diff options
-rw-r--r-- | train_lora.py | 53 | ||||
-rw-r--r-- | train_ti.py | 53 | ||||
-rw-r--r-- | training/functional.py | 17 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 4 | ||||
-rw-r--r-- | training/strategy/lora.py | 4 | ||||
-rw-r--r-- | training/strategy/ti.py | 23 |
6 files changed, 100 insertions, 54 deletions
diff --git a/train_lora.py b/train_lora.py index 4d4c16a..ba5aee1 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -84,9 +84,9 @@ def parse_args(): | |||
84 | ) | 84 | ) |
85 | parser.add_argument( | 85 | parser.add_argument( |
86 | "--auto_cycles", | 86 | "--auto_cycles", |
87 | type=int, | 87 | type=str, |
88 | default=1, | 88 | default="o", |
89 | help="How many cycles to run automatically." | 89 | help="Cycles to run automatically." |
90 | ) | 90 | ) |
91 | parser.add_argument( | 91 | parser.add_argument( |
92 | "--cycle_decay", | 92 | "--cycle_decay", |
@@ -95,11 +95,6 @@ def parse_args(): | |||
95 | help="Learning rate decay per cycle." | 95 | help="Learning rate decay per cycle." |
96 | ) | 96 | ) |
97 | parser.add_argument( | 97 | parser.add_argument( |
98 | "--cycle_constant", | ||
99 | action="store_true", | ||
100 | help="Use constant LR on cycles > 1." | ||
101 | ) | ||
102 | parser.add_argument( | ||
103 | "--placeholder_tokens", | 98 | "--placeholder_tokens", |
104 | type=str, | 99 | type=str, |
105 | nargs='*', | 100 | nargs='*', |
@@ -920,7 +915,6 @@ def main(): | |||
920 | annealing_func=args.lr_annealing_func, | 915 | annealing_func=args.lr_annealing_func, |
921 | warmup_exp=args.lr_warmup_exp, | 916 | warmup_exp=args.lr_warmup_exp, |
922 | annealing_exp=args.lr_annealing_exp, | 917 | annealing_exp=args.lr_annealing_exp, |
923 | cycles=args.lr_cycles, | ||
924 | end_lr=1e2, | 918 | end_lr=1e2, |
925 | mid_point=args.lr_mid_point, | 919 | mid_point=args.lr_mid_point, |
926 | ) | 920 | ) |
@@ -964,20 +958,38 @@ def main(): | |||
964 | 958 | ||
965 | lora_sample_output_dir = output_dir / lora_project / "samples" | 959 | lora_sample_output_dir = output_dir / lora_project / "samples" |
966 | 960 | ||
961 | auto_cycles = list(args.auto_cycles) | ||
962 | lr_scheduler = args.lr_scheduler | ||
963 | lr_warmup_epochs = args.lr_warmup_epochs | ||
964 | lr_cycles = args.lr_cycles | ||
965 | |||
967 | while True: | 966 | while True: |
968 | if training_iter >= args.auto_cycles: | 967 | if len(auto_cycles) != 0: |
969 | response = input("Run another cycle? [y/n] ") | 968 | response = auto_cycles.pop(0) |
970 | if response.lower().strip() == "n": | 969 | else: |
971 | break | 970 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
971 | |||
972 | if response.lower().strip() == "o": | ||
973 | lr_scheduler = "one_cycle" | ||
974 | lr_warmup_epochs = args.lr_warmup_epochs | ||
975 | lr_cycles = args.lr_cycles | ||
976 | if response.lower().strip() == "w": | ||
977 | lr_scheduler = "constant" | ||
978 | lr_warmup_epochs = num_train_epochs | ||
979 | if response.lower().strip() == "c": | ||
980 | lr_scheduler = "constant" | ||
981 | lr_warmup_epochs = 0 | ||
982 | if response.lower().strip() == "d": | ||
983 | lr_scheduler = "cosine" | ||
984 | lr_warmup_epochs = 0 | ||
985 | lr_cycles = 1 | ||
986 | elif response.lower().strip() == "s": | ||
987 | break | ||
972 | 988 | ||
973 | print("") | 989 | print("") |
974 | print(f"============ LoRA cycle {training_iter + 1} ============") | 990 | print(f"============ LoRA cycle {training_iter + 1} ============") |
975 | print("") | 991 | print("") |
976 | 992 | ||
977 | if args.cycle_constant and training_iter == 1: | ||
978 | args.lr_scheduler = "constant" | ||
979 | args.lr_warmup_epochs = 0 | ||
980 | |||
981 | params_to_optimize = [] | 993 | params_to_optimize = [] |
982 | 994 | ||
983 | if len(args.placeholder_tokens) != 0: | 995 | if len(args.placeholder_tokens) != 0: |
@@ -1012,12 +1024,13 @@ def main(): | |||
1012 | lora_optimizer = create_optimizer(params_to_optimize) | 1024 | lora_optimizer = create_optimizer(params_to_optimize) |
1013 | 1025 | ||
1014 | lora_lr_scheduler = create_lr_scheduler( | 1026 | lora_lr_scheduler = create_lr_scheduler( |
1015 | args.lr_scheduler, | 1027 | lr_scheduler, |
1016 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 1028 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
1017 | optimizer=lora_optimizer, | 1029 | optimizer=lora_optimizer, |
1018 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), | 1030 | num_training_steps_per_epoch=len(lora_datamodule.train_dataloader), |
1019 | train_epochs=num_train_epochs, | 1031 | train_epochs=num_train_epochs, |
1020 | warmup_epochs=args.lr_warmup_epochs, | 1032 | cycles=lr_cycles, |
1033 | warmup_epochs=lr_warmup_epochs, | ||
1021 | ) | 1034 | ) |
1022 | 1035 | ||
1023 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}" | 1036 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}" |
@@ -1031,7 +1044,7 @@ def main(): | |||
1031 | num_train_epochs=num_train_epochs, | 1044 | num_train_epochs=num_train_epochs, |
1032 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 1045 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
1033 | global_step_offset=training_iter * num_train_steps, | 1046 | global_step_offset=training_iter * num_train_steps, |
1034 | initial_samples=training_iter == 0, | 1047 | cycle=training_iter, |
1035 | # -- | 1048 | # -- |
1036 | group_labels=group_labels, | 1049 | group_labels=group_labels, |
1037 | sample_output_dir=lora_sample_output_dir, | 1050 | sample_output_dir=lora_sample_output_dir, |
diff --git a/train_ti.py b/train_ti.py index c452269..880320f 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -68,9 +68,9 @@ def parse_args(): | |||
68 | ) | 68 | ) |
69 | parser.add_argument( | 69 | parser.add_argument( |
70 | "--auto_cycles", | 70 | "--auto_cycles", |
71 | type=int, | 71 | type=str, |
72 | default=1, | 72 | default="o", |
73 | help="How many cycles to run automatically." | 73 | help="Cycles to run automatically." |
74 | ) | 74 | ) |
75 | parser.add_argument( | 75 | parser.add_argument( |
76 | "--cycle_decay", | 76 | "--cycle_decay", |
@@ -79,11 +79,6 @@ def parse_args(): | |||
79 | help="Learning rate decay per cycle." | 79 | help="Learning rate decay per cycle." |
80 | ) | 80 | ) |
81 | parser.add_argument( | 81 | parser.add_argument( |
82 | "--cycle_constant", | ||
83 | action="store_true", | ||
84 | help="Use constant LR on cycles > 1." | ||
85 | ) | ||
86 | parser.add_argument( | ||
87 | "--placeholder_tokens", | 82 | "--placeholder_tokens", |
88 | type=str, | 83 | type=str, |
89 | nargs='*', | 84 | nargs='*', |
@@ -921,27 +916,45 @@ def main(): | |||
921 | 916 | ||
922 | sample_output_dir = output_dir / project / "samples" | 917 | sample_output_dir = output_dir / project / "samples" |
923 | 918 | ||
919 | auto_cycles = list(args.auto_cycles) | ||
920 | lr_scheduler = args.lr_scheduler | ||
921 | lr_warmup_epochs = args.lr_warmup_epochs | ||
922 | lr_cycles = args.lr_cycles | ||
923 | |||
924 | while True: | 924 | while True: |
925 | if training_iter >= args.auto_cycles: | 925 | if len(auto_cycles) != 0: |
926 | response = input("Run another cycle? [y/n] ") | 926 | response = auto_cycles.pop(0) |
927 | if response.lower().strip() == "n": | 927 | else: |
928 | break | 928 | response = input("Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") |
929 | |||
930 | if response.lower().strip() == "o": | ||
931 | lr_scheduler = "one_cycle" | ||
932 | lr_warmup_epochs = args.lr_warmup_epochs | ||
933 | lr_cycles = args.lr_cycles | ||
934 | if response.lower().strip() == "w": | ||
935 | lr_scheduler = "constant" | ||
936 | lr_warmup_epochs = num_train_epochs | ||
937 | if response.lower().strip() == "c": | ||
938 | lr_scheduler = "constant" | ||
939 | lr_warmup_epochs = 0 | ||
940 | if response.lower().strip() == "d": | ||
941 | lr_scheduler = "cosine" | ||
942 | lr_warmup_epochs = 0 | ||
943 | lr_cycles = 1 | ||
944 | elif response.lower().strip() == "s": | ||
945 | break | ||
929 | 946 | ||
930 | print("") | 947 | print("") |
931 | print(f"------------ TI cycle {training_iter + 1} ------------") | 948 | print(f"------------ TI cycle {training_iter + 1} ------------") |
932 | print("") | 949 | print("") |
933 | 950 | ||
934 | if args.cycle_constant and training_iter == 1: | ||
935 | args.lr_scheduler = "constant" | ||
936 | args.lr_warmup_epochs = 0 | ||
937 | |||
938 | optimizer = create_optimizer( | 951 | optimizer = create_optimizer( |
939 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 952 | text_encoder.text_model.embeddings.token_embedding.parameters(), |
940 | lr=learning_rate, | 953 | lr=learning_rate, |
941 | ) | 954 | ) |
942 | 955 | ||
943 | lr_scheduler = get_scheduler( | 956 | lr_scheduler = get_scheduler( |
944 | args.lr_scheduler, | 957 | lr_scheduler, |
945 | optimizer=optimizer, | 958 | optimizer=optimizer, |
946 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | 959 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
947 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 960 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
@@ -950,10 +963,10 @@ def main(): | |||
950 | annealing_func=args.lr_annealing_func, | 963 | annealing_func=args.lr_annealing_func, |
951 | warmup_exp=args.lr_warmup_exp, | 964 | warmup_exp=args.lr_warmup_exp, |
952 | annealing_exp=args.lr_annealing_exp, | 965 | annealing_exp=args.lr_annealing_exp, |
953 | cycles=args.lr_cycles, | 966 | cycles=lr_cycles, |
954 | end_lr=1e3, | 967 | end_lr=1e3, |
955 | train_epochs=num_train_epochs, | 968 | train_epochs=num_train_epochs, |
956 | warmup_epochs=args.lr_warmup_epochs, | 969 | warmup_epochs=lr_warmup_epochs, |
957 | mid_point=args.lr_mid_point, | 970 | mid_point=args.lr_mid_point, |
958 | ) | 971 | ) |
959 | 972 | ||
@@ -966,7 +979,7 @@ def main(): | |||
966 | lr_scheduler=lr_scheduler, | 979 | lr_scheduler=lr_scheduler, |
967 | num_train_epochs=num_train_epochs, | 980 | num_train_epochs=num_train_epochs, |
968 | global_step_offset=training_iter * num_train_steps, | 981 | global_step_offset=training_iter * num_train_steps, |
969 | initial_samples=training_iter == 0, | 982 | cycle=training_iter, |
970 | # -- | 983 | # -- |
971 | group_labels=["emb"], | 984 | group_labels=["emb"], |
972 | checkpoint_output_dir=checkpoint_output_dir, | 985 | checkpoint_output_dir=checkpoint_output_dir, |
diff --git a/training/functional.py b/training/functional.py index 2da0f69..ebc40de 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -42,7 +42,7 @@ class TrainingCallbacks(): | |||
42 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() | 42 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() |
43 | on_after_epoch: Callable[[], None] = const() | 43 | on_after_epoch: Callable[[], None] = const() |
44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
45 | on_sample: Callable[[int], None] = const() | 45 | on_sample: Callable[[int, int], None] = const() |
46 | on_checkpoint: Callable[[int, str], None] = const() | 46 | on_checkpoint: Callable[[int, str], None] = const() |
47 | 47 | ||
48 | 48 | ||
@@ -96,6 +96,7 @@ def save_samples( | |||
96 | output_dir: Path, | 96 | output_dir: Path, |
97 | seed: int, | 97 | seed: int, |
98 | step: int, | 98 | step: int, |
99 | cycle: int = 1, | ||
99 | batch_size: int = 1, | 100 | batch_size: int = 1, |
100 | num_batches: int = 1, | 101 | num_batches: int = 1, |
101 | num_steps: int = 20, | 102 | num_steps: int = 20, |
@@ -125,7 +126,7 @@ def save_samples( | |||
125 | 126 | ||
126 | for pool, data, gen in datasets: | 127 | for pool, data, gen in datasets: |
127 | all_samples = [] | 128 | all_samples = [] |
128 | file_path = output_dir / pool / f"step_{step}.jpg" | 129 | file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" |
129 | file_path.parent.mkdir(parents=True, exist_ok=True) | 130 | file_path.parent.mkdir(parents=True, exist_ok=True) |
130 | 131 | ||
131 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) | 132 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) |
@@ -455,7 +456,7 @@ def train_loop( | |||
455 | sample_frequency: int = 10, | 456 | sample_frequency: int = 10, |
456 | checkpoint_frequency: int = 50, | 457 | checkpoint_frequency: int = 50, |
457 | milestone_checkpoints: bool = True, | 458 | milestone_checkpoints: bool = True, |
458 | initial_samples: bool = True, | 459 | cycle: int = 1, |
459 | global_step_offset: int = 0, | 460 | global_step_offset: int = 0, |
460 | num_epochs: int = 100, | 461 | num_epochs: int = 100, |
461 | gradient_accumulation_steps: int = 1, | 462 | gradient_accumulation_steps: int = 1, |
@@ -518,12 +519,12 @@ def train_loop( | |||
518 | try: | 519 | try: |
519 | for epoch in range(num_epochs): | 520 | for epoch in range(num_epochs): |
520 | if accelerator.is_main_process: | 521 | if accelerator.is_main_process: |
521 | if epoch % sample_frequency == 0 and (initial_samples or epoch != 0): | 522 | if epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0): |
522 | local_progress_bar.clear() | 523 | local_progress_bar.clear() |
523 | global_progress_bar.clear() | 524 | global_progress_bar.clear() |
524 | 525 | ||
525 | with on_eval(): | 526 | with on_eval(): |
526 | on_sample(global_step) | 527 | on_sample(cycle, global_step) |
527 | 528 | ||
528 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 529 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
529 | local_progress_bar.clear() | 530 | local_progress_bar.clear() |
@@ -648,7 +649,7 @@ def train_loop( | |||
648 | if accelerator.is_main_process: | 649 | if accelerator.is_main_process: |
649 | print("Finished!") | 650 | print("Finished!") |
650 | with on_eval(): | 651 | with on_eval(): |
651 | on_sample(global_step) | 652 | on_sample(cycle, global_step) |
652 | on_checkpoint(global_step, "end") | 653 | on_checkpoint(global_step, "end") |
653 | 654 | ||
654 | except KeyboardInterrupt: | 655 | except KeyboardInterrupt: |
@@ -680,7 +681,7 @@ def train( | |||
680 | sample_frequency: int = 20, | 681 | sample_frequency: int = 20, |
681 | checkpoint_frequency: int = 50, | 682 | checkpoint_frequency: int = 50, |
682 | milestone_checkpoints: bool = True, | 683 | milestone_checkpoints: bool = True, |
683 | initial_samples: bool = True, | 684 | cycle: int = 1, |
684 | global_step_offset: int = 0, | 685 | global_step_offset: int = 0, |
685 | guidance_scale: float = 0.0, | 686 | guidance_scale: float = 0.0, |
686 | prior_loss_weight: float = 1.0, | 687 | prior_loss_weight: float = 1.0, |
@@ -731,7 +732,7 @@ def train( | |||
731 | sample_frequency=sample_frequency, | 732 | sample_frequency=sample_frequency, |
732 | checkpoint_frequency=checkpoint_frequency, | 733 | checkpoint_frequency=checkpoint_frequency, |
733 | milestone_checkpoints=milestone_checkpoints, | 734 | milestone_checkpoints=milestone_checkpoints, |
734 | initial_samples=initial_samples, | 735 | cycle=cycle, |
735 | global_step_offset=global_step_offset, | 736 | global_step_offset=global_step_offset, |
736 | num_epochs=num_train_epochs, | 737 | num_epochs=num_train_epochs, |
737 | gradient_accumulation_steps=gradient_accumulation_steps, | 738 | gradient_accumulation_steps=gradient_accumulation_steps, |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 4ae28b7..e6fcc89 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -148,7 +148,7 @@ def dreambooth_strategy_callbacks( | |||
148 | torch.cuda.empty_cache() | 148 | torch.cuda.empty_cache() |
149 | 149 | ||
150 | @torch.no_grad() | 150 | @torch.no_grad() |
151 | def on_sample(step): | 151 | def on_sample(cycle, step): |
152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
154 | 154 | ||
@@ -158,7 +158,7 @@ def dreambooth_strategy_callbacks( | |||
158 | unet_.to(dtype=weight_dtype) | 158 | unet_.to(dtype=weight_dtype) |
159 | text_encoder_.to(dtype=weight_dtype) | 159 | text_encoder_.to(dtype=weight_dtype) |
160 | 160 | ||
161 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 161 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
162 | 162 | ||
163 | unet_.to(dtype=orig_unet_dtype) | 163 | unet_.to(dtype=orig_unet_dtype) |
164 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 164 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 48236fb..5c3012e 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -146,11 +146,11 @@ def lora_strategy_callbacks( | |||
146 | torch.cuda.empty_cache() | 146 | torch.cuda.empty_cache() |
147 | 147 | ||
148 | @torch.no_grad() | 148 | @torch.no_grad() |
149 | def on_sample(step): | 149 | def on_sample(cycle, step): |
150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
152 | 152 | ||
153 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 153 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
154 | 154 | ||
155 | del unet_, text_encoder_ | 155 | del unet_, text_encoder_ |
156 | 156 | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index f0b84b5..6bbff64 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -104,10 +104,28 @@ def textual_inversion_strategy_callbacks( | |||
104 | yield | 104 | yield |
105 | 105 | ||
106 | @torch.no_grad() | 106 | @torch.no_grad() |
107 | def on_before_optimize(epoch: int): | ||
108 | if use_emb_decay: | ||
109 | params = [ | ||
110 | p | ||
111 | for p in text_encoder.text_model.embeddings.token_embedding.parameters() | ||
112 | if p.grad is not None | ||
113 | ] | ||
114 | return torch.stack(params) if len(params) != 0 else None | ||
115 | |||
116 | @torch.no_grad() | ||
107 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
108 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
109 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) |
110 | 120 | ||
121 | if use_emb_decay and w is not None: | ||
122 | lr = lrs["emb"] or lrs["0"] | ||
123 | lambda_ = emb_decay * lr | ||
124 | |||
125 | if lambda_ != 0: | ||
126 | norm = w[:, :].norm(dim=-1, keepdim=True) | ||
127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
128 | |||
111 | def on_log(): | 129 | def on_log(): |
112 | if ema_embeddings is not None: | 130 | if ema_embeddings is not None: |
113 | return {"ema_decay": ema_embeddings.decay} | 131 | return {"ema_decay": ema_embeddings.decay} |
@@ -125,7 +143,7 @@ def textual_inversion_strategy_callbacks( | |||
125 | ) | 143 | ) |
126 | 144 | ||
127 | @torch.no_grad() | 145 | @torch.no_grad() |
128 | def on_sample(step): | 146 | def on_sample(cycle, step): |
129 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 147 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
130 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
131 | 149 | ||
@@ -135,7 +153,7 @@ def textual_inversion_strategy_callbacks( | |||
135 | unet_.to(dtype=weight_dtype) | 153 | unet_.to(dtype=weight_dtype) |
136 | text_encoder_.to(dtype=weight_dtype) | 154 | text_encoder_.to(dtype=weight_dtype) |
137 | 155 | ||
138 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 156 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) |
139 | 157 | ||
140 | unet_.to(dtype=orig_unet_dtype) | 158 | unet_.to(dtype=orig_unet_dtype) |
141 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 159 | text_encoder_.to(dtype=orig_text_encoder_dtype) |
@@ -148,6 +166,7 @@ def textual_inversion_strategy_callbacks( | |||
148 | return TrainingCallbacks( | 166 | return TrainingCallbacks( |
149 | on_train=on_train, | 167 | on_train=on_train, |
150 | on_eval=on_eval, | 168 | on_eval=on_eval, |
169 | on_before_optimize=on_before_optimize, | ||
151 | on_after_optimize=on_after_optimize, | 170 | on_after_optimize=on_after_optimize, |
152 | on_log=on_log, | 171 | on_log=on_log, |
153 | on_checkpoint=on_checkpoint, | 172 | on_checkpoint=on_checkpoint, |