| author | Volpeon <git@volpeon.ink> | 2023-04-13 09:13:35 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-04-13 09:13:35 +0200 |
| commit | 24f7a14defdbc050907ad282ebdc7ea6f6591363 | |
| tree | 996eaa43214f7c1ce7634605a8c08c9204898944 /training | |
| parent | Update | |
Added cycle LR decay
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 5 |
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/training/functional.py b/training/functional.py
index ed8ae3a..54bbe78 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -451,6 +451,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -513,7 +514,7 @@ def train_loop(
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
-                if epoch % sample_frequency == 0:
+                if epoch % sample_frequency == 0 and (initial_samples or epoch != 0):
                     local_progress_bar.clear()
                     global_progress_bar.clear()

@@ -673,6 +674,7 @@ def train(
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
+    initial_samples: bool = True,
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
@@ -723,6 +725,7 @@ def train(
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
         milestone_checkpoints=milestone_checkpoints,
+        initial_samples=initial_samples,
         global_step_offset=global_step_offset,
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
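In effect, the new `initial_samples` flag only changes when the sampling branch fires inside the epoch loop: samples are still produced every `sample_frequency` epochs, but the epoch-0 pass can now be skipped by passing `initial_samples=False` through `train()`. Below is a minimal sketch of that gate pulled out of the loop; the `should_sample` helper is hypothetical and exists only to illustrate the condition added in this commit.

```python
def should_sample(epoch: int, sample_frequency: int, initial_samples: bool = True) -> bool:
    """Illustrative stand-in for the gate added in train_loop():
    sample every `sample_frequency` epochs, optionally skipping epoch 0."""
    return epoch % sample_frequency == 0 and (initial_samples or epoch != 0)


# With the default (initial_samples=True) the behaviour is unchanged:
assert should_sample(0, 10)          # epoch 0 is still sampled
assert should_sample(10, 10)

# With initial_samples=False only the epoch-0 pass is skipped:
assert not should_sample(0, 10, initial_samples=False)
assert should_sample(10, 10, initial_samples=False)
```

Since `initial_samples` defaults to `True` in both `train_loop()` and `train()`, existing callers keep the previous behaviour.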
