Skip to content

Commit

Permalink
Add ONNX text test file (#884)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Mar 31, 2024
1 parent 90a7513 commit 19e4123
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 87 deletions.
152 changes: 67 additions & 85 deletions source/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -346,15 +346,18 @@ onnx.Value = class {
onnx.Node = class {

constructor(context, node, inputs, outputs) {
const op_type = node.op_type;
const domain = node.domain || 'ai.onnx';
const op_type = node.op_type;
const overload = node.overload || '';
const attributes = node.attribute || [];
const metadata_props = node.metadata_props || [];
this.type = context.type(op_type, domain) || { name: op_type, module: domain };
if (this.type.module !== domain && !(this.type instanceof onnx.Function)) {
this.type = context.type(domain, op_type, overload);
if (!this.type || (this.type.module !== domain && !(this.type instanceof onnx.Function))) {
this.type = Object.assign({}, this.type);
this.type.name = op_type;
this.type.module = domain;
this.type.overload = overload;
this.type.identifier = overload ? `${op_type}:${overload}` : `${op_type}`;
}
this.name = node.name || '';
this.description = node.doc_string || '';
Expand All @@ -377,7 +380,7 @@ onnx.Node = class {
type = 'float32';
break;
case onnx.AttributeType.INT:
value = attribute.i;
value = BigInt(attribute.i);
type = 'int64';
break;
case onnx.AttributeType.STRING:
Expand All @@ -397,7 +400,7 @@ onnx.Node = class {
type = 'float32[]';
break;
case onnx.AttributeType.INTS:
value = ArrayBuffer.isView(attribute.ints) ? Array.from(attribute.ints) : attribute.ints;
value = ArrayBuffer.isView(attribute.ints) ? Array.from(attribute.ints) : attribute.ints.map((value) => BigInt(value));
type = 'int64[]';
break;
case onnx.AttributeType.STRINGS:
Expand Down Expand Up @@ -431,18 +434,20 @@ onnx.Node = class {
default:
throw new onnx.Error(`Unsupported attribute type '${attribute.type}'.`);
}
const metadata = context.attribute(op_type, domain, attribute.name);
const metadata = context.attribute(domain, op_type, overload, attribute.name);
if (metadata) {
if (Object.prototype.hasOwnProperty.call(metadata, 'default') && value === metadata.default) {
visible = false;
if (metadata.default !== undefined) {
const defaultValue = type === 'int64' ? BigInt(metadata.default) : metadata.default;
if (value === defaultValue) {
visible = false;
}
}
if (metadata.type === 'DataType') {
type = metadata.type;
value = context.createDataType(value);
}
}
}
// (context, op_type, domain, attribute)
return new onnx.Argument(name, value, type, attribute.doc_string, visible);
});
this.metadata = metadata_props.map((metadata) => {
Expand Down Expand Up @@ -833,64 +838,35 @@ onnx.OptionalType = class {
onnx.Function = class {

constructor(context, func) {
this._name = func.name;
this._domain = func.domain;
this._description = func.doc_string;
this._inputs = [];
this._outputs = [];
this._attributes = func.attribute.map((attribtue) => {
this.type = 'function';
this.name = func.name;
this.module = func.domain;
this.overload = func.overload || '';
this.identifier = this.overload ? `${this.name}:${this.overload}` : this.name;
this.description = func.doc_string;
this.inputs = [];
this.outputs = [];
this.attributes = func.attribute.map((attribtue) => {
return { name: attribtue };
});
context = new onnx.Context.Graph(context, func);
func.input = func.input.map((input) => context.tensor(input));
func.output = func.output.map((output) => context.tensor(output));
context.push(func.node, func.input, func.output);
this._nodes = context.pop();
this.nodes = context.pop();
for (const input of func.input) {
const value = context.value(input.name);
if (!value.initializer) {
this._inputs.push(new onnx.Argument(input.name, [value]));
this.inputs.push(new onnx.Argument(input.name, [value]));
}
}
for (const output of func.output) {
const value = context.value(output.name);
if (!value.initializer) {
this._outputs.push(new onnx.Argument(output.name, [value]));
this.outputs.push(new onnx.Argument(output.name, [value]));
}
}
}

get type() {
return 'function';
}

get name() {
return this._name;
}

get module() {
return this._domain;
}

get description() {
return this._description;
}

get inputs() {
return this._inputs;
}

get outputs() {
return this._outputs;
}

get attributes() {
return this._attributes;
}

get nodes() {
return this._nodes;
}
};

onnx.Context = class {};
Expand All @@ -902,18 +878,18 @@ onnx.Context.Model = class {
this._locations = locations;
this._imageFormat = imageFormat;
this._imports = imports;
this._cache = new Map();
this._types = new Map();
this._attributes = new Map();
this._functions = new Map();
for (const func of functions || []) {
if (!this._functions.has(func.domain)) {
this._functions.set(func.domain, new Map());
const domain = func.domain;
const name = func.name;
const overload = func.overload;
const key = overload ? `${domain}:${name}:${overload}` : `${domain}:${name}`;
if (this._functions.has(key)) {
throw new onnx.Error(`Duplicate function identifier '${key}'.`);
}
const module = this._functions.get(func.domain);
if (module.has(func.name)) {
throw new onnx.Error(`Duplicate function identifier '${func.domain}.${func.name}'.`);
}
module.set(func.name, func);
this._functions.set(key, func);
}
}

Expand Down Expand Up @@ -943,36 +919,34 @@ onnx.Context.Model = class {
return null;
}

type(name, domain) {
const key = `${domain}:${name}`;
if (!this._cache.has(key)) {
let value = this._metadata.type(name, domain, this._imports);
if (!value) {
if (this._functions.has(domain)) {
const module = this._functions.get(domain);
if (module.has(name)) {
value = module.get(name);
if (value.domain !== undefined) {
value = new onnx.Function(this, value);
module.set(name, value);
}

}
type(domain, name, overload) {
const key = overload ? `${domain}:${name}:${overload}` : `${domain}:${name}`;
if (!this._types.has(key)) {
let value = null;
if (this._functions.has(key)) {
value = this._functions.get(key);
if (value.domain !== undefined) {
value = new onnx.Function(this, value);
this._functions.set(key, value);
}
}
this._cache.set(key, value);
if (!value) {
value = this._metadata.type(domain, name, this._imports);
}
this._types.set(key, value);
}
return this._cache.get(key);
return this._types.get(key);
}

attribute(type, domain, name) {
const key = `${domain}:${type}:${name}`;
attribute(domain, type, overload, name) {
const key = overload ? `${domain}:${type}:${overload}::${name}` : `${domain}:${type}::${name}`;
if (!this._attributes.has(key)) {
this._attributes.set(key, null);
const metadata = this.type(type, domain);
const metadata = this.type(domain, type);
if (metadata && Array.isArray(metadata.attributes) && metadata.attributes.length > 0) {
for (const attribute of metadata.attributes) {
const key = `${domain}:${type}:${attribute.name}`;
const name = attribute.name;
const key = overload ? `${domain}:${type}:${overload}::${name}` : `${domain}:${type}::${name}`;
this._attributes.set(key, attribute);
}
}
Expand Down Expand Up @@ -1014,7 +988,7 @@ onnx.Metadata = class {
}
}

type(name, domain, imports) {
type(domain, name, imports) {
domain = domain || 'ai.onnx';
let current = null;
if (this._types.has(domain)) {
Expand Down Expand Up @@ -1182,12 +1156,12 @@ onnx.Context.Graph = class {
}
}

type(name, domain) {
return this._context.type(name, domain);
type(domain, name, overload) {
return this._context.type(domain, name, overload);
}

attribute(type, domain, name) {
return this._context.type(type, domain, name);
attribute(domain, type, overload, name) {
return this._context.attribute(domain, type, overload, name);
}

graph(value) {
Expand Down Expand Up @@ -1354,7 +1328,9 @@ onnx.Context.Graph = class {
});
for (let node of nodes) {
const domain = node.domain || 'ai.onnx';
const type = this._context.type(node.op_type, domain);
const op_type = node.op_type;
const overload = node.overload || '';
const type = this._context.type(domain, op_type, overload);
const inputs = [];
node.input = node.input || [];
for (let i = 0; i < node.input.length;) {
Expand Down Expand Up @@ -2205,6 +2181,9 @@ onnx.TextReader = class {
}
node.domain = domain;
node.op_type = identifier;
if (this._match(':')) {
node.overload = this._parseIdentifier();
}
node.attribute = this._parseAttributeList();
this._expect('(');
node.input = this._parseIdentifierList();
Expand Down Expand Up @@ -2485,6 +2464,9 @@ onnx.TextReader = class {
case 'doc_string':
func[keyword] = this._parseString();
break;
case 'overload':
func[keyword] = this._parseString();
break;
default:
this._throw(`Unknown keyword '${keyword}'.`);
break;
Expand Down
9 changes: 7 additions & 2 deletions source/view.js
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,12 @@ view.View = class {
}
this.showDefinition(this._stack[0]);
});
const name = graph && graph.name ? graph.name : '';
let name = '';
if (graph && graph.identifier) {
name = graph.identifier;
} else if (graph && graph.name) {
name = graph.name;
}
if (name.length > 24) {
element.setAttribute('title', name);
element.innerHTML = `&hellip;${name.substring(name.length - 24, name.length)}`;
Expand Down Expand Up @@ -5399,7 +5404,7 @@ view.ModelFactoryService = class {
this._factories = [];
this.register('./server', ['.netron']);
this.register('./pytorch', ['.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf', '.jit', '.pte', '.bin.index.json', 'serialized_exported_program.json'], ['.model', '.pt2']);
this.register('./onnx', ['.onnx', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', 'onnxmodel', 'ngf', 'json']);
this.register('./onnx', ['.onnx', '.onn', '.pb', '.onnxtxt', '.pbtxt', '.prototxt', '.txt', '.model', '.pt', '.pth', '.pkl', '.ort', '.ort.onnx', 'onnxmodel', '.ngf', '.json', '.bin']);
this.register('./mxnet', ['.json', '.params'], ['.mar']);
this.register('./coreml', ['.mlmodel', '.bin', 'manifest.json', 'metadata.json', 'featuredescriptions.json', '.pb', '.pbtxt'], ['.mlpackage']);
this.register('./caffe', ['.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt']);
Expand Down
17 changes: 17 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -3669,13 +3669,15 @@
"type": "onnx",
"target": "candy.json",
"source": "https://github.com/lutzroeder/netron/files/12329067/candy.json.zip[candy.json]",
"assert": [ "model.graphs[0].nodes[2].attributes[1].visible == false" ],
"format": "ONNX JSON v3",
"link": "https://github.com/lutzroeder/netron/issues/6"
},
{
"type": "onnx",
"target": "candy.onnx",
"source": "https://raw.githubusercontent.com/Microsoft/Windows-Machine-Learning/master/Samples/FNSCandyStyleTransfer/UWP/cs/Assets/candy.onnx",
"assert": [ "model.graphs[0].nodes[2].attributes[1].visible == false" ],
"format": "ONNX v3",
"tags": "validation",
"link": "https://github.com/Microsoft/Windows-Machine-Learning/tree/master/Samples/FNSCandyStyleTransfer/UWP/cs/Assets"
Expand Down Expand Up @@ -3995,6 +3997,21 @@
"format": "ONNX v6",
"link": "https://github.com/lutzroeder/netron/issues/6"
},
{
"type": "onnx",
"target": "test_mi_overloaded_function.onnx",
"source": "https://github.com/lutzroeder/netron/files/14812772/test_mi_overloaded_function.zip[test_mi_overloaded_function.onnx]",
"assert": [ "model.graphs[0].nodes[0].type.identifier == 'cast:to_int32'" ],
"format": "ONNX v10",
"link": "https://github.com/lutzroeder/netron/issues/884"
},
{
"type": "onnx",
"target": "test_mi_overloaded_function.txt",
"source": "https://github.com/lutzroeder/netron/files/14812772/test_mi_overloaded_function.zip[test_mi_overloaded_function.txt]",
"format": "ONNX Text v10",
"link": "https://github.com/lutzroeder/netron/issues/884"
},
{
"type": "onnx",
"target": "mnist.onnx",
Expand Down

0 comments on commit 19e4123

Please sign in to comment.