Skip to content

Commit

Permalink
[AutoDiff] Fix adjoints for loop-local active values
Browse files Browse the repository at this point in the history
  • Loading branch information
kovdan01 committed Dec 19, 2024
1 parent f2ad9f3 commit 1590270
Show file tree
Hide file tree
Showing 4 changed files with 428 additions and 24 deletions.
73 changes: 73 additions & 0 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,79 @@ bool PullbackCloner::Implementation::run() {
// Visit original blocks in post-order and perform differentiation
// in corresponding pullback blocks. If errors occurred, back out.
else {
LLVM_DEBUG(getADDebugStream()
<< "Begin search for adjoints of loop-local active values\n");
llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
loopLocalActiveValues;
for (auto *bb : originalBlocks) {
const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
if (loop == nullptr)
continue;
SILBasicBlock *loopHeader = loop->getHeader();
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
LLVM_DEBUG(getADDebugStream()
<< "Original bb" << bb->getDebugID()
<< " belongs to a loop, original header bb"
<< loopHeader->getDebugID() << ", pullback header bb"
<< pbLoopHeader->getDebugID() << '\n');
builder.setInsertionPoint(pbLoopHeader);
auto bbActiveValuesIt = activeValues.find(bb);
if (bbActiveValuesIt == activeValues.end())
continue;
const auto &bbActiveValues = bbActiveValuesIt->second;
for (SILValue bbActiveValue : bbActiveValues) {
if (vjpCloner.getLoopInfo()->getLoopFor(
bbActiveValue->getParentBlock()) != loop) {
LLVM_DEBUG(
getADDebugStream()
<< "The following active value is NOT loop-local, skipping: "
<< bbActiveValue);
continue;
}
auto [_, wasInserted] =
loopLocalActiveValues[loop].insert(bbActiveValue);
LLVM_DEBUG(getADDebugStream()
<< "The following active value is loop-local, ");
if (!wasInserted) {
LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: "
<< bbActiveValue);
continue;
}
if (getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Object) {
LLVM_DEBUG(llvm::dbgs()
<< "zeroing its adjoint value in loop header: "
<< bbActiveValue);
setAdjointValue(bb, bbActiveValue,
makeZeroAdjointValue(getRemappedTangentType(
bbActiveValue->getType())));
} else {
assert(getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Address);
// Adjoint for address projections are handled automatically:
// 1. If the source address is loop-local, it's adjoint buffer will
// be zeroed, and we'll have zero adjoint projection to it.
// 2. Otherwise, we do not need to zero the projection buffer.
// Consider '%X = begin_access [modify] [static] %Y' where %Y is
// not loop-local - adjoint buffer for projection %X should not
// be zeroed.
if (!getAdjointProjection(bb, bbActiveValue)) {
LLVM_DEBUG(llvm::dbgs()
<< "zeroing its adjoint buffer in loop header: "
<< bbActiveValue);
builder.emitZeroIntoBuffer(
pbLoc, getAdjointBuffer(bb, bbActiveValue), IsInitialization);
} else {
LLVM_DEBUG(llvm::dbgs()
<< "but it is an address projection, skipping: "
<< bbActiveValue);
}
}
}
}
LLVM_DEBUG(getADDebugStream()
<< "End search for adjoints of loop-local active values\n");

for (auto *bb : originalBlocks) {
visitSILBasicBlock(bb);
if (errorOccurred)
Expand Down
24 changes: 12 additions & 12 deletions test/AutoDiff/SILOptimizer/pullback_generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -182,19 +182,19 @@ func f4(a: NonTrivial) -> Float {
}

// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
// CHECK: %88 = alloc_stack $NonTrivial
// CHECK: bb5(%[[#ARG0:]] : @owned $NonTrivial, %[[#]] : $Float, %[[#]] : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
// CHECK: %[[#T0:]] = alloc_stack $NonTrivial

// Non-trivial value must be copied or there will be a
// double consume when all owned parameters are destroyed
// at the end of the basic block.
// CHECK: %89 = copy_value %67 : $NonTrivial

// CHECK: store %89 to [init] %88 : $*NonTrivial
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
// CHECK: %92 = alloc_stack $Float
// CHECK: store %86 to [trivial] %92 : $*Float
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %95 = metatype $@thick Float.Type
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_value %67 : $NonTrivial
// CHECK: %[[#T1:]] = copy_value %[[#ARG0]] : $NonTrivial

// CHECK: store %[[#T1]] to [init] %[[#T0]] : $*NonTrivial
// CHECK: %[[#T2:]] = struct_element_addr %[[#T0]] : $*NonTrivial, #NonTrivial.x
// CHECK: %[[#T3:]] = alloc_stack $Float
// CHECK: store %[[#T4:]] to [trivial] %[[#T3]] : $*Float
// CHECK: %[[#T5:]] = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %[[#T6:]] = metatype $@thick Float.Type
// CHECK: %[[#]] = apply %[[#T5]]<Float>(%[[#T2]], %[[#T3]], %[[#T6]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_value %[[#ARG0]] : $NonTrivial
Loading

0 comments on commit 1590270

Please sign in to comment.