diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 31 | ||||
-rw-r--r-- | training/util.py | 1 |
2 files changed, 10 insertions, 22 deletions
diff --git a/training/functional.py b/training/functional.py
index 2d582bf..36269f0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -253,7 +253,7 @@ def loss_step( | |||
253 | text_encoder: CLIPTextModel, | 253 | text_encoder: CLIPTextModel, |
254 | with_prior_preservation: bool, | 254 | with_prior_preservation: bool, |
255 | prior_loss_weight: float, | 255 | prior_loss_weight: float, |
256 | low_freq_noise: float, | 256 | noise_offset: float, |
257 | seed: int, | 257 | seed: int, |
258 | step: int, | 258 | step: int, |
259 | batch: dict[str, Any], | 259 | batch: dict[str, Any], |
@@ -268,30 +268,19 @@ def loss_step( | |||
268 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | 268 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None |
269 | 269 | ||
270 | # Sample noise that we'll add to the latents | 270 | # Sample noise that we'll add to the latents |
271 | noise = torch.randn( | 271 | offsets = noise_offset * torch.randn( |
272 | latents.shape, | 272 | latents.shape[0], 1, 1, 1, |
273 | dtype=latents.dtype, | 273 | dtype=latents.dtype, |
274 | layout=latents.layout, | 274 | layout=latents.layout, |
275 | device=latents.device, | 275 | device=latents.device, |
276 | generator=generator | 276 | generator=generator |
277 | ).expand(latents.shape) | ||
278 | noise = torch.normal( | ||
279 | mean=offsets, | ||
280 | std=1, | ||
281 | generator=generator, | ||
277 | ) | 282 | ) |
278 | 283 | ||
279 | if low_freq_noise != 0: | ||
280 | low_freq_factor = low_freq_noise * torch.randn( | ||
281 | latents.shape[0], 1, 1, 1, | ||
282 | dtype=latents.dtype, | ||
283 | layout=latents.layout, | ||
284 | device=latents.device, | ||
285 | generator=generator | ||
286 | ) | ||
287 | noise = noise * (1 - low_freq_factor) + low_freq_factor * torch.randn( | ||
288 | latents.shape[0], latents.shape[1], 1, 1, | ||
289 | dtype=latents.dtype, | ||
290 | layout=latents.layout, | ||
291 | device=latents.device, | ||
292 | generator=generator | ||
293 | ) | ||
294 | |||
295 | # Sample a random timestep for each image | 284 | # Sample a random timestep for each image |
296 | timesteps = torch.randint( | 285 | timesteps = torch.randint( |
297 | 0, | 286 | 0, |
@@ -576,7 +565,7 @@ def train( | |||
576 | global_step_offset: int = 0, | 565 | global_step_offset: int = 0, |
577 | with_prior_preservation: bool = False, | 566 | with_prior_preservation: bool = False, |
578 | prior_loss_weight: float = 1.0, | 567 | prior_loss_weight: float = 1.0, |
579 | low_freq_noise: float = 0.1, | 568 | noise_offset: float = 0.2, |
580 | **kwargs, | 569 | **kwargs, |
581 | ): | 570 | ): |
582 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 571 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
@@ -611,7 +600,7 @@ def train( | |||
611 | text_encoder, | 600 | text_encoder, |
612 | with_prior_preservation, | 601 | with_prior_preservation, |
613 | prior_loss_weight, | 602 | prior_loss_weight, |
614 | low_freq_noise, | 603 | noise_offset, |
615 | seed, | 604 | seed, |
616 | ) | 605 | ) |
617 | 606 | ||
diff --git a/training/util.py b/training/util.py
index c8524de..8bd8a83 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,5 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | 2 | import json |
3 | import copy | ||
4 | from typing import Iterable, Any | 3 | from typing import Iterable, Any |
5 | from contextlib import contextmanager | 4 | from contextlib import contextmanager |
6 | 5 | ||