-rw-r--r--  data/csv.py              |  11
-rw-r--r--  train_ti.py              |  74
-rw-r--r--  training/functional.py   | 100
-rw-r--r--  training/lr.py           |  29
-rw-r--r--  training/strategy/ti.py  |  54
5 files changed, 106 insertions(+), 162 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index b058a3e..5de3ac7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,28 +100,25 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
-    with_prior = all("class_prompt_ids" in example for example in examples)
-
+def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    if with_prior:
+    if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
     pixel_values = torch.stack(pixel_values)
-    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+    pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
 
     prompts = unify_input_ids(tokenizer, prompt_ids)
     nprompts = unify_input_ids(tokenizer, nprompt_ids)
     inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
-        "with_prior": torch.tensor([with_prior] * len(examples)),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
@@ -285,7 +282,7 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
 
         self.train_dataloader = DataLoader(
             train_dataset,
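
The change above makes prior preservation an explicit collate_fn argument, bound once per DataModule, instead of being re-derived from every batch. A minimal usage sketch, not part of the diff; collate_fn, tokenizer, train_dataset and num_class_images are assumed to be the objects from data/csv.py and its caller:

from functools import partial

import torch
from torch.utils.data import DataLoader

# Bind dtype, tokenizer and the prior-preservation switch up front.
collate_fn_ = partial(
    collate_fn,              # from data/csv.py
    torch.float32,           # dtype used for pixel_values
    tokenizer,               # CLIPTokenizer instance
    num_class_images != 0,   # with_prior_preservation
)

train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=collate_fn_)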
diff --git a/train_ti.py b/train_ti.py
index 3c9810f..4bac736 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,11 +15,11 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import EMAModel, save_args
+from training.util import save_args
 
 logger = get_logger(__name__)
 
@@ -82,7 +82,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=1,
+        default=0,
         help="How many class images to generate."
     )
     parser.add_argument(
@@ -398,7 +398,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=0,
+        default=1,
         type=float,
         help="Embedding decay factor."
     )
@@ -540,16 +540,6 @@ def main():
     placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
     print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
-    if args.use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-    else:
-        ema_embeddings = None
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -654,23 +644,13 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    if args.use_ema:
-        ema_embeddings.to(accelerator.device)
-
-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        text_encoder=text_encoder,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
-    )
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
-    strategy = textual_inversion_strategy(
+    vae.to(accelerator.device, dtype=weight_dtype)
+
+    callbacks = textual_inversion_strategy(
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
@@ -679,7 +659,6 @@ def main():
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=output_dir,
         seed=args.seed,
         placeholder_tokens=args.placeholder_tokens,
@@ -700,31 +679,54 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        args.num_class_images != 0,
+        args.prior_loss_weight,
+        args.seed,
+    )
+
     if args.find_lr:
         lr_finder = LRFinder(
             accelerator=accelerator,
             optimizer=optimizer,
-            model=text_encoder,
             train_dataloader=train_dataloader,
             val_dataloader=val_dataloader,
-            **strategy,
+            callbacks=callbacks,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        trainer(
+        if accelerator.is_main_process:
+            accelerator.init_trackers("textual_inversion")
+
+        train_loop(
+            accelerator=accelerator,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            num_train_epochs=args.num_train_epochs,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            prior_loss_weight=args.prior_loss_weight,
-            callbacks=strategy,
+            callbacks=callbacks,
         )
 
+    accelerator.end_training()
+
 
 if __name__ == "__main__":
     main()
diff --git a/training/functional.py b/training/functional.py
index 4ca7470..c01595a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -33,6 +33,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[float], None] = const()
+    on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], None] = const()
@@ -267,6 +268,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
+    with_prior_preservation: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -322,7 +324,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if batch["with_prior"].all():
+    if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -347,7 +349,6 @@ def train_loop(
     accelerator: Accelerator,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    model: torch.nn.Module,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -387,28 +388,37 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")
 
+    model = callbacks.on_model()
+    on_log = callbacks.on_log
+    on_train = callbacks.on_train
+    on_before_optimize = callbacks.on_before_optimize
+    on_after_optimize = callbacks.on_after_optimize
+    on_eval = callbacks.on_eval
+    on_sample = callbacks.on_sample
+    on_checkpoint = callbacks.on_checkpoint
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    callbacks.on_sample(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    callbacks.on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
             model.train()
 
-            with callbacks.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(train_dataloader):
                     with accelerator.accumulate(model):
                         loss, acc, bsz = loss_step(step, batch)
 
                         accelerator.backward(loss)
 
-                        callbacks.on_before_optimize(epoch)
+                        on_before_optimize(epoch)
 
                         optimizer.step()
                         lr_scheduler.step()
@@ -419,7 +429,7 @@ def train_loop(
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
-                        callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
@@ -433,7 +443,7 @@ def train_loop(
                         "train/cur_acc": acc.item(),
                         "lr": lr_scheduler.get_last_lr()[0],
                     }
-                    logs.update(callbacks.on_log())
+                    logs.update(on_log())
 
                     accelerator.log(logs, step=global_step)
 
@@ -449,7 +459,7 @@ def train_loop(
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()
 
-            with torch.inference_mode(), callbacks.on_eval():
+            with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, True)
 
@@ -485,80 +495,16 @@ def train_loop(
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    callbacks.on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            callbacks.on_checkpoint(global_step + global_step_offset, "end")
-            callbacks.on_sample(global_step + global_step_offset)
-            accelerator.end_training()
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            callbacks.on_checkpoint(global_step + global_step_offset, "end")
-            accelerator.end_training()
-
-
-def train(
-    accelerator: Accelerator,
-    unet: UNet2DConditionModel,
-    text_encoder: CLIPTextModel,
-    vae: AutoencoderKL,
-    noise_scheduler: DDPMScheduler,
-    train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
-    dtype: torch.dtype,
-    seed: int,
-    optimizer: torch.optim.Optimizer,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    num_train_epochs: int = 100,
-    sample_frequency: int = 20,
-    checkpoint_frequency: int = 50,
-    global_step_offset: int = 0,
-    prior_loss_weight: float = 0,
-    callbacks: TrainingCallbacks = TrainingCallbacks(),
-):
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=dtype)
-
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        prior_loss_weight,
-        seed,
-    )
-
-    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
-
-    train_loop(
-        accelerator=accelerator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model=text_encoder,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        loss_step=loss_step_,
-        sample_frequency=sample_frequency,
-        checkpoint_frequency=checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        num_epochs=num_train_epochs,
-        callbacks=callbacks,
-    )
-
-    accelerator.free_memory()
+            on_checkpoint(global_step + global_step_offset, "end")
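
With the train() wrapper removed, callers prepare the models themselves and hand train_loop a TrainingCallbacks whose on_model supplies the module to train. A minimal wiring sketch, not part of the diff; accelerator, text_encoder, optimizer, lr_scheduler, the dataloaders and loss_step_ are assumed to be set up as in train_ti.py:

from contextlib import nullcontext

from training.functional import TrainingCallbacks, train_loop

callbacks = TrainingCallbacks(
    on_model=lambda: text_encoder,          # train_loop fetches the trainable model here
    on_train=lambda epoch: nullcontext(),   # context entered around each training epoch
    on_eval=lambda: nullcontext(),          # context entered around validation
)

train_loop(
    accelerator=accelerator,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_step=loss_step_,
    callbacks=callbacks,
)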
diff --git a/training/lr.py b/training/lr.py
index 7584ba2..902c4eb 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -9,6 +9,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
 
+from training.functional import TrainingCallbacks
 from training.util import AverageMeter
 
 
@@ -24,26 +25,19 @@ class LRFinder():
     def __init__(
         self,
         accelerator,
-        model,
         optimizer,
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-        on_before_optimize: Callable[[int], None] = noop,
-        on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+        callbacks: TrainingCallbacks = TrainingCallbacks()
     ):
         self.accelerator = accelerator
-        self.model = model
+        self.model = callbacks.on_model()
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
-        self.on_train = on_train
-        self.on_before_optimize = on_before_optimize
-        self.on_after_optimize = on_after_optimize
-        self.on_eval = on_eval
+        self.callbacks = callbacks
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
@@ -82,6 +76,13 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
+        self.callbacks.on_prepare()
+
+        on_train = self.callbacks.on_train
+        on_before_optimize = self.callbacks.on_before_optimize
+        on_after_optimize = self.callbacks.on_after_optimize
+        on_eval = self.callbacks.on_eval
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -90,7 +91,7 @@ class LRFinder():
 
             self.model.train()
 
-            with self.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -100,21 +101,21 @@ class LRFinder():
 
                     self.accelerator.backward(loss)
 
-                    self.on_before_optimize(epoch)
+                    on_before_optimize(epoch)
 
                     self.optimizer.step()
                     lr_scheduler.step()
                     self.optimizer.zero_grad(set_to_none=True)
 
                     if self.accelerator.sync_gradients:
-                        self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     progress_bar.update(1)
 
             self.model.eval()
 
             with torch.inference_mode():
-                with self.on_eval():
+                with on_eval():
                     for step, batch in enumerate(self.val_dataloader):
                         if step >= num_val_batches:
                             break
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6f8384f..753dce0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,6 @@ def textual_inversion_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
@@ -48,6 +47,12 @@ def textual_inversion_strategy(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
@@ -58,7 +63,7 @@ def textual_inversion_strategy(
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=dtype,
+        dtype=weight_dtype,
         output_dir=output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -78,6 +83,17 @@ def textual_inversion_strategy(
     else:
         ema_embeddings = None
 
+    def ema_context():
+        if use_ema:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
+    def on_model():
+        return text_encoder
+
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
 
@@ -89,24 +105,15 @@ def textual_inversion_strategy(
 
     @contextmanager
     def on_train(epoch: int):
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
+        tokenizer.train()
+        yield
 
     @contextmanager
     def on_eval():
-        try:
-            tokenizer.eval()
+        tokenizer.eval()
 
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
+        with ema_context():
+            yield
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
@@ -131,13 +138,7 @@ def textual_inversion_strategy(
         checkpoints_path = output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        text_encoder = accelerator.unwrap_model(text_encoder)
-
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
@@ -146,15 +147,12 @@ def textual_inversion_strategy(
 
     @torch.no_grad()
     def on_sample(step):
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             save_samples_(step=step)
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
+        on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
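
The repeated EMA plumbing collapses into one helper: ema_context() returns either the EMA weights applied temporarily or a plain nullcontext(). A standalone sketch of the same pattern, not the project's code, using hypothetical names (make_ema_context, ema_model, parameters_fn):

from contextlib import nullcontext

def make_ema_context(use_ema, ema_model, parameters_fn):
    # Returns a factory: each call produces a fresh context manager, so EMA
    # weights are swapped in only while the `with` block is active.
    def ema_context():
        if use_ema:
            return ema_model.apply_temporary(parameters_fn())
        return nullcontext()
    return ema_context

# Usage, mirroring on_sample/on_checkpoint above:
# ema_context = make_ema_context(
#     True, ema_embeddings,
#     lambda: text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
# )
# with ema_context():
#     save_samples_(step=step)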