diff options
author | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-04-07 14:14:00 +0200 |
commit | 21d70916f66e74a87c631a06b70774954b085b48 (patch) | |
tree | d1b443b9270f45ae6936f3acb565f767c7c65b1f | |
parent | Run PTI only if placeholder tokens arg isn't empty (diff) | |
download | textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.tar.gz textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.tar.bz2 textual-inversion-diff-21d70916f66e74a87c631a06b70774954b085b48.zip |
Fix
-rw-r--r-- | train_lora.py | 13 | ||||
-rw-r--r-- | training/functional.py | 6 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 6 | ||||
-rw-r--r-- | training/strategy/lora.py | 9 | ||||
-rw-r--r-- | training/strategy/ti.py | 6 |
5 files changed, 16 insertions, 24 deletions
diff --git a/train_lora.py b/train_lora.py index daf1f6c..476efcf 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -548,15 +548,18 @@ def parse_args(): | |||
548 | if args.project is None: | 548 | if args.project is None: |
549 | raise ValueError("You must specify --project") | 549 | raise ValueError("You must specify --project") |
550 | 550 | ||
551 | if args.initializer_tokens is None: | ||
552 | args.initializer_tokens = [] | ||
553 | |||
554 | if args.placeholder_tokens is None: | ||
555 | args.placeholder_tokens = [] | ||
556 | |||
551 | if isinstance(args.placeholder_tokens, str): | 557 | if isinstance(args.placeholder_tokens, str): |
552 | args.placeholder_tokens = [args.placeholder_tokens] | 558 | args.placeholder_tokens = [args.placeholder_tokens] |
553 | 559 | ||
554 | if isinstance(args.initializer_tokens, str): | 560 | if isinstance(args.initializer_tokens, str): |
555 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 561 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) |
556 | 562 | ||
557 | if len(args.initializer_tokens) == 0: | ||
558 | raise ValueError("You must specify --initializer_tokens") | ||
559 | |||
560 | if len(args.placeholder_tokens) == 0: | 563 | if len(args.placeholder_tokens) == 0: |
561 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 564 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] |
562 | 565 | ||
@@ -884,7 +887,7 @@ def main(): | |||
884 | num_pti_epochs = math.ceil( | 887 | num_pti_epochs = math.ceil( |
885 | args.num_pti_steps / len(pti_datamodule.train_dataset) | 888 | args.num_pti_steps / len(pti_datamodule.train_dataset) |
886 | ) * args.pti_gradient_accumulation_steps | 889 | ) * args.pti_gradient_accumulation_steps |
887 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps)) | 890 | pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps)) |
888 | 891 | ||
889 | pti_optimizer = create_optimizer( | 892 | pti_optimizer = create_optimizer( |
890 | [ | 893 | [ |
@@ -915,7 +918,7 @@ def main(): | |||
915 | # -- | 918 | # -- |
916 | sample_output_dir=pti_sample_output_dir, | 919 | sample_output_dir=pti_sample_output_dir, |
917 | checkpoint_output_dir=pti_checkpoint_output_dir, | 920 | checkpoint_output_dir=pti_checkpoint_output_dir, |
918 | sample_frequency=pti_sample_frequency, | 921 | sample_frequency=math.inf, |
919 | placeholder_tokens=args.placeholder_tokens, | 922 | placeholder_tokens=args.placeholder_tokens, |
920 | placeholder_token_ids=placeholder_token_ids, | 923 | placeholder_token_ids=placeholder_token_ids, |
921 | use_emb_decay=args.use_emb_decay, | 924 | use_emb_decay=args.use_emb_decay, |
diff --git a/training/functional.py b/training/functional.py index c30d1c0..4d83df1 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -34,7 +34,6 @@ def const(result=None): | |||
34 | 34 | ||
35 | @dataclass | 35 | @dataclass |
36 | class TrainingCallbacks(): | 36 | class TrainingCallbacks(): |
37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) | ||
38 | on_log: Callable[[], dict[str, Any]] = const({}) | 37 | on_log: Callable[[], dict[str, Any]] = const({}) |
39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 38 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
40 | on_before_optimize: Callable[[float, int], Any] = const() | 39 | on_before_optimize: Callable[[float, int], Any] = const() |
@@ -461,7 +460,6 @@ def train_loop( | |||
461 | ) | 460 | ) |
462 | global_progress_bar.set_description("Total progress") | 461 | global_progress_bar.set_description("Total progress") |
463 | 462 | ||
464 | model = callbacks.on_accum_model() | ||
465 | on_log = callbacks.on_log | 463 | on_log = callbacks.on_log |
466 | on_train = callbacks.on_train | 464 | on_train = callbacks.on_train |
467 | on_before_optimize = callbacks.on_before_optimize | 465 | on_before_optimize = callbacks.on_before_optimize |
@@ -498,8 +496,6 @@ def train_loop( | |||
498 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") | 496 | local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") |
499 | local_progress_bar.reset() | 497 | local_progress_bar.reset() |
500 | 498 | ||
501 | model.train() | ||
502 | |||
503 | with on_train(epoch): | 499 | with on_train(epoch): |
504 | for step, batch in enumerate(train_dataloader): | 500 | for step, batch in enumerate(train_dataloader): |
505 | loss, acc, bsz = loss_step(step, batch, cache) | 501 | loss, acc, bsz = loss_step(step, batch, cache) |
@@ -560,8 +556,6 @@ def train_loop( | |||
560 | on_after_epoch() | 556 | on_after_epoch() |
561 | 557 | ||
562 | if val_dataloader is not None: | 558 | if val_dataloader is not None: |
563 | model.eval() | ||
564 | |||
565 | cur_loss_val = AverageMeter() | 559 | cur_loss_val = AverageMeter() |
566 | cur_acc_val = AverageMeter() | 560 | cur_acc_val = AverageMeter() |
567 | 561 | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 9808027..0286673 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -84,11 +84,9 @@ def dreambooth_strategy_callbacks( | |||
84 | else: | 84 | else: |
85 | return nullcontext() | 85 | return nullcontext() |
86 | 86 | ||
87 | def on_accum_model(): | ||
88 | return unet | ||
89 | |||
90 | @contextmanager | 87 | @contextmanager |
91 | def on_train(epoch: int): | 88 | def on_train(epoch: int): |
89 | unet.train() | ||
92 | tokenizer.train() | 90 | tokenizer.train() |
93 | 91 | ||
94 | if epoch < train_text_encoder_epochs: | 92 | if epoch < train_text_encoder_epochs: |
@@ -101,6 +99,7 @@ def dreambooth_strategy_callbacks( | |||
101 | 99 | ||
102 | @contextmanager | 100 | @contextmanager |
103 | def on_eval(): | 101 | def on_eval(): |
102 | unet.eval() | ||
104 | tokenizer.eval() | 103 | tokenizer.eval() |
105 | text_encoder.eval() | 104 | text_encoder.eval() |
106 | 105 | ||
@@ -174,7 +173,6 @@ def dreambooth_strategy_callbacks( | |||
174 | torch.cuda.empty_cache() | 173 | torch.cuda.empty_cache() |
175 | 174 | ||
176 | return TrainingCallbacks( | 175 | return TrainingCallbacks( |
177 | on_accum_model=on_accum_model, | ||
178 | on_train=on_train, | 176 | on_train=on_train, |
179 | on_eval=on_eval, | 177 | on_eval=on_eval, |
180 | on_before_optimize=on_before_optimize, | 178 | on_before_optimize=on_before_optimize, |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 6730dc9..80ffa9c 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -64,17 +64,17 @@ def lora_strategy_callbacks( | |||
64 | image_size=sample_image_size, | 64 | image_size=sample_image_size, |
65 | ) | 65 | ) |
66 | 66 | ||
67 | def on_accum_model(): | ||
68 | return unet | ||
69 | |||
70 | @contextmanager | 67 | @contextmanager |
71 | def on_train(epoch: int): | 68 | def on_train(epoch: int): |
72 | tokenizer.train() | 69 | unet.train() |
73 | text_encoder.train() | 70 | text_encoder.train() |
71 | tokenizer.train() | ||
74 | yield | 72 | yield |
75 | 73 | ||
76 | @contextmanager | 74 | @contextmanager |
77 | def on_eval(): | 75 | def on_eval(): |
76 | unet.eval() | ||
77 | text_encoder.eval() | ||
78 | tokenizer.eval() | 78 | tokenizer.eval() |
79 | yield | 79 | yield |
80 | 80 | ||
@@ -152,7 +152,6 @@ def lora_strategy_callbacks( | |||
152 | torch.cuda.empty_cache() | 152 | torch.cuda.empty_cache() |
153 | 153 | ||
154 | return TrainingCallbacks( | 154 | return TrainingCallbacks( |
155 | on_accum_model=on_accum_model, | ||
156 | on_train=on_train, | 155 | on_train=on_train, |
157 | on_eval=on_eval, | 156 | on_eval=on_eval, |
158 | on_before_optimize=on_before_optimize, | 157 | on_before_optimize=on_before_optimize, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 55e9934..6a637c3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks( | |||
89 | else: | 89 | else: |
90 | return nullcontext() | 90 | return nullcontext() |
91 | 91 | ||
92 | def on_accum_model(): | ||
93 | return text_encoder.text_model.embeddings.token_override_embedding.params | ||
94 | |||
95 | @contextmanager | 92 | @contextmanager |
96 | def on_train(epoch: int): | 93 | def on_train(epoch: int): |
94 | text_encoder.text_model.embeddings.token_override_embedding.params.train() | ||
97 | tokenizer.train() | 95 | tokenizer.train() |
98 | yield | 96 | yield |
99 | 97 | ||
100 | @contextmanager | 98 | @contextmanager |
101 | def on_eval(): | 99 | def on_eval(): |
100 | text_encoder.text_model.embeddings.token_override_embedding.params.eval() | ||
102 | tokenizer.eval() | 101 | tokenizer.eval() |
103 | 102 | ||
104 | with ema_context(): | 103 | with ema_context(): |
@@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks( | |||
166 | torch.cuda.empty_cache() | 165 | torch.cuda.empty_cache() |
167 | 166 | ||
168 | return TrainingCallbacks( | 167 | return TrainingCallbacks( |
169 | on_accum_model=on_accum_model, | ||
170 | on_train=on_train, | 168 | on_train=on_train, |
171 | on_eval=on_eval, | 169 | on_eval=on_eval, |
172 | on_before_optimize=on_before_optimize, | 170 | on_before_optimize=on_before_optimize, |