summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-04-21 11:43:50 +0200
committerVolpeon <git@volpeon.ink>2023-04-21 11:43:50 +0200
commit7da4f0485032bb8b8acfc678546ffcea3a23a44b (patch)
tree1e7880189df21132861114b5dbf4c614405c9855 /train_lora.py
parentFix PTI (diff)
downloadtextual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.tar.gz
textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.tar.bz2
textual-inversion-diff-7da4f0485032bb8b8acfc678546ffcea3a23a44b.zip
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py
index 0d8b8cb..1d1485d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -873,7 +873,6 @@ def main():
873 seed=args.seed, 873 seed=args.seed,
874 guidance_scale=args.guidance_scale, 874 guidance_scale=args.guidance_scale,
875 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, 875 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
876 offset_noise_strength=args.offset_noise_strength,
877 sample_scheduler=sample_scheduler, 876 sample_scheduler=sample_scheduler,
878 sample_batch_size=args.sample_batch_size, 877 sample_batch_size=args.sample_batch_size,
879 sample_num_batches=args.sample_batches, 878 sample_num_batches=args.sample_batches,
@@ -984,13 +983,14 @@ def main():
984 lr_scheduler=pti_lr_scheduler, 983 lr_scheduler=pti_lr_scheduler,
985 num_train_epochs=num_train_epochs, 984 num_train_epochs=num_train_epochs,
986 gradient_accumulation_steps=args.gradient_accumulation_steps, 985 gradient_accumulation_steps=args.gradient_accumulation_steps,
987 cycle=1, 986 cycle=0,
988 pti_mode=True, 987 pti_mode=True,
989 # -- 988 # --
990 group_labels=["emb"], 989 group_labels=["emb"],
991 sample_output_dir=pti_sample_output_dir, 990 sample_output_dir=pti_sample_output_dir,
992 checkpoint_output_dir=pti_checkpoint_output_dir, 991 checkpoint_output_dir=pti_checkpoint_output_dir,
993 sample_frequency=pti_sample_frequency, 992 sample_frequency=pti_sample_frequency,
993 offset_noise_strength=0,
994 no_val=True, 994 no_val=True,
995 ) 995 )
996 996
@@ -1132,11 +1132,13 @@ def main():
1132 gradient_accumulation_steps=args.gradient_accumulation_steps, 1132 gradient_accumulation_steps=args.gradient_accumulation_steps,
1133 global_step_offset=training_iter * num_train_steps, 1133 global_step_offset=training_iter * num_train_steps,
1134 cycle=training_iter, 1134 cycle=training_iter,
1135 train_text_encoder_cycles=args.train_text_encoder_cycles,
1135 # -- 1136 # --
1136 group_labels=group_labels, 1137 group_labels=group_labels,
1137 sample_output_dir=lora_sample_output_dir, 1138 sample_output_dir=lora_sample_output_dir,
1138 checkpoint_output_dir=lora_checkpoint_output_dir, 1139 checkpoint_output_dir=lora_checkpoint_output_dir,
1139 sample_frequency=lora_sample_frequency, 1140 sample_frequency=lora_sample_frequency,
1141 offset_noise_strength=args.offset_noise_strength,
1140 no_val=args.valid_set_size == 0, 1142 no_val=args.valid_set_size == 0,
1141 ) 1143 )
1142 1144