From de1b516923934339b4c3abd0cbbb9803e273e510 Mon Sep 17 00:00:00 2001 From: Daniil Kovalev Date: Thu, 19 Dec 2024 02:26:06 +0300 Subject: [PATCH] [AutoDiff] Fix adjoints for loop-local active values Fixes #78264 --- .../Differentiation/PullbackCloner.cpp | 157 ++++++++- .../SILOptimizer/pullback_generation.swift | 24 +- .../pullback_generation_loop_adjoints.swift | 316 ++++++++++++++++++ .../validation-test/control_flow.swift | 95 +++++- 4 files changed, 555 insertions(+), 37 deletions(-) create mode 100644 test/AutoDiff/SILOptimizer/pullback_generation_loop_adjoints.swift diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 1ea2cc3ebbef6..66a3f09a93af1 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -2112,6 +2112,20 @@ PullbackCloner::PullbackCloner(VJPCloner &vjpCloner) PullbackCloner::~PullbackCloner() { delete &impl; } +static SILValue getArrayValue(ApplyInst *ai) { + SILValue arrayValue; + for (auto use : ai->getUses()) { + auto *dti = dyn_cast(use->getUser()); + if (!dti) + continue; + assert(!arrayValue && "Array value already found"); + // The first `destructure_tuple` result is the `Array` value. + arrayValue = dti->getResult(0); + } + assert(arrayValue); + return arrayValue; +} + //--------------------------------------------------------------------------// // Entry point //--------------------------------------------------------------------------// @@ -2456,6 +2470,133 @@ bool PullbackCloner::Implementation::run() { // Visit original blocks in post-order and perform differentiation // in corresponding pullback blocks. If errors occurred, back out. else { + LLVM_DEBUG(getADDebugStream() + << "Begin search for adjoints of loop-local active values\n"); + llvm::DenseMap> + loopLocalActiveValues; + for (auto *bb : originalBlocks) { + const SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb); + if (loop == nullptr) + continue; + SILBasicBlock *loopHeader = loop->getHeader(); + SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader); + LLVM_DEBUG(getADDebugStream() + << "Original bb" << bb->getDebugID() + << " belongs to a loop, original header bb" + << loopHeader->getDebugID() << ", pullback header bb" + << pbLoopHeader->getDebugID() << '\n'); + builder.setInsertionPoint(pbLoopHeader); + auto bbActiveValuesIt = activeValues.find(bb); + if (bbActiveValuesIt == activeValues.end()) + continue; + const auto &bbActiveValues = bbActiveValuesIt->second; + for (SILValue bbActiveValue : bbActiveValues) { + if (vjpCloner.getLoopInfo()->getLoopFor( + bbActiveValue->getParentBlock()) != loop) { + LLVM_DEBUG( + getADDebugStream() + << "The following active value is NOT loop-local, skipping: " + << bbActiveValue); + continue; + } + + auto [_, wasInserted] = + loopLocalActiveValues[loop].insert(bbActiveValue); + LLVM_DEBUG(getADDebugStream() + << "The following active value is loop-local, "); + if (!wasInserted) { + LLVM_DEBUG(llvm::dbgs() << "but it was already processed, skipping: " + << bbActiveValue); + continue; + } + + if (getTangentValueCategory(bbActiveValue) == + SILValueCategory::Object) { + LLVM_DEBUG(llvm::dbgs() + << "zeroing its adjoint value in loop header: " + << bbActiveValue); + setAdjointValue(bb, bbActiveValue, + makeZeroAdjointValue(getRemappedTangentType( + bbActiveValue->getType()))); + continue; + } + + assert(getTangentValueCategory(bbActiveValue) == + SILValueCategory::Address); + + // getAdjointProjection might call materializeAdjointDirect which + // writes to debug output, emit \n. + LLVM_DEBUG(llvm::dbgs() + << "checking if it's adjoint is a projection\n"); + + if (!getAdjointProjection(bb, bbActiveValue)) { + LLVM_DEBUG(getADDebugStream() + << "Adjoint for the following value is NOT a projection, " + "zeroing its adjoint buffer in loop header: " + << bbActiveValue); + + builder.emitZeroIntoBuffer(pbLoc, getAdjointBuffer(bb, bbActiveValue), + IsInitialization); + + continue; + } + + LLVM_DEBUG(getADDebugStream() + << "Adjoint for the following value is a projection, "); + + // If Projection::isAddressProjection(v) is true for a value v, it + // is not added to active values list (see recordValueIfActive). + // + // Ensure that only the following value types conforming to + // getAdjointProjection but not conforming to + // Projection::isAddressProjection can go here. + // + // Instructions conforming to Projection::isAddressProjection and + // thus never corresponding to an active value do not need any + // handling, because only active values can have adjoints from + // previous iterations propagated via BB arguments. + do { + // Consider '%X = begin_access [modify] [static] %Y'. + // 1. If %Y is loop-local, it's adjoint buffer will + // be zeroed, and we'll have zero adjoint projection to it. + // 2. Otherwise, we do not need to zero the projection buffer. + // Thus, we can just skip. + if (dyn_cast(bbActiveValue)) { + LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue); + break; + } + + // Consider the following sequence: + // %1 = function_ref @allocUninitArray + // %2 = apply %1(%0) + // (%3, %4) = destructure_tuple %2 + // %5 = mark_dependence %4 on %3 + // %6 = pointer_to_address %6 to [strict] $*Float + // Since %6 is active, %3 (which is an array) must also be active. + // Thus, adjoint for %3 will be zeroed if needed. Ensure that expected + // invariants hold and then skip. + if (auto *ai = getAllocateUninitializedArrayIntrinsicElementAddress( + bbActiveValue)) { +#ifndef NDEBUG + assert(isa(bbActiveValue)); + + SILValue arrayValue = getArrayValue(ai); + assert(llvm::find(bbActiveValues, arrayValue) != + bbActiveValues.end()); + assert(vjpCloner.getLoopInfo()->getLoopFor( + arrayValue->getParentBlock()) == loop); +#endif + LLVM_DEBUG(llvm::dbgs() << "skipping: " << bbActiveValue); + break; + } + + assert(false); + } while (false); + } + } + LLVM_DEBUG(getADDebugStream() + << "End search for adjoints of loop-local active values\n"); + for (auto *bb : originalBlocks) { visitSILBasicBlock(bb); if (errorOccurred) @@ -3371,19 +3512,9 @@ SILValue PullbackCloner::Implementation::getAdjointProjection( eltIndex = ili->getValue().getLimitedValue(); } // Get the array adjoint value. - SILValue arrayAdjoint; - assert(ai && "Expected `array.uninitialized_intrinsic` application"); - for (auto use : ai->getUses()) { - auto *dti = dyn_cast(use->getUser()); - if (!dti) - continue; - assert(!arrayAdjoint && "Array adjoint already found"); - // The first `destructure_tuple` result is the `Array` value. - auto arrayValue = dti->getResult(0); - arrayAdjoint = materializeAdjointDirect( - getAdjointValue(origBB, arrayValue), definingInst->getLoc()); - } - assert(arrayAdjoint && "Array does not have adjoint value"); + SILValue arrayValue = getArrayValue(ai); + SILValue arrayAdjoint = materializeAdjointDirect( + getAdjointValue(origBB, arrayValue), definingInst->getLoc()); // Apply `Array.TangentVector.subscript` to get array element adjoint value. auto *eltAdjBuffer = getArrayAdjointElementBuffer(arrayAdjoint, eltIndex, ai->getLoc()); diff --git a/test/AutoDiff/SILOptimizer/pullback_generation.swift b/test/AutoDiff/SILOptimizer/pullback_generation.swift index b92403eff7a54..1151b4b786f84 100644 --- a/test/AutoDiff/SILOptimizer/pullback_generation.swift +++ b/test/AutoDiff/SILOptimizer/pullback_generation.swift @@ -182,19 +182,19 @@ func f4(a: NonTrivial) -> Float { } // CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial { -// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)): -// CHECK: %88 = alloc_stack $NonTrivial +// CHECK: bb5(%[[#ARG0:]] : @owned $NonTrivial, %[[#]] : $Float, %[[#]] : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)): +// CHECK: %[[#T0:]] = alloc_stack $NonTrivial // Non-trivial value must be copied or there will be a // double consume when all owned parameters are destroyed // at the end of the basic block. -// CHECK: %89 = copy_value %67 : $NonTrivial - -// CHECK: store %89 to [init] %88 : $*NonTrivial -// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x -// CHECK: %92 = alloc_stack $Float -// CHECK: store %86 to [trivial] %92 : $*Float -// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () -// CHECK: %95 = metatype $@thick Float.Type -// CHECK: %96 = apply %94(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () -// CHECK: destroy_value %67 : $NonTrivial +// CHECK: %[[#T1:]] = copy_value %[[#ARG0]] : $NonTrivial + +// CHECK: store %[[#T1]] to [init] %[[#T0]] : $*NonTrivial +// CHECK: %[[#T2:]] = struct_element_addr %[[#T0]] : $*NonTrivial, #NonTrivial.x +// CHECK: %[[#T3:]] = alloc_stack $Float +// CHECK: store %[[#T4:]] to [trivial] %[[#T3]] : $*Float +// CHECK: %[[#T5:]] = witness_method $Float, #AdditiveArithmetic."+=" : (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () +// CHECK: %[[#T6:]] = metatype $@thick Float.Type +// CHECK: %[[#]] = apply %[[#T5]](%[[#T2]], %[[#T3]], %[[#T6]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () +// CHECK: destroy_value %[[#ARG0]] : $NonTrivial diff --git a/test/AutoDiff/SILOptimizer/pullback_generation_loop_adjoints.swift b/test/AutoDiff/SILOptimizer/pullback_generation_loop_adjoints.swift new file mode 100644 index 0000000000000..c8cc1af2026b9 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/pullback_generation_loop_adjoints.swift @@ -0,0 +1,316 @@ +// RUN: %target-swift-frontend -emit-sil -Xllvm --sil-print-after=differentiation %s 2>&1 | %FileCheck %s +// RUN: %target-swift-frontend -emit-sil -Xllvm --debug-only=differentiation %s 2>&1 | %FileCheck --check-prefix=DEBUG %s + +// Needed for '--debug-only' +// REQUIRES: asserts + +import _Differentiation + +@differentiable(reverse) +func repeat_while_loop(x: Float) -> Float { + var result : Float + repeat { + result = x + 0 + } while 0 == 1 + return result +} + +// DEBUG-LABEL: [AD] Running PullbackCloner on +// DEBUG-NEXT: // repeat_while_loop +// DEBUG: [AD] Begin search for adjoints of loop-local active values + +// CHECK-LABEL: // pullback of repeat_while_loop(x:) + +// DEBUG-NEXT: [AD] Original bb1 belongs to a loop, original header bb1, pullback header bb3 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#AARG0:]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#A0:]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#A1:]] = apply %[[#A2:]](%[[#AARG0]], %[[#A3:]], %[[#A4:]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#A1]] = apply %[[#A2]](%[[#AARG0]], %[[#A3]], %[[#A4]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#A5:]] = begin_access [modify] [static] %[[#A0]] : $*Float + +// CHECK: bb3(%[[ARG31:[0-9]+]] : $Float, %[[ARG32:[0-9]+]] : $Float, %[[ARG33:[0-9]+]] : @owned $(predecessor: _AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float)): +// CHECK: (%[[T01:[0-9]+]], %[[T02:[0-9]+]]) = destructure_tuple %[[ARG33]] +// CHECK: %[[T03:[0-9]+]] = load [trivial] %[[V1:[0-9]+]] + +/// Ensure that we do not add adjoint from the previous iteration +/// The incorrect SIL looked like the following: +/// %[[T03]] = load [trivial] %[[V1]] +/// store %[[#B08]] to [trivial] %[[#B13]] // <-- we check absence of this +/// store %[[T03]] to [trivial] %[[#B12]] +/// %62 = witness_method $Float, #AdditiveArithmetic."+" +/// %63 = metatype $@thick Float.Type +/// %64 = apply %62(%[[#]], %[[#B12]], %[[#B13]], %63) +// CHECK-NOT: store %[[ARG32]] to [trivial] %[[#]] + +// CHECK: %[[T04:[0-9]+]] = witness_method $Float, #AdditiveArithmetic.zero!getter : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 +// CHECK: %[[T05:[0-9]+]] = metatype $@thick Float.Type +// CHECK: %[[#]] = apply %[[T04]](%[[V1]], %[[T05]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 + +/// It's crucial that we call pullback with T03 which does not contain adjoints from previous iterations +// CHECK: %[[T07:[0-9]+]] = apply %[[T02]](%[[T03]]) : $@callee_guaranteed (Float) -> Float + +// CHECK: destroy_value %[[T02]] +// CHECK: %[[T08:[0-9]+]] = alloc_stack $Float +// CHECK: %[[T09:[0-9]+]] = alloc_stack $Float +// CHECK: %[[T10:[0-9]+]] = alloc_stack $Float +// CHECK: store %[[ARG31]] to [trivial] %[[T09]] +// CHECK: store %[[T07]] to [trivial] %[[T10]] +// CHECK: %[[T11:[0-9]+]] = witness_method $Float, #AdditiveArithmetic."+" : (Self.Type) -> (Self, Self) -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 +// CHECK: %[[T12:[0-9]+]] = metatype $@thick Float.Type +// CHECK: %[[#]] = apply %[[T11]](%[[T08]], %[[T10]], %[[T09]], %[[T12]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 +// CHECK: destroy_addr %[[T09]] +// CHECK: destroy_addr %[[T10]] +// CHECK: dealloc_stack %[[T10]] +// CHECK: dealloc_stack %[[T09]] +// CHECK: %[[T14:[0-9]+]] = load [trivial] %[[T08]] +// CHECK: dealloc_stack %[[T08]] +// CHECK: debug_value %[[T14]], let, name "x", argno 1 +// CHECK: copy_addr %[[V1]] to %[[V2:[0-9]+]] +// CHECK: debug_value %[[T14]], let, name "x", argno 1 +// CHECK: copy_addr %[[V1]] to %[[V3:[0-9]+]] +// CHECK: switch_enum %[[T01]], case #_AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb1__Pred__src_0_wrt_0.bb2!enumelt: bb4, case #_AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb1__Pred__src_0_wrt_0.bb0!enumelt: bb6, forwarding: @owned + +// CHECK: bb4(%[[ARG41:[0-9]+]] : $Builtin.RawPointer): +// CHECK: %[[T15:[0-9]+]] = pointer_to_address %[[ARG41]] to [strict] $*(predecessor: _AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb2__Pred__src_0_wrt_0) +// CHECK: %[[T16:[0-9]+]] = load [trivial] %[[T15]] +// CHECK: br bb5(%[[T14]], %[[T03]], %[[T16]]) + +// DEBUG-NEXT: [AD] Original bb2 belongs to a loop, original header bb1, pullback header bb3 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#AARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#A0]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#A1]] = apply %[[#A2]](%[[#AARG0]], %[[#A3]], %[[#A4]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#A5]] = begin_access [modify] [static] %[[#A0]] : $*Float + +// CHECK: bb5(%[[ARG51:[0-9]+]] : $Float, %[[ARG52:[0-9]+]] : $Float, %[[ARG53:[0-9]+]] : $(predecessor: _AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb2__Pred__src_0_wrt_0)): + +/// Ensure that we do not zero adjoints for non-loop-local active values +// CHECK-NOT: witness_method $Float, #AdditiveArithmetic.zero!getter + +// CHECK: %[[T17:[0-9]+]] = destructure_tuple %[[ARG53]] +// CHECK: debug_value %[[ARG51]], let, name "x", argno 1 +// CHECK: copy_addr %[[V2]] to %[[V1]] +// CHECK: switch_enum %[[T17]], case #_AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb2__Pred__src_0_wrt_0.bb1!enumelt: bb1, forwarding: @owned + +// DEBUG-NEXT: [AD] End search for adjoints of loop-local active values + +@differentiable(reverse) +func repeat_while_loop_nested(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + var temp = x + var j = 0 + repeat { + temp = temp * x + j += 1 + } while j < 2 + result = result * temp + i += 1 + } while i < 2 + return result +} + +// DEBUG-LABEL: [AD] Running PullbackCloner on +// DEBUG-NEXT: // repeat_while_loop_nested +// DEBUG: [AD] Begin search for adjoints of loop-local active values + +// DEBUG-NEXT: [AD] Original bb4 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#BARG0:]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B00:]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is NOT a projection, zeroing its adjoint buffer in loop header: %[[#B01:]] = alloc_stack [var_decl] $Float, var, name "temp", type $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B02:]] = begin_access [read] [static] %[[#B01]] : $*Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B03:]] = load [trivial] %[[#B02]] : $*Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B04:]] = apply %[[#B05:]](%[[#B03]], %[[#BARG0]], %[[#B06:]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B07:]] = begin_access [modify] [static] %[[#B01]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#B08:]] = begin_access [read] [static] %[[#B00]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#B09:]] = load [trivial] %[[#B08]] : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#B09]] = load [trivial] %[[#B08]] : $*Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#B10:]] = begin_access [read] [static] %[[#B01]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#B11:]] = load [trivial] %[[#B10]] : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#B11]] = load [trivial] %[[#B10]] : $*Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#B12:]] = apply %[[#B13:]](%[[#B09]], %[[#B11]], %[[#B14:]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#B12]] = apply %[[#B13]](%[[#B09]], %[[#B11]], %[[#B14]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#B15:]] = begin_access [modify] [static] %[[#B00]] : $*Float + +// DEBUG-NEXT: [AD] Original bb2 belongs to a loop, original header bb2, pullback header bb6 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#BARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B01]] = alloc_stack [var_decl] $Float, var, name "temp", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#B02]] = begin_access [read] [static] %[[#B01]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#B03]] = load [trivial] %[[#B02]] : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#B03]] = load [trivial] %[[#B02]] : $*Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#B04]] = apply %[[#B05]](%[[#B03]], %[[#BARG0]], %[[#B06]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#B04]] = apply %[[#B05]](%[[#B03]], %[[#BARG0]], %[[#B06]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#B07]] = begin_access [modify] [static] %[[#B01]] : $*Float + +// DEBUG-NEXT: [AD] Original bb3 belongs to a loop, original header bb2, pullback header bb6 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#BARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B01]] = alloc_stack [var_decl] $Float, var, name "temp", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B02]] = begin_access [read] [static] %[[#B01]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B03]] = load [trivial] %[[#B02]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B04]] = apply %[[#B05]](%[[#B03]], %[[#BARG0]], %[[#B06]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B07]] = begin_access [modify] [static] %[[#B01]] : $*Float + +// DEBUG-NEXT: [AD] Original bb1 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#BARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B01]] = alloc_stack [var_decl] $Float, var, name "temp", type $Float + +// DEBUG-NEXT: [AD] Original bb5 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#BARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B01]] = alloc_stack [var_decl] $Float, var, name "temp", type $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B02]] = begin_access [read] [static] %[[#B01]] : $*Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B03]] = load [trivial] %[[#B02]] : $*Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B04]] = apply %[[#B05]](%[[#B03]], %[[#BARG0]], %[[#B06]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#B07]] = begin_access [modify] [static] %[[#B01]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B08]] = begin_access [read] [static] %[[#B00]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B09]] = load [trivial] %[[#B08]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B10]] = begin_access [read] [static] %[[#B01]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B11]] = load [trivial] %[[#B10]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B12]] = apply %[[#B13]](%[[#B09]], %[[#B11]], %[[#B14]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but it was already processed, skipping: %[[#B15]] = begin_access [modify] [static] %[[#B00]] : $*Float + +// DEBUG-NEXT: [AD] End search for adjoints of loop-local active values + +@differentiable(reverse) +func repeat_while_loop_condition(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + if i == 2 { + break + } + if i == 0 { + result = result * x + } else { + result = result * result + } + i += 1 + } while i < 10 + return result +} + +// DEBUG-LABEL: [AD] Running PullbackCloner on +// DEBUG-NEXT: // repeat_while_loop_condition +// DEBUG: [AD] Begin search for adjoints of loop-local active values + +// DEBUG-NEXT: [AD] Original bb6 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#CARG0:]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#C00:]] = alloc_stack [var_decl] $Float, var, name "result", type $Float + +// DEBUG-NEXT: [AD] Original bb1 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#CARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#C00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float + +// DEBUG-NEXT: [AD] Original bb5 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#CARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#C00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#C01:]] = begin_access [read] [static] %[[#C00]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#C02:]] = load [trivial] %[[#C01]] : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#C02]] = load [trivial] %[[#C01]] : $*Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#C03:]] = begin_access [read] [static] %[[#C00]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#C04:]] = load [trivial] %[[#C03]] : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#C04]] = load [trivial] %[[#C03]] : $*Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#C05:]] = apply %[[#C06:]](%[[#C02]], %[[#C04]], %[[#C07:]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#C05]] = apply %[[#C06]](%[[#C02]], %[[#C04]], %[[#C07]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#]] = begin_access [modify] [static] %[[#C00]] : $*Float + +// DEBUG-NEXT: [AD] Original bb4 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#CARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#C00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#C08:]] = begin_access [read] [static] %[[#C00]] : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#C09:]] = load [trivial] %[[#C08]] : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#C09]] = load [trivial] %[[#C08]] : $*Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#C10:]] = apply %[[#C11:]](%[[#C09]], %[[#CARG0]], %[[#C12:]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#C10]] = apply %[[#C11]](%[[#C09]], %[[#CARG0]], %[[#C12]]) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#]] = begin_access [modify] [static] %[[#C00]] : $*Float + +// DEBUG-NEXT: [AD] Original bb7 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#CARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#C00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float + +// DEBUG-NEXT: [AD] Original bb3 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#CARG0]] = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#C00]] = alloc_stack [var_decl] $Float, var, name "result", type $Float + +// DEBUG-NEXT: [AD] End search for adjoints of loop-local active values + +typealias FloatArrayTan = Array.TangentVector + +func identity(_ array: [Float]) -> [Float] { + var results: [Float] = [] + for i in withoutDerivative(at: array.indices) { + results += [array[i]] + } + return results +} + +pullback(at: [1, 2, 3], of: identity)(FloatArrayTan([4, -5, 6])) + +// DEBUG-LABEL: [AD] Running PullbackCloner on +// DEBUG-NEXT: // identity +// DEBUG: [AD] Begin search for adjoints of loop-local active values + +// DEBUG-NEXT: [AD] Original bb1 belongs to a loop, original header bb1, pullback header bb3 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#DARG0:]] = argument of bb0 : $Array +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#D0:]] = alloc_stack [var_decl] $Array, var, name "results", type $Array + +// DEBUG-NEXT: [AD] Original bb2 belongs to a loop, original header bb1, pullback header bb3 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#DARG0:]] = argument of bb0 : $Array +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %[[#D0]] = alloc_stack [var_decl] $Array, var, name "results", type $Array +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#D1:]] = apply %[[#]](%[[#]]) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer) +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#D:]] = apply %[[#]](%[[#]]) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer) +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Array.DifferentiableView] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: (**%[[#D2:]]**, %[[#D3:]]) = destructure_tuple %[[#D1]] : $(Array, Builtin.RawPointer) +// DEBUG-NEXT: [AD] Setting adjoint value for (**%[[#D2]]**, %[[#D3]]) = destructure_tuple %[[#D1]] : $(Array, Builtin.RawPointer) +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Array.DifferentiableView] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Materializing adjoint for Zero[$Array.DifferentiableView] +// DEBUG-NEXT: [AD] Recorded temporary %[[#]] = load [take] %[[#]] : $*Array.DifferentiableView +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#D4:]] = pointer_to_address %[[#D5:]] : $Builtin.RawPointer to [strict] $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint value in loop header: %[[#D6:]] = apply %[[#]](%[[#D2]]) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> +// DEBUG-NEXT: [AD] Setting adjoint value for %[[#D6]] = apply %[[#]](%[[#D2]]) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Array.DifferentiableView] +// DEBUG-NEXT: [AD] The following active value is loop-local, checking if it's adjoint is a projection +// DEBUG-NEXT: [AD] Adjoint for the following value is a projection, skipping: %[[#]] = begin_access [modify] [static] %[[#D0]] : $*Array + +// DEBUG-NEXT: [AD] End search for adjoints of loop-local active values diff --git a/test/AutoDiff/validation-test/control_flow.swift b/test/AutoDiff/validation-test/control_flow.swift index 915049b0a5236..70f5260aa7310 100644 --- a/test/AutoDiff/validation-test/control_flow.swift +++ b/test/AutoDiff/validation-test/control_flow.swift @@ -628,12 +628,8 @@ ControlFlowTests.test("Loops") { } while i < 2 return result } - // FIXME(TF-584): Investigate incorrect (too big) gradient values for - // repeat-while loops. - // expectEqual((8, 12), valueWithGradient(at: 2, of: repeat_while_loop)) - // expectEqual((27, 27), valueWithGradient(at: 3, of: repeat_while_loop)) - expectEqual((8, 18), valueWithGradient(at: 2, of: repeat_while_loop)) - expectEqual((27, 36), valueWithGradient(at: 3, of: repeat_while_loop)) + expectEqual((8, 12), valueWithGradient(at: 2, of: repeat_while_loop)) + expectEqual((27, 27), valueWithGradient(at: 3, of: repeat_while_loop)) func repeat_while_loop_nonactive_initial_value(_ x: Float) -> Float { var result: Float = 1 @@ -644,12 +640,87 @@ ControlFlowTests.test("Loops") { } while i < 2 return result } - // FIXME(TF-584): Investigate incorrect (too big) gradient values for - // repeat-while loops. - // expectEqual((4, 4), valueWithGradient(at: 2, of: repeat_while_loop_nonactive_initial_value)) - // expectEqual((9, 6), valueWithGradient(at: 3, of: repeat_while_loop_nonactive_initial_value)) - expectEqual((4, 5), valueWithGradient(at: 2, of: repeat_while_loop_nonactive_initial_value)) - expectEqual((9, 7), valueWithGradient(at: 3, of: repeat_while_loop_nonactive_initial_value)) + expectEqual((4, 4), valueWithGradient(at: 2, of: repeat_while_loop_nonactive_initial_value)) + expectEqual((9, 6), valueWithGradient(at: 3, of: repeat_while_loop_nonactive_initial_value)) + + @differentiable(reverse) + func repeat_while_loop_nested_repeat(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + var temp = x + var j = 0 + repeat { + temp = temp * x + j += 1 + } while j < 2 + result = result * temp + i += 1 + } while i < 2 + return result + } + + expectEqual((128, 448), valueWithGradient(at: 2, of: repeat_while_loop_nested_repeat)) + expectEqual((2187, 5103), valueWithGradient(at: 3, of: repeat_while_loop_nested_repeat)) + + @differentiable(reverse) + func repeat_while_loop_nested_while(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + var temp = x + var j = 0 + while j < 2 { + temp = temp * x + j += 1 + } + result = result * temp + i += 1 + } while i < 2 + return result + } + + expectEqual((128, 448), valueWithGradient(at: 2, of: repeat_while_loop_nested_while)) + expectEqual((2187, 5103), valueWithGradient(at: 3, of: repeat_while_loop_nested_while)) + + @differentiable(reverse) + func repeat_while_loop_nested_for(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + var temp = x + for j in 0..<2 { + temp = temp * x + } + result = result * temp + i += 1 + } while i < 2 + return result + } + + expectEqual((128, 448), valueWithGradient(at: 2, of: repeat_while_loop_nested_for)) + expectEqual((2187, 5103), valueWithGradient(at: 3, of: repeat_while_loop_nested_for)) + + @differentiable(reverse) + func repeat_while_loop_condition(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + if i == 2 { + break + } + if i == 0 { + result = result * x + } else { + result = result * result + } + i += 1 + } while i < 10 + return result + } + + expectEqual((16, 32), valueWithGradient(at: 2, of: repeat_while_loop_condition)) + expectEqual((81, 108), valueWithGradient(at: 3, of: repeat_while_loop_condition)) func loop_continue(_ x: Float) -> Float { var result = x