summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-03 18:53:15 +0100
committerVolpeon <git@volpeon.ink>2023-03-03 18:53:15 +0100
commite32b4d4c04a31b22051740e5f26e16960464f787 (patch)
tree9fc842453e4974b936e64b4f012acfef726a8e51
parentLow freq noise with randomized strength (diff)
downloadtextual-inversion-diff-e32b4d4c04a31b22051740e5f26e16960464f787.tar.gz
textual-inversion-diff-e32b4d4c04a31b22051740e5f26e16960464f787.tar.bz2
textual-inversion-diff-e32b4d4c04a31b22051740e5f26e16960464f787.zip
Implemented different noise offset
-rw-r--r--environment.yaml2
-rw-r--r--train_dreambooth.py4
-rw-r--r--train_lora.py2
-rw-r--r--train_ti.py4
-rw-r--r--training/functional.py31
-rw-r--r--training/util.py1
6 files changed, 16 insertions, 28 deletions
diff --git a/environment.yaml b/environment.yaml
index 4899709..018a9ab 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -24,4 +24,4 @@ dependencies:
24 - setuptools==65.6.3 24 - setuptools==65.6.3
25 - test-tube>=0.7.5 25 - test-tube>=0.7.5
26 - transformers==4.26.1 26 - transformers==4.26.1
27 - triton==2.0.0a2 27 - triton==2.0.0
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6d699f3..8571dff 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -438,7 +438,7 @@ def main():
438 438
439 accelerator = Accelerator( 439 accelerator = Accelerator(
440 log_with=LoggerType.TENSORBOARD, 440 log_with=LoggerType.TENSORBOARD,
441 logging_dir=f"{output_dir}", 441 project_dir=f"{output_dir}",
442 mixed_precision=args.mixed_precision 442 mixed_precision=args.mixed_precision
443 ) 443 )
444 444
@@ -526,7 +526,7 @@ def main():
526 with_prior_preservation=args.num_class_images != 0, 526 with_prior_preservation=args.num_class_images != 0,
527 prior_loss_weight=args.prior_loss_weight, 527 prior_loss_weight=args.prior_loss_weight,
528 no_val=args.valid_set_size == 0, 528 no_val=args.valid_set_size == 0,
529 # low_freq_noise=0, 529 # noise_offset=0,
530 ) 530 )
531 531
532 checkpoint_output_dir = output_dir / "model" 532 checkpoint_output_dir = output_dir / "model"
diff --git a/train_lora.py b/train_lora.py
index 0a3d4c9..e213e3d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -398,7 +398,7 @@ def main():
398 398
399 accelerator = Accelerator( 399 accelerator = Accelerator(
400 log_with=LoggerType.TENSORBOARD, 400 log_with=LoggerType.TENSORBOARD,
401 logging_dir=f"{output_dir}", 401 project_dir=f"{output_dir}",
402 mixed_precision=args.mixed_precision 402 mixed_precision=args.mixed_precision
403 ) 403 )
404 404
diff --git a/train_ti.py b/train_ti.py
index 394711f..bc9348d 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -517,7 +517,7 @@ def main():
517 517
518 accelerator = Accelerator( 518 accelerator = Accelerator(
519 log_with=LoggerType.TENSORBOARD, 519 log_with=LoggerType.TENSORBOARD,
520 logging_dir=f"{output_dir}", 520 project_dir=f"{output_dir}",
521 mixed_precision=args.mixed_precision 521 mixed_precision=args.mixed_precision
522 ) 522 )
523 523
@@ -607,7 +607,7 @@ def main():
607 with_prior_preservation=args.num_class_images != 0, 607 with_prior_preservation=args.num_class_images != 0,
608 prior_loss_weight=args.prior_loss_weight, 608 prior_loss_weight=args.prior_loss_weight,
609 no_val=args.valid_set_size == 0, 609 no_val=args.valid_set_size == 0,
610 # low_freq_noise=0, 610 noise_offset=0,
611 strategy=textual_inversion_strategy, 611 strategy=textual_inversion_strategy,
612 num_train_epochs=args.num_train_epochs, 612 num_train_epochs=args.num_train_epochs,
613 gradient_accumulation_steps=args.gradient_accumulation_steps, 613 gradient_accumulation_steps=args.gradient_accumulation_steps,
diff --git a/training/functional.py b/training/functional.py
index 2d582bf..36269f0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -253,7 +253,7 @@ def loss_step(
253 text_encoder: CLIPTextModel, 253 text_encoder: CLIPTextModel,
254 with_prior_preservation: bool, 254 with_prior_preservation: bool,
255 prior_loss_weight: float, 255 prior_loss_weight: float,
256 low_freq_noise: float, 256 noise_offset: float,
257 seed: int, 257 seed: int,
258 step: int, 258 step: int,
259 batch: dict[str, Any], 259 batch: dict[str, Any],
@@ -268,30 +268,19 @@ def loss_step(
268 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None 268 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
269 269
270 # Sample noise that we'll add to the latents 270 # Sample noise that we'll add to the latents
271 noise = torch.randn( 271 offsets = noise_offset * torch.randn(
272 latents.shape, 272 latents.shape[0], 1, 1, 1,
273 dtype=latents.dtype, 273 dtype=latents.dtype,
274 layout=latents.layout, 274 layout=latents.layout,
275 device=latents.device, 275 device=latents.device,
276 generator=generator 276 generator=generator
277 ).expand(latents.shape)
278 noise = torch.normal(
279 mean=offsets,
280 std=1,
281 generator=generator,
277 ) 282 )
278 283
279 if low_freq_noise != 0:
280 low_freq_factor = low_freq_noise * torch.randn(
281 latents.shape[0], 1, 1, 1,
282 dtype=latents.dtype,
283 layout=latents.layout,
284 device=latents.device,
285 generator=generator
286 )
287 noise = noise * (1 - low_freq_factor) + low_freq_factor * torch.randn(
288 latents.shape[0], latents.shape[1], 1, 1,
289 dtype=latents.dtype,
290 layout=latents.layout,
291 device=latents.device,
292 generator=generator
293 )
294
295 # Sample a random timestep for each image 284 # Sample a random timestep for each image
296 timesteps = torch.randint( 285 timesteps = torch.randint(
297 0, 286 0,
@@ -576,7 +565,7 @@ def train(
576 global_step_offset: int = 0, 565 global_step_offset: int = 0,
577 with_prior_preservation: bool = False, 566 with_prior_preservation: bool = False,
578 prior_loss_weight: float = 1.0, 567 prior_loss_weight: float = 1.0,
579 low_freq_noise: float = 0.1, 568 noise_offset: float = 0.2,
580 **kwargs, 569 **kwargs,
581): 570):
582 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 571 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -611,7 +600,7 @@ def train(
611 text_encoder, 600 text_encoder,
612 with_prior_preservation, 601 with_prior_preservation,
613 prior_loss_weight, 602 prior_loss_weight,
614 low_freq_noise, 603 noise_offset,
615 seed, 604 seed,
616 ) 605 )
617 606
diff --git a/training/util.py b/training/util.py
index c8524de..8bd8a83 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,5 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import copy
4from typing import Iterable, Any 3from typing import Iterable, Any
5from contextlib import contextmanager 4from contextlib import contextmanager
6 5