Diffstat (limited to 'training')
 training/functional.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+), 0 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 1c38635..db46766 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
+from util.noise import perlin_noise
 
 
 def const(result=None):
@@ -253,6 +254,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
+    perlin_strength: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -275,6 +277,19 @@ def loss_step(
         generator=generator
     )
 
+    if perlin_strength != 0:
+        noise += perlin_strength * perlin_noise(
+            latents.shape[0],
+            latents.shape[1],
+            latents.shape[2],
+            latents.shape[3],
+            res=1,
+            octaves=4,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -559,6 +574,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
+    perlin_strength: float = 0.1,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -593,6 +609,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
+        perlin_strength,
         seed,
     )
 
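Note: util/noise.py is not touched by this commit, so the perlin_noise implementation itself is not shown. For readers without the repo at hand, below is a minimal sketch of a fractal Perlin noise generator compatible with the call site above. Only the perlin_noise signature is taken from the diff; rand_perlin_2d and the persistence parameter are illustrative names, and the real util/noise.py may differ.

# Hypothetical sketch of util/noise.py -- NOT the code from this commit.
# Only the perlin_noise signature mirrors the call in loss_step above;
# the rest is a standard fractal Perlin construction.
import math

import torch


def rand_perlin_2d(shape, res, dtype=None, device=None, generator=None):
    # One (H, W) tile of 2D Perlin noise with `res` gradient cells per
    # axis. Assumes shape[0] and shape[1] are divisible by res.
    delta = (res / shape[0], res / shape[1])
    d = (shape[0] // res, shape[1] // res)

    # Fractional coordinates of every pixel within its gradient cell.
    grid = torch.stack(torch.meshgrid(
        torch.arange(0, res, delta[0], dtype=dtype, device=device),
        torch.arange(0, res, delta[1], dtype=dtype, device=device),
        indexing="ij",
    ), dim=-1) % 1

    # Random unit gradients at the cell corners.
    angles = 2 * math.pi * torch.rand(
        res + 1, res + 1, dtype=dtype, device=device, generator=generator
    )
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def corner_dot(row, col, shift):
        # Dot product between each pixel's offset from a given corner
        # and that corner's gradient, broadcast over all cells.
        g = (gradients[row:row + res, col:col + res]
             .repeat_interleave(d[0], 0)
             .repeat_interleave(d[1], 1))
        offset = torch.stack(
            (grid[..., 0] + shift[0], grid[..., 1] + shift[1]), dim=-1
        )
        return (offset * g).sum(dim=-1)

    n00 = corner_dot(0, 0, (0, 0))
    n10 = corner_dot(1, 0, (-1, 0))
    n01 = corner_dot(0, 1, (0, -1))
    n11 = corner_dot(1, 1, (-1, -1))

    t = 6 * grid ** 5 - 15 * grid ** 4 + 10 * grid ** 3  # smoothstep fade
    return math.sqrt(2) * torch.lerp(
        torch.lerp(n00, n10, t[..., 0]),
        torch.lerp(n01, n11, t[..., 0]),
        t[..., 1],
    )


def perlin_noise(batch, channels, height, width,
                 res=1, octaves=1, persistence=0.5,
                 dtype=None, device=None, generator=None):
    # Sum `octaves` Perlin layers with doubling frequency and decaying
    # amplitude, independently for each (batch, channel) slice.
    noise = torch.zeros(batch, channels, height, width,
                        dtype=dtype, device=device)
    for b in range(batch):
        for c in range(channels):
            amplitude, frequency = 1.0, res
            for _ in range(octaves):
                noise[b, c] += amplitude * rand_perlin_2d(
                    (height, width), frequency,
                    dtype=dtype, device=device, generator=generator,
                )
                amplitude *= persistence
                frequency *= 2
    return noise

With res=1 and octaves=4 as used in loss_step, height and width must be divisible by res * 2**(octaves - 1) = 8, which holds for typical 64x64 Stable Diffusion latents. Reference implementations usually vectorize over batch and channels; the explicit loops here just keep the sketch short.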