
feat: bitwise-ops-for-tensors #2498

Open · wants to merge 4 commits into main

Conversation

@quinton11 (Contributor) commented Nov 16, 2024

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#2234
Blocked by CubeCL

Changes

Bitwise Operations for Tensors

  • bitwise_and
  • bitwise_or
  • bitwise_xor
  • bitwise_and_scalar
  • bitwise_or_scalar
  • bitwise_xor_scalar
  • bitwise_not
  • bitwise_left_shift
  • bitwise_right_shift
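
For illustration, a rough sketch of the intended tensor-level usage (method names assumed from the list above; exact signatures may differ):

```rust
use burn::tensor::{backend::Backend, Int, Tensor};

// Hedged sketch: these method names mirror the op list above and may
// differ from the final API.
fn demo<B: Backend>(device: &B::Device) {
    let a = Tensor::<B, 1, Int>::from_ints([0b1100, 0b1010], device);
    let b = Tensor::<B, 1, Int>::from_ints([0b1010, 0b0110], device);

    let _and = a.clone().bitwise_and(b.clone()); // [0b1000, 0b0010]
    let _or = a.clone().bitwise_or(b.clone()); // [0b1110, 0b1110]
    let _xor = a.clone().bitwise_xor(b); // [0b0110, 0b1100]
    let _masked = a.clone().bitwise_and_scalar(0b0100); // [0b0100, 0b0000]
    let _flipped = a.bitwise_not();
}
```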

Testing

The corresponding tests for the ops were included under the burn_tensor/tensor/tests directory.
Candle does not appear to have bitwise operations, so for now the candle backend
implementation is stubbed with the todo! macro.
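
For reference, the Candle stubs follow this rough shape (a sketch only; the signature is assumed from burn's IntTensorOps trait and is not the literal PR code):

```rust
// Hedged sketch of the stub pattern described above. The real code lives in
// burn-candle's IntTensorOps impl; names and types here are illustrative.
fn int_bitwise_and(_lhs: IntTensor<Self>, _rhs: IntTensor<Self>) -> IntTensor<Self> {
    // Candle exposes no bitwise kernels, so the op panics if ever called.
    todo!("bitwise_and is not supported by the Candle backend")
}
```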

@laggui (Member) left a comment:


Great start! Some minor comments but here is my preliminary review:

I think the bitwise ops should only be added as int tensor operations; they don't make sense for floats.

We could eventually extend the ops to boolean tensors with their logical counterpart (would be applied on a single bit represented by the bool), but this can be left for another PR.

We can leave the candle ops as unimplemented, but for the JIT backends we should wait to merge until it's implemented with cubecl.

crates/burn-candle/src/ops/int_tensor.rs (outdated, resolved)
crates/burn-router/src/runner.rs (outdated, resolved)
crates/burn-tch/src/ops/base.rs (outdated, resolved)
Comment on lines 518 to 552
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseAnd(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseAndScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseOr(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseOrScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [add](crate::ops::FloatTensorOps::float_add).
/// Int => [add](crate::ops::IntTensorOps::int_add).
BitwiseXor(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
/// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
BitwiseXorScalar(ScalarOperationDescription<E>),
/// Operation corresponding to:
///
/// Float => [powf](crate::ops::FloatTensorOps::float_powf).
/// Int => [powf](crate::ops::IntTensorOps::int_powf).
BitwiseNot(UnaryOperationDescription),
Member:

Copy/pasta docstrings should be fixed.

Also, I'm not sure this should be a numeric op, since it won't be implemented for float (and doesn't really make sense there either). I think it should go in the int operations only (and possibly add the equivalent logical operations for bool).
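
For example, something along these lines once the ops live in the int-only enum (a sketch, assuming int_bitwise_and as the backend method name):

```rust
/// Operation corresponding to:
///
/// Int => [bitwise and](crate::ops::IntTensorOps::int_bitwise_and).
BitwiseAnd(BinaryOperationDescription),
```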

Contributor Author:

Done

crates/burn-tensor/src/tensor/api/numeric.rs (outdated, resolved)
crates/burn-tensor/src/tests/ops/mod.rs (outdated, resolved)
@quinton11 (Contributor, Author) commented
The idea behind the float implementation was to convert to int before computing and back to float afterwards. You're right, it doesn't really make sense; I'll make the changes.

I raised an issue in the cubecl repo for adding bitwise op support; I'll look into that first before implementing here.

@github-actions (bot) commented

This PR has been marked as stale because it has not been updated for over a month

@github-actions bot added the stale (the issue or PR has been open for too long) label on Dec 21, 2024
@0x7CFE commented Dec 24, 2024

Amazing! This is the PR I was looking for!

Is there a way to support bit counting on integer tensors as well?

P.S.: Also, what is the status of the PR? Does somebody work on it? Maybe I should pick it up?

@github-actions bot removed the stale (the issue or PR has been open for too long) label on Dec 24, 2024
@quinton11 (Contributor, Author) commented
> Amazing! This is the PR I was looking for!
>
> Is there a way to support bit counting on integer tensors as well?
>
> P.S.: Also, what is the status of the PR? Does somebody work on it? Maybe I should pick it up?

@0x7CFE Yes, this implementation will allow for bit counting, unless you mean having direct tensor op methods you can call to have the bits counted, e.g. tensor::count_leading_ones or tensor::count_zeros. Existing tensor libraries like torch or numpy don't have such direct methods, but this implementation allows for using Rust's bitwise operators with the tensors to achieve whatever bit counting you want to do. @laggui, what do you think, should there be direct methods?
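
For example, a classic SWAR popcount could be composed from just these ops. A rough sketch (method names assumed from the op list; 32-bit non-negative int elements and the scalar shift variants discussed later in this thread are assumed):

```rust
use burn::tensor::{backend::Backend, Int, Tensor};

// Hedged sketch: per-element popcount from bitwise ops alone (classic SWAR).
// Assumes 32-bit int elements, non-negative values (so shifts behave like
// logical shifts), and scalar variants of the shift/and ops.
fn popcount<B: Backend>(x: Tensor<B, 1, Int>) -> Tensor<B, 1, Int> {
    // Count bits in each 2-bit pair, then each nibble.
    let x = x.clone() - x.bitwise_right_shift_scalar(1).bitwise_and_scalar(0x5555_5555);
    let x = x.clone().bitwise_and_scalar(0x3333_3333)
        + x.bitwise_right_shift_scalar(2).bitwise_and_scalar(0x3333_3333);
    let x = (x.clone() + x.bitwise_right_shift_scalar(4)).bitwise_and_scalar(0x0F0F_0F0F);
    // Sum the four per-byte counts and mask to the final 0..=32 result.
    let x = x.clone() + x.bitwise_right_shift_scalar(8);
    let x = x.clone() + x.bitwise_right_shift_scalar(16);
    x.bitwise_and_scalar(0x3F)
}
```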

Also, this PR is pending a cubecl release containing that implementation.

@0x7CFE commented Dec 27, 2024

> Existing tensor libraries like torch or numpy don't have such direct methods.

Just in case: numpy has bitwise_count, WGSL has countOneBits, and CUDA has __popc. Apparently pytorch does not have a ready-made solution, but there are some tricks to do the counting efficiently given that boolean ops are available.

So I believe it would be possible to efficiently implement this at least for some of the backends.

P.S.: Since the blocker mentioned above was already resolved, what are the next steps for this PR?

@quinton11 (Contributor, Author) commented Dec 28, 2024

> So I believe it would be possible to efficiently implement this at least for some of the backends.

Possibly, yes. What bit counting ops exactly are you looking to include? As for whether it's something the burn team would like to move forward with, @laggui would have to comment on that.

> P.S.: Since the blocker mentioned above was already resolved, what are the next steps for this PR?

Waiting on an official CubeCL release containing that change so we can implement it in burn.

@0x7CFE commented Dec 28, 2024

> What bit counting ops exactly are you looking to include?

I described my use case in #2641, so essentially I need uXX::count_ones().
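
(That is, the elementwise analogue of Rust's standard count_ones:)

```rust
// Scalar semantics of the requested op, applied per tensor element:
assert_eq!(0b0110_1001_u32.count_ones(), 4);
```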

@wingertge (Contributor) commented
> I described my use case in #2641, so essentially I need uXX::count_ones().

I opened a PR to add the required backend ops to cubecl (tracel-ai/cubecl#391), so this should be easy enough to add at least for CPU and JIT backends.

@laggui (Member) commented Jan 2, 2025

Sorry all, I didn't have access to my computer during the holidays 😅 The GitHub app is not as great for reviews and long discussions.

> This implementation allows for using Rust's bitwise operators with the tensors to achieve whatever bit counting you want to do. @laggui, what do you think, should there be direct methods?

I think something like bitwise_count could make sense.

> Also, this PR is pending a cubecl release containing that implementation.

The main branch of burn is actually kept up to date as much as possible with the latest cubecl revision, so your changes are already available 🙂

@laggui (Member) commented Jan 2, 2025

> P.S.: Since the blocker mentioned above was already resolved, what are the next steps for this PR?

I just need to review the requested changes for the current PR, but this one should land pretty soon.

Regarding your request for bit counting, I think once the linked cubecl PR lands (thanks Genna 🙏) we could add a bitwise_count op. Feel free to open a draft PR if you want to get the ball rolling 🙂

@laggui (Member) left a comment:

Thanks for addressing the changes 🙂

I have a couple of comments, most of them just pointing out some commented code you forgot to remove.

Also, bitwise tests are failing but that's probably just because you're not using the correct operations (as pointed out in the review below).

Comment on lines +288 to +289
// not implemented
//todo!()
Member:

Leftover comment 🙂 it's implemented

Comment on lines +294 to +295
// not implemented
//todo!()
Member:

+1

Comment on lines +300 to +301
// not implemented
//todo!()
Member:

+1

Comment on lines +306 to +307
// not implemented
//todo!()
Member:

+1

Comment on lines +312 to +313
// not implemented
//todo!()
Member:

+1

@@ -75,7 +75,7 @@ pub enum OperationDescription {
     /// Operation specific to a bool tensor.
     Bool(BoolOperationDescription),
     /// Operation specific to an int tensor.
-    Int(IntOperationDescription),
+    Int(DType, IntOperationDescription<i32>),
Member:

What's the current motivation behind adding the DType to Int operations?

Is it only because we moved the bitwise ops from NumericInt, which required the dtype?

Comment on lines +3354 to +3374

// /// Applies logical `and` operation element-wise between two tensors.
// fn bitwise_and(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `and` operation element-wise between a tensor and a scalar.
// fn bitwise_and_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `or` operation element-wise between two tensors.
// fn bitwise_or(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `or` operation element-wise between a tensor and a scalar.
// fn bitwise_or_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `xor` operation element-wise between two tensors.
// fn bitwise_xor(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `xor` operation element-wise between a tensor and a scalar.
// fn bitwise_xor_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `not` operation element-wise on a tensor.
// fn bitwise_not(tensor: Self::Primitive) -> Self::Primitive;
Member:

Dead code

Comment on lines +1210 to +1216
// fn int_bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

// fn int_bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;

// fn int_bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;

// fn int_bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
Member:

Do you want to implement these in the current PR? Otherwise, we can remove the commented function definitions.

Comment on lines +313 to +319
// burn_tensor::testgen_bitwise_and!();
// burn_tensor::testgen_bitwise_or!();
// burn_tensor::testgen_bitwise_xor!();
// burn_tensor::testgen_bitwise_and_scalar!();
// burn_tensor::testgen_bitwise_or_scalar!();
// burn_tensor::testgen_bitwise_xor_scalar!();
// burn_tensor::testgen_bitwise_not!();
Member:

Dead commented code

Comment on lines +69 to +96
#[cube]
impl<N: Numeric> BinaryOp<N> for BitwiseAndOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
//lhs + rhs
lhs + rhs
}
}

#[cube]
impl<N: Numeric> BinaryOp<N> for BitwiseOrOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs + rhs
}
}

#[cube]
impl<N: Numeric> BinaryOp<N> for BitwiseXorOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs + rhs
}
}

#[cube]
impl<N: Numeric> BinaryOp<N> for BitwiseNotOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
lhs + rhs
}
}
Member:

You're always using addition instead of the correct line operations... that probably explains the failing tests 😄

Contributor Author:

Yeah, it was intentional: the bitwise line ops I added in cubecl weren't merged at the time, so this was just to get past compilation issues. I'll patch it now that it's available.

Member:

Ahhh makes sense! Sounds good 👍
