Feature add new one hot function meeting multi-dimensions (ranks) #2613
base: main
Conversation
add one hot test
modify format add tests
modify comments
Codecov Report: All modified and coverable lines are covered by tests ✅

Coverage Diff (main vs. #2613):

|          | main   | #2613  | +/-    |
| -------- | ------ | ------ | ------ |
| Coverage | 81.86% | 81.85% | -0.02% |
| Files    | 833    | 838    | +5     |
| Lines    | 106450 | 107528 | +1078  |
| Hits     | 87146  | 88016  | +870   |
| Misses   | 19304  | 19512  | +208   |

View full report in Codecov by Sentry.
I'll take some time to look at this later, but just a couple comments before reviewing the actual code.
- We don't need to have all of the ops comply with the ONNX spec.
- Introducing this as a new operation means that we now have multiple definitions for one-hot. One definition should take over, otherwise it makes everything cluttered.
- Regarding the motivation, do you actually need this one-hot definition? Or is it simply for ONNX conversion? If only the latter, then it can probably just live in the ONNX import code.
@laggui Thank you for your comments. For the existing float version of one-hot, I cannot come up with any use case. In PyTorch, for example, the function is minimally designed yet still supports multiple dimensions, and this is something that needs improvement in our framework as well.
Furthermore, another major framework, TensorFlow, not only supports multiple dimensions but also provides flexibility through additional parameters.
Further use cases: the ability to configure multiple dimensions. Regarding the concern about having multiple definitions, I share the same sentiment and agree that unification is necessary. My proposed new function is designed to support both. I look forward to your feedback and hope for your support in making this improvement.
Ok, makes sense! Thanks for the detailed response.
I think it is especially useful for the rank > 1 use cases, the rest of the configurable stuff seems less relevant to me. But I understand that there could be value in supporting the broad spec.
See my comments below 🙂
Regarding the multiple definitions, I think I would deprecate the other definitions since this can do it all. Just make sure to adapt the existing tests.
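To make the rank > 1 use case concrete, here is a minimal plain-Rust sketch (illustrative only; `one_hot_rank2` is a hypothetical name, not the burn API) showing how a rank-2 index matrix expands to a rank-3 one-hot output with the class axis appended last, as `torch.nn.functional.one_hot` does for higher ranks:

```rust
// Hypothetical sketch: one-hot encode a [N, M] index matrix into a
// [N, M, num_classes] output, with the class axis last.
fn one_hot_rank2(indices: &[Vec<usize>], num_classes: usize) -> Vec<Vec<Vec<u8>>> {
    indices
        .iter()
        .map(|row| {
            row.iter()
                .map(|&idx| {
                    let mut one_hot = vec![0u8; num_classes];
                    one_hot[idx] = 1; // assumes idx < num_classes (unchecked here)
                    one_hot
                })
                .collect()
        })
        .collect()
}

fn main() {
    // [2, 2] indices -> [2, 2, 3] one-hot
    let out = one_hot_rank2(&[vec![0, 2], vec![1, 1]], 3);
    assert_eq!(out[0][1], vec![0, 0, 1]);
    assert_eq!(out[1][0], vec![0, 1, 0]);
}
```

This only illustrates the shape semantics; the actual tensor implementation would of course operate on backend tensors, not nested `Vec`s.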
@laggui I modified the code; please review it again (maybe after your Christmas vacation, enjoy!).
Hope you had a nice holiday break! Thanks for addressing my previous comments 🙂
I have some follow-up changes. Mostly form over content.
@@ -157,6 +157,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.one_hot_fill(depth, on_value, off_value, axis)` | N/A |
`one_hot_fill` is implemented for numeric (int + float), not all kinds (including bool). We should move it to the correct documentation table.
@@ -258,7 +259,7 @@ Those operations are only available for `Float` tensors.

| Burn API | PyTorch Equivalent |
| --------------------------------------------- | ---------------------------------- |
| `Tensor::one_hot(index, num_classes, device)` | N/A |
| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` |
Instead of having `one_hot` defined for int and float, we should have only one definition (and doc) in numeric.
pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self {
    check!(TensorCheck::one_hot_index(index, num_classes));

    let mut dims = [1; D];
    dims[D - 1] = num_classes;
    let shape = Shape::new(dims);
    let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect();
    let tensor = Tensor::zeros(shape, device);
    let mut ranges: [core::ops::Range<usize>; D] = ranges.try_into().unwrap();
    ranges[D - 1] = index..index + 1;

    tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device))
This function definition doesn't clash with the new "multi-purpose" definition of `one_hot`. To correctly deprecate it, I think we should simply mark it for deprecation and remove it at a later time.
#[deprecated(
since = "0.16.0",
note = "`Tensor::one_hot(...)` will be removed in the future, please use the new `tensor.one_hot(...)` method instead"
)]
We could also change the docstring to include the equivalent call for the example given.
pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2> {
    check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
    self.one_hot_fill(num_classes, 1.0, 0.0, -1)
(As per my previous comment.) Instead of having `one_hot` defined for int and float, we should have only one definition (and doc) in numeric.
pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, Int> {
    check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
    self.one_hot_fill(num_classes, 1.0, 0.0, -1)
}
(Same comment.) Instead of having `one_hot` defined for int and float, we should have only one definition (and doc) in numeric.
Target of This Pull Request
First, I attempted to implement a one-hot operation for ONNX. However, I realized that the existing one-hot function did not meet the requirements and, in fact, did not support multidimensional inputs at all. As I explored solutions, including the ONNX specification, PyTorch, and TensorFlow, I concluded that it was necessary to implement a new one-hot function. This led to the implementation I am now submitting as a pull request.
(PyTorch also does not implement a complete one-hot function, though.)
Hope this will work for Burn and the community.
Checklist
The `run-checks all` script has been executed.

Related Issues/PRs
Indirectly related to ONNX issue #1714
Changes
Newly implemented a one-hot method for numeric tensors. It belongs in numeric because the return type is determined by `on_value` and `off_value`, not by the input tensor itself, so the output can be either int or float.
This function comprehensively covers all aspects defined by ONNX, including `depth`, `on_value`, `off_value`, and `axis`, and complies with the one-hot operator specifications introduced in ONNX version 11 and later. With this, it becomes possible to handle multidimensional one-hot encoding while also providing a concise and efficient implementation of the ONNX operator. For these reasons, I deemed it essential to create this function.
I considered removing and updating the existing one-hot method, but I decided to take a more conservative approach: leave the existing method as it is and add a new one instead.
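The ONNX-style parameters can be sketched in plain Rust as follows. This is illustrative only, not the burn implementation: `one_hot_fill` here is a standalone function covering the rank-1, trailing-axis (`axis = -1`) case, with negative indices counting from the back as in the ONNX OneHot spec.

```rust
// Hypothetical sketch of ONNX-style one-hot: each index becomes a row of
// length `depth`, filled with `off_value` except for `on_value` at the
// index position. Negative indices wrap (e.g. -1 selects the last class).
fn one_hot_fill(indices: &[i64], depth: usize, on_value: f32, off_value: f32) -> Vec<Vec<f32>> {
    indices
        .iter()
        .map(|&idx| {
            let mut row = vec![off_value; depth];
            let i = if idx < 0 { idx + depth as i64 } else { idx };
            if (0..depth as i64).contains(&i) {
                row[i as usize] = on_value;
            }
            row
        })
        .collect()
}

fn main() {
    let encoded = one_hot_fill(&[0, 2, -1], 3, 5.0, -5.0);
    assert_eq!(encoded[0], vec![5.0, -5.0, -5.0]);
    assert_eq!(encoded[2], vec![-5.0, -5.0, 5.0]); // -1 wraps to the last class
}
```

Note that `on_value`/`off_value` being arbitrary numbers is what forces the output type to be decided by those values rather than by the input tensor, matching the reasoning above.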
Testing
Added tests in `crates/burn-tensor/src/tests/ops/one_hot.rs` and passed `run-checks all`.