diff options
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py index 0d8b8cb..1d1485d 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -873,7 +873,6 @@ def main(): | |||
873 | seed=args.seed, | 873 | seed=args.seed, |
874 | guidance_scale=args.guidance_scale, | 874 | guidance_scale=args.guidance_scale, |
875 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 875 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
876 | offset_noise_strength=args.offset_noise_strength, | ||
877 | sample_scheduler=sample_scheduler, | 876 | sample_scheduler=sample_scheduler, |
878 | sample_batch_size=args.sample_batch_size, | 877 | sample_batch_size=args.sample_batch_size, |
879 | sample_num_batches=args.sample_batches, | 878 | sample_num_batches=args.sample_batches, |
@@ -984,13 +983,14 @@ def main(): | |||
984 | lr_scheduler=pti_lr_scheduler, | 983 | lr_scheduler=pti_lr_scheduler, |
985 | num_train_epochs=num_train_epochs, | 984 | num_train_epochs=num_train_epochs, |
986 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 985 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
987 | cycle=1, | 986 | cycle=0, |
988 | pti_mode=True, | 987 | pti_mode=True, |
989 | # -- | 988 | # -- |
990 | group_labels=["emb"], | 989 | group_labels=["emb"], |
991 | sample_output_dir=pti_sample_output_dir, | 990 | sample_output_dir=pti_sample_output_dir, |
992 | checkpoint_output_dir=pti_checkpoint_output_dir, | 991 | checkpoint_output_dir=pti_checkpoint_output_dir, |
993 | sample_frequency=pti_sample_frequency, | 992 | sample_frequency=pti_sample_frequency, |
993 | offset_noise_strength=0, | ||
994 | no_val=True, | 994 | no_val=True, |
995 | ) | 995 | ) |
996 | 996 | ||
@@ -1132,11 +1132,13 @@ def main(): | |||
1132 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 1132 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
1133 | global_step_offset=training_iter * num_train_steps, | 1133 | global_step_offset=training_iter * num_train_steps, |
1134 | cycle=training_iter, | 1134 | cycle=training_iter, |
1135 | train_text_encoder_cycles=args.train_text_encoder_cycles, | ||
1135 | # -- | 1136 | # -- |
1136 | group_labels=group_labels, | 1137 | group_labels=group_labels, |
1137 | sample_output_dir=lora_sample_output_dir, | 1138 | sample_output_dir=lora_sample_output_dir, |
1138 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1139 | checkpoint_output_dir=lora_checkpoint_output_dir, |
1139 | sample_frequency=lora_sample_frequency, | 1140 | sample_frequency=lora_sample_frequency, |
1141 | offset_noise_strength=args.offset_noise_strength, | ||
1140 | no_val=args.valid_set_size == 0, | 1142 | no_val=args.valid_set_size == 0, |
1141 | ) | 1143 | ) |
1142 | 1144 | ||