Diffstat:
 -rw-r--r--  train_ti.py              108
 -rw-r--r--  training/functional.py   118
 -rw-r--r--  training/strategy/ti.py  164

3 files changed, 312 insertions(+), 78 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index 97e4e72..2fd325b 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -17,7 +17,8 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
 from trainer_old.base import Checkpointer
-from training.functional import train, loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.lr import LRFinder
 from training.util import EMAModel, save_args
@@ -387,6 +388,11 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--use_emb_decay",
+        action="store_true",
+        help="Whether to use embedding decay."
+    )
+    parser.add_argument(
         "--emb_decay_target",
         default=0.4,
         type=float,
@@ -591,14 +597,6 @@ def main():
     else:
         ema_embeddings = None

-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-
-    text_encoder.text_model.encoder.requires_grad_(False)
-    text_encoder.text_model.final_layer_norm.requires_grad_(False)
-    text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
-    text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -719,73 +717,36 @@ def main():
         seed=args.seed,
     )

-    def on_prepare():
-        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
-
-        if args.gradient_checkpointing:
-            unet.train()
-
-    @contextmanager
-    def on_train(epoch: int):
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
-
-    @contextmanager
-    def on_eval():
-        try:
-            tokenizer.eval()
-
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
-
-    @torch.no_grad()
-    def on_after_optimize(lr: float):
-        if args.emb_decay_factor != 0:
-            text_encoder.text_model.embeddings.normalize(
-                args.emb_decay_target,
-                min(1.0, max(0.0, args.emb_decay_factor * ((lr - args.emb_decay_start) / (args.learning_rate - args.emb_decay_start))))
-            )
-
-        if args.use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-    def on_log():
-        if args.use_ema:
-            return {"ema_decay": ema_embeddings.decay}
-        return {}
-
-    checkpointer = TextualInversionCheckpointer(
-        dtype=weight_dtype,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
+    strategy = textual_inversion_strategy(
         accelerator=accelerator,
-        vae=vae,
         unet=unet,
-        tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_embeddings=ema_embeddings,
+        tokenizer=tokenizer,
+        vae=vae,
         sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        output_dir=output_dir,
+        seed=args.seed,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=output_dir,
-        sample_steps=args.sample_steps,
-        sample_image_size=args.sample_image_size,
+        learning_rate=args.learning_rate,
+        gradient_checkpointing=args.gradient_checkpointing,
+        use_emb_decay=args.use_emb_decay,
+        emb_decay_target=args.emb_decay_target,
+        emb_decay_factor=args.emb_decay_factor,
+        emb_decay_start=args.emb_decay_start,
+        use_ema=args.use_ema,
+        ema_inv_gamma=args.ema_inv_gamma,
+        ema_power=args.ema_power,
+        ema_max_decay=args.ema_max_decay,
         sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
     )

-    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
-
     if args.find_lr:
         lr_finder = LRFinder(
             accelerator=accelerator,
@@ -793,10 +754,7 @@ def main():
             model=text_encoder,
             train_dataloader=train_dataloader,
             val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            on_train=on_train,
-            on_eval=on_eval,
-            on_after_optimize=on_after_optimize,
+            **strategy,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)

@@ -811,13 +769,7 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         global_step_offset=global_step_offset,
         prior_loss_weight=args.prior_loss_weight,
-        on_prepare=on_prepare,
-        on_log=on_log,
-        on_train=on_train,
-        on_after_optimize=on_after_optimize,
-        on_eval=on_eval,
-        on_sample=checkpointer.save_samples,
-        on_checkpoint=checkpointer.checkpoint,
+        **strategy,
     )

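Note: textual_inversion_strategy() returns a plain dict of callbacks, so the **strategy splat above hands train() and LRFinder() the same on_* keyword arguments that were previously defined inline in main(). A minimal, self-contained sketch of that pattern (the toy_* names are illustrative, not from this repo):

from contextlib import contextmanager

def toy_strategy():
    def on_prepare():
        print("prepare")

    @contextmanager
    def on_train(epoch: int):
        print(f"train epoch {epoch}")
        yield

    # the strategy is nothing more than a dict of callbacks
    return {"on_prepare": on_prepare, "on_train": on_train}

def toy_train(num_epochs: int = 1, on_prepare=None, on_train=None, **callbacks):
    on_prepare()
    for epoch in range(num_epochs):
        with on_train(epoch):
            pass  # loss/optimizer steps would run here

toy_train(num_epochs=2, **toy_strategy())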
diff --git a/training/functional.py b/training/functional.py
index 1f2ca6d..e54c9c8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -2,6 +2,8 @@ import math
 from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union, Optional
 from functools import partial
+from pathlib import Path
+import itertools

 import torch
 import torch.nn.functional as F
@@ -26,6 +28,14 @@ def const(result=None):
     return fn


+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols*w, i//cols*h))
+    return grid
+
+
 def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -40,6 +50,107 @@ def get_models(pretrained_model_name_or_path: str):
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings


+def save_samples(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    output_dir: Path,
+    seed: int,
+    step: int,
+    batch_size: int = 1,
+    num_batches: int = 1,
+    num_steps: int = 20,
+    guidance_scale: float = 7.5,
+    image_size: Optional[int] = None,
+):
+    print(f"Saving samples for step {step}...")
+
+    samples_path = output_dir.joinpath("samples")
+
+    grid_cols = min(batch_size, 4)
+    grid_rows = (num_batches * batch_size) // grid_cols
+
+    unet = accelerator.unwrap_model(unet)
+    text_encoder = accelerator.unwrap_model(text_encoder)
+
+    orig_unet_dtype = unet.dtype
+    orig_text_encoder_dtype = text_encoder.dtype
+
+    unet.to(dtype=dtype)
+    text_encoder.to(dtype=dtype)
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=sample_scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    generator = torch.Generator(device=accelerator.device).manual_seed(seed)
+
+    for pool, data, gen in [
+        ("stable", val_dataloader, generator),
+        ("val", val_dataloader, None),
+        ("train", train_dataloader, None)
+    ]:
+        all_samples = []
+        file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+
+        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
+        prompt_ids = [
+            prompt
+            for batch in batches
+            for prompt in batch["prompt_ids"]
+        ]
+        nprompt_ids = [
+            prompt
+            for batch in batches
+            for prompt in batch["nprompt_ids"]
+        ]
+
+        for i in range(num_batches):
+            start = i * batch_size
+            end = (i + 1) * batch_size
+            prompt = prompt_ids[start:end]
+            nprompt = nprompt_ids[start:end]
+
+            samples = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=image_size,
+                width=image_size,
+                generator=gen,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                output_type='pil'
+            ).images
+
+            all_samples += samples
+
+        image_grid = make_grid(all_samples, grid_rows, grid_cols)
+        image_grid.save(file_path, quality=85)
+
+    unet.to(dtype=orig_unet_dtype)
+    text_encoder.to(dtype=orig_text_encoder_dtype)
+
+    del unet
+    del text_encoder
+    del generator
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 def generate_class_images(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
@@ -109,6 +220,10 @@ def get_models(pretrained_model_name_or_path: str):

     embeddings = patch_managed_embeddings(text_encoder)

+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings


@@ -427,6 +542,9 @@ def train(
         seed,
     )

+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
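Note: in make_grid above, Python operator precedence makes i % cols*w parse as (i % cols) * w and i//cols*h as (i // cols) * h, so tiles are pasted row-major. A quick self-check with Pillow (make_grid is copied verbatim from the diff; the solid-color tiles are test data, not repo code):

from PIL import Image

def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols*w, i//cols*h))
    return grid

tiles = [Image.new('RGB', (64, 64), c) for c in ("red", "green", "blue", "white")]
grid = make_grid(tiles, rows=2, cols=2)
assert grid.size == (128, 128)  # a 2x2 grid of 64x64 tiles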
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
new file mode 100644
index 0000000..83dc566
--- /dev/null
+++ b/training/strategy/ti.py
@@ -0,0 +1,164 @@
+from contextlib import nullcontext
+from typing import Optional
+from functools import partial
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+
+import torch
+from torch.utils.data import DataLoader
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from slugify import slugify
+
+from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
+from training.functional import save_samples
+
+
+def textual_inversion_strategy(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    output_dir: Path,
+    seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    learning_rate: float,
+    gradient_checkpointing: bool = False,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay_factor: float = 1,
+    emb_decay_start: float = 1e-4,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
+    sample_batch_size: int = 1,
+    sample_num_batches: int = 1,
+    sample_num_steps: int = 20,
+    sample_guidance_scale: float = 7.5,
+    sample_image_size: Optional[int] = None,
+):
+    save_samples_ = partial(
+        save_samples,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        vae=vae,
+        sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=dtype,
+        output_dir=output_dir,
+        seed=seed,
+        batch_size=sample_batch_size,
+        num_batches=sample_num_batches,
+        num_steps=sample_num_steps,
+        guidance_scale=sample_guidance_scale,
+        image_size=sample_image_size,
+    )
+
+    if use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+    else:
+        ema_embeddings = None
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        if use_ema:
+            ema_embeddings.to(accelerator.device)
+
+        if gradient_checkpointing:
+            unet.train()
+
+    @contextmanager
+    def on_train(epoch: int):
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval():
+        try:
+            tokenizer.eval()
+
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if use_emb_decay:
+            text_encoder.text_model.embeddings.normalize(
+                emb_decay_target,
+                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
+            )
+
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
+    @torch.no_grad()
+    def on_checkpoint(step, postfix):
+        print(f"Saving checkpoint for step {step}...")
+
+        checkpoints_path = output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = accelerator.unwrap_model(text_encoder)
+
+        ema_context = ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+    @torch.no_grad()
+    def on_sample(step):
+        ema_context = ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            save_samples_(step=step)
+
+    return {
+        "on_prepare": on_prepare,
+        "on_train": on_train,
+        "on_eval": on_eval,
+        "on_after_optimize": on_after_optimize,
+        "on_log": on_log,
+        "on_checkpoint": on_checkpoint,
+        "on_sample": on_sample,
+    }
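Note: on_eval, on_checkpoint and on_sample all rely on EMAModel.apply_temporary to swap the EMA-averaged embedding weights in for the duration of a with-block and restore the live weights afterwards. A simplified stand-in showing that mechanic (ToyEMA is illustrative only; see training/util.py for the repo's actual EMAModel):

from contextlib import contextmanager
import torch

class ToyEMA:
    def __init__(self, params, decay: float = 0.999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in params]

    def step(self, params):
        # blend the live weights into the running average
        for s, p in zip(self.shadow, params):
            s.mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    @contextmanager
    def apply_temporary(self, params):
        params = list(params)
        backup = [p.detach().clone() for p in params]
        for p, s in zip(params, self.shadow):
            p.data.copy_(s)  # the with-block sees the averaged weights
        try:
            yield
        finally:
            for p, b in zip(params, backup):
                p.data.copy_(b)  # training resumes with the live weights

emb = torch.nn.Embedding(4, 8)
ema = ToyEMA(emb.parameters())
emb.weight.data.add_(1.0)      # simulate an optimizer update
ema.step(emb.parameters())
with ema.apply_temporary(emb.parameters()):
    pass                        # sampling/checkpointing would happen here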