feat: bitwise-ops-for-tensors #2498
base: main
Conversation
Great start! Some minor comments but here is my preliminary review:
I think the bitwise ops should only be added as an int tensor operation; they don't make sense for floats.
We could eventually extend the ops to boolean tensors with their logical counterpart (would be applied on a single bit represented by the bool), but this can be left for another PR.
We can leave the candle ops as unimplemented, but for the JIT backends we should wait to merge once it's implemented with cubecl.
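For context, a minimal sketch of what int-only definitions could look like, modeled on the `IntTensorOps` signatures quoted later in this thread (the trait name, method names, and op set here are assumptions, not the PR's final API):

```rust
/// Sketch: int-only bitwise ops in the style of burn's IntTensorOps.
/// Trait name and exact op set are assumptions for illustration.
pub trait IntBitwiseOps<B: Backend> {
    /// Element-wise bitwise AND of two int tensors.
    fn int_bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
    /// Bitwise AND between each element and a scalar.
    fn int_bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
    /// Element-wise bitwise NOT of an int tensor.
    fn int_bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
}
```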
```rust
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseAnd(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseAndScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseOr(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseOrScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseXor(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseXorScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [powf](crate::ops::FloatTensorOps::float_powf).
/// Int => [powf](crate::ops::IntTensorOps::int_powf).
BitwiseNot(UnaryOperationDescription),
```
Copy/pasted docstrings should be fixed.
Also, I'm not sure this should be a numeric op since it won't be implemented for float (and doesn't really make sense either). I think it should go in the int operations only (and possibly add the equivalent logical operations for bool).
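For illustration, the corrected docstrings would presumably point at the matching bitwise ops instead, something like (the `int_bitwise_*` paths are assumptions):

```rust
/// Operation corresponding to:
///
/// Int => [bitwise and](crate::ops::IntTensorOps::int_bitwise_and).
BitwiseAnd(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Int => [bitwise and scalar](crate::ops::IntTensorOps::int_bitwise_and_scalar).
BitwiseAndScalar(ScalarOperationDescription<E>),
```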
Done
The idea behind the whole float implementation was to convert to int before computing and back to float after. You're right, it doesn't really make sense; I'll make the changes. I raised an issue in the cubecl repo for adding bitwise op support and will look at that first before implementing here.
This PR has been marked as stale because it has not been updated for over a month.
Amazing! This is the PR I was looking for! Is there a way to support bit counting on integer tensors as well? P.S.: Also, what is the status of the PR? Is somebody working on it? Maybe I should pick it up?
@0x7CFE Yes, this implementation will allow for bit counting, unless you mean having direct tensor op methods where you can just call and have the bits counted. Also, this PR is pending a release of cubecl with this implementation.
Just in case: numpy has bitwise_count, wgsl has countOneBits, and CUDA has __popc. Apparently, pytorch does not have a ready solution, but there are some tricks to do the counting efficiently given that boolean ops are available. So I believe it would be possible to efficiently implement this for at least some of the backends. P.S.: Since the blocker mentioned above was already resolved, what are the next steps for this PR?
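To make those "tricks" concrete: a classic SWAR population count needs only shifts, masks, and integer adds, all expressible once bitwise tensor ops exist. A scalar Rust sketch of the technique (plain Rust, not burn API):

```rust
/// Counts set bits in a 32-bit word using only shifts, ANDs, and adds
/// (the classic SWAR popcount). An element-wise tensor variant would
/// apply the same steps with bitwise and arithmetic tensor ops.
fn popcount32(mut v: u32) -> u32 {
    v = v - ((v >> 1) & 0x5555_5555);                 // 2-bit partial sums
    v = (v & 0x3333_3333) + ((v >> 2) & 0x3333_3333); // 4-bit partial sums
    v = (v + (v >> 4)) & 0x0F0F_0F0F;                 // 8-bit partial sums
    v.wrapping_mul(0x0101_0101) >> 24                 // sum the four bytes
}

fn main() {
    assert_eq!(popcount32(0b1011), 3);
    assert_eq!(popcount32(u32::MAX), 32);
}
```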
Possibly yes; which bit counting ops exactly are you looking to include? As for whether it's something the burn team would like to move forward with, @laggui would have to comment on that.
Waiting on an official release of cubecl.
I described my use case in #2641; essentially I need bit counting.
I opened a PR to add the required backend ops to cubecl.
Sorry all, I didn't have access to my computer during the holidays 😅 The GitHub app is not as great for reviews and long discussions.
I think something like
The main branch of burn is actually kept up to date as much as possible with the latest cubecl revision, so your changes are already available 🙂
I just need to review the requested changes for the current PR, but this one should land pretty soon. Regarding your request for bit counting, I think once the linked cubecl PR lands (thanks Genna 🙏) we could add a bit counting op.
Thanks for addressing the changes 🙂
I have a couple of comments, most of them just pointing out some commented code you forgot to remove.
Also, bitwise tests are failing but that's probably just because you're not using the correct operations (as pointed out in the review below).
```rust
// not implemented
//todo!()
```
Leftover comment 🙂 it's implemented
```rust
// not implemented
//todo!()
```
+1
```rust
// not implemented
//todo!()
```
+1
```rust
// not implemented
//todo!()
```
+1
```rust
// not implemented
//todo!()
```
+1
```diff
@@ -75,7 +75,7 @@ pub enum OperationDescription {
     /// Operation specific to a bool tensor.
     Bool(BoolOperationDescription),
     /// Operation specific to an int tensor.
-    Int(IntOperationDescription),
+    Int(DType, IntOperationDescription<i32>),
```
What's the current motivation behind adding the `DType` to `Int` operations? Is it only because we moved the bitwise ops from `NumericInt`, which required the dtype?
```rust
// /// Applies logical `and` operation element-wise between two tensors.
// fn bitwise_and(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `and` operation element-wise between a tensor and a scalar.
// fn bitwise_and_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `or` operation element-wise between two tensors.
// fn bitwise_or(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `or` operation element-wise between a tensor and a scalar.
// fn bitwise_or_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `xor` operation element-wise between two tensors.
// fn bitwise_xor(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `xor` operation element-wise between a tensor and a scalar.
// fn bitwise_xor_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `not` operation element-wise on a tensor.
// fn bitwise_not(tensor: Self::Primitive) -> Self::Primitive;
```
Dead code
```rust
// fn int_bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

// fn int_bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

// fn int_bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

// fn int_bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
```
Do you want to implement these in the current PR? Otherwise, we can remove the commented function definitions.
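If they were kept, a hedged guess at the uncommented, documented form (the doc text is an assumption; the signatures are taken from the commented code above):

```rust
/// Element-wise left shift of `lhs` by the bit counts in `rhs`.
fn int_bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

/// Left shift of each element of `lhs` by a scalar bit count.
fn int_bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

/// Element-wise right shift of `lhs` by the bit counts in `rhs`.
fn int_bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

/// Right shift of each element of `lhs` by a scalar bit count.
fn int_bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
```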
```rust
// burn_tensor::testgen_bitwise_and!();
// burn_tensor::testgen_bitwise_or!();
// burn_tensor::testgen_bitwise_xor!();
// burn_tensor::testgen_bitwise_and_scalar!();
// burn_tensor::testgen_bitwise_or_scalar!();
// burn_tensor::testgen_bitwise_xor_scalar!();
// burn_tensor::testgen_bitwise_not!();
```
Dead commented code
```rust
#[cube]
impl<N: Numeric> BinaryOp<N> for BitwiseAndOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        //lhs + rhs
        lhs + rhs
    }
}

#[cube]
impl<N: Numeric> BinaryOp<N> for BitwiseOrOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs + rhs
    }
}

#[cube]
impl<N: Numeric> BinaryOp<N> for BitwiseXorOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs + rhs
    }
}

#[cube]
impl<N: Numeric> BinaryOp<N> for BitwiseNotOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs + rhs
    }
}
```
You're always using addition instead of the correct line operations... that probably explains the failing tests 😄
Yeah, it was intentional: the bitwise line ops I did in cubecl weren't merged at the time, so this was just to get past compilation issues. Will patch since it's available now.
Ahhh makes sense! Sounds good 👍
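For reference, a hedged sketch of the patched impls, assuming the cubecl bitwise line ops mentioned above expose the standard Rust bitwise operators on `Line` for integer element types (the `Int` bound is an assumption; `BitwiseNotOp` is omitted since a unary op would presumably not implement `BinaryOp`):

```rust
#[cube]
impl<N: Int> BinaryOp<N> for BitwiseAndOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs & rhs // actual bitwise AND instead of the placeholder addition
    }
}

#[cube]
impl<N: Int> BinaryOp<N> for BitwiseOrOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs | rhs
    }
}

#[cube]
impl<N: Int> BinaryOp<N> for BitwiseXorOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs ^ rhs
    }
}
```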
Pull Request Template
Checklist
- The run-checks all script has been executed.

Related Issues/PRs
#2234
Blocked by CubeCL

Changes
Bitwise Operations for Tensors

Testing
The corresponding tests for the ops were included under the burn_tensor/tensor/tests directory. Candle does not seem to have bitwise operations, so as it stands the implementation for the candle backend is replaced with the todo macro.
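As a hedged illustration of the candle stub described above (the method name and signature are assumptions):

```rust
// Sketch: a Candle backend stub for an unsupported bitwise op.
fn int_bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
    todo!("Candle does not expose bitwise tensor operations")
}
```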