Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

burn-import: add some tests for ConstantNode #2623

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

jameshiew
Copy link

@jameshiew jameshiew commented Dec 17, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

I've been trying to implement OneHot ONNX op (#1714) in a WIP draft branch. The ONNX model ends up containing a constant integer vector values=[0, 1] used by the OneHot op, this vector was causing issues when trying to test the model. I looked at ConstantNode and these are the tests so far I could get working while investigating.

Issues for ConstantNode
#2624 - constant tensors aren't populated with values
#2625 - generated code for const int tensors doesn't compile

Changes

  • added a helper method ConstantNode::tensor_ty_tokens for tests, but this PR otherwise shouldn't be changing how ConstantNode currently works
  • add codegen tests for ConstantNode (i32/64 + f32/64 scalar, tensors)
  • add ONNX model tests for i32/64 + f32/64 scalars - implicitly testing by adding the constant

Testing

Ran added tests

cargo xtask check all
cargo nextest run --manifest-path crates/burn-import/Cargo.toml
cargo nextest run --manifest-path crates/burn-import/onnx-tests/Cargo.toml

I checked the .onnx models contain the expected scalar constants using Netron

Screenshots f32 f64 i32 i64

let device = Default::default();
let model = constant_f64::Model::<Backend>::new(&device);
let input = TensorData::zeros::<f64, _>(Shape::from([2, 3, 4]));
let expected_output = TensorData::full(Shape::from([2, 3, 4]), 2f32);
Copy link
Author

@jameshiew jameshiew Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below). I wasn't sure how to get PyTorch to just forward the constant by itself so these tests are adding the constant to the input

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe by having the output return the constant only? But a simple constant addition works too.

In case you're curious, you could also manually define the onnx graph like the ConstOfShape script. PyTorch tends doesn't always have a 1-to-1 correspondence for ops, so in such cases it could be easier to define the graph manually.

I'm not sure if the addition is coercing f64 -> f32 somewhere (and i32 -> i64 below)

The floating point and integer data types are defined by the backend used. A model is not as statically defined like an ONNX graph. If you look at the other tests, the input(s) and output(s) are created using the Tensor methods, not from TensorData.

@jameshiew jameshiew marked this pull request as ready for review December 17, 2024 19:16
Copy link

codecov bot commented Jan 2, 2025

Codecov Report

Attention: Patch coverage is 99.74160% with 1 line in your changes missing coverage. Please review.

Project coverage is 82.15%. Comparing base (8a89293) to head (756fd3c).
Report is 10 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-import/src/burn/node/constant.rs 99.71% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2623      +/-   ##
==========================================
+ Coverage   82.06%   82.15%   +0.09%     
==========================================
  Files         831      832       +1     
  Lines      106003   106749     +746     
==========================================
+ Hits        86990    87699     +709     
- Misses      19013    19050      +37     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants