author     Volpeon <git@volpeon.ink>   2023-01-17 07:20:45 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-17 07:20:45 +0100
commit     5821523a524190490a287c5e2aacb6e72cc3a4cf (patch)
tree       c0eac536c754f078683be6d59893ad23d70baf51
parent     Training update (diff)
download   textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.gz
           textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.tar.bz2
           textual-inversion-diff-5821523a524190490a287c5e2aacb6e72cc3a4cf.zip
Update
-rw-r--r--  train_dreambooth.py              |   5
-rw-r--r--  train_ti.py                      | 113
-rw-r--r--  training/functional.py           |  19
-rw-r--r--  training/strategy/dreambooth.py  |  10
-rw-r--r--  training/strategy/ti.py          |  19
-rw-r--r--  training/util.py                 |  11
6 files changed, 104 insertions(+), 73 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index d722e68..48bdcf8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -14,8 +14,7 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
-from training.strategy.ti import textual_inversion_strategy
+from training.functional import train, get_models
 from training.strategy.dreambooth import dreambooth_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -610,7 +609,7 @@ def main():
     )
 
     trainer(
-        callbacks_fn=dreambooth_strategy,
+        strategy=dreambooth_strategy,
         project="dreambooth",
         train_dataloader=datamodule.train_dataloader,
         val_dataloader=datamodule.val_dataloader,
diff --git a/train_ti.py b/train_ti.py
index e7aeb23..0891c49 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -14,7 +14,7 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.util import save_args
@@ -79,6 +79,10 @@ def parse_args():
         help="Number of vectors per embedding."
     )
     parser.add_argument(
+        "--simultaneous",
+        action="store_true",
+    )
+    parser.add_argument(
         "--num_class_images",
         type=int,
         default=0,
@@ -474,11 +478,12 @@ def parse_args():
     if len(args.placeholder_tokens) != len(args.num_vectors):
         raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
 
-    if isinstance(args.train_data_template, str):
-        args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
+    if not args.simultaneous:
+        if isinstance(args.train_data_template, str):
+            args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
 
-    if len(args.placeholder_tokens) != len(args.train_data_template):
-        raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+        if len(args.placeholder_tokens) != len(args.train_data_template):
+            raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -560,6 +565,8 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
+    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -569,30 +576,50 @@ def main():
         noise_scheduler=noise_scheduler,
         dtype=weight_dtype,
         seed=args.seed,
-        callbacks_fn=textual_inversion_strategy
-    )
-
-    checkpoint_output_dir = output_dir.joinpath("checkpoints")
-
-    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
-        range(len(args.placeholder_tokens)),
-        args.placeholder_tokens,
-        args.initializer_tokens,
-        args.num_vectors,
-        args.train_data_template
-    ):
-        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        strategy=textual_inversion_strategy,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        checkpoint_output_dir=checkpoint_output_dir,
+        learning_rate=args.learning_rate,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay_factor=args.emb_decay_factor,
+        emb_decay_start=args.emb_decay_start,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )
+
+    def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template):
+        if len(placeholder_tokens) == 1:
+            sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token[0]}")
+        else:
+            sample_output_dir = output_dir.joinpath("samples")
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
-            placeholder_tokens=[placeholder_token],
-            initializer_tokens=[initializer_token],
-            num_vectors=[num_vectors]
+            placeholder_tokens=placeholder_tokens,
+            initializer_tokens=initializer_tokens,
+            num_vectors=num_vectors
         )
 
-        print(
-            f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
+        stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
+
+        print(f"{i + 1}: {stats})")
 
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
@@ -612,7 +639,7 @@ def main():
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
             seed=args.seed,
-            filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
+            filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections),
             dtype=weight_dtype
         )
         datamodule.setup()
@@ -647,36 +674,24 @@ def main():
             val_dataloader=datamodule.val_dataloader,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            num_train_epochs=args.num_train_epochs,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            with_prior_preservation=args.num_class_images != 0,
-            prior_loss_weight=args.prior_loss_weight,
             # --
-            tokenizer=tokenizer,
-            sample_scheduler=sample_scheduler,
             sample_output_dir=sample_output_dir,
-            checkpoint_output_dir=checkpoint_output_dir,
-            placeholder_tokens=[placeholder_token],
+            placeholder_tokens=placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
-            learning_rate=args.learning_rate,
-            gradient_checkpointing=args.gradient_checkpointing,
-            use_emb_decay=args.use_emb_decay,
-            emb_decay_target=args.emb_decay_target,
-            emb_decay_factor=args.emb_decay_factor,
-            emb_decay_start=args.emb_decay_start,
-            use_ema=args.use_ema,
-            ema_inv_gamma=args.ema_inv_gamma,
-            ema_power=args.ema_power,
-            ema_max_decay=args.ema_max_decay,
-            sample_batch_size=args.sample_batch_size,
-            sample_num_batches=args.sample_batches,
-            sample_num_steps=args.sample_steps,
-            sample_image_size=args.sample_image_size,
         )
 
-        embeddings.persist()
+    if args.simultaneous:
+        run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
+    else:
+        for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
+            range(len(args.placeholder_tokens)),
+            args.placeholder_tokens,
+            args.initializer_tokens,
+            args.num_vectors,
+            args.train_data_template
+        ):
+            run(i, [placeholder_token], [initializer_token], [num_vectors], data_template)
+            embeddings.persist()
 
 
 if __name__ == "__main__":
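
Note: train_ti.py now wraps a single training pass in a local run() helper and dispatches it in one of two ways. A self-contained sketch of that dispatch (run() here is a toy stand-in for the real helper above; the token names and counts are made up):

def run(i, placeholder_tokens, initializer_tokens, num_vectors, data_template):
    # placeholder for the real helper, which builds the datamodule, optimizer and trainer
    print(f"run {i}: tokens={placeholder_tokens} vectors={num_vectors} template={data_template}")

placeholder_tokens = ["<tokA>", "<tokB>"]
initializer_tokens = ["person", "style"]
num_vectors = [2, 4]
train_data_template = ["template_a", "template_b"]
simultaneous = False  # mirrors the new --simultaneous flag

if simultaneous:
    # one training pass covering every placeholder token at once
    run(0, placeholder_tokens, initializer_tokens, num_vectors, train_data_template)
else:
    # one training pass per placeholder token, as before
    for i, (pt, it, nv, dt) in enumerate(
        zip(placeholder_tokens, initializer_tokens, num_vectors, train_data_template)
    ):
        run(i, [pt], [it], [nv], dt)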
diff --git a/training/functional.py b/training/functional.py
index 3d27380..7a3e821 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -39,11 +39,18 @@ class TrainingCallbacks():
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], None] = const()
     on_after_optimize: Callable[[float], None] = const()
+    on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
     on_checkpoint: Callable[[int, str], None] = const()
 
 
+@dataclass
+class TrainingStrategy():
+    callbacks: Callable[..., TrainingCallbacks]
+    prepare_unet: bool = False
+
+
 def make_grid(images, rows, cols):
     w, h = images[0].size
     grid = Image.new('RGB', size=(cols*w, rows*h))
@@ -373,6 +380,7 @@ def train_loop(
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
    on_after_optimize = callbacks.on_after_optimize
+    on_after_epoch = callbacks.on_after_epoch
     on_eval = callbacks.on_eval
     on_sample = callbacks.on_sample
     on_checkpoint = callbacks.on_checkpoint
@@ -434,6 +442,8 @@ def train_loop(
 
         accelerator.wait_for_everyone()
 
+        on_after_epoch(lr_scheduler.get_last_lr()[0])
+
         if val_dataloader is not None:
             model.eval()
 
@@ -512,8 +522,7 @@ def train(
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    callbacks_fn: Callable[..., TrainingCallbacks],
-    prepare_unet: bool = False,
+    strategy: TrainingStrategy,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -524,12 +533,12 @@
 ):
     prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
 
-    if prepare_unet:
+    if strategy.prepare_unet:
         prep.append(unet)
 
     prep = accelerator.prepare(*prep)
 
-    if prepare_unet:
+    if strategy.prepare_unet:
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
     else:
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
@@ -542,7 +551,7 @@ def train(
         model.requires_grad_(False)
         model.eval()
 
-    callbacks = callbacks_fn(
+    callbacks = strategy.callbacks(
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 93c81cb..bc26ee6 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import TrainingCallbacks, save_samples
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-def dreambooth_strategy(
+def dreambooth_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -185,3 +185,9 @@ def dreambooth_strategy(
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
+
+
+dreambooth_strategy = TrainingStrategy(
+    callbacks=dreambooth_strategy_callbacks,
+    prepare_unet=True
+)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 00f3529..597abd0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,10 +15,10 @@ from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import TrainingCallbacks, save_samples
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-def textual_inversion_strategy(
+def textual_inversion_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -119,17 +119,18 @@ def textual_inversion_strategy(
         with ema_context():
             yield
 
-    @torch.no_grad()
     def on_after_optimize(lr: float):
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    @torch.no_grad()
+    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
                 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
             )
 
-        if use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
     def on_log():
         if use_ema:
             return {"ema_decay": ema_embeddings.decay}
@@ -157,7 +158,13 @@ def textual_inversion_strategy(
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
+        on_after_epoch=on_after_epoch,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
+
+
+textual_inversion_strategy = TrainingStrategy(
+    callbacks=textual_inversion_strategy_callbacks,
+)
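
Note: the embedding-decay step that moved from on_after_optimize into on_after_epoch scales its strength with the current learning rate. A small, self-contained sketch of that scaling term (the expression is taken from the normalize() call above; the learning-rate values are made up):

def emb_decay_scale(lr, learning_rate, emb_decay_start, emb_decay_factor):
    # clamped linear ramp: 0 at emb_decay_start, emb_decay_factor at the base learning rate
    return min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))

base_lr = 1e-3
decay_start = 1e-5
decay_factor = 1.0

for current_lr in (1e-3, 5e-4, 1e-4, 1e-5):
    # the decay weakens as the scheduler lowers the learning rate
    print(current_lr, emb_decay_scale(current_lr, base_lr, decay_start, decay_factor))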
diff --git a/training/util.py b/training/util.py
index 557b196..237626f 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,18 +1,11 @@
 from pathlib import Path
 import json
 import copy
-from typing import Iterable, Union
+from typing import Iterable, Any
 from contextlib import contextmanager
 
 import torch
 
-from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
-
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.tokenizer import MultiCLIPTokenizer
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
-
 
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
@@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
+    avg: Any
+
     def __init__(self, name=None):
         self.name = name
         self.reset()
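
Note: the avg: Any annotation on AverageMeter declares the attribute at class level so type checkers accept it even though it is first assigned inside reset(). A hypothetical minimal meter with the same shape (update() and the running-mean formula are illustrative, not copied from the repository):

from typing import Any

class AverageMeter:
    avg: Any  # declared up front; first assigned in reset()

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.count = 0
        self.avg = 0

    def update(self, value, n=1):
        # incremental running mean over `count` observations
        self.count += n
        self.avg += (value - self.avg) * n / self.count

meter = AverageMeter("loss")
meter.update(0.5)
meter.update(0.3)
print(meter.avg)  # 0.4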