-rw-r--r--   train_ti.py        | 27
-rw-r--r--   training/common.py | 37
-rw-r--r--   training/util.py   | 16

3 files changed, 41 insertions, 39 deletions
diff --git a/train_ti.py b/train_ti.py
index 928b721..8631892 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -397,7 +397,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=1,
+        default=0,
         type=float,
         help="Embedding decay factor."
     )
@@ -532,13 +532,17 @@ class Checkpointer(CheckpointerBase):
 
     @torch.no_grad()
     def save_samples(self, step):
+        unet = self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         ema_context = self.ema_embeddings.apply_temporary(
             text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
-            orig_dtype = text_encoder.dtype
+            orig_unet_dtype = unet.dtype
+            orig_text_encoder_dtype = text_encoder.dtype
+
+            unet.to(dtype=self.weight_dtype)
             text_encoder.to(dtype=self.weight_dtype)
 
             pipeline = VlpnStableDiffusion(
@@ -552,7 +556,8 @@ class Checkpointer(CheckpointerBase):
 
             super().save_samples(pipeline, step)
 
-            text_encoder.to(dtype=orig_dtype)
+            unet.to(dtype=orig_unet_dtype)
+            text_encoder.to(dtype=orig_text_encoder_dtype)
 
         del text_encoder
         del pipeline
@@ -742,20 +747,17 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
     vae.to(accelerator.device, dtype=weight_dtype)
-    unet.to(accelerator.device, dtype=weight_dtype)
 
     if args.use_ema:
         ema_embeddings.to(accelerator.device)
 
     if args.gradient_checkpointing:
         unet.train()
-    else:
-        unet.eval()
 
     @contextmanager
     def on_train(epoch: int):
@@ -780,10 +782,11 @@ def main():
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
-        text_encoder.text_model.embeddings.normalize(
-            args.emb_decay_target,
-            min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
-        )
+        if args.emb_decay_factor != 0:
+            text_encoder.text_model.embeddings.normalize(
+                args.emb_decay_target,
+                min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
+            )
 
     if args.use_ema:
         ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
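Note on the emb_decay change above: with the new default of 0, on_after_optimize skips the normalize() call entirely; when a non-zero factor is given, the decay strength is scaled by how far the current lr sits between emb_decay_start and the base learning_rate, clamped to [0, 1]. A minimal standalone sketch of that scaling (emb_decay_scale is a hypothetical helper, not part of the patch):

def emb_decay_scale(lr, learning_rate, emb_decay_start, emb_decay_factor):
    # Mirrors the expression passed to normalize() in on_after_optimize:
    # scale the factor by the lr's position between emb_decay_start and
    # the base learning_rate, clamped to the [0.0, 1.0] range.
    frac = (lr - emb_decay_start) / (learning_rate - emb_decay_start)
    return min(1.0, max(0.0, emb_decay_factor * frac))

# Example: learning_rate=1e-4, emb_decay_start=0.0, emb_decay_factor=1.0
# gives 0.5 at lr=5e-5; emb_decay_factor=0 always yields 0.0.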
diff --git a/training/common.py b/training/common.py
index 8083137..5d1e3f9 100644
--- a/training/common.py
+++ b/training/common.py
@@ -316,30 +316,29 @@ def train_loop(
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()
 
-            with torch.inference_mode():
-                with on_eval():
-                    for step, batch in enumerate(val_dataloader):
-                        loss, acc, bsz = loss_step(step, batch, True)
+            with torch.inference_mode(), on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loss_step(step, batch, True)
 
-                        loss = loss.detach_()
-                        acc = acc.detach_()
+                    loss = loss.detach_()
+                    acc = acc.detach_()
 
-                        cur_loss_val.update(loss, bsz)
-                        cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
 
-                        avg_loss_val.update(loss, bsz)
-                        avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
 
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-                        logs = {
-                            "val/loss": avg_loss_val.avg.item(),
-                            "val/acc": avg_acc_val.avg.item(),
-                            "val/cur_loss": loss.item(),
-                            "val/cur_acc": acc.item(),
-                        }
-                        local_progress_bar.set_postfix(**logs)
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
+                    local_progress_bar.set_postfix(**logs)
 
             logs["val/cur_loss"] = cur_loss_val.avg.item()
             logs["val/cur_acc"] = cur_acc_val.avg.item()
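The training/common.py hunk does not change behaviour: the two nested context managers are simply entered in one with statement, dropping one indentation level in the validation loop. A tiny illustration of the equivalence (nullcontext stands in for the real managers):

from contextlib import nullcontext

# Nested form (before):
with nullcontext():
    with nullcontext():
        pass

# Combined form (after), entering the same managers in the same order:
with nullcontext(), nullcontext():
    pass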
diff --git a/training/util.py b/training/util.py
index 1008021..781cf04 100644
--- a/training/util.py
+++ b/training/util.py
@@ -134,11 +134,11 @@ class EMAModel:
     def __init__(
         self,
         parameters: Iterable[torch.nn.Parameter],
-        update_after_step=0,
-        inv_gamma=1.0,
-        power=2 / 3,
-        min_value=0.0,
-        max_value=0.9999,
+        update_after_step: int = 0,
+        inv_gamma: float = 1.0,
+        power: float = 2 / 3,
+        min_value: float = 0.0,
+        max_value: float = 0.9999,
     ):
         """
         @crowsonkb's notes on EMA Warmup:
@@ -165,7 +165,7 @@ class EMAModel:
         self.decay = 0.0
         self.optimization_step = 0
 
-    def get_decay(self, optimization_step):
+    def get_decay(self, optimization_step: int):
         """
         Compute the decay factor for the exponential moving average.
         """
@@ -276,5 +276,5 @@ class EMAModel:
             self.copy_to(parameters)
             yield
         finally:
-            for s_param, param in zip(original_params, parameters):
-                param.data.copy_(s_param.data)
+            for o_param, param in zip(original_params, parameters):
+                param.data.copy_(o_param.data)
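The type annotations added to EMAModel.__init__ and get_decay do not alter the warmup schedule those parameters control. The body of get_decay is not part of this diff; under the assumption that it follows @crowsonkb's EMA warmup formula referenced in the docstring, it would look roughly like this (illustrative sketch only, not the repo's exact code):

def get_decay(optimization_step: int, update_after_step: int = 0,
              inv_gamma: float = 1.0, power: float = 2 / 3,
              min_value: float = 0.0, max_value: float = 0.9999) -> float:
    # Assumed @crowsonkb-style warmup: the decay ramps from 0 toward
    # max_value as training progresses, starting after update_after_step.
    step = max(0, optimization_step - update_after_step - 1)
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))

# With the defaults above (inv_gamma=1.0, power=2/3), the decay reaches
# roughly 0.999 around 31.6k steps and is capped at max_value.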
