diff --git a/pts/modules/gaussian_diffusion.py b/pts/modules/gaussian_diffusion.py index ed9d363..95ddac8 100644 --- a/pts/modules/gaussian_diffusion.py +++ b/pts/modules/gaussian_diffusion.py @@ -28,17 +28,17 @@ def noise_like(shape, device, repeat=False): return repeat_noise() if repeat else noise() -def cosine_beta_schedule(timesteps, s = 0.008): +def cosine_beta_schedule(timesteps, s=0.008): """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ """ steps = timesteps + 1 - x = torch.linspace(0, timesteps, steps) - alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + x = np.linspace(0, timesteps, steps) + alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0, 0.999) + return np.clip(betas, 0, 0.999) class GaussianDiffusion(nn.Module):