diff options
author | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:18 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-06 09:07:18 +0100 |
commit | f4f90c487cbc247952689e906519d8e2eb21da99 (patch) | |
tree | fc308cdcf02c36437e8017fab5961294f86930fe | |
parent | Log EMA decay (diff) | |
download | textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.gz textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.tar.bz2 textual-inversion-diff-f4f90c487cbc247952689e906519d8e2eb21da99.zip |
Add contextmanager to EMAModel to apply weights temporarily
-rw-r--r-- | train_ti.py | 57 | ||||
-rw-r--r-- | training/util.py | 12 |
2 files changed, 36 insertions, 33 deletions
diff --git a/train_ti.py b/train_ti.py index 2f13128..aa2bf02 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -5,6 +5,7 @@ import logging | |||
5 | import copy | 5 | import copy |
6 | from pathlib import Path | 6 | from pathlib import Path |
7 | from functools import partial | 7 | from functools import partial |
8 | from contextlib import nullcontext | ||
8 | 9 | ||
9 | import torch | 10 | import torch |
10 | import torch.utils.checkpoint | 11 | import torch.utils.checkpoint |
@@ -509,20 +510,15 @@ class Checkpointer(CheckpointerBase): | |||
509 | 510 | ||
510 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 511 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
511 | 512 | ||
512 | if self.ema_embeddings is not None: | 513 | ema_context = self.ema_embeddings.apply_temporary( |
513 | orig_weights = text_encoder.text_model.embeddings.temp_token_embedding | 514 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() |
514 | ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding) | ||
515 | self.ema_embeddings.copy_to(ema_weights.parameters()) | ||
516 | text_encoder.text_model.embeddings.temp_token_embedding = ema_weights | ||
517 | 515 | ||
518 | for (token, ids) in zip(self.placeholder_token, self.new_ids): | 516 | with ema_context: |
519 | text_encoder.text_model.embeddings.save_embed( | 517 | for (token, ids) in zip(self.placeholder_token, self.new_ids): |
520 | ids, | 518 | text_encoder.text_model.embeddings.save_embed( |
521 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 519 | ids, |
522 | ) | 520 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
523 | 521 | ) | |
524 | if self.ema_embeddings is not None: | ||
525 | text_encoder.text_model.embeddings.temp_token_embedding = orig_weights | ||
526 | 522 | ||
527 | del text_encoder | 523 | del text_encoder |
528 | 524 | ||
@@ -530,30 +526,25 @@ class Checkpointer(CheckpointerBase): | |||
530 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 526 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
531 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 527 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
532 | 528 | ||
533 | if self.ema_embeddings is not None: | 529 | ema_context = self.ema_embeddings.apply_temporary( |
534 | orig_weights = text_encoder.text_model.embeddings.temp_token_embedding | 530 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext() |
535 | ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding) | ||
536 | self.ema_embeddings.copy_to(ema_weights.parameters()) | ||
537 | text_encoder.text_model.embeddings.temp_token_embedding = ema_weights | ||
538 | |||
539 | orig_dtype = text_encoder.dtype | ||
540 | text_encoder.to(dtype=self.weight_dtype) | ||
541 | 531 | ||
542 | pipeline = VlpnStableDiffusion( | 532 | with ema_context: |
543 | text_encoder=text_encoder, | 533 | orig_dtype = text_encoder.dtype |
544 | vae=self.vae, | 534 | text_encoder.to(dtype=self.weight_dtype) |
545 | unet=self.unet, | ||
546 | tokenizer=self.tokenizer, | ||
547 | scheduler=self.scheduler, | ||
548 | ).to(self.accelerator.device) | ||
549 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
550 | 535 | ||
551 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) | 536 | pipeline = VlpnStableDiffusion( |
537 | text_encoder=text_encoder, | ||
538 | vae=self.vae, | ||
539 | unet=self.unet, | ||
540 | tokenizer=self.tokenizer, | ||
541 | scheduler=self.scheduler, | ||
542 | ).to(self.accelerator.device) | ||
543 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
552 | 544 | ||
553 | text_encoder.to(dtype=orig_dtype) | 545 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) |
554 | 546 | ||
555 | if self.ema_embeddings is not None: | 547 | text_encoder.to(dtype=orig_dtype) |
556 | text_encoder.text_model.embeddings.temp_token_embedding = orig_weights | ||
557 | 548 | ||
558 | del text_encoder | 549 | del text_encoder |
559 | del pipeline | 550 | del pipeline |
diff --git a/training/util.py b/training/util.py index 93b6248..6f1e85a 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -2,6 +2,7 @@ from pathlib import Path | |||
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | from typing import Iterable | 4 | from typing import Iterable |
5 | from contextlib import contextmanager | ||
5 | 6 | ||
6 | import torch | 7 | import torch |
7 | from PIL import Image | 8 | from PIL import Image |
@@ -259,3 +260,14 @@ class EMAModel: | |||
259 | raise ValueError("collected_params must all be Tensors") | 260 | raise ValueError("collected_params must all be Tensors") |
260 | if len(self.collected_params) != len(self.shadow_params): | 261 | if len(self.collected_params) != len(self.shadow_params): |
261 | raise ValueError("collected_params and shadow_params must have the same length") | 262 | raise ValueError("collected_params and shadow_params must have the same length") |
263 | |||
264 | @contextmanager | ||
265 | def apply_temporary(self, parameters): | ||
266 | try: | ||
267 | parameters = list(parameters) | ||
268 | original_params = [p.clone() for p in parameters] | ||
269 | self.copy_to(parameters) | ||
270 | yield | ||
271 | finally: | ||
272 | for s_param, param in zip(original_params, parameters): | ||
273 | param.data.copy_(s_param.data) | ||