author     Volpeon <git@volpeon.ink>  2023-03-24 17:23:09 +0100
committer  Volpeon <git@volpeon.ink>  2023-03-24 17:23:09 +0100
commit     9bfb4a078f63a7ce6e35e89093f17febd9ff4b51 (patch)
tree       41b83780c79803531c7208a72bff9206ffa908da
parent     Fixed Lora training perf issue (diff)
download   textual-inversion-diff-9bfb4a078f63a7ce6e35e89093f17febd9ff4b51.tar.gz
           textual-inversion-diff-9bfb4a078f63a7ce6e35e89093f17febd9ff4b51.tar.bz2
           textual-inversion-diff-9bfb4a078f63a7ce6e35e89093f17febd9ff4b51.zip
Update
 train_dreambooth.py    |  7 +++++++
 train_lora.py          |  7 +++++++
 train_ti.py            |  6 +++---
 training/functional.py | 11 +++++------
 4 files changed, 22 insertions(+), 9 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a85ae4c..1b8a3d2 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -176,6 +176,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--offset_noise_strength",
+        type=float,
+        default=0.15,
+        help="Perlin offset noise strength.",
+    )
+    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -621,6 +627,7 @@ def main():
         num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
+        offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
diff --git a/train_lora.py b/train_lora.py
index fa24cee..b16a99b 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -165,6 +165,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--offset_noise_strength",
+        type=float,
+        default=0.15,
+        help="Perlin offset noise strength.",
+    )
+    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -649,6 +655,7 @@ def main():
         num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
+        offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
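Both trainers gain the same six-line argument block, wired through to train(). As a quick sanity check, here is a minimal, self-contained sketch of just this flag; the real parsers in train_dreambooth.py and train_lora.py define many more options:

    import argparse

    # Only the flag added by this commit; everything else is omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--offset_noise_strength",
        type=float,
        default=0.15,
        help="Perlin offset noise strength.",
    )

    print(parser.parse_args([]).offset_noise_strength)
    # 0.15 (the new default)
    print(parser.parse_args(["--offset_noise_strength", "0"]).offset_noise_strength)
    # 0.0, which disables the extra noise entirely, since loss_step
    # only applies it when offset_noise_strength != 0 (see below)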
diff --git a/train_ti.py b/train_ti.py
index ef71f6f..bbc5524 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -188,9 +188,9 @@ def parse_args():
         help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
     )
     parser.add_argument(
-        "--perlin_strength",
+        "--offset_noise_strength",
         type=float,
-        default=0.1,
+        default=0.15,
         help="Perlin offset noise strength.",
     )
     parser.add_argument(
@@ -661,7 +661,7 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
-        perlin_strength=args.perlin_strength,
+        offset_noise_strength=args.offset_noise_strength,
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
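For train_ti.py this is a breaking rename: --perlin_strength becomes --offset_noise_strength and the default rises from 0.1 to 0.15, so existing invocation scripts need updating. A hypothetical check against the renamed flag (parser again reduced to the one relevant option):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--offset_noise_strength", type=float, default=0.15)

    # The old flag is no longer registered, so argparse leaves it unparsed.
    args, unknown = parser.parse_known_args(["--perlin_strength", "0.1"])
    print(unknown)                     # ['--perlin_strength', '0.1']
    print(args.offset_noise_strength)  # 0.15 (falls back to the new default)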
diff --git a/training/functional.py b/training/functional.py
index ee73ab2..87bb339 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -277,7 +277,7 @@ def loss_step(
     with_prior_preservation: bool,
     prior_loss_weight: float,
     seed: int,
-    perlin_strength: float,
+    offset_noise_strength: float,
     step: int,
     batch: dict[str, Any],
     eval: bool = False,
@@ -300,11 +300,10 @@ def loss_step(
         generator=generator
     )
 
-    if perlin_strength != 0:
-        noise += perlin_strength * perlin_noise(
+    if offset_noise_strength != 0:
+        noise += offset_noise_strength * perlin_noise(
             latents.shape,
             res=1,
-            octaves=4,
             dtype=latents.dtype,
             device=latents.device,
             generator=generator
@@ -610,7 +609,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
-    perlin_strength: float = 0.1,
+    offset_noise_strength: float = 0.1,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -642,7 +641,7 @@
         with_prior_preservation,
         prior_loss_weight,
         seed,
-        perlin_strength,
+        offset_noise_strength,
     )
 
     if accelerator.is_main_process:
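For context: when offset_noise_strength is nonzero, loss_step adds low-frequency Perlin noise (res=1, with the octaves knob removed by this commit) on top of the per-pixel Gaussian noise before the diffusion loss is computed. The repository's perlin_noise helper is not part of this diff; the sketch below instead uses the classic per-channel offset-noise formulation to illustrate the general idea, so the function and shapes here are assumptions, not this repo's implementation:

    from typing import Optional

    import torch

    def apply_offset_noise(
        noise: torch.Tensor,              # per-pixel Gaussian noise, (B, C, H, W)
        strength: float,                  # 0 disables the offset, as in loss_step
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        # Classic offset noise: one random scalar per (sample, channel),
        # broadcast over H and W, shifting the mean of each channel.
        if strength == 0:
            return noise
        b, c, _, _ = noise.shape
        offset = torch.randn(
            (b, c, 1, 1),
            dtype=noise.dtype,
            device=noise.device,
            generator=generator,
        )
        return noise + strength * offset

    # Usage mirroring the call site in loss_step:
    latents = torch.randn(2, 4, 64, 64)
    noise = torch.randn_like(latents)
    noise = apply_offset_noise(noise, strength=0.15)

The Perlin variant used here serves the same purpose, injecting low-frequency structure that pure Gaussian noise lacks, but produces smooth spatial variation rather than a constant per-channel shift.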