 train_lora.py             | 6 ++++--
 training/functional.py    | 4 ++--
 training/strategy/lora.py | 8 +++++---
 training/strategy/ti.py   | 2 +-
 4 files changed, 12 insertions(+), 8 deletions(-)
diff --git a/train_lora.py b/train_lora.py
index 0d8b8cb..1d1485d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -873,7 +873,6 @@ def main():
         seed=args.seed,
         guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
-        offset_noise_strength=args.offset_noise_strength,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
         sample_num_batches=args.sample_batches,
@@ -984,13 +983,14 @@ def main():
         lr_scheduler=pti_lr_scheduler,
         num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
-        cycle=1,
+        cycle=0,
         pti_mode=True,
         # --
         group_labels=["emb"],
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
+        offset_noise_strength=0,
         no_val=True,
     )
 
@@ -1132,11 +1132,13 @@ def main():
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         global_step_offset=training_iter * num_train_steps,
         cycle=training_iter,
+        train_text_encoder_cycles=args.train_text_encoder_cycles,
         # --
         group_labels=group_labels,
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
+        offset_noise_strength=args.offset_noise_strength,
         no_val=args.valid_set_size == 0,
     )
 
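Note: offset_noise_strength is no longer fixed once for the whole trainer; it is now chosen per phase, disabled (0) for the PTI pass and taken from the CLI for the LoRA pass. A minimal sketch of how such a strength value is typically applied inside a loss step, assuming the usual offset-noise formulation (the helper name and signature are illustrative, not taken from training/functional.py):

import torch

def add_offset_noise(noise: torch.Tensor, latents: torch.Tensor, strength: float) -> torch.Tensor:
    # Offset noise adds a per-sample, per-channel constant on top of the usual
    # Gaussian noise so the model can learn overall brightness/contrast shifts.
    # With strength == 0 (the PTI pass above), the plain noise is returned unchanged.
    if strength == 0:
        return noise
    offset = torch.randn(
        latents.shape[0], latents.shape[1], 1, 1,
        device=latents.device, dtype=latents.dtype,
    )
    return noise + strength * offset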
diff --git a/training/functional.py b/training/functional.py
index c6ceb20..695a24f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -456,7 +456,7 @@ def train_loop(
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     milestone_checkpoints: bool = True,
-    cycle: int = 1,
+    cycle: int = 0,
     global_step_offset: int = 0,
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
@@ -537,7 +537,7 @@ def train_loop(
 
         logs = {}
 
-        with on_train(epoch):
+        with on_train(cycle):
             for step, batch in enumerate(train_dataloader):
                 loss, acc, bsz = loss_step(step, batch, cache)
                 loss /= gradient_accumulation_steps
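Note the contract change: train_loop now hands the strategy's on_train context manager the 0-based cycle index instead of the epoch index. A minimal sketch of a callback written against the new contract (illustrative only; the real callbacks live under training/strategy/):

from contextlib import contextmanager

@contextmanager
def on_train(cycle: int):
    # Still entered once per epoch by train_loop, but the argument now
    # identifies the training cycle (0, 1, 2, ...), so per-cycle decisions
    # such as freezing a model component can be made here.
    print(f"training, cycle {cycle}")
    yield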
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 5c3012e..1f0a117 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -35,6 +35,7 @@ def lora_strategy_callbacks(
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
     pti_mode: bool = False,
+    train_text_encoder_cycles: int = 99999,
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay: float = 1e-2,
@@ -66,10 +67,11 @@ def lora_strategy_callbacks(
     )
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
-        text_encoder.train()
-        tokenizer.train()
+        if cycle < train_text_encoder_cycles:
+            text_encoder.train()
+            tokenizer.train()
         yield
 
     @contextmanager
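Effect of the new parameter: with the default of 99999 the text encoder and tokenizer are switched into train() mode on every cycle as before, while a finite value stops switching them into train() mode from that cycle onward (only the U-Net continues to be put into train() mode). A small sketch of the gating, assuming cycles are numbered from 0 as in the cycle=0 defaults above:

train_text_encoder_cycles = 2  # illustrative value, not a project default

for cycle in range(4):
    # Mirrors the condition added in on_train above.
    trains_text_encoder = cycle < train_text_encoder_cycles
    print(cycle, trains_text_encoder)
# 0 True
# 1 True
# 2 False
# 3 False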
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index f330cb7..6bc1d7d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -90,7 +90,7 @@ def textual_inversion_strategy_callbacks(
         return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         text_encoder.train()
         tokenizer.train()
         yield
