Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  31
-rw-r--r--  training/util.py          1
2 files changed, 10 insertions, 22 deletions
diff --git a/training/functional.py b/training/functional.py
index 2d582bf..36269f0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -253,7 +253,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
-    low_freq_noise: float,
+    noise_offset: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -268,30 +268,19 @@ def loss_step(
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
 
     # Sample noise that we'll add to the latents
-    noise = torch.randn(
-        latents.shape,
+    offsets = noise_offset * torch.randn(
+        latents.shape[0], 1, 1, 1,
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
         generator=generator
+    ).expand(latents.shape)
+    noise = torch.normal(
+        mean=offsets,
+        std=1,
+        generator=generator,
     )
 
-    if low_freq_noise != 0:
-        low_freq_factor = low_freq_noise * torch.randn(
-            latents.shape[0], 1, 1, 1,
-            dtype=latents.dtype,
-            layout=latents.layout,
-            device=latents.device,
-            generator=generator
-        )
-        noise = noise * (1 - low_freq_factor) + low_freq_factor * torch.randn(
-            latents.shape[0], latents.shape[1], 1, 1,
-            dtype=latents.dtype,
-            layout=latents.layout,
-            device=latents.device,
-            generator=generator
-        )
-
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -576,7 +565,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
-    low_freq_noise: float = 0.1,
+    noise_offset: float = 0.2,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -611,7 +600,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
-        low_freq_noise,
+        noise_offset,
         seed,
     )
 
diff --git a/training/util.py b/training/util.py
index c8524de..8bd8a83 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,5 @@
 from pathlib import Path
 import json
-import copy
 from typing import Iterable, Any
 from contextlib import contextmanager
 
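
For reference, a minimal standalone sketch of the noise sampling that functional.py switches to, assuming latents shaped (batch, channels, height, width); the helper name sample_offset_noise and the dummy tensors below are illustrative only, not identifiers from this repository:

import torch
from typing import Optional

def sample_offset_noise(latents: torch.Tensor,
                        noise_offset: float = 0.2,
                        generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # One scalar offset per batch element, drawn from N(0, noise_offset^2),
    # broadcast across channels and spatial dimensions.
    offsets = noise_offset * torch.randn(
        latents.shape[0], 1, 1, 1,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    ).expand(latents.shape)
    # Unit-variance Gaussian noise centred on those offsets; equivalent to
    # offsets + torch.randn_like(latents).
    return torch.normal(mean=offsets, std=1, generator=generator)

# Illustrative usage with dummy latents
latents = torch.randn(4, 4, 64, 64)
noise = sample_offset_noise(latents, generator=torch.Generator().manual_seed(0))

With std=1 this is what is commonly called the "noise offset" trick: a per-sample constant shifts the mean of otherwise standard Gaussian noise, replacing the earlier low_freq_noise mixing of per-sample and per-channel noise that the commit removes.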