summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py16
1 files changed, 0 insertions, 16 deletions
diff --git a/training/functional.py b/training/functional.py
index db46766..27a43c2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -254,7 +254,6 @@ def loss_step(
254 text_encoder: CLIPTextModel, 254 text_encoder: CLIPTextModel,
255 with_prior_preservation: bool, 255 with_prior_preservation: bool,
256 prior_loss_weight: float, 256 prior_loss_weight: float,
257 perlin_strength: float,
258 seed: int, 257 seed: int,
259 step: int, 258 step: int,
260 batch: dict[str, Any], 259 batch: dict[str, Any],
@@ -277,19 +276,6 @@ def loss_step(
277 generator=generator 276 generator=generator
278 ) 277 )
279 278
280 if perlin_strength != 0:
281 noise += perlin_strength * perlin_noise(
282 latents.shape[0],
283 latents.shape[1],
284 latents.shape[2],
285 latents.shape[3],
286 res=1,
287 octaves=4,
288 dtype=latents.dtype,
289 device=latents.device,
290 generator=generator
291 )
292
293 # Sample a random timestep for each image 279 # Sample a random timestep for each image
294 timesteps = torch.randint( 280 timesteps = torch.randint(
295 0, 281 0,
@@ -574,7 +560,6 @@ def train(
574 global_step_offset: int = 0, 560 global_step_offset: int = 0,
575 with_prior_preservation: bool = False, 561 with_prior_preservation: bool = False,
576 prior_loss_weight: float = 1.0, 562 prior_loss_weight: float = 1.0,
577 perlin_strength: float = 0.1,
578 **kwargs, 563 **kwargs,
579): 564):
580 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 565 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -609,7 +594,6 @@ def train(
609 text_encoder, 594 text_encoder,
610 with_prior_preservation, 595 with_prior_preservation,
611 prior_loss_weight, 596 prior_loss_weight,
612 perlin_strength,
613 seed, 597 seed,
614 ) 598 )
615 599