From a2c240c8c55dfe930657f66372975d6f26feb168 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 10:02:30 +0100
Subject: TI: Prepare UNet with Accelerate as well

---
 train_ti.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index 928b721..8631892 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -397,7 +397,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=1,
+        default=0,
         type=float,
         help="Embedding decay factor."
     )
@@ -532,13 +532,17 @@ class Checkpointer(CheckpointerBase):
 
     @torch.no_grad()
     def save_samples(self, step):
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         ema_context = self.ema_embeddings.apply_temporary(
             text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
-            orig_dtype = text_encoder.dtype
+            orig_unet_dtype = unet.dtype
+            orig_text_encoder_dtype = text_encoder.dtype
+
+            unet.to(dtype=self.weight_dtype)
             text_encoder.to(dtype=self.weight_dtype)
 
             pipeline = VlpnStableDiffusion(
@@ -552,7 +556,8 @@ class Checkpointer(CheckpointerBase):
 
             super().save_samples(pipeline, step)
 
-            text_encoder.to(dtype=orig_dtype)
+            unet.to(dtype=orig_unet_dtype)
+            text_encoder.to(dtype=orig_text_encoder_dtype)
 
             del text_encoder
             del pipeline
@@ -742,20 +747,17 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
     vae.to(accelerator.device, dtype=weight_dtype)
-    unet.to(accelerator.device, dtype=weight_dtype)
 
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
     if args.gradient_checkpointing:
         unet.train()
-    else:
-        unet.eval()
 
     @contextmanager
     def on_train(epoch: int):
@@ -780,10 +782,11 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(
-            args.emb_decay_target,
-            min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
-        )
+        if args.emb_decay_factor != 0:
+            text_encoder.text_model.embeddings.normalize(
+                args.emb_decay_target,
+                min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
+            )
 
         if args.use_ema:
             ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
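
Note (not part of the commit): below is a minimal, hedged sketch of the pattern this patch adopts, shown outside the context of train_ti.py — passing the UNet through accelerator.prepare() together with the text encoder instead of moving it to the device by hand, and temporarily casting the unwrapped models to the sampling dtype before restoring the original dtype afterwards, as save_samples now does. The model, optimizer, and function names here are hypothetical stand-ins, not the real objects from the training script.

# sketch.py -- illustrative only; assumes torch and accelerate are installed
import torch
from accelerate import Accelerator

accelerator = Accelerator()
weight_dtype = torch.float16  # assumption: dtype used only while sampling

# hypothetical stand-ins for the real UNet / text encoder / optimizer
unet = torch.nn.Linear(8, 8)
text_encoder = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(text_encoder.parameters())

# let Accelerate handle device placement and wrapping for both models
unet, text_encoder, optimizer = accelerator.prepare(unet, text_encoder, optimizer)

@torch.no_grad()
def save_samples():
    # unwrap before handing the models to an inference pipeline
    u = accelerator.unwrap_model(unet)
    te = accelerator.unwrap_model(text_encoder)

    orig_unet_dtype = next(u.parameters()).dtype
    orig_text_encoder_dtype = next(te.parameters()).dtype
    try:
        # cast to the sampling dtype for generation
        u.to(dtype=weight_dtype)
        te.to(dtype=weight_dtype)
        # ... build the sampling pipeline and generate images here ...
    finally:
        # restore the original training dtypes, mirroring the patched save_samples
        u.to(dtype=orig_unet_dtype)
        te.to(dtype=orig_text_encoder_dtype)

save_samples()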