Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sequential Macro #2565

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Sequential Macro #2565

wants to merge 3 commits into from

Conversation

ImTheSquid
Copy link

@ImTheSquid ImTheSquid commented Nov 28, 2024

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Changes

Burn currently lacks an analog to nn.Sequential in PyTorch, so I created a sequential macro to generate a similar structure. I used a macro since there is no unifying trait amongst all modules, specifically with regards to how they are initialized. I created a relatively lenient system that classes modules into different bins depending on how they're initialized, then generate a structure with that information. This structure can be initialized like normal (using SequentialConfig, a struct containing an enum that allows for individual customization of each step) then executed with Sequential::forward.

Testing

I have tested various combinations of modules with the macro and it performs as expected. I also added a unit test to ensure that the macro properly generates on each new build.

Copy link

codecov bot commented Dec 1, 2024

Codecov Report

Attention: Patch coverage is 40.32258% with 37 lines in your changes missing coverage. Please review.

Project coverage is 82.35%. Comparing base (6f494e5) to head (0c019b2).
Report is 31 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-core/src/nn/sequential.rs 17.77% 37 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2565      +/-   ##
==========================================
- Coverage   82.37%   82.35%   -0.02%     
==========================================
  Files         825      828       +3     
  Lines      105643   105773     +130     
==========================================
+ Hits        87026    87112      +86     
- Misses      18617    18661      +44     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +10 to +20
/// gen_sequential! {
/// // No config
/// Relu,
/// Sigmoid;
/// // Has config
/// DropoutConfig => Dropout,
/// LeakyReluConfig => LeakyRelu;
/// // Requires a backend (<B>)
/// LinearConfig => Linear
/// }
/// ```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would use pattern matching to differentiate modules that require a backend and config:

Sequential!(
    Relu, // Without config
    Dropout(DropoutConfig), // With config
    Linear(LinearConfig; B), // With config + Generics
    Custom(; B), // No config + Generics
    Custom2(config; B, A, C), // With config + many generics
)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this would be my ideal method for doing this, it would require me to rewrite it as a proc macro since there isn't a great way to differentiate depending on whether a value is present or not (specifically the whole ($cfg$(; $($generic),+))?) in declarative macros. Separating into multiple blocks allows me to know that the structs in that block need to have .init() called if they have a config or .init(device) called if they have a backend-dependent config without needing actual Rust code to differentiate them. A proc macro would fix this but would also require a new crate just for the macro and some more advanced parsing techniques.

Additionally, how would multiple generics work? Do all generics need to be unique? If I define A as generic across Custom2 and Custom3, is it the same generic? I assume it would be. We would probably designate B as reserved and meaning "needs a device passed to it on initialization".

Should I look into rewriting this as a more complicated but more flexible proc macro or should I keep the simpler but slightly more rigid declarative macro?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I actually don't think the macro helps much, and I don't think we should implement a proc macro. Maybe the real solution would be to implement a trait Forward instead of simply having a method. We could then support tuples as sequential layers. The Forward trait would be totaly decoupled from the Module trait and only used to simplify composing multiple forward methods.

@antimora
Copy link
Collaborator

antimora commented Dec 2, 2024

Copy link
Contributor

github-actions bot commented Jan 2, 2025

This PR has been marked as stale because it has not been updated for over a month

@github-actions github-actions bot added the stale The issue or pr has been open for too long label Jan 2, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
stale The issue or pr has been open for too long
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants