diff options
author | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-16 19:03:25 +0200 |
commit | 71f4a40bb48be4f2759ba2d83faff39691cb2955 (patch) | |
tree | 29c704ca549a4c4323403b6cbb0e62f54040ae22 /training/functional.py | |
parent | Added option to use constant LR on cycles > 1 (diff) | |
download | textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.gz textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.tar.bz2 textual-inversion-diff-71f4a40bb48be4f2759ba2d83faff39691cb2955.zip |
Improved automation caps
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 17 |
1 file changed, 9 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py index 2da0f69..ebc40de 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -42,7 +42,7 @@ class TrainingCallbacks(): | |||
42 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() | 42 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() |
43 | on_after_epoch: Callable[[], None] = const() | 43 | on_after_epoch: Callable[[], None] = const() |
44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
45 | on_sample: Callable[[int], None] = const() | 45 | on_sample: Callable[[int, int], None] = const() |
46 | on_checkpoint: Callable[[int, str], None] = const() | 46 | on_checkpoint: Callable[[int, str], None] = const() |
47 | 47 | ||
48 | 48 | ||
@@ -96,6 +96,7 @@ def save_samples( | |||
96 | output_dir: Path, | 96 | output_dir: Path, |
97 | seed: int, | 97 | seed: int, |
98 | step: int, | 98 | step: int, |
99 | cycle: int = 1, | ||
99 | batch_size: int = 1, | 100 | batch_size: int = 1, |
100 | num_batches: int = 1, | 101 | num_batches: int = 1, |
101 | num_steps: int = 20, | 102 | num_steps: int = 20, |
@@ -125,7 +126,7 @@ def save_samples( | |||
125 | 126 | ||
126 | for pool, data, gen in datasets: | 127 | for pool, data, gen in datasets: |
127 | all_samples = [] | 128 | all_samples = [] |
128 | file_path = output_dir / pool / f"step_{step}.jpg" | 129 | file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" |
129 | file_path.parent.mkdir(parents=True, exist_ok=True) | 130 | file_path.parent.mkdir(parents=True, exist_ok=True) |
130 | 131 | ||
131 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) | 132 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) |
@@ -455,7 +456,7 @@ def train_loop( | |||
455 | sample_frequency: int = 10, | 456 | sample_frequency: int = 10, |
456 | checkpoint_frequency: int = 50, | 457 | checkpoint_frequency: int = 50, |
457 | milestone_checkpoints: bool = True, | 458 | milestone_checkpoints: bool = True, |
458 | initial_samples: bool = True, | 459 | cycle: int = 1, |
459 | global_step_offset: int = 0, | 460 | global_step_offset: int = 0, |
460 | num_epochs: int = 100, | 461 | num_epochs: int = 100, |
461 | gradient_accumulation_steps: int = 1, | 462 | gradient_accumulation_steps: int = 1, |
@@ -518,12 +519,12 @@ def train_loop( | |||
518 | try: | 519 | try: |
519 | for epoch in range(num_epochs): | 520 | for epoch in range(num_epochs): |
520 | if accelerator.is_main_process: | 521 | if accelerator.is_main_process: |
521 | if epoch % sample_frequency == 0 and (initial_samples or epoch != 0): | 522 | if epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0): |
522 | local_progress_bar.clear() | 523 | local_progress_bar.clear() |
523 | global_progress_bar.clear() | 524 | global_progress_bar.clear() |
524 | 525 | ||
525 | with on_eval(): | 526 | with on_eval(): |
526 | on_sample(global_step) | 527 | on_sample(cycle, global_step) |
527 | 528 | ||
528 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 529 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
529 | local_progress_bar.clear() | 530 | local_progress_bar.clear() |
@@ -648,7 +649,7 @@ def train_loop( | |||
648 | if accelerator.is_main_process: | 649 | if accelerator.is_main_process: |
649 | print("Finished!") | 650 | print("Finished!") |
650 | with on_eval(): | 651 | with on_eval(): |
651 | on_sample(global_step) | 652 | on_sample(cycle, global_step) |
652 | on_checkpoint(global_step, "end") | 653 | on_checkpoint(global_step, "end") |
653 | 654 | ||
654 | except KeyboardInterrupt: | 655 | except KeyboardInterrupt: |
@@ -680,7 +681,7 @@ def train( | |||
680 | sample_frequency: int = 20, | 681 | sample_frequency: int = 20, |
681 | checkpoint_frequency: int = 50, | 682 | checkpoint_frequency: int = 50, |
682 | milestone_checkpoints: bool = True, | 683 | milestone_checkpoints: bool = True, |
683 | initial_samples: bool = True, | 684 | cycle: int = 1, |
684 | global_step_offset: int = 0, | 685 | global_step_offset: int = 0, |
685 | guidance_scale: float = 0.0, | 686 | guidance_scale: float = 0.0, |
686 | prior_loss_weight: float = 1.0, | 687 | prior_loss_weight: float = 1.0, |
@@ -731,7 +732,7 @@ def train( | |||
731 | sample_frequency=sample_frequency, | 732 | sample_frequency=sample_frequency, |
732 | checkpoint_frequency=checkpoint_frequency, | 733 | checkpoint_frequency=checkpoint_frequency, |
733 | milestone_checkpoints=milestone_checkpoints, | 734 | milestone_checkpoints=milestone_checkpoints, |
734 | initial_samples=initial_samples, | 735 | cycle=cycle, |
735 | global_step_offset=global_step_offset, | 736 | global_step_offset=global_step_offset, |
736 | num_epochs=num_train_epochs, | 737 | num_epochs=num_train_epochs, |
737 | gradient_accumulation_steps=gradient_accumulation_steps, | 738 | gradient_accumulation_steps=gradient_accumulation_steps, |