path: root/train_ti.py
author     Volpeon <git@volpeon.ink>    2023-01-05 22:05:25 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-05 22:05:25 +0100
commit     5c115a212e40ff177c734351601f9babe29419ce (patch)
tree       a66c8c67d2811e126b52ac4d4cd30a1c3ea2c2b9 /train_ti.py
parent     Fix LR finder (diff)
download   textual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.tar.gz
           textual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.tar.bz2
           textual-inversion-diff-5c115a212e40ff177c734351601f9babe29419ce.zip
Added EMA to TI
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py | 59
1 file changed, 58 insertions(+), 1 deletion(-)
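This commit keeps an exponential moving average (EMA) of the newly added placeholder-token embeddings and swaps the averaged weights in whenever checkpoints or sample images are written. The EMAModel class imported from training.util is not part of this diff; as a reference for reading the hunks below, a minimal sketch of an EMA helper that matches the flags and calls introduced here (inv_gamma/power/max_value, step(), copy_to(), to()) could look as follows. It is modelled on the diffusers-style EMAModel and is an assumption, not the repository's actual implementation.

    import torch

    class EMAModelSketch:
        """Hypothetical stand-in for training.util.EMAModel (diffusers-style warm-up EMA)."""

        def __init__(self, parameters, inv_gamma=1.0, power=6 / 7, max_value=0.9999):
            # Keep a detached "shadow" copy of every tracked parameter.
            self.shadow_params = [p.clone().detach() for p in parameters]
            self.inv_gamma = inv_gamma
            self.power = power
            self.max_value = max_value
            self.optimization_step = 0

        def get_decay(self, step):
            # Warm-up schedule: decay starts near 0 and approaches max_value.
            decay = 1 - (1 + step / self.inv_gamma) ** -self.power
            return max(0.0, min(decay, self.max_value))

        @torch.no_grad()
        def step(self, parameters):
            self.optimization_step += 1
            decay = self.get_decay(self.optimization_step)
            for shadow, param in zip(self.shadow_params, parameters):
                # shadow <- decay * shadow + (1 - decay) * param
                shadow.sub_((1 - decay) * (shadow - param))

        def copy_to(self, parameters):
            # Overwrite the given live parameters with the averaged ones.
            for shadow, param in zip(self.shadow_params, parameters):
                param.data.copy_(shadow.data)

        def to(self, device=None, dtype=None):
            self.shadow_params = [p.to(device=device, dtype=dtype) for p in self.shadow_params]
            return self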
diff --git a/train_ti.py b/train_ti.py
index 98385dd..dc36e42 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import datetime
 import logging
+import copy
 from pathlib import Path
 from functools import partial
 
@@ -24,7 +25,7 @@ from data.csv import CSVDataModule, CSVDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
-from training.util import AverageMeter, CheckpointerBase, save_args
+from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.prompt import PromptProcessor
 from models.clip.tokenizer import MultiCLIPTokenizer
@@ -267,6 +268,27 @@ def parse_args():
         help="Minimum learning rate in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=1.0
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=6/7
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
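Two notes on the arguments above. First, --use_ema combines action="store_true" with default=True, so the flag is effectively always on and cannot be switched off from the command line; a store_false counterpart (or argparse.BooleanOptionalAction on Python 3.9+) would be needed to opt out. Second, with the defaults inv_gamma=1.0, power=6/7 and max_decay=0.9999, the effective decay ramps up over training, so the EMA tracks the live embeddings closely at first and only later becomes a slow average. Assuming the warm-up formula sketched above, the schedule looks roughly like this:

    # decay = 1 - (1 + step / inv_gamma) ** -power, clipped at max_decay (values approximate)
    inv_gamma, power, max_decay = 1.0, 6 / 7, 0.9999

    def decay_at(step):
        return min(1 - (1 + step / inv_gamma) ** -power, max_decay)

    for step in (100, 1_000, 10_000, 100_000):
        print(step, round(decay_at(step), 5))
    # 100     -> ~0.981
    # 1000    -> ~0.9973
    # 10000   -> ~0.99963
    # 100000  -> 0.9999   (clipped by --ema_max_decay)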
@@ -449,6 +471,7 @@ class Checkpointer(CheckpointerBase):
         unet,
         tokenizer,
         text_encoder,
+        ema_embeddings,
         scheduler,
         placeholder_token,
         new_ids,
@@ -473,6 +496,7 @@ class Checkpointer(CheckpointerBase):
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
+        self.ema_embeddings = ema_embeddings
         self.scheduler = scheduler
         self.placeholder_token = placeholder_token
         self.new_ids = new_ids
@@ -486,17 +510,33 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         for (token, ids) in zip(self.placeholder_token, self.new_ids):
             text_encoder.text_model.embeddings.save_embed(
                 ids,
                 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
             )
 
+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        if self.ema_embeddings is not None:
+            orig_weights = text_encoder.text_model.embeddings.temp_token_embedding
+            ema_weights = copy.deepcopy(text_encoder.text_model.embeddings.temp_token_embedding)
+            self.ema_embeddings.copy_to(ema_weights.parameters())
+            text_encoder.text_model.embeddings.temp_token_embedding = ema_weights
+
         orig_dtype = text_encoder.dtype
         text_encoder.to(dtype=self.weight_dtype)
 
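save_checkpoint and save_samples now share the same swap pattern: remember the live temp_token_embedding, deep-copy it, load the EMA averages into the copy with copy_to(), swap the copy in, and restore the original afterwards. If this pattern spreads further, it could be factored into a context manager along these lines (the helper name is hypothetical; only temp_token_embedding and copy_to() come from this commit), which also guarantees the original weights come back if saving raises:

    import copy
    from contextlib import contextmanager

    @contextmanager
    def ema_embeddings_applied(text_encoder, ema_embeddings):
        """Temporarily replace temp_token_embedding with its EMA average."""
        if ema_embeddings is None:
            yield
            return
        embeddings = text_encoder.text_model.embeddings
        orig_weights = embeddings.temp_token_embedding
        ema_weights = copy.deepcopy(orig_weights)
        ema_embeddings.copy_to(ema_weights.parameters())
        embeddings.temp_token_embedding = ema_weights
        try:
            yield
        finally:
            embeddings.temp_token_embedding = orig_weights

Both methods would then wrap their bodies in "with ema_embeddings_applied(text_encoder, self.ema_embeddings):" instead of repeating the swap and restore blocks.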
@@ -513,6 +553,9 @@ class Checkpointer(CheckpointerBase):
 
         text_encoder.to(dtype=orig_dtype)
 
+        if self.ema_embeddings is not None:
+            text_encoder.text_model.embeddings.temp_token_embedding = orig_weights
+
         del text_encoder
         del pipeline
 
@@ -567,6 +610,7 @@ def main():
         text_encoder.gradient_checkpointing_enable()
 
     embeddings = patch_managed_embeddings(text_encoder)
+    ema_embeddings = None
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
@@ -592,6 +636,14 @@ def main():
 
     print(f"Added {len(new_ids)} new tokens: {list(zip(args.placeholder_token, new_ids, init_ratios))}")
 
+    if args.use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+        )
+
     vae.requires_grad_(False)
     unet.requires_grad_(False)
 
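Only the parameters of temp_token_embedding, i.e. the freshly added placeholder-token vectors, are shadowed here, so the EMA copy is a handful of embedding rows rather than a second text encoder. A quick way to confirm that (illustrative; it uses only objects that exist at this point in main()):

    # The EMA shadows just the trainable placeholder-token embedding.
    tracked = sum(p.numel() for p in text_encoder.text_model.embeddings.temp_token_embedding.parameters())
    total = sum(p.numel() for p in text_encoder.parameters())
    print(f"EMA shadows {tracked:,} of {total:,} text-encoder parameters")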
@@ -788,6 +840,7 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
+    ema_embeddings.to(accelerator.device)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
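The move to the accelerator device is unconditional. With the argparse defaults above, use_ema is always true, so ema_embeddings is never None here; if the flag ever becomes genuinely optional, the call would need the same guard used elsewhere in this commit:

    if ema_embeddings is not None:
        ema_embeddings.to(accelerator.device)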
@@ -883,6 +936,7 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
+        ema_embeddings=ema_embeddings,
         scheduler=checkpoint_scheduler,
         placeholder_token=args.placeholder_token,
         new_ids=new_ids,
@@ -935,6 +989,9 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
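The EMA update is gated on accelerator.sync_gradients, so it runs once per effective optimizer step rather than once per micro-batch under gradient accumulation, and the decay schedule therefore counts real optimization steps. A self-contained toy of what one such update does to a shadow value (first step, schedule as above):

    import torch

    shadow = torch.zeros(3)          # EMA copy, initialised before training
    live = torch.ones(3)             # weights after one optimizer step
    step = 1
    decay = min(1 - (1 + step / 1.0) ** -(6 / 7), 0.9999)   # ~0.448 on the first step
    shadow = decay * shadow + (1 - decay) * live
    print(decay, shadow)             # the shadow has moved ~55% of the way toward the live weights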