summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-21 09:09:50 +0100
committerVolpeon <git@volpeon.ink>2023-02-21 09:09:50 +0100
commit16b92605a59d59c65789c89b54bb97da51908056 (patch)
treeb0cbf8677897c3f44c736b710fd034eb2c5de6a0 /training/functional.py
parentUpdate (diff)
downloadtextual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.tar.gz
textual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.tar.bz2
textual-inversion-diff-16b92605a59d59c65789c89b54bb97da51908056.zip
Embedding normalization: Ignore tensors with grad = 0
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/training/functional.py b/training/functional.py
index 85dd884..739d055 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -362,6 +362,7 @@ def train_loop(
362 loss_step: LossCallable, 362 loss_step: LossCallable,
363 sample_frequency: int = 10, 363 sample_frequency: int = 10,
364 checkpoint_frequency: int = 50, 364 checkpoint_frequency: int = 50,
365 milestone_checkpoints: bool = True,
365 global_step_offset: int = 0, 366 global_step_offset: int = 0,
366 num_epochs: int = 100, 367 num_epochs: int = 100,
367 callbacks: TrainingCallbacks = TrainingCallbacks(), 368 callbacks: TrainingCallbacks = TrainingCallbacks(),
@@ -514,7 +515,7 @@ def train_loop(
514 accelerator.log(logs, step=global_step) 515 accelerator.log(logs, step=global_step)
515 516
516 if accelerator.is_main_process: 517 if accelerator.is_main_process:
517 if avg_acc_val.avg.item() > best_acc_val: 518 if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
518 local_progress_bar.clear() 519 local_progress_bar.clear()
519 global_progress_bar.clear() 520 global_progress_bar.clear()
520 521
@@ -527,7 +528,7 @@ def train_loop(
527 accs.append(avg_acc_val.avg.item()) 528 accs.append(avg_acc_val.avg.item())
528 else: 529 else:
529 if accelerator.is_main_process: 530 if accelerator.is_main_process:
530 if avg_acc.avg.item() > best_acc: 531 if avg_acc.avg.item() > best_acc and milestone_checkpoints:
531 local_progress_bar.clear() 532 local_progress_bar.clear()
532 global_progress_bar.clear() 533 global_progress_bar.clear()
533 534
@@ -572,6 +573,7 @@ def train(
572 num_train_epochs: int = 100, 573 num_train_epochs: int = 100,
573 sample_frequency: int = 20, 574 sample_frequency: int = 20,
574 checkpoint_frequency: int = 50, 575 checkpoint_frequency: int = 50,
576 milestone_checkpoints: bool = True,
575 global_step_offset: int = 0, 577 global_step_offset: int = 0,
576 with_prior_preservation: bool = False, 578 with_prior_preservation: bool = False,
577 prior_loss_weight: float = 1.0, 579 prior_loss_weight: float = 1.0,
@@ -626,6 +628,7 @@ def train(
626 loss_step=loss_step_, 628 loss_step=loss_step_,
627 sample_frequency=sample_frequency, 629 sample_frequency=sample_frequency,
628 checkpoint_frequency=checkpoint_frequency, 630 checkpoint_frequency=checkpoint_frequency,
631 milestone_checkpoints=milestone_checkpoints,
629 global_step_offset=global_step_offset, 632 global_step_offset=global_step_offset,
630 num_epochs=num_train_epochs, 633 num_epochs=num_train_epochs,
631 callbacks=callbacks, 634 callbacks=callbacks,