From 8e9d62225db11913bf7ef67221fc3508d7fe1149 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 17 Jan 2023 16:39:33 +0100 Subject: Update --- train_dreambooth.py | 14 ++++++-------- training/functional.py | 12 ++++++++---- training/lr.py | 2 +- training/strategy/dreambooth.py | 5 ++--- training/strategy/ti.py | 14 ++++++++------ 5 files changed, 25 insertions(+), 22 deletions(-) diff --git a/train_dreambooth.py b/train_dreambooth.py index 48bdcf8..9c1e41c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -1,6 +1,7 @@ import argparse import datetime import logging +import itertools from pathlib import Path from functools import partial @@ -578,14 +579,11 @@ def main(): datamodule.setup() optimizer = optimizer_class( - [ - { - 'params': unet.parameters(), - }, - { - 'params': text_encoder.parameters(), - } - ], + itertools.chain( + unet.parameters(), + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + ), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, diff --git a/training/functional.py b/training/functional.py index 7a3e821..a450ef6 100644 --- a/training/functional.py +++ b/training/functional.py @@ -1,7 +1,7 @@ from dataclasses import dataclass import math from contextlib import _GeneratorContextManager, nullcontext -from typing import Callable, Any, Tuple, Union, Optional, Type +from typing import Callable, Any, Tuple, Union, Optional, Protocol from functools import partial from pathlib import Path import itertools @@ -37,7 +37,7 @@ class TrainingCallbacks(): on_model: Callable[[], torch.nn.Module] = const(None) on_log: Callable[[], dict[str, Any]] = const({}) on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) - on_before_optimize: Callable[[int], None] = const() + on_before_optimize: Callable[[float, int], None] = const() on_after_optimize: Callable[[float], None] = const() on_after_epoch: Callable[[float], None] = const() on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) @@ -331,13 +331,17 @@ def loss_step( return loss, acc, bsz +class LossCallable(Protocol): + def __call__(self, step: int, batch: dict[str, Any], eval: bool = False) -> Tuple[Any, Any, int]: ... 
+ + def train_loop( accelerator: Accelerator, optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], - loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], + loss_step: LossCallable, sample_frequency: int = 10, checkpoint_frequency: int = 50, global_step_offset: int = 0, @@ -406,7 +410,7 @@ def train_loop( accelerator.backward(loss) - on_before_optimize(epoch) + on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) optimizer.step() lr_scheduler.step() diff --git a/training/lr.py b/training/lr.py index 902c4eb..9690738 100644 --- a/training/lr.py +++ b/training/lr.py @@ -101,7 +101,7 @@ class LRFinder(): self.accelerator.backward(loss) - on_before_optimize(epoch) + on_before_optimize(lr_scheduler.get_last_lr()[0], epoch) self.optimizer.step() lr_scheduler.step() diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index d813b49..f57e736 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -99,8 +99,7 @@ def dreambooth_strategy_callbacks( def on_prepare(): unet.requires_grad_(True) text_encoder.requires_grad_(True) - text_encoder.text_model.embeddings.persist() - text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) + text_encoder.text_model.embeddings.requires_grad_(False) if ema_unet is not None: ema_unet.to(accelerator.device) @@ -125,7 +124,7 @@ def dreambooth_strategy_callbacks( with ema_context(): yield - def on_before_optimize(epoch: int): + def on_before_optimize(lr: float, epoch: int): if accelerator.sync_gradients: params_to_clip = [unet.parameters()] if epoch < train_text_encoder_epochs: diff --git a/training/strategy/ti.py b/training/strategy/ti.py index ba78b98..e922954 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -117,14 +117,15 @@ def textual_inversion_strategy_callbacks( with ema_context(): yield - def on_after_optimize(lr: float): + @torch.no_grad() + def on_before_optimize(lr: float, epoch: int): if use_emb_decay: - with torch.no_grad(): - text_encoder.text_model.embeddings.normalize( - emb_decay_target, - min(1.0, emb_decay * lr) - ) + text_encoder.text_model.embeddings.normalize( + emb_decay_target, + min(1.0, emb_decay * lr) + ) + def on_after_optimize(lr: float): if ema_embeddings is not None: ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) @@ -154,6 +155,7 @@ def textual_inversion_strategy_callbacks( on_model=on_model, on_train=on_train, on_eval=on_eval, + on_before_optimize=on_before_optimize, on_after_optimize=on_after_optimize, on_log=on_log, on_checkpoint=on_checkpoint, -- cgit v1.2.3-70-g09d2
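
The optimizer change in train_dreambooth.py does two things at once: it collapses the two parameter groups (which shared every hyperparameter anyway) into a single itertools.chain, and it drops text_encoder.text_model.embeddings from training, matching the strategy side where embeddings.requires_grad_(False) replaces the old persist()/temp_token_embedding handling. A minimal sketch of the pattern, with stand-in modules in place of the real UNet and CLIP text model (only the structure is from the patch; the module shapes and hyperparameter values are illustrative):

import itertools

import torch

# Stand-ins for the UNet and the CLIP text model submodules; the real
# objects come from diffusers/transformers.
unet = torch.nn.Linear(4, 4)
encoder = torch.nn.Linear(4, 4)
final_layer_norm = torch.nn.LayerNorm(4)
embeddings = torch.nn.Embedding(10, 4)

# Frozen on the strategy side; leaving them out of the chain below also
# means the optimizer never allocates moment buffers for them.
embeddings.requires_grad_(False)

optimizer = torch.optim.AdamW(
    itertools.chain(
        unet.parameters(),
        encoder.parameters(),
        final_layer_norm.parameters(),
    ),
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
)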
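The typing change in training/functional.py replaces the old Union of two Callable signatures with a Protocol, because Callable[...] cannot express a keyword argument with a default value. A self-contained sketch of how the structural check works (loss_step's dummy body and the batch contents are illustrative, not from the patch):

from typing import Any, Protocol, Tuple

class LossCallable(Protocol):
    def __call__(self, step: int, batch: dict[str, Any],
                 eval: bool = False) -> Tuple[Any, Any, int]: ...

def loss_step(step: int, batch: dict[str, Any],
              eval: bool = False) -> Tuple[Any, Any, int]:
    # Dummy body standing in for the real diffusion loss computation.
    return 0.0, 0.0, 1

def train_loop(loss_fn: LossCallable) -> None:
    loss, acc, bsz = loss_fn(0, {"pixel_values": None})        # training pass
    loss, acc, bsz = loss_fn(0, {"pixel_values": None}, True)  # eval pass

train_loop(loss_step)  # satisfies the protocol structurally; no inheritance needed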
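The signature change from on_before_optimize(epoch) to on_before_optimize(lr, epoch) exists so the textual-inversion strategy can apply its lr-scaled embedding decay before optimizer.step() rather than after it. A stripped-down view of the resulting call order, assuming objects shaped like those in the patch (model, callbacks, and num_epochs are placeholders, not names from the repository):

def train_epochs(accelerator, model, optimizer, lr_scheduler,
                 train_dataloader, callbacks, num_epochs):
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            loss, acc, bsz = model(batch)
            accelerator.backward(loss)

            # The current lr comes off the scheduler and reaches the strategy
            # *before* the weights move; the TI strategy normalizes the
            # embeddings here, under @torch.no_grad().
            callbacks.on_before_optimize(lr_scheduler.get_last_lr()[0], epoch)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            # EMA bookkeeping stays where it was, after the step.
            callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0])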