Adjoints for active values in loops are just wrong #78264
Comments
Tagging @kovdan01 @JaapWijnen @rxwei @dan-zheng
It looks like one way of doing things is as follows:
Given that in the pullback the loop header is executed after the corresponding loop body, zeroing out the adjoints of active values there effectively resets them after each iteration. Values "reused" from previous iterations are propagated via phi nodes and are therefore unaffected; only values defined in the loop body are affected.
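The premise above, that the pullback of the loop header runs after the pullback of the corresponding loop body, can be checked with a toy model. This is a sketch, not compiler code. The block roles follow the while/for case in the description: `bb1` is the header, `bb2` is the body, and `bb3` is the code after the loop. The entry block `bb0`, the exact trace shape, and all function names are hypothetical. The model treats the pullback order as the forward block trace reversed.

```swift
// Forward block trace of a `while` loop running `n` iterations.
// bb1 = loop header (condition), bb2 = loop body, bb3 = code after the loop.
// bb0 (entry) and the trace shape are assumptions for illustration.
func whileLoopForwardTrace(iterations n: Int) -> [String] {
    var trace = ["bb0"]
    for _ in 0..<n {
        trace += ["bb1", "bb2"]   // header check succeeds, body runs
    }
    trace += ["bb1", "bb3"]       // final header check fails, exit the loop
    return trace
}

// Reverse-mode AD replays the forward control flow backwards, so the
// block pullbacks run in the reverse of the forward trace order.
func pullbackOrder(_ forwardTrace: [String]) -> [String] {
    Array(forwardTrace.reversed())
}

// True when, after every pullback of `body`, the pullback of `header` runs
// before the next pullback of `body`. If so, zeroing the body's adjoints
// in the header pullback resets them once per loop iteration.
func headerRunsAfterEachBody(_ order: [String], header: String, body: String) -> Bool {
    var bodyAdjointsDirty = false
    for block in order {
        if block == body {
            if bodyAdjointsDirty { return false }  // body pullback would reuse stale adjoints
            bodyAdjointsDirty = true
        } else if block == header {
            bodyAdjointsDirty = false              // header pullback zeroes them
        }
    }
    return true
}

let order = pullbackOrder(whileLoopForwardTrace(iterations: 2))
print(order) // ["bb3", "bb1", "bb2", "bb1", "bb2", "bb1", "bb0"]
print(headerRunsAfterEachBody(order, header: "bb1", body: "bb2")) // true
```

The check fails for any order in which two body pullbacks run back to back with no header pullback between them. That is the situation where a single shared adjoint buffer keeps accumulating across iterations.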
Description
Kudos to @kovdan01 for the initial analysis of this issue.
It turns out that adjoints for active values in loops are just plain wrong. Consider the reproducer. As one can see, the gradient for `repeat_while_loop` is wrong, while the gradient for `while_loop` is correct. Moreover, if we replace the code in the loop with `result *= x`, then the `repeat_while_loop` case starts working.

Why is that?
The loop body for `repeat_while_loop` looks as follows (loop condition calculation removed for brevity):

The key thing here is the active value `%13` (which is essentially the `result` value), so we need to generate an adjoint for it. The AutoDiff code uses the notion of "adjoint for value X in basic block Y". This is fine for code without loops, but it is just plain wrong for values inside loops: there it should be "adjoint for value X in basic block Y on loop iteration Z". The values are different on different loop iterations, so their adjoints should be distinct as well. Without this, we end up with artificial adjoint accumulation (since a single adjoint value is shared between loop iterations) and wrong results.

So, when generating the pullback for this loop body, we need to ensure that the initial value of the adjoint of `%13` is zero on each iteration, and then perform the usual pullback cloning that involves adjoint value generation and accumulation. We don't do this, so essentially we're accumulating into the adjoint left over from the previous loop iteration.

Sure, if things are so broken, why haven't we noticed this before? I would say: coincidence.
For code like `result *= x`, we do not have these extra active values: `Float.*=` takes the adjoint buffer as an `inout` argument and performs proper adjoint generation there.

The `while` / `for` case is more interesting. Here the code looks as follows:

So, we have the loop header (`bb1`) first, then the loop body, and finally the code after the loop (`bb3`). The pullback cloner propagates adjoints of active values into predecessor BBs while traversing the function in reverse post-order. Here, we first visit `bb3`, then `bb1`. Inside `bb1` we realize that there are active values (`%25` and `%28`) in the predecessor `bb2`, so we take their adjoints in `bb1` and propagate them into `bb2`. Since no adjoints were defined before, they are zero-initialized and propagated further. And since `bb1` coincidentally is the loop header, we end up zero-initializing them on each loop iteration, because in the pullback the loop header is executed after the loop body. Everything magically works.

The situation with the `repeat` loop is the reverse: there is no "loop header" BB in the common sense; instead, there is a "loop footer" fused into the loop body. So, the adjoints for `%13` and `%16` are first zero-initialized in the pullback block corresponding to `bb3` and then propagated further to `bb1`. As a result, there is no zero-initialization on each loop iteration, adjoint values are reused from the previous loop iteration, and the results are wrong.

It seems to me that we need to perform explicit adjoint zeroing inside loop headers in the pullback cloner.
Reproduction
Expected behavior
Correct gradient calculation for both cases
Environment
Swift version 6.2-dev (LLVM e404f8897f17aff, Swift 5a68861)
Target: arm64-apple-macosx15.0
Additional information
No response