author     Volpeon <git@volpeon.ink>    2023-01-05 22:05:25 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-05 22:05:25 +0100
commit     5c115a212e40ff177c734351601f9babe29419ce
tree       a66c8c67d2811e126b52ac4d4cd30a1c3ea2c2b9 /train_ti.py
parent     Fix LR finder
Added EMA to TI
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 59
1 file changed, 58 insertions(+), 1 deletion(-)
diff --git a/train_ti.py b/train_ti.py
index 98385dd..dc36e42 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import datetime
 import logging
+import copy
 from pathlib import Path
 from functools import partial
 
@@ -24,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer
@@ -267,6 +268,27 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=6/7
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
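Note: the --ema_inv_gamma, --ema_power and --ema_max_decay flags added in the hunk above parameterize the EMA decay schedule. training.util.EMAModel itself is not part of this diff; assuming it follows the diffusers-style warm-up schedule, the effective decay at a given optimization step would look roughly like this sketch:

def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 6 / 7,
              max_decay: float = 0.9999) -> float:
    # Assumed diffusers-style schedule: decay ramps up from 0 toward max_decay.
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max_decay, max(0.0, value))

# With the defaults above: roughly 0.45 after 1 step, roughly 0.98 after 100 steps;
# the 0.9999 cap is only reached after tens of thousands of steps.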
@@ -449,6 +471,7 @@ class Checkpointer(CheckpointerBase):
         unet,
         tokenizer,
         text_encoder,
+        ema_embeddings,
         scheduler,
         placeholder_token,
         new_ids,
@@ -473,6 +496,7 @@ class Checkpointer(CheckpointerBase):
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
         self.scheduler = scheduler
         self.placeholder_token = placeholder_token
         self.new_ids = new_ids
@@ -486,17 +510,33 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         for (token, ids) in zip(self.placeholder_token, self.new_ids):
             text_encoder.text_model.embeddings.save_embed(
                 ids,
                 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
             )
 
+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         orig_dtype = text_encoder.dtype
         text_encoder.to(dtype=self.weight_dtype)
 
@@ -513,6 +553,9 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder.to(dtype=orig_dtype)
 
+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
         del pipeline
 
@@ -567,6 +610,7 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     embeddings = patch_managed_embeddings(text_encoder)
+    ema_embeddings = None
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -592,6 +636,14 @@ def main():
 
     print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
 
+    if args.use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
+
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
@@ -788,6 +840,7 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
+    ema_embeddings.to(accelerator.device)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -883,6 +936,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
         scheduler=checkpoint_scheduler,
         placeholder_token=args.placeholder_token,
         new_ids=new_ids,
@@ -935,6 +989,9 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
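For readers unfamiliar with the pattern used in Checkpointer.checkpoint() and save_samples() above, the following minimal sketch shows the same idea in isolation: an EMA shadow of the embedding weights is updated once per optimization step, copied into a throwaway clone of the embedding for saving and sampling, and the live weights are restored afterwards. ToyEMA is a hypothetical stand-in for training.util.EMAModel (not shown in this diff); only its step() and copy_to() methods mirror calls that actually appear in the patch.

import copy

import torch
import torch.nn as nn


class ToyEMA:
    """Hypothetical stand-in for training.util.EMAModel (fixed decay, no schedule)."""

    def __init__(self, parameters, decay=0.999):
        self.decay = decay
        self.shadow = [p.clone().detach() for p in parameters]

    @torch.no_grad()
    def step(self, parameters):
        # EMA update: shadow <- decay * shadow + (1 - decay) * current parameter
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p, alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, parameters):
        # Overwrite the given parameters with the shadow (EMA) values
        for s, p in zip(self.shadow, parameters):
            p.copy_(s)


embedding = nn.Embedding(10, 4)        # plays the role of temp_token_embedding
ema = ToyEMA(embedding.parameters())

# After each optimizer step, the training loop updates the shadow weights:
ema.step(embedding.parameters())

# Checkpoint/sample time: swap the EMA weights in, then restore the live ones.
orig_weights = embedding
ema_weights = copy.deepcopy(embedding)
ema.copy_to(ema_weights.parameters())
embedding = ema_weights                # save embeddings / render samples here
embedding = orig_weights               # restore for further training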