
Feature add new one hot function meeting multi-dimensions (ranks) #2613

Open

tiruka wants to merge 17 commits into main

Conversation

@tiruka (Contributor) commented Dec 15, 2024

Target of This Pull Request

First, I attempted to implement a one-hot operation for ONNX. However, I realized that the existing one-hot function did not meet the requirements and, in fact, did not support multidimensional inputs at all. As I explored solutions, including the ONNX specification, PyTorch, and TensorFlow, I concluded that it was necessary to implement a new one-hot function. This led to the creation of this implementation, which I am now submitting as a pull request.
(PyTorch does not implement the complete one-hot function either, though.)

I hope this will work for Burn and the community.

Checklist

  • Confirmed that the `run-checks all` script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Indirectly related to ONNX issue #1714

Changes

Newly implemented a one-hot method for numeric tensors. The reason it belongs to numeric is that the output values are defined by `on_value` and `off_value`, not by the tensor itself, so the output can be either int or float.
This function comprehensively covers all aspects defined by ONNX, including `depth`, `on_value`, `off_value`, and `axis`, and complies with the one-hot operator specifications introduced in ONNX version 11 and later. With this, it becomes possible to handle multidimensional one-hot encoding while also providing a concise and efficient implementation of the ONNX operator. For these reasons, I deemed it essential to create this function.
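
For illustration, here is a minimal usage sketch of the proposed API. The `one_hot_fill(depth, on_value, off_value, axis)` signature follows the docs table changed in this PR; the surrounding function and generic annotations are assumptions for the example only.

use burn_tensor::backend::Backend;
use burn_tensor::{Int, Tensor};

fn one_hot_example<B: Backend>(device: &B::Device) {
    // Rank-2 int indices of shape [2, 2].
    let indices = Tensor::<B, 2, Int>::from_data([[0, 2], [1, 0]], device);
    // depth = 3, on_value = 1.0, off_value = 0.0, axis = -1 (new last axis),
    // producing a rank-3 output of shape [2, 2, 3].
    let encoded: Tensor<B, 3, Int> = indices.one_hot_fill(3, 1.0, 0.0, -1);
    let _ = encoded;
}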

I considered removing and updating the existing one-hot method, but I decided to take a more conservative approach by leaving the existing method as it is and adding a new one instead.

Testing

Added tests in `crates/burn-tensor/src/tests/ops/one_hot.rs` and passed `run-checks all`.
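
A representative sketch of the kind of test added there, following the burn testgen pattern; the test name and the exact `one_hot_fill` signature here are assumptions based on this PR:

#[burn_tensor_testgen::testgen(one_hot)]
mod tests {
    use super::*;
    use burn_tensor::{Int, Tensor, TensorData};

    #[test]
    fn should_one_hot_fill_along_last_axis() {
        // depth = 3, on_value = 1.0, off_value = 0.0, axis = -1.
        let device = Default::default();
        let indices = Tensor::<TestBackend, 1, Int>::from_data([0, 2], &device);
        let encoded: Tensor<TestBackend, 2, Int> = indices.one_hot_fill(3, 1.0, 0.0, -1);
        let expected = TensorData::from([[1, 0, 0], [0, 0, 1]]);
        encoded.into_data().assert_eq(&expected, false);
    }
}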

codecov bot commented Dec 15, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.85%. Comparing base (f1558ad) to head (ac060d1).
Report is 16 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2613      +/-   ##
==========================================
- Coverage   81.86%   81.85%   -0.02%     
==========================================
  Files         833      838       +5     
  Lines      106450   107528    +1078     
==========================================
+ Hits        87146    88016     +870     
- Misses      19304    19512     +208     


@laggui (Member) left a comment

I'll take some time to look at this later, but just a couple comments before reviewing the actual code.

  1. We don't need to have all of the ops comply with the ONNX spec.
  2. Introducing this as a new operation means that we now have multiple definitions for one-hot. One definition should take over, otherwise it makes everything cluttered.
  3. Regarding the motivation, do you actually need this one-hot definition? Or is it simply for ONNX conversion? If only the latter, then it can probably just live in the ONNX import code.

@tiruka (Contributor, Author) commented Dec 19, 2024

@laggui Thank you for your comments.

The existing one_hot function only operates on rank-1 tensors, which limits its usability.

impl<B: Backend> Tensor<B, 1, Int> {
    // ...
    pub fn one_hot(self, num_classes: usize) -> Tensor<B, 2, Int> {
        // ...
    }
}

As for the current float version of `one_hot`, I cannot come up with any use case.

In PyTorch, for example, the function is minimally designed yet supports multiple dimensions, and this is something that needs improvement in our framework as well.
PyTorch example:

import torch
import torch.nn.functional as F

F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
# tensor([[1, 0, 0, 0, 0],
#         [0, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [1, 0, 0, 0, 0],
#         [0, 1, 0, 0, 0]])

Furthermore, another major framework, TensorFlow, not only supports multiple dimensions but also provides flexibility with parameters such as `axis` and on/off values.
TensorFlow example:

import tensorflow as tf

indices = [[0, 2], [1, -1]]
depth = 3
tf.one_hot(indices, depth,
           on_value=1.0, off_value=0.0,
           axis=-1)  # output shape: [2, 2, 3]
# [[[1.0, 0.0, 0.0],   # one_hot(0)
#   [0.0, 0.0, 1.0]],  # one_hot(2)
#  [[0.0, 1.0, 0.0],   # one_hot(1)
#   [0.0, 0.0, 0.0]]]  # one_hot(-1)

Further use cases

The ability to configure multiple dimensions, axis, and values is an expected feature in popular frameworks, and I believe this would greatly benefit Burn users, myself included, by helping the framework stay aligned with modern expectations. This one-hot function is not limited to ONNX.

Regarding the concern about having multiple definitions, I share the same sentiment and agree that unification is necessary. My proposed new function is designed to support both int and float types, making it closer to the one_hot definitions found in other frameworks. As such, I would advocate deprecating the existing implementation and unifying it with this new version. If there is agreement on this approach, I would be happy to submit changes either as part of this PR or in a separate PR to address these points.

I look forward to your feedback and hope for your support in making this improvement.

@laggui (Member) left a comment

Ok, makes sense! Thanks for the detailed response.

I think it is especially useful for the rank > 1 use cases; the rest of the configurable stuff seems less relevant to me. But I understand that there could be value in supporting the broad spec.

See my comments below 🙂

Regarding the multiple definitions, I think I would deprecate the other definitions since this can do it all. Just make sure to adapt the existing tests.

crates/burn-tensor/src/tensor/api/numeric.rs: 3 resolved review threads (outdated)
@tiruka (Contributor, Author) commented Dec 25, 2024

@laggui I modified the code, so please review it again (maybe after your Christmas vacation; enjoy!).

@tiruka changed the title from "Feature add new one hot function meeting full requirements." to "Feature add new one hot function meeting multi-dimensions (ranks)" on Dec 26, 2024
@laggui (Member) left a comment

Hope you had a nice holiday break! Thanks for addressing my previous comments 🙂

I have some follow-up comments, mostly form over content.

@@ -157,6 +157,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.one_hot_fill(depth, on_value, off_value, axis)` | N/A |
Member commented:

`one_hot_fill` is implemented for numeric (int + float), not all kinds (including bool). We should move it to the correct documentation table.

@@ -258,7 +259,7 @@ Those operations are only available for `Float` tensors.

| Burn API | PyTorch Equivalent |
| --------------------------------------------- | ---------------------------------- |
| `Tensor::one_hot(index, num_classes, device)` | N/A |
| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` |
Member commented:

Instead of having one_hot defined for int and float, we should have only one definition (and doc) in numeric.
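
A rough sketch of what that single numeric definition could look like, with bounds mirroring the existing numeric impl block (a sketch under those assumptions, not final code):

impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    /// One-hot encodes along a new last axis, using on/off values 1 and 0.
    pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, K> {
        check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
        self.one_hot_fill(num_classes, 1.0, 0.0, -1)
    }
}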

Comment on lines -192 to -203
pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self {
check!(TensorCheck::one_hot_index(index, num_classes));

let mut dims = [1; D];
dims[D - 1] = num_classes;
let shape = Shape::new(dims);
let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect();
let tensor = Tensor::zeros(shape, device);
let mut ranges: [core::ops::Range<usize>; D] = ranges.try_into().unwrap();
ranges[D - 1] = index..index + 1;

tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device))
Member commented:

This function definition doesn't clash with the new "multi-purpose" definition of `one_hot`. To correctly deprecate this, I think we should simply mark it for deprecation and remove it at a later time.

#[deprecated(
    since = "0.16.0",
    note = "`Tensor::one_hot(...)` will be removed in the future, please use the new `tensor.one_hot(...)` method instead"
)]

We could also change the docstring to include the equivalent call for the example given.

Comment on lines +190 to +192
pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2> {
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
self.one_hot_fill(num_classes, 1.0, 0.0, -1)
Member commented:

(as per my previous comment)

Instead of having one_hot defined for int and float, we should have only one definition (and doc) in numeric.

Comment on lines +125 to +128
pub fn one_hot<const D2: usize>(self, num_classes: usize) -> Tensor<B, D2, Int> {
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
self.one_hot_fill(num_classes, 1.0, 0.0, -1)
}
Member commented:

(same comment)

Instead of having one_hot defined for int and float, we should have only one definition (and doc) in numeric.
