summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train_dreambooth.py7
-rw-r--r--train_lora.py7
-rw-r--r--train_ti.py6
-rw-r--r--training/functional.py11
4 files changed, 22 insertions, 9 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a85ae4c..1b8a3d2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -176,6 +176,12 @@ def parse_args():
176 ), 176 ),
177 ) 177 )
178 parser.add_argument( 178 parser.add_argument(
179 "--offset_noise_strength",
180 type=float,
181 default=0.15,
182 help="Perlin offset noise strength.",
183 )
184 parser.add_argument(
179 "--num_train_epochs", 185 "--num_train_epochs",
180 type=int, 186 type=int,
181 default=100 187 default=100
@@ -621,6 +627,7 @@ def main():
621 num_train_epochs=args.num_train_epochs, 627 num_train_epochs=args.num_train_epochs,
622 gradient_accumulation_steps=args.gradient_accumulation_steps, 628 gradient_accumulation_steps=args.gradient_accumulation_steps,
623 sample_frequency=args.sample_frequency, 629 sample_frequency=args.sample_frequency,
630 offset_noise_strength=args.offset_noise_strength,
624 # -- 631 # --
625 tokenizer=tokenizer, 632 tokenizer=tokenizer,
626 sample_scheduler=sample_scheduler, 633 sample_scheduler=sample_scheduler,
diff --git a/train_lora.py b/train_lora.py
index fa24cee..b16a99b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -165,6 +165,12 @@ def parse_args():
165 ), 165 ),
166 ) 166 )
167 parser.add_argument( 167 parser.add_argument(
168 "--offset_noise_strength",
169 type=float,
170 default=0.15,
171 help="Perlin offset noise strength.",
172 )
173 parser.add_argument(
168 "--num_train_epochs", 174 "--num_train_epochs",
169 type=int, 175 type=int,
170 default=100 176 default=100
@@ -649,6 +655,7 @@ def main():
649 num_train_epochs=args.num_train_epochs, 655 num_train_epochs=args.num_train_epochs,
650 gradient_accumulation_steps=args.gradient_accumulation_steps, 656 gradient_accumulation_steps=args.gradient_accumulation_steps,
651 sample_frequency=args.sample_frequency, 657 sample_frequency=args.sample_frequency,
658 offset_noise_strength=args.offset_noise_strength,
652 # -- 659 # --
653 tokenizer=tokenizer, 660 tokenizer=tokenizer,
654 sample_scheduler=sample_scheduler, 661 sample_scheduler=sample_scheduler,
diff --git a/train_ti.py b/train_ti.py
index ef71f6f..bbc5524 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -188,9 +188,9 @@ def parse_args():
188 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', 188 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
189 ) 189 )
190 parser.add_argument( 190 parser.add_argument(
191 "--perlin_strength", 191 "--offset_noise_strength",
192 type=float, 192 type=float,
193 default=0.1, 193 default=0.15,
194 help="Perlin offset noise strength.", 194 help="Perlin offset noise strength.",
195 ) 195 )
196 parser.add_argument( 196 parser.add_argument(
@@ -661,7 +661,7 @@ def main():
661 checkpoint_frequency=args.checkpoint_frequency, 661 checkpoint_frequency=args.checkpoint_frequency,
662 milestone_checkpoints=not args.no_milestone_checkpoints, 662 milestone_checkpoints=not args.no_milestone_checkpoints,
663 global_step_offset=global_step_offset, 663 global_step_offset=global_step_offset,
664 perlin_strength=args.perlin_strength, 664 offset_noise_strength=args.offset_noise_strength,
665 # -- 665 # --
666 tokenizer=tokenizer, 666 tokenizer=tokenizer,
667 sample_scheduler=sample_scheduler, 667 sample_scheduler=sample_scheduler,
diff --git a/training/functional.py b/training/functional.py
index ee73ab2..87bb339 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -277,7 +277,7 @@ def loss_step(
277 with_prior_preservation: bool, 277 with_prior_preservation: bool,
278 prior_loss_weight: float, 278 prior_loss_weight: float,
279 seed: int, 279 seed: int,
280 perlin_strength: float, 280 offset_noise_strength: float,
281 step: int, 281 step: int,
282 batch: dict[str, Any], 282 batch: dict[str, Any],
283 eval: bool = False, 283 eval: bool = False,
@@ -300,11 +300,10 @@ def loss_step(
300 generator=generator 300 generator=generator
301 ) 301 )
302 302
303 if perlin_strength != 0: 303 if offset_noise_strength != 0:
304 noise += perlin_strength * perlin_noise( 304 noise += offset_noise_strength * perlin_noise(
305 latents.shape, 305 latents.shape,
306 res=1, 306 res=1,
307 octaves=4,
308 dtype=latents.dtype, 307 dtype=latents.dtype,
309 device=latents.device, 308 device=latents.device,
310 generator=generator 309 generator=generator
@@ -610,7 +609,7 @@ def train(
610 global_step_offset: int = 0, 609 global_step_offset: int = 0,
611 with_prior_preservation: bool = False, 610 with_prior_preservation: bool = False,
612 prior_loss_weight: float = 1.0, 611 prior_loss_weight: float = 1.0,
613 perlin_strength: float = 0.1, 612 offset_noise_strength: float = 0.1,
614 **kwargs, 613 **kwargs,
615): 614):
616 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 615 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -642,7 +641,7 @@ def train(
642 with_prior_preservation, 641 with_prior_preservation,
643 prior_loss_weight, 642 prior_loss_weight,
644 seed, 643 seed,
645 perlin_strength, 644 offset_noise_strength,
646 ) 645 )
647 646
648 if accelerator.is_main_process: 647 if accelerator.is_main_process: