From 21d70916f66e74a87c631a06b70774954b085b48 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 7 Apr 2023 14:14:00 +0200 Subject: Fix --- train_lora.py | 13 ++++++++----- training/functional.py | 6 ------ training/strategy/dreambooth.py | 6 ++---- training/strategy/lora.py | 9 ++++----- training/strategy/ti.py | 6 ++---- 5 files changed, 16 insertions(+), 24 deletions(-) diff --git a/train_lora.py b/train_lora.py index daf1f6c..476efcf 100644 --- a/train_lora.py +++ b/train_lora.py @@ -548,15 +548,18 @@ def parse_args(): if args.project is None: raise ValueError("You must specify --project") + if args.initializer_tokens is None: + args.initializer_tokens = [] + + if args.placeholder_tokens is None: + args.placeholder_tokens = [] + if isinstance(args.placeholder_tokens, str): args.placeholder_tokens = [args.placeholder_tokens] if isinstance(args.initializer_tokens, str): args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) - if len(args.initializer_tokens) == 0: - raise ValueError("You must specify --initializer_tokens") - if len(args.placeholder_tokens) == 0: args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] @@ -884,7 +887,7 @@ def main(): num_pti_epochs = math.ceil( args.num_pti_steps / len(pti_datamodule.train_dataset) ) * args.pti_gradient_accumulation_steps - pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_train_steps)) + pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps)) pti_optimizer = create_optimizer( [ @@ -915,7 +918,7 @@ def main(): # -- sample_output_dir=pti_sample_output_dir, checkpoint_output_dir=pti_checkpoint_output_dir, - sample_frequency=pti_sample_frequency, + sample_frequency=math.inf, placeholder_tokens=args.placeholder_tokens, placeholder_token_ids=placeholder_token_ids, use_emb_decay=args.use_emb_decay, diff --git a/training/functional.py b/training/functional.py index c30d1c0..4d83df1 100644 --- a/training/functional.py +++ b/training/functional.py @@ -34,7 +34,6 @@ def const(result=None): @dataclass class TrainingCallbacks(): - on_accum_model: Callable[[], torch.nn.Module] = const(None) on_log: Callable[[], dict[str, Any]] = const({}) on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) on_before_optimize: Callable[[float, int], Any] = const() @@ -461,7 +460,6 @@ def train_loop( ) global_progress_bar.set_description("Total progress") - model = callbacks.on_accum_model() on_log = callbacks.on_log on_train = callbacks.on_train on_before_optimize = callbacks.on_before_optimize @@ -498,8 +496,6 @@ def train_loop( local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") local_progress_bar.reset() - model.train() - with on_train(epoch): for step, batch in enumerate(train_dataloader): loss, acc, bsz = loss_step(step, batch, cache) @@ -560,8 +556,6 @@ def train_loop( on_after_epoch() if val_dataloader is not None: - model.eval() - cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 9808027..0286673 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -84,11 +84,9 @@ def dreambooth_strategy_callbacks( else: return nullcontext() - def on_accum_model(): - return unet - @contextmanager def on_train(epoch: int): + unet.train() tokenizer.train() if epoch < train_text_encoder_epochs: @@ -101,6 +99,7 @@ def dreambooth_strategy_callbacks( @contextmanager def on_eval(): + unet.eval() tokenizer.eval() text_encoder.eval() @@ -174,7 +173,6 @@ def dreambooth_strategy_callbacks( torch.cuda.empty_cache() return TrainingCallbacks( - on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 6730dc9..80ffa9c 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -64,17 +64,17 @@ def lora_strategy_callbacks( image_size=sample_image_size, ) - def on_accum_model(): - return unet - @contextmanager def on_train(epoch: int): - tokenizer.train() + unet.train() text_encoder.train() + tokenizer.train() yield @contextmanager def on_eval(): + unet.eval() + text_encoder.eval() tokenizer.eval() yield @@ -152,7 +152,6 @@ def lora_strategy_callbacks( torch.cuda.empty_cache() return TrainingCallbacks( - on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 55e9934..6a637c3 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -89,16 +89,15 @@ def textual_inversion_strategy_callbacks( else: return nullcontext() - def on_accum_model(): - return text_encoder.text_model.embeddings.token_override_embedding.params - @contextmanager def on_train(epoch: int): + text_encoder.text_model.embeddings.token_override_embedding.params.train() tokenizer.train() yield @contextmanager def on_eval(): + text_encoder.text_model.embeddings.token_override_embedding.params.eval() tokenizer.eval() with ema_context(): @@ -166,7 +165,6 @@ def textual_inversion_strategy_callbacks( torch.cuda.empty_cache() return TrainingCallbacks( - on_accum_model=on_accum_model, on_train=on_train, on_eval=on_eval, on_before_optimize=on_before_optimize, -- cgit v1.2.3-70-g09d2