Skip to content

Commit

Permalink
[AutoDiff] Fix adjoints for loop-local active values
Browse files Browse the repository at this point in the history
  • Loading branch information
kovdan01 committed Dec 27, 2024
1 parent 55189ba commit 77cc3f9
Show file tree
Hide file tree
Showing 4 changed files with 558 additions and 37 deletions.
160 changes: 147 additions & 13 deletions lib/SILOptimizer/Differentiation/PullbackCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2112,6 +2112,23 @@ PullbackCloner::PullbackCloner(VJPCloner &vjpCloner)

PullbackCloner::~PullbackCloner() { delete &impl; }

static SILValue getArrayValue(ApplyInst *ai) {
SILValue arrayValue;
for (auto use : ai->getUses()) {
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
if (!dti)
continue;
assert(!arrayValue && "Array value already found");
// The first `destructure_tuple` result is the `Array` value.
arrayValue = dti->getResult(0);
#ifdef NDEBUG
break;
#endif
}
assert(arrayValue);
return arrayValue;
}

//--------------------------------------------------------------------------//
// Entry point
//--------------------------------------------------------------------------//
Expand Down Expand Up @@ -2456,6 +2473,133 @@ bool PullbackCloner::Implementation::run() {
// Visit original blocks in post-order and perform differentiation
// in corresponding pullback blocks. If errors occurred, back out.
else {
LLVM_DEBUG(getADDebugStream()
<< "Begin search for adjoints of loop-local active values\n");
llvm::DenseMap<const SILLoop *, llvm::DenseSet<SILValue>>
loopLocalActiveValues;
for (auto *bb : originalBlocks) {
const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb);
if (loop == nullptr)
continue;
SILBasicBlock *loopHeader = loop->getHeader();
SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader);
LLVM_DEBUG(getADDebugStream()
<< "Original bb" << bb->getDebugID()
<< " belongs to a loop, original header bb"
<< loopHeader->getDebugID() << ", pullback header bb"
<< pbLoopHeader->getDebugID() << '\n');
builder.setInsertionPoint(pbLoopHeader);
auto bbActiveValuesIt = activeValues.find(bb);
if (bbActiveValuesIt == activeValues.end())
continue;
const auto &bbActiveValues = bbActiveValuesIt->second;
for (SILValue bbActiveValue : bbActiveValues) {
if (vjpCloner.getLoopInfo()->getLoopFor(
bbActiveValue->getParentBlock()) != loop) {
LLVM_DEBUG(
getADDebugStream()
<< "The following active value is NOT loop-local, skipping: "
<< bbActiveValue);
continue;
}

auto [_, wasInserted] =
loopLocalActiveValues[loop].insert(bbActiveValue);
LLVM_DEBUG(getADDebugStream()
<< "The following active value is loop-local, ");
if (!wasInserted) {
LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: "
<< bbActiveValue);
continue;
}

if (getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Object) {
LLVM_DEBUG(llvm::dbgs()
<< "zeroing its adjoint value in loop header: "
<< bbActiveValue);
setAdjointValue(bb, bbActiveValue,
makeZeroAdjointValue(getRemappedTangentType(
bbActiveValue->getType())));
continue;
}

assert(getTangentValueCategory(bbActiveValue) ==
SILValueCategory::Address);

// getAdjointProjection might call materializeAdjointDirect which
// writes to debug output, emit \n.
LLVM_DEBUG(llvm::dbgs()
<< "checking if it's adjoint is a projection\n");

if (!getAdjointProjection(bb, bbActiveValue)) {
LLVM_DEBUG(getADDebugStream()
<< "Adjoint for the following value is NOT a projection, "
"zeroing its adjoint buffer in loop header: "
<< bbActiveValue);

builder.emitZeroIntoBuffer(pbLoc, getAdjointBuffer(bb, bbActiveValue),
IsInitialization);

continue;
}

LLVM_DEBUG(getADDebugStream()
<< "Adjoint for the following value is a projection, ");

// If Projection::isAddressProjection(v) is true for a value v, it
// is not added to active values list (see recordValueIfActive).
//
// Ensure that only the following value types conforming to
// getAdjointProjection but not conforming to
// Projection::isAddressProjection can go here.
//
// Instructions conforming to Projection::isAddressProjection and
// thus never corresponding to an active value do not need any
// handling, because only active values can have adjoints from
// previous iterations propagated via BB arguments.
do {
// Consider '%X = begin_access [modify] [static] %Y'.
// 1. If %Y is loop-local, it's adjoint buffer will
// be zeroed, and we'll have zero adjoint projection to it.
// 2. Otherwise, we do not need to zero the projection buffer.
// Thus, we can just skip.
if (dyn_cast<BeginAccessInst>(bbActiveValue)) {
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
break;
}

// Consider the following sequence:
// %1 = function_ref @allocUninitArray
// %2 = apply %1<Float>(%0)
// (%3, %4) = destructure_tuple %2
// %5 = mark_dependence %4 on %3
// %6 = pointer_to_address %6 to [strict] $*Float
// Since %6 is active, %3 (which is an array) must also be active.
// Thus, adjoint for %3 will be zeroed if needed. Ensure that expected
// invariants hold and then skip.
if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress(
bbActiveValue)) {
#ifndef NDEBUG
assert(isa<PointerToAddressInst>(bbActiveValue));

SILValue arrayValue = getArrayValue(ai);
assert(llvm::find(bbActiveValues, arrayValue) !=
bbActiveValues.end());
assert(vjpCloner.getLoopInfo()->getLoopFor(
arrayValue->getParentBlock()) == loop);
#endif
LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue);
break;
}

