mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
fixed cosine schedule
This commit is contained in:
@@ -28,17 +28,17 @@ def noise_like(shape, device, repeat=False):
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
|
||||
def cosine_beta_schedule(timesteps, s=0.008):
|
||||
def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = np.linspace(0, steps, steps)
|
||||
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
||||
x = torch.linspace(0, timesteps, steps)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return np.clip(betas, a_min=0, a_max=0.999)
|
||||
return torch.clip(betas, 0, 0.999)
|
||||
|
||||
|
||||
class GaussianDiffusion(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user