diff --git a/test/backend.py b/test/backend.py index 8ab30b0893..674f9dae55 100755 --- a/test/backend.py +++ b/test/backend.py @@ -32,6 +32,12 @@ def _test_onnx_iterate(): address = netron.serve(file, model, verbosity='quiet') netron.stop(address) +def _test_torchscript(file): + torch = __import__('torch') + model = torch.load(os.path.join(test_data_dir, 'pytorch', file)) + torch._C._jit_pass_inline(model.graph) # pylint: disable=protected-access + netron.serve(file, model) + def _test_torchscript_transformer(): torch = __import__('torch') model = torch.nn.Transformer(nhead=16, num_encoder_layers=12) @@ -60,48 +66,15 @@ def _test_torchscript_quantized(): torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access netron.serve('d2go', trace) -def _test_torchscript_inception_v3(): - torch = __import__('torch') - trace = torch.jit.load(os.path.join(test_data_dir, 'pytorch', 'inception_v3_traced.pt')) - torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access - netron.serve('inception_v3', trace) - -def _test_torchscript_scalar(): - torch = __import__('torch') - trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'netron_issue_920.pt')) - # trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'UNet.pt')) - torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access - netron.serve('inception_v3', trace) - -def _test_torchscript_tuple(): - torch = __import__('torch') - __import__('torchvision') - trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'fasterrcnn_resnet50_fpn.pt')) - torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access - netron.serve('inception_v3', trace) - -def _test_torchscript_nnapi(): - torch = __import__('torch') - trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'mobilenetv2-quant_full-nnapi.pt')) - torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access - netron.serve('inception_v3', trace) - -def _test_torchscript_alexnet(): - torch = __import__('torch') - trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'alexnet.pt')) - torch._C._jit_pass_inline(trace.graph) # pylint: disable=protected-access - netron.serve('alexnet', trace) - - # _test_onnx() # _test_onnx_iterate() -# _test_torchscript() +# _test_torchscript('alexnet.pt') +_test_torchscript('gpt2.pt') +# _test_torchscript('inception_v3_traced.pt') +# _test_torchscript('netron_issue_920.pt') # scalar +# _test_torchscript('fasterrcnn_resnet50_fpn.pt') # tuple +# _test_torchscript('mobilenetv2-quant_full-nnapi.pt') # nnapi # _test_torchscript_quantized() # _test_torchscript_resnet34() -# _test_torchscript_inception_v3() -# _test_torchscript_scalar() -# _test_torchscript_tuple() -# _test_torchscript_nnapi() # _test_torchscript_transformer() -_test_torchscript_alexnet()