assert(false);
} while (false);
}
}
LLVM_DEBUG(getADDebugStream()
<< "End search for adjoints of loop-local active values\n");

for (auto *bb : originalBlocks) {
visitSILBasicBlock(bb);
if (errorOccurred)
Expand Down Expand Up @@ -3371,19 +3515,9 @@ SILValue PullbackCloner::Implementation::getAdjointProjection(
eltIndex = ili->getValue().getLimitedValue();
}
// Get the array adjoint value.
SILValue arrayAdjoint;
assert(ai && "Expected `array.uninitialized_intrinsic` application");
for (auto use : ai->getUses()) {
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
if (!dti)
continue;
assert(!arrayAdjoint && "Array adjoint already found");
// The first `destructure_tuple` result is the `Array` value.
auto arrayValue = dti->getResult(0);
arrayAdjoint = materializeAdjointDirect(
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
}
assert(arrayAdjoint && "Array does not have adjoint value");
SILValue arrayValue = getArrayValue(ai);
SILValue arrayAdjoint = materializeAdjointDirect(
getAdjointValue(origBB, arrayValue), definingInst->getLoc());
// Apply `Array.TangentVector.subscript` to get array element adjoint value.
auto *eltAdjBuffer =
getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc());
Expand Down
24 changes: 12 additions & 12 deletions test/AutoDiff/SILOptimizer/pullback_generation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -182,19 +182,19 @@ func f4(a: NonTrivial) -> Float {
}

// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
// CHECK: %88 = alloc_stack $NonTrivial
// CHECK: bb5(%[[#ARG0:]] : @owned $NonTrivial, %[[#]] : $Float, %[[#]] : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
// CHECK: %[[#T0:]] = alloc_stack $NonTrivial

// Non-trivial value must be copied or there will be a
// double consume when all owned parameters are destroyed
// at the end of the basic block.
// CHECK: %89 = copy_value %67 : $NonTrivial

// CHECK: store %89 to [init] %88 : $*NonTrivial
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
// CHECK: %92 = alloc_stack $Float
// CHECK: store %86 to [trivial] %92 : $*Float
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %95 = metatype $@thick Float.Type
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_value %67 : $NonTrivial
// CHECK: %[[#T1:]] = copy_value %[[#ARG0]] : $NonTrivial

// CHECK: store %[[#T1]] to [init] %[[#T0]] : $*NonTrivial
// CHECK: %[[#T2:]] = struct_element_addr %[[#T0]] : $*NonTrivial, #NonTrivial.x
// CHECK: %[[#T3:]] = alloc_stack $Float
// CHECK: store %[[#T4:]] to [trivial] %[[#T3]] : $*Float
// CHECK: %[[#T5:]] = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: %[[#T6:]] = metatype $@thick Float.Type
// CHECK: %[[#]] = apply %[[#T5]]<Float>(%[[#T2]], %[[#T3]], %[[#T6]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
// CHECK: destroy_value %[[#ARG0]] : $NonTrivial
Loading

0 comments on commit 77cc3f9

Please sign in to comment.