summaryrefslogtreecommitdiffstats
path: root/dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-13 10:45:53 +0100
committerVolpeon <git@volpeon.ink>2022-12-13 10:45:53 +0100
commit26ece1a796c7ef87ed96f5b38fab80d0ae958b9a (patch)
treed733df7000d5259c1a3ad19abd6dec982d65974d /dreambooth.py
parentAdd support for resume in Textual Inversion (diff)
downloadtextual-inversion-diff-26ece1a796c7ef87ed96f5b38fab80d0ae958b9a.tar.gz
textual-inversion-diff-26ece1a796c7ef87ed96f5b38fab80d0ae958b9a.tar.bz2
textual-inversion-diff-26ece1a796c7ef87ed96f5b38fab80d0ae958b9a.zip
Fixed sample/checkpoint frequency
Diffstat (limited to 'dreambooth.py')
-rw-r--r--dreambooth.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 9a6f70a..31416e9 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -625,8 +625,6 @@ def main():
625 vae.requires_grad_(False) 625 vae.requires_grad_(False)
626 626
627 if len(args.placeholder_token) != 0: 627 if len(args.placeholder_token) != 0:
628 print(f"Adding text embeddings: {args.placeholder_token}")
629
630 # Convert the initializer_token, placeholder_token to ids 628 # Convert the initializer_token, placeholder_token to ids
631 initializer_token_ids = torch.stack([ 629 initializer_token_ids = torch.stack([
632 torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) 630 torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -1114,7 +1112,7 @@ def main():
1114 max_acc_val = avg_acc_val.avg.item() 1112 max_acc_val = avg_acc_val.avg.item()
1115 1113
1116 if accelerator.is_main_process: 1114 if accelerator.is_main_process:
1117 if epoch % args.sample_frequency == 0: 1115 if (epoch + 1) % args.sample_frequency == 0:
1118 checkpointer.save_samples(global_step, args.sample_steps) 1116 checkpointer.save_samples(global_step, args.sample_steps)
1119 1117
1120 # Create the pipeline using using the trained modules and save it. 1118 # Create the pipeline using using the trained modules and save it.