diff options
Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 6 |
1 file changed, 4 insertions, 2 deletions
diff --git a/train_lora.py b/train_lora.py index 0d8b8cb..1d1485d 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -873,7 +873,6 @@ def main(): | |||
| 873 | seed=args.seed, | 873 | seed=args.seed, |
| 874 | guidance_scale=args.guidance_scale, | 874 | guidance_scale=args.guidance_scale, |
| 875 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 875 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 876 | offset_noise_strength=args.offset_noise_strength, | ||
| 877 | sample_scheduler=sample_scheduler, | 876 | sample_scheduler=sample_scheduler, |
| 878 | sample_batch_size=args.sample_batch_size, | 877 | sample_batch_size=args.sample_batch_size, |
| 879 | sample_num_batches=args.sample_batches, | 878 | sample_num_batches=args.sample_batches, |
| @@ -984,13 +983,14 @@ def main(): | |||
| 984 | lr_scheduler=pti_lr_scheduler, | 983 | lr_scheduler=pti_lr_scheduler, |
| 985 | num_train_epochs=num_train_epochs, | 984 | num_train_epochs=num_train_epochs, |
| 986 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 985 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 987 | cycle=1, | 986 | cycle=0, |
| 988 | pti_mode=True, | 987 | pti_mode=True, |
| 989 | # -- | 988 | # -- |
| 990 | group_labels=["emb"], | 989 | group_labels=["emb"], |
| 991 | sample_output_dir=pti_sample_output_dir, | 990 | sample_output_dir=pti_sample_output_dir, |
| 992 | checkpoint_output_dir=pti_checkpoint_output_dir, | 991 | checkpoint_output_dir=pti_checkpoint_output_dir, |
| 993 | sample_frequency=pti_sample_frequency, | 992 | sample_frequency=pti_sample_frequency, |
| 993 | offset_noise_strength=0, | ||
| 994 | no_val=True, | 994 | no_val=True, |
| 995 | ) | 995 | ) |
| 996 | 996 | ||
| @@ -1132,11 +1132,13 @@ def main(): | |||
| 1132 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 1132 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 1133 | global_step_offset=training_iter * num_train_steps, | 1133 | global_step_offset=training_iter * num_train_steps, |
| 1134 | cycle=training_iter, | 1134 | cycle=training_iter, |
| 1135 | train_text_encoder_cycles=args.train_text_encoder_cycles, | ||
| 1135 | # -- | 1136 | # -- |
| 1136 | group_labels=group_labels, | 1137 | group_labels=group_labels, |
| 1137 | sample_output_dir=lora_sample_output_dir, | 1138 | sample_output_dir=lora_sample_output_dir, |
| 1138 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1139 | checkpoint_output_dir=lora_checkpoint_output_dir, |
| 1139 | sample_frequency=lora_sample_frequency, | 1140 | sample_frequency=lora_sample_frequency, |
| 1141 | offset_noise_strength=args.offset_noise_strength, | ||
| 1140 | no_val=args.valid_set_size == 0, | 1142 | no_val=args.valid_set_size == 0, |
| 1141 | ) | 1143 | ) |
| 1142 | 1144 | ||
