author     Volpeon <git@volpeon.ink>    2023-01-06 09:07:18 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-06 09:07:18 +0100
commit     f4f90c487cbc247952689e906519d8e2eb21da99 (patch)
tree       fc308cdcf02c36437e8017fab5961294f86930fe
parent     Log EMA decay (diff)
Add contextmanager to EMAModel to apply weights temporarily
 train_ti.py      | 57
 training/util.py | 12
 2 files changed, 36 insertions(+), 33 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 2f13128..aa2bf02 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -5,6 +5,7 @@ import logging
 import copy
 from pathlib import Path
 from functools import partial
+from contextlib import nullcontext
 
 import torch
 import torch.utils.checkpoint
@@ -509,20 +510,15 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        if self.ema_embeddings is not None:
-            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
-            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
-            self.ema_embeddings.copy_to(ema_weights.parameters())
-            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
-        for (token, ids) in zip(self.placeholder_token, self.new_ids):
-            text_encoder.text_model.embeddings.save_embed(
-                ids,
-                checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
-            )
-
-        if self.ema_embeddings is not None:
-            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_token, self.new_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
 
         del text_encoder
 
@@ -530,30 +526,25 @@ class Checkpointer(CheckpointerBase):
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        if self.ema_embeddings is not None:
-            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
-            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
-            self.ema_embeddings.copy_to(ema_weights.parameters())
-            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
-
-        orig_dtype = text_encoder.dtype
-        text_encoder.to(dtype=self.weight_dtype)
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
 
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=self.vae,
-            unet=self.unet,
-            tokenizer=self.tokenizer,
-            scheduler=self.scheduler,
-        ).to(self.accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
+        with ema_context:
+            orig_dtype = text_encoder.dtype
+            text_encoder.to(dtype=self.weight_dtype)
 
-        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=self.vae,
+                unet=self.unet,
+                tokenizer=self.tokenizer,
+                scheduler=self.scheduler,
+            ).to(self.accelerator.device)
+            pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        text_encoder.to(dtype=orig_dtype)
+            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
-        if self.ema_embeddings is not None:
-            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+            text_encoder.to(dtype=orig_dtype)
 
         del text_encoder
         del pipeline
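The call sites in train_ti.py avoid a second `if self.ema_embeddings is not None` check by building either the EMA context manager or `contextlib.nullcontext()` up front, so the body of the `with` block is identical whether EMA weights are applied or not. A minimal sketch of that pattern with hypothetical names (`save_with_optional_ema`, `module`), assuming `apply_temporary` as added in this commit:

```python
from contextlib import nullcontext

def save_with_optional_ema(ema_model, module):
    # pick the EMA context manager if an EMA model exists, otherwise a no-op context
    ema_context = (
        ema_model.apply_temporary(module.parameters())
        if ema_model is not None
        else nullcontext()
    )
    with ema_context:
        # inside the block, `module` holds EMA-averaged weights if available,
        # otherwise its current training weights; the saving/sampling code is unchanged
        ...
```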
diff --git a/training/util.py b/training/util.py
index 93b6248..6f1e85a 100644
--- a/training/util.py
+++ b/training/util.py
@@ -2,6 +2,7 @@ from pathlib import Path
 import json
 import copy
 from typing import Iterable
+from contextlib import contextmanager
 
 import torch
 from PIL import Image
@@ -259,3 +260,14 @@ class EMAModel:
                raise ValueError("collected_params must all be Tensors")
            if len(self.collected_params) != len(self.shadow_params):
                raise ValueError("collected_params and shadow_params must have the same length")
+
+    @contextmanager
+    def apply_temporary(self, parameters):
+        try:
+            parameters = list(parameters)
+            original_params = [p.clone() for p in parameters]
+            self.copy_to(parameters)
+            yield
+        finally:
+            for s_param, param in zip(original_params, parameters):
+                param.data.copy_(s_param.data)
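The new `apply_temporary` context manager snapshots the live parameters, copies the EMA shadow weights over them for the duration of the `with` block, and restores the snapshot in a `finally` clause even if the block raises. A self-contained sketch of the same pattern, using a toy `TinyEMA` stand-in rather than the full `EMAModel` class:

```python
import torch
from contextlib import contextmanager

class TinyEMA:
    """Toy stand-in for EMAModel: keeps a shadow copy of some parameters."""
    def __init__(self, parameters):
        self.shadow_params = [p.clone().detach() for p in parameters]

    def copy_to(self, parameters):
        # overwrite the live parameters with the shadow (EMA) values
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    @contextmanager
    def apply_temporary(self, parameters):
        try:
            parameters = list(parameters)
            original_params = [p.clone() for p in parameters]
            self.copy_to(parameters)
            yield
        finally:
            # restore the training weights even if the block raised
            for o_param, param in zip(original_params, parameters):
                param.data.copy_(o_param.data)

# usage: weights are swapped only inside the `with` block
layer = torch.nn.Linear(4, 4)
ema = TinyEMA(layer.parameters())
with torch.no_grad():
    for p in layer.parameters():
        p.add_(1.0)  # training updates drift away from the EMA copy
with ema.apply_temporary(layer.parameters()):
    pass  # here `layer` holds the EMA weights (e.g. for checkpointing or sampling)
# after the block the drifted training weights are back in place
```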
