From b17b3c3eedcf77e00ae035f94918aa0336aa18e9 Mon Sep 17 00:00:00 2001 From: Daniil Kovalev Date: Thu, 19 Dec 2024 02:26:06 +0300 Subject: [PATCH] [AutoDiff] Fix adjoints for loop-local active values Fixes #78264 --- .../Differentiation/PullbackCloner.cpp | 50 ++++ .../SILOptimizer/pullback_generation.swift | 20 +- .../pullback_generation_loop_adjoints.swift | 260 ++++++++++++++++++ .../validation-test/control_flow.swift | 95 ++++++- 4 files changed, 403 insertions(+), 22 deletions(-) create mode 100644 test/AutoDiff/SILOptimizer/pullback_generation_loop_adjoints.swift diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index 1ea2cc3ebbef6..822399bc5aed1 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -2456,6 +2456,56 @@ bool PullbackCloner::Implementation::run() { // Visit original blocks in post-order and perform differentiation // in corresponding pullback blocks. If errors occurred, back out. 
else { + LLVM_DEBUG(getADDebugStream() + << "Begin search for adjoints of loop-local active values\n"); + llvm::DenseMap<SILLoop *, llvm::DenseSet<SILValue>> loopLocalActiveValues; + for (auto *bb : originalBlocks) { + SILLoop *loop = vjpCloner.getLoopInfo()->getLoopFor(bb); + if (loop == nullptr) + continue; + SILBasicBlock *loopHeader = loop->getHeader(); + SILBasicBlock *pbLoopHeader = getPullbackBlock(loopHeader); + LLVM_DEBUG(getADDebugStream() + << "Original bb" << bb->getDebugID() + << " belongs to a loop, original header bb" + << loopHeader->getDebugID() << ", pullback header bb" + << pbLoopHeader->getDebugID() << '\n'); + builder.setInsertionPoint(pbLoopHeader); + auto &bbActiveValues = activeValues[bb]; + for (SILValue bbActiveValue : bbActiveValues) { + if (vjpCloner.getLoopInfo()->getLoopFor( + bbActiveValue->getParentBlock()) != loop) { + LLVM_DEBUG( + getADDebugStream() + << "The following active value is NOT loop-local, skipping: " + << bbActiveValue); + continue; + } + auto [_, wasInserted] = + loopLocalActiveValues[loop].insert(bbActiveValue); + LLVM_DEBUG(getADDebugStream() + << "The following active value is loop-local, " + << (wasInserted ? 
"zeroing its adjoint in loop header: " + : "but its adjoint was already zeroed in " + "loop header, skipping: ") + << bbActiveValue); + if (!wasInserted) + continue; + if (getTangentValueCategory(bbActiveValue) == + SILValueCategory::Object) { + setAdjointValue(bb, bbActiveValue, + makeZeroAdjointValue(getRemappedTangentType( + bbActiveValue->getType()))); + } else { + assert(getTangentValueCategory(bbActiveValue) == + SILValueCategory::Address); + getAdjointBuffer(bb, bbActiveValue); + } + } + } + LLVM_DEBUG(getADDebugStream() + << "End search for adjoints of loop-local active values\n"); + for (auto *bb : originalBlocks) { visitSILBasicBlock(bb); if (errorOccurred) diff --git a/test/AutoDiff/SILOptimizer/pullback_generation.swift b/test/AutoDiff/SILOptimizer/pullback_generation.swift index b92403eff7a54..6247640cde867 100644 --- a/test/AutoDiff/SILOptimizer/pullback_generation.swift +++ b/test/AutoDiff/SILOptimizer/pullback_generation.swift @@ -183,18 +183,18 @@ func f4(a: NonTrivial) -> Float { // CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial { // CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)): -// CHECK: %88 = alloc_stack $NonTrivial +// CHECK: %74 = alloc_stack $NonTrivial // Non-trivial value must be copied or there will be a // double consume when all owned parameters are destroyed // at the end of the basic block. 
-// CHECK: %89 = copy_value %67 : $NonTrivial - -// CHECK: store %89 to [init] %88 : $*NonTrivial -// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x -// CHECK: %92 = alloc_stack $Float -// CHECK: store %86 to [trivial] %92 : $*Float -// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () -// CHECK: %95 = metatype $@thick Float.Type -// CHECK: %96 = apply %94(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () +// CHECK: %75 = copy_value %67 : $NonTrivial + +// CHECK: store %75 to [init] %74 : $*NonTrivial +// CHECK: %77 = struct_element_addr %74 : $*NonTrivial, #NonTrivial.x +// CHECK: %78 = alloc_stack $Float +// CHECK: store %72 to [trivial] %78 : $*Float +// CHECK: %80 = witness_method $Float, #AdditiveArithmetic."+=" : (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () +// CHECK: %81 = metatype $@thick Float.Type +// CHECK: %82 = apply %80(%77, %78, %81) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // CHECK: destroy_value %67 : $NonTrivial diff --git a/test/AutoDiff/SILOptimizer/pullback_generation_loop_adjoints.swift b/test/AutoDiff/SILOptimizer/pullback_generation_loop_adjoints.swift new file mode 100644 index 0000000000000..7b7c0dac3e1d3 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/pullback_generation_loop_adjoints.swift @@ -0,0 +1,260 @@ +// RUN: %target-swift-frontend -emit-sil -Xllvm --sil-print-after=differentiation %s 2>&1 | %FileCheck %s +// RUN: 
%target-swift-frontend -emit-sil -Xllvm --debug-only=differentiation %s 2>&1 | %FileCheck --check-prefix=DEBUG %s + +// Needed for '--debug-only' +// REQUIRES: asserts + +import _Differentiation + +@differentiable(reverse) +func repeat_while_loop(x: Float) -> Float { + var result : Float + repeat { + result = x + 0 + } while 0 == 1 + return result +} + +// DEBUG-LABEL: [AD] Running PullbackCloner on +// DEBUG-NEXT: // repeat_while_loop +// DEBUG: [AD] Begin search for adjoints of loop-local active values + +// CHECK-LABEL: // pullback of repeat_while_loop(x:) + +// DEBUG-NEXT: [AD] Original bb1 belongs to a loop, original header bb1, pullback header bb3 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %10 = apply %9(%0, %8, %4) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] Setting adjoint value for %10 = apply %9(%0, %8, %4) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] No debug variable found. 
+// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %11 = begin_access [modify] [static] %2 : $*Float + +// CHECK: bb3(%[[ARG31:[0-9]+]] : $Float, %[[ARG32:[0-9]+]] : $Float, %[[ARG33:[0-9]+]] : @owned $(predecessor: _AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb1__Pred__src_0_wrt_0, @callee_guaranteed (Float) -> Float)): +// CHECK: (%[[T01:[0-9]+]], %[[T02:[0-9]+]]) = destructure_tuple %[[ARG33]] +// CHECK: %[[T03:[0-9]+]] = load [trivial] %[[V1:[0-9]+]] + +/// Ensure that we do not add adjoint from the previous iteration +/// The incorrect SIL looked like the following: +/// %[[T03]] = load [trivial] %[[V1]] +/// store %52 to [trivial] %58 // <-- we check absence of this +/// store %[[T03]] to [trivial] %59 +/// %62 = witness_method $Float, #AdditiveArithmetic."+" +/// %63 = metatype $@thick Float.Type +/// %64 = apply %62(%57, %59, %58, %63) +// CHECK-NOT: store %[[ARG32]] to [trivial] %[[#]] + +// CHECK: %[[T04:[0-9]+]] = witness_method $Float, #AdditiveArithmetic.zero!getter : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 +// CHECK: %[[T05:[0-9]+]] = metatype $@thick Float.Type +// CHECK: %[[#]] = apply %[[T04]](%[[V1]], %[[T05]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 + +/// It's crucial that we call pullback with T03 which does not contain adjoints from previous iterations +// CHECK: %[[T07:[0-9]+]] = apply %[[T02]](%[[T03]]) : $@callee_guaranteed (Float) -> Float + +// CHECK: destroy_value %[[T02]] +// CHECK: %[[T08:[0-9]+]] = alloc_stack $Float +// CHECK: %[[T09:[0-9]+]] = alloc_stack $Float +// CHECK: %[[T10:[0-9]+]] = alloc_stack $Float +// CHECK: store %[[ARG31]] to [trivial] %[[T09]] +// CHECK: store 
%[[T07]] to [trivial] %[[T10]] +// CHECK: %[[T11:[0-9]+]] = witness_method $Float, #AdditiveArithmetic."+" : (Self.Type) -> (Self, Self) -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 +// CHECK: %[[T12:[0-9]+]] = metatype $@thick Float.Type +// CHECK: %[[#]] = apply %[[T11]](%[[T08]], %[[T10]], %[[T09]], %[[T12]]) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> @out τ_0_0 +// CHECK: destroy_addr %[[T09]] +// CHECK: destroy_addr %[[T10]] +// CHECK: dealloc_stack %[[T10]] +// CHECK: dealloc_stack %[[T09]] +// CHECK: %[[T14:[0-9]+]] = load [trivial] %[[T08]] +// CHECK: dealloc_stack %[[T08]] +// CHECK: debug_value %[[T14]], let, name "x", argno 1 +// CHECK: copy_addr %[[V1]] to %[[V2:[0-9]+]] +// CHECK: debug_value %[[T14]], let, name "x", argno 1 +// CHECK: copy_addr %[[V1]] to %[[V3:[0-9]+]] +// CHECK: switch_enum %[[T01]], case #_AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb1__Pred__src_0_wrt_0.bb2!enumelt: bb4, case #_AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb1__Pred__src_0_wrt_0.bb0!enumelt: bb6, forwarding: @owned + +// CHECK: bb4(%[[ARG41:[0-9]+]] : $Builtin.RawPointer): +// CHECK: %[[T15:[0-9]+]] = pointer_to_address %[[ARG41]] to [strict] $*(predecessor: _AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb2__Pred__src_0_wrt_0) +// CHECK: %[[T16:[0-9]+]] = load [trivial] %[[T15]] +// CHECK: br bb5(%[[T14]], %[[T03]], %[[T16]]) + +// DEBUG-NEXT: [AD] Original bb2 belongs to a loop, original header bb1, pullback header bb3 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name 
"result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %10 = apply %9(%0, %8, %4) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %11 = begin_access [modify] [static] %2 : $*Float + +// CHECK: bb5(%[[ARG51:[0-9]+]] : $Float, %[[ARG52:[0-9]+]] : $Float, %[[ARG53:[0-9]+]] : $(predecessor: _AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb2__Pred__src_0_wrt_0)): + +/// Ensure that we do not zero adjoints for non-loop-local active values +// CHECK-NOT: witness_method $Float, #AdditiveArithmetic.zero!getter + +// CHECK: %[[T17:[0-9]+]] = destructure_tuple %[[ARG53]] +// CHECK: debug_value %[[ARG51]], let, name "x", argno 1 +// CHECK: copy_addr %[[V2]] to %[[V1]] +// CHECK: switch_enum %[[T17]], case #_AD__$s33pullback_generation_loop_adjoints013repeat_while_C01xS2f_tF_bb2__Pred__src_0_wrt_0.bb1!enumelt: bb1, forwarding: @owned + +// DEBUG-NEXT: [AD] End search for adjoints of loop-local active values + +@differentiable(reverse) +func repeat_while_loop_nested(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + var temp = x + var j = 0 + repeat { + temp = temp * x + j += 1 + } while j < 2 + result = result * temp + i += 1 + } while i < 2 + return result +} + +// DEBUG-LABEL: [AD] Running PullbackCloner on +// DEBUG-NEXT: // repeat_while_loop_nested +// DEBUG: [AD] Begin search for adjoints of loop-local active values + +// DEBUG-NEXT: [AD] Original bb4 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is 
loop-local, zeroing its adjoint in loop header: %11 = alloc_stack [var_decl] $Float, var, name "temp", type $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %21 = begin_access [read] [static] %11 : $*Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %22 = load [trivial] %21 : $*Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %25 = apply %24(%22, %0, %20) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %26 = begin_access [modify] [static] %11 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %52 = begin_access [read] [static] %2 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %53 = load [trivial] %52 : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %53 = load [trivial] %52 : $*Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %55 = begin_access [read] [static] %11 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %56 = load [trivial] %55 : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %56 = load [trivial] %55 : $*Float +// DEBUG-NEXT: [AD] No debug variable found. 
+// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %59 = apply %58(%53, %56, %51) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] Setting adjoint value for %59 = apply %58(%53, %56, %51) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %60 = begin_access [modify] [static] %2 : $*Float + +// DEBUG-NEXT: [AD] Original bb2 belongs to a loop, original header bb2, pullback header bb6 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %11 = alloc_stack [var_decl] $Float, var, name "temp", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %21 = begin_access [read] [static] %11 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %22 = load [trivial] %21 : $*Float +// DEBUG-NEXT: [AD] Setting adjoint value for %22 = load [trivial] %21 : $*Float +// DEBUG-NEXT: [AD] No debug variable found. 
+// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %25 = apply %24(%22, %0, %20) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] Setting adjoint value for %25 = apply %24(%22, %0, %20) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %26 = begin_access [modify] [static] %11 : $*Float + +// DEBUG-NEXT: [AD] Original bb3 belongs to a loop, original header bb2, pullback header bb6 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %11 = alloc_stack [var_decl] $Float, var, name "temp", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %21 = begin_access [read] [static] %11 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %22 = load [trivial] %21 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %25 = apply %24(%22, %0, %20) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %26 = begin_access [modify] [static] %11 : $*Float + +// DEBUG-NEXT: [AD] Original bb1 belongs to a loop, original header 
bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %11 = alloc_stack [var_decl] $Float, var, name "temp", type $Float + +// DEBUG-NEXT: [AD] Original bb5 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %11 = alloc_stack [var_decl] $Float, var, name "temp", type $Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %21 = begin_access [read] [static] %11 : $*Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %22 = load [trivial] %21 : $*Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %25 = apply %24(%22, %0, %20) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %26 = begin_access [modify] [static] %11 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %52 = begin_access [read] [static] %2 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %53 = load [trivial] %52 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, 
skipping: %55 = begin_access [read] [static] %11 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %56 = load [trivial] %55 : $*Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %59 = apply %58(%53, %56, %51) : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// DEBUG-NEXT: [AD] The following active value is loop-local, but its adjoint was already zeroed in loop header, skipping: %60 = begin_access [modify] [static] %2 : $*Float + +// DEBUG-NEXT: [AD] End search for adjoints of loop-local active values + +@differentiable(reverse) +func repeat_while_loop_condition(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + if i == 2 { + break + } + if i == 0 { + result = result * x + } else { + result = result * result + } + i += 1 + } while i < 10 + return result +} + +// DEBUG-LABEL: [AD] Running PullbackCloner on +// DEBUG-NEXT: // repeat_while_loop_condition +// DEBUG: [AD] Begin search for adjoints of loop-local active values + +// DEBUG-NEXT: [AD] Original bb6 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float // users: %41, %3, %1 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float // users: %86, %3, %37, %42, %47, %50, %55, %82 + +// DEBUG-NEXT: [AD] Original bb1 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float // users: %41, %3, %1 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float // users: %86, %3, %37, %42, %47, %50, %55, %82 + +// DEBUG-NEXT: [AD] Original 
bb5 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float // users: %41, %3, %1 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float // users: %86, %3, %37, %42, %47, %50, %55, %82 +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %47 = begin_access [read] [static] %2 : $*Float // users: %49, %48 +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %48 = load [trivial] %47 : $*Float // user: %54 +// DEBUG-NEXT: [AD] Setting adjoint value for %48 = load [trivial] %47 : $*Float // user: %54 +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %50 = begin_access [read] [static] %2 : $*Float // users: %52, %51 +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %51 = load [trivial] %50 : $*Float // user: %54 +// DEBUG-NEXT: [AD] Setting adjoint value for %51 = load [trivial] %50 : $*Float // user: %54 +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %54 = apply %53(%48, %51, %46) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %56 +// DEBUG-NEXT: [AD] Setting adjoint value for %54 = apply %53(%48, %51, %46) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %56 +// DEBUG-NEXT: [AD] No debug variable found. 
+// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %55 = begin_access [modify] [static] %2 : $*Float // users: %56, %57 + +// DEBUG-NEXT: [AD] Original bb4 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float // users: %41, %3, %1 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float // users: %86, %3, %37, %42, %47, %50, %55, %82 +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %37 = begin_access [read] [static] %2 : $*Float // users: %39, %38 +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %38 = load [trivial] %37 : $*Float // user: %41 +// DEBUG-NEXT: [AD] Setting adjoint value for %38 = load [trivial] %37 : $*Float // user: %41 +// DEBUG-NEXT: [AD] No debug variable found. +// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %41 = apply %40(%38, %0, %36) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %43 +// DEBUG-NEXT: [AD] Setting adjoint value for %41 = apply %40(%38, %0, %36) : $@convention(method) (Float, Float, @thin Float.Type) -> Float // user: %43 +// DEBUG-NEXT: [AD] No debug variable found. 
+// DEBUG-NEXT: [AD] The new adjoint value, replacing the existing one, is: Zero[$Float] +// DEBUG-NEXT: [AD] The following active value is loop-local, zeroing its adjoint in loop header: %42 = begin_access [modify] [static] %2 : $*Float // users: %43, %44 + +// DEBUG-NEXT: [AD] Original bb7 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float // users: %41, %3, %1 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float // users: %86, %3, %37, %42, %47, %50, %55, %82 + +// DEBUG-NEXT: [AD] Original bb3 belongs to a loop, original header bb1, pullback header bb10 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %0 = argument of bb0 : $Float // users: %41, %3, %1 +// DEBUG-NEXT: [AD] The following active value is NOT loop-local, skipping: %2 = alloc_stack [var_decl] $Float, var, name "result", type $Float // users: %86, %3, %37, %42, %47, %50, %55, %82 + +// DEBUG-NEXT: [AD] End search for adjoints of loop-local active values diff --git a/test/AutoDiff/validation-test/control_flow.swift b/test/AutoDiff/validation-test/control_flow.swift index 915049b0a5236..70f5260aa7310 100644 --- a/test/AutoDiff/validation-test/control_flow.swift +++ b/test/AutoDiff/validation-test/control_flow.swift @@ -628,12 +628,8 @@ ControlFlowTests.test("Loops") { } while i < 2 return result } - // FIXME(TF-584): Investigate incorrect (too big) gradient values for - // repeat-while loops. 
- // expectEqual((8, 12), valueWithGradient(at: 2, of: repeat_while_loop)) - // expectEqual((27, 27), valueWithGradient(at: 3, of: repeat_while_loop)) - expectEqual((8, 18), valueWithGradient(at: 2, of: repeat_while_loop)) - expectEqual((27, 36), valueWithGradient(at: 3, of: repeat_while_loop)) + expectEqual((8, 12), valueWithGradient(at: 2, of: repeat_while_loop)) + expectEqual((27, 27), valueWithGradient(at: 3, of: repeat_while_loop)) func repeat_while_loop_nonactive_initial_value(_ x: Float) -> Float { var result: Float = 1 @@ -644,12 +640,87 @@ ControlFlowTests.test("Loops") { } while i < 2 return result } - // FIXME(TF-584): Investigate incorrect (too big) gradient values for - // repeat-while loops. - // expectEqual((4, 4), valueWithGradient(at: 2, of: repeat_while_loop_nonactive_initial_value)) - // expectEqual((9, 6), valueWithGradient(at: 3, of: repeat_while_loop_nonactive_initial_value)) - expectEqual((4, 5), valueWithGradient(at: 2, of: repeat_while_loop_nonactive_initial_value)) - expectEqual((9, 7), valueWithGradient(at: 3, of: repeat_while_loop_nonactive_initial_value)) + expectEqual((4, 4), valueWithGradient(at: 2, of: repeat_while_loop_nonactive_initial_value)) + expectEqual((9, 6), valueWithGradient(at: 3, of: repeat_while_loop_nonactive_initial_value)) + + @differentiable(reverse) + func repeat_while_loop_nested_repeat(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + var temp = x + var j = 0 + repeat { + temp = temp * x + j += 1 + } while j < 2 + result = result * temp + i += 1 + } while i < 2 + return result + } + + expectEqual((128, 448), valueWithGradient(at: 2, of: repeat_while_loop_nested_repeat)) + expectEqual((2187, 5103), valueWithGradient(at: 3, of: repeat_while_loop_nested_repeat)) + + @differentiable(reverse) + func repeat_while_loop_nested_while(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + var temp = x + var j = 0 + while j < 2 { + temp = temp * x + j += 1 + } + result = result * temp + i += 1 + } 
while i < 2 + return result + } + + expectEqual((128, 448), valueWithGradient(at: 2, of: repeat_while_loop_nested_while)) + expectEqual((2187, 5103), valueWithGradient(at: 3, of: repeat_while_loop_nested_while)) + + @differentiable(reverse) + func repeat_while_loop_nested_for(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + var temp = x + for j in 0..<2 { + temp = temp * x + } + result = result * temp + i += 1 + } while i < 2 + return result + } + + expectEqual((128, 448), valueWithGradient(at: 2, of: repeat_while_loop_nested_for)) + expectEqual((2187, 5103), valueWithGradient(at: 3, of: repeat_while_loop_nested_for)) + + @differentiable(reverse) + func repeat_while_loop_condition(_ x: Float) -> Float { + var result = x + var i = 0 + repeat { + if i == 2 { + break + } + if i == 0 { + result = result * x + } else { + result = result * result + } + i += 1 + } while i < 10 + return result + } + + expectEqual((16, 32), valueWithGradient(at: 2, of: repeat_while_loop_condition)) + expectEqual((81, 108), valueWithGradient(at: 3, of: repeat_while_loop_condition)) func loop_continue(_ x: Float) -> Float { var result = x