From a2c240c8c55dfe930657f66372975d6f26feb168 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 10:02:30 +0100 Subject: TI: Prepare UNet with Accelerate as well --- train_ti.py | 27 +++++++++++++++------------ training/common.py | 37 ++++++++++++++++++------------------- training/util.py | 16 ++++++++-------- 3 files changed, 41 insertions(+), 39 deletions(-) diff --git a/train_ti.py b/train_ti.py index 928b721..8631892 100644 --- a/train_ti.py +++ b/train_ti.py @@ -397,7 +397,7 @@ def parse_args(): ) parser.add_argument( "--emb_decay_factor", - default=1, + default=0, type=float, help="Embedding decay factor." ) @@ -532,13 +532,17 @@ class Checkpointer(CheckpointerBase): @torch.no_grad() def save_samples(self, step): + unet = self.accelerator.unwrap_model(self.unet) text_encoder = self.accelerator.unwrap_model(self.text_encoder) ema_context = self.ema_embeddings.apply_temporary( text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() with ema_context: - orig_dtype = text_encoder.dtype + orig_unet_dtype = unet.dtype + orig_text_encoder_dtype = text_encoder.dtype + + unet.to(dtype=self.weight_dtype) text_encoder.to(dtype=self.weight_dtype) pipeline = VlpnStableDiffusion( @@ -552,7 +556,8 @@ class Checkpointer(CheckpointerBase): super().save_samples(pipeline, step) - text_encoder.to(dtype=orig_dtype) + unet.to(dtype=orig_unet_dtype) + text_encoder.to(dtype=orig_text_encoder_dtype) del text_encoder del pipeline @@ -742,20 +747,17 @@ def main(): warmup_epochs=args.lr_warmup_epochs, ) - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler ) vae.to(accelerator.device, dtype=weight_dtype) - unet.to(accelerator.device, dtype=weight_dtype) if args.use_ema: ema_embeddings.to(accelerator.device) if args.gradient_checkpointing: unet.train() - else: - unet.eval() @contextmanager def on_train(epoch: int): @@ -780,10 +782,11 @@ def main(): @torch.no_grad() def on_after_optimize(lr: float): - text_encoder.text_model.embeddings.normalize( - args.emb_decay_target, - min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start)))) - ) + if args.emb_decay_factor != 0: + text_encoder.text_model.embeddings.normalize( + args.emb_decay_target, + min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start)))) + ) if args.use_ema: ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) diff --git a/training/common.py b/training/common.py index 8083137..5d1e3f9 100644 --- a/training/common.py +++ b/training/common.py @@ -316,30 +316,29 @@ def train_loop( cur_loss_val = AverageMeter() cur_acc_val = AverageMeter() - with torch.inference_mode(): - with on_eval(): - for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loss_step(step, batch, True) + with torch.inference_mode(), on_eval(): + for step, batch in enumerate(val_dataloader): + loss, acc, bsz = loss_step(step, batch, True) - loss = loss.detach_() - acc = acc.detach_() + loss = loss.detach_() + acc = acc.detach_() - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) + cur_loss_val.update(loss, bsz) + cur_acc_val.update(acc, bsz) - avg_loss_val.update(loss, bsz) - avg_acc_val.update(acc, bsz) + avg_loss_val.update(loss, bsz) + avg_acc_val.update(acc, bsz) - local_progress_bar.update(1) - global_progress_bar.update(1) + local_progress_bar.update(1) + global_progress_bar.update(1) - logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), - "val/cur_loss": loss.item(), - "val/cur_acc": acc.item(), - } - local_progress_bar.set_postfix(**logs) + logs = { + "val/loss": avg_loss_val.avg.item(), + "val/acc": avg_acc_val.avg.item(), + "val/cur_loss": loss.item(), + "val/cur_acc": acc.item(), + } + local_progress_bar.set_postfix(**logs) logs["val/cur_loss"] = cur_loss_val.avg.item() logs["val/cur_acc"] = cur_acc_val.avg.item() diff --git a/training/util.py b/training/util.py index 1008021..781cf04 100644 --- a/training/util.py +++ b/training/util.py @@ -134,11 +134,11 @@ class EMAModel: def __init__( self, parameters: Iterable[torch.nn.Parameter], - update_after_step=0, - inv_gamma=1.0, - power=2 / 3, - min_value=0.0, - max_value=0.9999, + update_after_step: int = 0, + inv_gamma: float = 1.0, + power: float = 2 / 3, + min_value: float = 0.0, + max_value: float = 0.9999, ): """ @crowsonkb's notes on EMA Warmup: @@ -165,7 +165,7 @@ class EMAModel: self.decay = 0.0 self.optimization_step = 0 - def get_decay(self, optimization_step): + def get_decay(self, optimization_step: int): """ Compute the decay factor for the exponential moving average. """ @@ -276,5 +276,5 @@ class EMAModel: self.copy_to(parameters) yield finally: - for s_param, param in zip(original_params, parameters): - param.data.copy_(s_param.data) + for o_param, param in zip(original_params, parameters): + param.data.copy_(o_param.data) -- cgit v1.2.3-54-g00ecf