Hello,
I really like this repo, thank you to all the contributors!
I was wondering if there are plans to add the LCM (Latent Consistency Model) scheduler to this repo. From my understanding, it wouldn't require major modifications to the existing code, and it would allow much faster inference (typically 4-8 denoising steps instead of 25-50).
I already started working on it, but I think both my understanding of the scheduler and my Swift skills are too limited to complete the task on my own.
Below is my current code. It does not work yet, but I think it could be a good starting point for anyone who wants to contribute.
import CoreML
import Foundation

// MARK: - LCMScheduler

/// A Swift port of the Latent Consistency Model (LCM) scheduler.
///
/// This implementation tries to match:
/// [Hugging Face Diffusers LCMScheduler](https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/schedulers/scheduling_lcm.py)
///
/// `Scheduler`, `BetaSchedule` and `linspace` are assumed to be the existing declarations from this repo's `Scheduler.swift`.
@available(iOS 16.2, macOS 13.1, *)
public final class LCMScheduler: Scheduler {
public var trainStepCount: Int
public var inferenceStepCount: Int
public var origStepCount: Int
/// Multiplier applied to the timestep when computing the boundary conditions (`timestep_scaling` in diffusers)
public var timeStepScaling: Float
/// Clamp range for the predicted original sample (`clip_sample_range` in diffusers)
public var clipSampleRange: Float
public var betas: [Float]
public var alphas: [Float]
public var alphasCumProd: [Float]
public var timeSteps: [Int]
// Internal state
var currentSample: MLShapedArray<Float32>?
var stepIndex: Int?
// var finalAlphaCumProd: Float
// var initNoiseSigma: Float = 1.0
// var customTimesteps: Bool = false
// Initialize with similar parameters as in Python
public init(
stepCount: Int = 4,
trainStepCount: Int = 1000,
betaSchedule: BetaSchedule = .scaledLinear,
betaStart: Float = 0.00085,
betaEnd: Float = 0.012,
origStepCount: Int = 50,
timeStepScaling: Float = 10.0
) {
self.trainStepCount = trainStepCount
self.inferenceStepCount = stepCount
self.origStepCount = origStepCount
self.timeStepScaling = timeStepScaling
self.clipSampleRange = 1.0
switch betaSchedule {
case .linear:
self.betas = linspace(betaStart, betaEnd, trainStepCount)
case .scaledLinear:
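// diffusers "scaled_linear": torch.linspace(beta_start**0.5, beta_end**0.5, steps) ** 2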
self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 })
}
// // Optionally rescale betas for zero terminal SNR
// if rescaleBetasZeroSNR {
// // TODO: implement rescale_zero_terminal_snr equivalent in Swift
// self.betas = rescaleBetasForZeroSNR(betas: self.betas)
// }
self.alphas = betas.map({ 1.0 - $0 })
var alphasCumProd = self.alphas
for i in 1..<alphasCumProd.count {
alphasCumProd[i] *= alphasCumProd[i - 1]
}
self.alphasCumProd = alphasCumProd
// Initialize timeSteps to an empty array; it will be populated by setTimesteps
self.timeSteps = []
// Call setTimesteps to initialize the timeSteps property based on the provided parameters
self.setTimesteps(stepCount: stepCount, origStepCount: origStepCount)
self.currentSample = nil
}
}
@available(iOS 16.2, macOS 13.1, *)
extension LCMScheduler {
func setTimesteps(
stepCount: Int? = nil,
origStepCount: Int? = nil,
strength: Float = 1.0
) {
let origSteps = origStepCount ?? self.origStepCount
guard origSteps <= self.trainStepCount else {
fatalError("`origSteps`: \(origSteps) cannot be larger than `trainStepCount`: \(self.trainStepCount).")
}
let k = self.trainStepCount / origSteps
let lcmOriginTimesteps = (1...Int(Float(origSteps) * strength)).map { $0 * k - 1 }
let finalStepCount = stepCount ?? self.inferenceStepCount
guard finalStepCount <= self.trainStepCount else {
fatalError("`stepCount`: \(finalStepCount) cannot be larger than `trainStepCount`: \(self.trainStepCount).")
}
let skippingStep = lcmOriginTimesteps.count / finalStepCount
guard skippingStep >= 1 else {
fatalError("The combination of `origSteps x strength`: \(origSteps) x \(strength) is smaller than `stepCount`: \(finalStepCount).")
}
guard finalStepCount <= origSteps else {
fatalError("`stepCount`: \(finalStepCount) cannot be larger than `origStepCount`: \(origSteps).")
}
self.inferenceStepCount = finalStepCount
// Reverse so the timesteps run from high noise to low noise, then stride to pick the inference timesteps
let descendingOriginTimesteps = Array(lcmOriginTimesteps.reversed())
let inferenceTimesteps = stride(from: 0, to: descendingOriginTimesteps.count, by: skippingStep).map { descendingOriginTimesteps[$0] }
self.timeSteps = Array(inferenceTimesteps.prefix(finalStepCount))
// Reset internal state related to timestep tracking
self.stepIndex = nil
}
}
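If I did the arithmetic right, the defaults should reproduce the usual 4-step LCM schedule: with trainStepCount = 1000, origStepCount = 50 and stepCount = 4 we get k = 20 and skippingStep = 12, so

let scheduler = LCMScheduler(stepCount: 4)
// scheduler.timeSteps == [999, 759, 519, 279]

which should be close to what diffusers produces for the same settings (recent diffusers versions pick the indices with linspace rather than a fixed stride, so the exact values can differ slightly).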
@available(iOS 16.2, macOS 13.1, *)
extension LCMScheduler {
/// Compute the sample at the previous timestep from the current model output
///
/// - Parameters:
///   - output: Predicted noise residual from the model at the current timestep, e_θ(x_t, t)
///   - timeStep: The current timestep t
///   - sample: The current latent x_t
/// - Returns: The denoised prediction, with fresh noise re-injected for every step except the last
public func step(
output: MLShapedArray<Float32>,
timeStep t: Int,
sample s: MLShapedArray<Float32>
) -> MLShapedArray<Float32> {
// `timeSteps` must have been populated by `setTimesteps` (called from `init`, and again whenever the step count changes)
guard !self.timeSteps.isEmpty else {
fatalError("`timeSteps` is empty; run `setTimesteps` before calling `step`")
}
// Initialize stepIndex if it hasn't been set
if self.stepIndex == nil {
// Find the index(es) in timeSteps that match the current timeStep
let indexCandidates = self.timeSteps.enumerated().filter { $0.element == t }.map { $0.offset }
// Determine the step_index based on indexCandidates
if indexCandidates.count > 1 {
self.stepIndex = indexCandidates[1]
} else if let firstIndex = indexCandidates.first {
self.stepIndex = firstIndex
} else {
fatalError("Current timeStep not found in timeSteps")
}
}
// 1. Compute the index for the previous timestep based on the current step index.
let prevStepIndex = stepIndex! + 1
let prevTimeStep = prevStepIndex < self.timeSteps.count ? self.timeSteps[prevStepIndex] : t
// 2. compute alphas, betas
let alphaProdT = self.alphasCumProd[t]
let alphaProdTPrev = self.alphasCumProd[max(0, prevTimeStep)]
let betaProdT = 1 - alphaProdT
let betaProdTPrev = 1 - alphaProdTPrev
// 3. Get scalings for boundary conditions
let scaledTimeStep = Float32(t) * Float32(self.timeStepScaling)
let sigmaDataSquared = Float32(0.5 * 0.5) // sigma_data = 0.5, the default in the diffusers implementation
let cSkip = sigmaDataSquared / (scaledTimeStep * scaledTimeStep + sigmaDataSquared)
let cOut = scaledTimeStep / sqrt(scaledTimeStep * scaledTimeStep + sigmaDataSquared)
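// These match the consistency-model boundary conditions used by diffusers:
//   c_skip(t) = sigma_data^2 / ((t * timestep_scaling)^2 + sigma_data^2)
//   c_out(t)  = (t * timestep_scaling) / sqrt((t * timestep_scaling)^2 + sigma_data^2)
// so c_skip -> 1 and c_out -> 0 as t -> 0, i.e. the model acts as the identity at zero noise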
// 4. + 5. Compute the predicted original sample x_0 (epsilon parameterization) and clamp it:
//    x_0 = (x_t - sqrt(1 - alpha_prod_t) * eps) / sqrt(alpha_prod_t)
let betaProdTSqrt = sqrt(betaProdT)
let alphaProdTSqrt = sqrt(alphaProdT)
let predictedOrigSampleElements = zip(s.scalars, output.scalars).map { sampleElement, outputElement in
(sampleElement - outputElement * betaProdTSqrt) / alphaProdTSqrt
}.map { element in
// Clamp each element to the configured sample range
min(max(element, -self.clipSampleRange), self.clipSampleRange)
}
let predictedOrigSample = MLShapedArray<Float32>(scalars: predictedOrigSampleElements, shape: s.shape)
// 6. Denoise the model output using the boundary conditions: denoised = c_out * x_0 + c_skip * x_t
var denoised = MLShapedArray<Float32>(
scalars: zip(predictedOrigSample.scalars, s.scalars).map { x0, xt in cOut * x0 + cSkip * xt },
shape: s.shape
)
// 7. Re-inject noise z ~ N(0, I) for multi-step inference, except at the final timestep:
//    x_prev = sqrt(alpha_prod_t_prev) * denoised + sqrt(1 - alpha_prod_t_prev) * z
if stepIndex! < self.inferenceStepCount - 1 {
let alphaProdTPrevSqrt = sqrt(alphaProdTPrev)
let betaProdTPrevSqrt = sqrt(betaProdTPrev)
// Placeholder Box-Muller noise; ideally this should come from the pipeline's seeded random source
let noise: [Float32] = (0..<denoised.scalars.count).map { _ in
let u1 = Float32.random(in: Float32.leastNonzeroMagnitude..<1)
let u2 = Float32.random(in: 0..<1)
return sqrt(-2 * log(u1)) * cos(2 * Float32.pi * u2)
}
denoised = MLShapedArray<Float32>(
scalars: zip(denoised.scalars, noise).map { d, z in alphaProdTPrevSqrt * d + betaProdTPrevSqrt * z },
shape: s.shape
)
}
// Advance the step index for the next call to `step`
self.stepIndex = prevStepIndex
return denoised
}
}
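For context, this is roughly how I imagine the scheduler being driven from the pipeline's denoising loop. The names initialLatent, wEmbedding and predictNoise below are placeholders for whatever the pipeline actually provides, not real APIs from this repo:

let scheduler = LCMScheduler(stepCount: 4, origStepCount: 50)
var latent = initialLatent // hypothetical starting latent (standard normal noise)
for t in scheduler.timeSteps {
// stand-in for however the UNet is invoked; LCM also conditions on a guidance embedding w
let noisePrediction = predictNoise(latent: latent, timeStep: t, guidanceEmbedding: wEmbedding)
latent = scheduler.step(output: noisePrediction, timeStep: t, sample: latent)
}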
Any help is welcome! Thank you 😃