From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- training/strategy/dreambooth.py | 29 +++++++++++++++++------------ training/strategy/lora.py | 41 +++++++++++++++++++++++++++++------------ training/strategy/ti.py | 27 +++++++++++++++++++-------- 3 files changed, 65 insertions(+), 32 deletions(-) (limited to 'training/strategy') diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e6fcc89..88b441b 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks( sample_output_dir: Path, checkpoint_output_dir: Path, seed: int, - train_text_encoder_epochs: int, + train_text_encoder_cycles: int, max_grad_norm: float = 1.0, use_ema: bool = False, ema_inv_gamma: float = 1.0, @@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks( return nullcontext() @contextmanager - def on_train(epoch: int): + def on_train(cycle: int): unet.train() tokenizer.train() - if epoch < train_text_encoder_epochs: + if cycle < train_text_encoder_cycles: text_encoder.train() - elif epoch == train_text_encoder_epochs: - text_encoder.requires_grad_(False) - text_encoder.eval() + tokenizer.train() yield @@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks( with ema_context(): yield - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): params_to_clip = [unet.parameters()] - if epoch < train_text_encoder_epochs: + if cycle < train_text_encoder_cycles: params_to_clip.append(text_encoder.parameters()) accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) @@ -189,8 +187,16 @@ def dreambooth_prepare( lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) text_encoder.text_model.embeddings.requires_grad_(False) @@ -198,6 +204,5 @@ def dreambooth_prepare( dreambooth_strategy = TrainingStrategy( - callbacks=dreambooth_strategy_callbacks, - prepare=dreambooth_prepare + callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare ) diff --git a/training/strategy/lora.py b/training/strategy/lora.py index f942b76..14e3384 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -81,7 +81,7 @@ def lora_strategy_callbacks( tokenizer.eval() yield - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): if not pti_mode: accelerator.clip_grad_norm_( itertools.chain( @@ -89,7 +89,7 @@ def lora_strategy_callbacks( text_encoder.text_model.encoder.parameters(), text_encoder.text_model.final_layer_norm.parameters(), ), - max_grad_norm + max_grad_norm, ) if len(placeholder_tokens) != 0 and use_emb_decay: @@ -108,7 +108,9 @@ def lora_strategy_callbacks( if lambda_ != 0: norm = w[:, :].norm(dim=-1, keepdim=True) - w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + w[:].add_( + (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) + ) @torch.no_grad() def on_checkpoint(step, postfix): @@ -128,25 +130,32 @@ def lora_strategy_callbacks( if not pti_mode: lora_config = {} - state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) + state_dict = get_peft_model_state_dict( + unet_, state_dict=accelerator.get_state_dict(unet_) + ) lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) text_encoder_state_dict = get_peft_model_state_dict( text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) ) - text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} + text_encoder_state_dict = { + f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items() + } state_dict.update(text_encoder_state_dict) - lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) + lora_config[ + "text_encoder_peft_config" + ] = text_encoder_.get_peft_config_as_dict(inference=True) if len(placeholder_tokens) != 0: ti_state_dict = { f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) - for (token, ids) - in zip(placeholder_tokens, placeholder_token_ids) + for (token, ids) in zip(placeholder_tokens, placeholder_token_ids) } state_dict.update(ti_state_dict) - save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") + save_file( + state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors" + ) with open(checkpoint_output_dir / "lora_config.json", "w") as f: json.dump(lora_config, f) @@ -185,10 +194,18 @@ def lora_prepare( train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, - **kwargs + **kwargs, ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6bc1d7d..7373982 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( yield @torch.no_grad() - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): if use_emb_decay: params = [ p @@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) + ema_embeddings.step( + text_encoder.text_model.embeddings.token_embedding.parameters() + ) if use_emb_decay and w is not None: lr = lrs["emb"] if "emb" in lrs else lrs["0"] @@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks( if lambda_ != 0: norm = w[:, :].norm(dim=-1, keepdim=True) - w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + w[:].add_( + (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) + ) def on_log(): if ema_embeddings is not None: @@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks( print(f"Saving checkpoint for step {step}...") with ema_context(): - for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): + for token, ids in zip(placeholder_tokens, placeholder_token_ids): text_encoder.text_model.embeddings.save_embed( ids, - checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" + checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin", ) @torch.no_grad() @@ -183,7 +187,7 @@ def textual_inversion_prepare( val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, gradient_checkpointing: bool = False, - **kwargs + **kwargs, ): weight_dtype = torch.float32 if accelerator.state.mixed_precision == "fp16": @@ -191,8 +195,15 @@ def textual_inversion_prepare( elif accelerator.state.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) unet.to(accelerator.device, dtype=weight_dtype) unet.requires_grad_(False) -- cgit v1.2.3-70-g09d2