Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/Split ONNX Import #2568

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft

Conversation

agelas
Copy link
Contributor

@agelas agelas commented Nov 29, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#2440

Changes

Adds split to the list of supported ops that can be imported via ONNX.

Testing

tbd

@agelas
Copy link
Contributor Author

agelas commented Dec 14, 2024

@antimora For nodes that produce multiple outputs, is there a recommended pattern for how to assign and return these outputs in the forward method? I tried to do the reverse of concat, but still having a bit of trouble.

Also, in the generated ONNX graph, all three outputs of the graph are named "split1_out1", but the outputs of the node are named uniquely. How do we ensure that each output of a multi-output node is named uniquely (assuming they should be)? I keep getting this error when I try generating the IR:

thread 'main' panicked at crates/burn-import/src/burn/graph.rs:566:40:
Output type not found for split1_out1

I have a feeling that the non-unique names might be tripping this up, because the overall structure of the outputs is consistent with other IRs I've generated.

I also copied some of the generated graph to show you what I mean.

ParsedOnnxGraph(
    // omitting the constant node for brevity
            Node {
                // omitting a bunch of stuff here too
                outputs: [
                    Argument {
                        name: "split1_out1",
                        ty: Tensor(
                            TensorType {
                                elem_type: Int64,
                                dim: 2,
                                shape: None,
                            },
                        ),
                        value: None,
                        passed: false,
                    },
                    Argument {
                        name: "split1_out2",
                        ty: Tensor(
                            TensorType {
                                elem_type: Int64,
                                dim: 2,
                                shape: None,
                            },
                        ),
                        value: None,
                        passed: false,
                    },
                    Argument {
                        name: "split1_out3",
                        ty: Tensor(
                            TensorType {
                                elem_type: Int64,
                                dim: 2,
                                shape: None,
                            },
                        ),
                        value: None,
                        passed: false,
                    },
                ],
                attrs: {
                    "axis": Int64(
                        0,
                    ),
                },
            },
        ],
        inputs: [
           // ignoring this
        ],
        outputs: [
            Argument {
                name: "split1_out1", <--- this name is the same for everything in outputs[]
                ty: Tensor(
                    TensorType {
                        elem_type: Int64,
                        dim: 2,
                        shape: None,
                    },
                ),
                value: None,
                passed: false,
            },
            Argument {
                name: "split1_out1",
                ty: Tensor(
                    TensorType {
                        elem_type: Int64,
                        dim: 2,
                        shape: None,
                    },
                ),
                value: None,
                passed: false,
            },
            Argument {
                name: "split1_out1",
                ty: Tensor(
                    TensorType {
                        elem_type: Int64,
                        dim: 2,
                        shape: None,
                    },
                ),
                value: None,
                passed: false,
            },
        ],
    },
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant