From f9673beafd9c2a3d1127b22e3447d4ebe3f4586e Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Thu, 7 Nov 2024 16:50:33 +0300 Subject: [PATCH] Ensure we are adding T : Differentiable conformance from protocol conditional conformance. Fixes #75711 --- lib/SILGen/SILGenPoly.cpp | 8 ++++++ lib/Sema/TypeCheckProtocol.cpp | 26 +++++++++++++++++++ ...nditional-differentiable-conformance.swift | 22 ++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 test/AutoDiff/compiler_crashers_fixed/issue-75711-conditional-differentiable-conformance.swift diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index a53f4bf4e87f9..5b686c8f9e174 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -7043,6 +7043,14 @@ void SILGenFunction::emitProtocolWitness( FullExpr scope(Cleanups, cleanupLoc); FormalEvaluationScope formalEvalScope(*this); + // The protocol conditional conformance itself might bring some T : + // Differentiable conformances. They are already added to the derivative + // generic signature. Update witness substitution map generic signature to + // have them as well. + if (auto *derivativeId = witness.getDerivativeFunctionIdentifier()) + witnessSubs = SubstitutionMap::get(derivativeId->getDerivativeGenericSignature(), + witnessSubs); + auto thunkTy = F.getLoweredFunctionType(); SmallVector origParams; diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index c302209230901..a024a928b06cd 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -294,6 +294,29 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness, return witness; } +static GenericSignature maybeAddDifferentiableFromContext(DeclContext *dc, + GenericSignature derivativeGenSig) { + auto conformanceGenSig = dc->getGenericSignatureOfContext(); + if (!conformanceGenSig) + return derivativeGenSig; + + // The protocol conditional conformance itself might bring some T : + // Differentiable conformances. Add them the the derivative generic signature. + SmallVector diffRequirements; + llvm::copy_if(conformanceGenSig.getRequirements(), + std::back_inserter(diffRequirements), + [](const Requirement &requirement) { + if (requirement.getKind() != RequirementKind::Conformance) + return false; + + auto protoKind = requirement.getProtocolDecl()->getKnownProtocolKind(); + return protoKind && *protoKind == KnownProtocolKind::Differentiable; + }); + + return buildGenericSignature(dc->getASTContext(), derivativeGenSig, + {}, std::move(diffRequirements), /*allowInverses=*/true); +} + /// Given a witness, a requirement, and an existing `RequirementMatch` result, /// check if the requirement's `@differentiable` attributes are met by the /// witness. @@ -425,6 +448,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, auto derivativeGenSig = witnessAFD->getGenericSignature(); if (supersetConfig) derivativeGenSig = supersetConfig->derivativeGenericSignature; + + derivativeGenSig = maybeAddDifferentiableFromContext(dc, derivativeGenSig); + // Use source location of the witness declaration as the source location // of the implicit `@differentiable` attribute. auto *newAttr = DifferentiableAttr::create( diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-75711-conditional-differentiable-conformance.swift b/test/AutoDiff/compiler_crashers_fixed/issue-75711-conditional-differentiable-conformance.swift new file mode 100644 index 0000000000000..30b766898c0bb --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/issue-75711-conditional-differentiable-conformance.swift @@ -0,0 +1,22 @@ +// RUN: %target-swift-frontend -emit-sil -verify %s + +// https://github.com/swiftlang/swift/issues/75711 + +// Ensure we propagate T : Differentiable conditional conformance + +import _Differentiation + +struct Wrapper { + func read(_ t: T) -> T { + return t + } +} + +protocol P { + associatedtype T: Differentiable + + @differentiable(reverse) + func read(_: T) -> T +} + +extension Wrapper: P where T: Differentiable {}