fixed cosine schedule

This commit is contained in:
Kashif Rasul
2022-04-20 12:01:17 +02:00
committed by GitHub
parent 0c4b37a48c
commit ae0aed8339
+4 -4
View File
@@ -28,17 +28,17 @@ def noise_like(shape, device, repeat=False):
return repeat_noise() if repeat else noise()
def cosine_beta_schedule(timesteps, s=0.008):
def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return np.clip(betas, a_min=0, a_max=0.999)
return torch.clip(betas, 0, 0.999)
class GaussianDiffusion(nn.Module):