author    Volpeon <git@volpeon.ink>  2023-04-16 19:03:25 +0200
committer Volpeon <git@volpeon.ink>  2023-04-16 19:03:25 +0200
commit    71f4a40bb48be4f2759ba2d83faff39691cb2955
tree      29c704ca549a4c4323403b6cbb0e62f54040ae22 /training/functional.py
parent    Added option to use constant LR on cycles > 1
Improved automation caps
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py | 17 +++++++++--------
1 file changed, 9 insertions(+), 8 deletions(-)
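
This patch replaces the boolean initial_samples flag with an integer cycle index that is threaded through train(), train_loop(), and save_samples(): sample file names now embed the cycle (step_{cycle}_{step}.jpg) so later cycles do not overwrite earlier samples, the on_sample callback receives the cycle alongside the global step, and epoch-0 sampling only runs on the first cycle. A standalone sketch of the new behaviour follows the diff.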
diff --git a/training/functional.py b/training/functional.py
index 2da0f69..ebc40de 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -42,7 +42,7 @@ class TrainingCallbacks():
     on_after_optimize: Callable[[Any, dict[str, float]], None] = const()
     on_after_epoch: Callable[[], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
-    on_sample: Callable[[int], None] = const()
+    on_sample: Callable[[int, int], None] = const()
     on_checkpoint: Callable[[int, str], None] = const()
 
 
@@ -96,6 +96,7 @@ def save_samples(
     output_dir: Path,
     seed: int,
     step: int,
+    cycle: int = 1,
     batch_size: int = 1,
     num_batches: int = 1,
     num_steps: int = 20,
@@ -125,7 +126,7 @@ def save_samples(
 
     for pool, data, gen in datasets:
         all_samples = []
-        file_path = output_dir / pool / f"step_{step}.jpg"
+        file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
@@ -455,7 +456,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    initial_samples: bool = True,
+    cycle: int = 1,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -518,12 +519,12 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
+                if epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     with on_eval():
-                        on_sample(global_step)
+                        on_sample(cycle, global_step)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     local_progress_bar.clear()
@@ -648,7 +649,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             with on_eval():
-                on_sample(cycle, global_step)
             on_checkpoint(global_step, "end")
 
     except KeyboardInterrupt:
@@ -680,7 +681,7 @@ def train(
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    initial_samples: bool = True,
+    cycle: int = 1,
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
@@ -731,7 +732,7 @@ def train(
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
         milestone_checkpoints=milestone_checkpoints,
-        initial_samples=initial_samples,
+        cycle=cycle,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
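
For reference, a minimal standalone sketch of the two behavioural changes above. sample_file_path and should_sample are hypothetical helpers written for illustration only; they mirror the patched expressions in save_samples and train_loop but do not exist in the repository.

from pathlib import Path

def sample_file_path(output_dir: Path, pool: str, cycle: int, step: int) -> Path:
    # Mirrors the patched save_samples: embedding the cycle in the file name
    # keeps samples from cycle 2+ from overwriting those of cycle 1.
    return output_dir / pool / f"step_{cycle}_{step}.jpg"

def should_sample(epoch: int, cycle: int, sample_frequency: int) -> bool:
    # Mirrors the patched train_loop condition: the old initial_samples flag
    # is gone; epoch-0 samples are produced only on the first cycle, while
    # periodic sampling still runs on every cycle.
    return epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0)

assert sample_file_path(Path("out"), "val", cycle=2, step=100) == Path("out/val/step_2_100.jpg")
assert should_sample(epoch=0, cycle=1, sample_frequency=10)      # first cycle samples at epoch 0
assert not should_sample(epoch=0, cycle=2, sample_frequency=10)  # later cycles skip epoch 0
assert should_sample(epoch=10, cycle=2, sample_frequency=10)     # periodic sampling unaffected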