Skip to content

Commit

Permalink
Upstream Euler-Smea-Dyn-Sampler (add 1 sampler)
Browse files Browse the repository at this point in the history
  • Loading branch information
Panchovix committed Nov 19, 2024
1 parent b618f30 commit 1fd5778
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
2 changes: 2 additions & 0 deletions modules/sd_samplers_kdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
('Euler SMEA Dy', 'sample_euler_smea_dy', ['k_euler_smea_dy'], {}),
('Euler Negative', 'sample_euler_negative', ['k_euler_negative'], {}),
('Euler Negative Dy', 'sample_euler_dy_negative', ['k_euler_dy_negative'], {}),
('Kohaku_LoNyu_Yog', 'sample_Kohaku_LoNyu_Yog', ['k_kohaku_lonyu_yog'], {}),
]
samplers_k_diffusion.extend(additional_samplers)

Expand All @@ -60,6 +61,7 @@
'sample_euler_smea_dy': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_euler_negative': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_euler_dy_negative': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_Kohaku_LoNyu_Yog': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
})

k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
Expand Down
33 changes: 33 additions & 0 deletions modules/sd_samplers_kdiffusion_smea.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,36 @@ def sample_euler_dy_negative(model, x, sigmas, extra_args=None, callback=None, d
else:
x = x + d * dt
return x

@torch.no_grad()
def sample_Kohaku_LoNyu_Yog(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
                            s_tmax=float('inf'), s_noise=1., noise_sampler=None, eta=1.):
    """Kohaku_LoNyu_Yog sampler: churned Euler with a corrector + ancestral noise.

    For the first half of the sigma schedule, each step averages the Euler
    slope with a slope taken at a midpoint derived from the negated latent
    (a Heun-like corrector), then injects ancestral noise.  The second half
    of the schedule is a plain Euler step with no added noise.

    Args:
        model: denoiser callable, invoked as model(x, sigma * s_in, **extra_args).
        x: initial noisy latent tensor.
        sigmas: 1-D tensor of noise levels, descending.
        extra_args: extra keyword arguments forwarded to the model.
        callback: optional per-step callback receiving a progress dict.
        disable: disables the tqdm progress bar.
        s_churn, s_tmin, s_tmax, s_noise: Karras-style churn controls.
        noise_sampler: callable(sigma_from, sigma_to) producing ancestral
            noise; defaults to k-diffusion's default_noise_sampler.
        eta: ancestral step strength, forwarded to get_ancestral_step.

    Returns:
        The latent tensor after the final step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    noise_sampler = sampling.default_noise_sampler(x) if noise_sampler is None else noise_sampler
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            # Fix: draw churn noise only when it is used.  The original called
            # torch.randn_like unconditionally every step, wasting a full-size
            # allocation (and advancing RNG state) whenever s_churn == 0, which
            # is the default.  Upstream k-diffusion's sample_euler uses this
            # same guard.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = sampling.to_d(x, sigma_hat, denoised)
        sigma_down, sigma_up = sampling.get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # NOTE(review): dt is measured from sigmas[i], not sigma_hat, even though
        # the slope d was taken at sigma_hat — verify this is intentional for
        # runs with s_churn > 0 (with the default s_churn == 0 the two coincide).
        dt = sigma_down - sigmas[i]
        if i <= (len(sigmas) - 1) / 2:
            # First half of the schedule: evaluate the model at the negated
            # latent, average that slope with d to pick a midpoint, re-evaluate
            # there, and step along the averaged slope before adding noise.
            x2 = - x
            denoised2 = model(x2, sigma_hat * s_in, **extra_args)
            d2 = sampling.to_d(x2, sigma_hat, denoised2)
            x3 = x + ((d + d2) / 2) * dt
            denoised3 = model(x3, sigma_hat * s_in, **extra_args)
            d3 = sampling.to_d(x3, sigma_hat, denoised3)
            real_d = (d + d3) / 2
            x = x + real_d * dt
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        else:
            # Second half: plain Euler step toward sigma_down, no noise added.
            x = x + d * dt
    return x

0 comments on commit 1fd5778

Please sign in to comment.