author    Volpeon <git@volpeon.ink>  2023-01-14 21:53:07 +0100
committer Volpeon <git@volpeon.ink>  2023-01-14 21:53:07 +0100
commit    83808fe00ac891ad2f625388d144c318b2cb5bfe
tree      b7ca19d27f90be6f02b14f4a39c62fc7250041a2
parent    TI: Prepare UNet with Accelerate as well
WIP: Modularization ("free(): invalid pointer" my ass)
-rw-r--r--  infer.py                                                    19
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py          1
-rw-r--r--  train.py                                                   672
-rw-r--r--  train_dreambooth.py                                          3
-rw-r--r--  train_ti.py                                                 74
-rw-r--r--  trainer/base.py                                            544
-rw-r--r--  trainer/dreambooth.py                                        0
-rw-r--r--  trainer/ti.py                                              164
-rw-r--r--  training/functional.py (renamed from training/common.py)    29
-rw-r--r--  training/lora.py                                           107
-rw-r--r--  training/util.py                                           214
11 files changed, 1541 insertions(+), 286 deletions(-)
diff --git a/infer.py b/infer.py
index 36b5a2c..2b07b21 100644
--- a/infer.py
+++ b/infer.py
@@ -214,10 +214,21 @@ def load_embeddings(pipeline, embeddings_dir):
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype)
-
-    patch_managed_embeddings(pipeline.text_encoder)
-
+    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
+    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
+    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
+    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
+
+    patch_managed_embeddings(text_encoder)
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    )
     pipeline.enable_xformers_memory_efficient_attention()
     pipeline.enable_vae_slicing()
     pipeline.to("cuda")
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a5cfc60..43141bd 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -52,7 +52,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             EulerAncestralDiscreteScheduler,
             DPMSolverMultistepScheduler,
         ],
-        **kwargs,
     ):
         super().__init__()
 
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..d8644c4
--- /dev/null
+++ b/train.py
@@ -0,0 +1,672 @@
1import argparse
2import datetime
3import logging
4from pathlib import Path
5
6import torch
7import torch.utils.checkpoint
8
9from accelerate import Accelerator
10from accelerate.logging import get_logger
11from accelerate.utils import LoggerType, set_seed
12from slugify import slugify
13
14from data.csv import VlpnDataModule, VlpnDataItem
15from util import load_config, load_embeddings_from_dir
16
17from trainer.ti import TextualInversionTrainingStrategy
18from trainer.base import Trainer
19from training.optimization import get_scheduler
20from training.util import save_args, generate_class_images, add_placeholder_tokens, get_models
21
22logger = get_logger(__name__)
23
24
25torch.backends.cuda.matmul.allow_tf32 = True
26torch.backends.cudnn.benchmark = True
27
28
29def parse_args():
30 parser = argparse.ArgumentParser(
31 description="Simple example of a training script."
32 )
33 parser.add_argument(
34 "--pretrained_model_name_or_path",
35 type=str,
36 default=None,
37 help="Path to pretrained model or model identifier from huggingface.co/models.",
38 )
39 parser.add_argument(
40 "--tokenizer_name",
41 type=str,
42 default=None,
43 help="Pretrained tokenizer name or path if not the same as model_name",
44 )
45 parser.add_argument(
46 "--train_data_file",
47 type=str,
48 default=None,
49 help="A CSV file containing the training data."
50 )
51 parser.add_argument(
52 "--train_data_template",
53 type=str,
54 default="template",
55 )
56 parser.add_argument(
57 "--project",
58 type=str,
59 default=None,
60 help="The name of the current project.",
61 )
62 parser.add_argument(
63 "--placeholder_tokens",
64 type=str,
65 nargs='*',
66 help="A token to use as a placeholder for the concept.",
67 )
68 parser.add_argument(
69 "--initializer_tokens",
70 type=str,
71 nargs='*',
72 help="A token to use as initializer word."
73 )
74 parser.add_argument(
75 "--num_vectors",
76 type=int,
77 nargs='*',
78 help="Number of vectors per embedding."
79 )
80 parser.add_argument(
81 "--num_class_images",
82 type=int,
83 default=1,
84 help="How many class images to generate."
85 )
86 parser.add_argument(
87 "--class_image_dir",
88 type=str,
89 default="cls",
90 help="The directory where class images will be saved.",
91 )
92 parser.add_argument(
93 "--exclude_collections",
94 type=str,
95 nargs='*',
96 help="Exclude all items with a listed collection.",
97 )
98 parser.add_argument(
99 "--output_dir",
100 type=str,
101 default="output/text-inversion",
102 help="The output directory where the model predictions and checkpoints will be written.",
103 )
104 parser.add_argument(
105 "--embeddings_dir",
106 type=str,
107 default=None,
108 help="The embeddings directory where Textual Inversion embeddings are stored.",
109 )
110 parser.add_argument(
111 "--collection",
112 type=str,
113 nargs='*',
114 help="A collection to filter the dataset.",
115 )
116 parser.add_argument(
117 "--seed",
118 type=int,
119 default=None,
120 help="A seed for reproducible training."
121 )
122 parser.add_argument(
123 "--resolution",
124 type=int,
125 default=768,
126 help=(
127 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
128 " resolution"
129 ),
130 )
131 parser.add_argument(
132 "--num_buckets",
133 type=int,
134 default=0,
135 help="Number of aspect ratio buckets in either direction.",
136 )
137 parser.add_argument(
138 "--progressive_buckets",
139 action="store_true",
140 help="Include images in smaller buckets as well.",
141 )
142 parser.add_argument(
143 "--bucket_step_size",
144 type=int,
145 default=64,
146 help="Step size between buckets.",
147 )
148 parser.add_argument(
149 "--bucket_max_pixels",
150 type=int,
151 default=None,
152 help="Maximum pixels per bucket.",
153 )
154 parser.add_argument(
155 "--tag_dropout",
156 type=float,
157 default=0,
158 help="Tag dropout probability.",
159 )
160 parser.add_argument(
161 "--no_tag_shuffle",
162 action="store_true",
163 help="Shuffle tags.",
164 )
165 parser.add_argument(
166 "--vector_dropout",
167 type=int,
168 default=0,
169 help="Vector dropout probability.",
170 )
171 parser.add_argument(
172 "--vector_shuffle",
173 type=str,
174 default="auto",
175 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
176 )
177 parser.add_argument(
178 "--num_train_epochs",
179 type=int,
180 default=100
181 )
182 parser.add_argument(
183 "--gradient_accumulation_steps",
184 type=int,
185 default=1,
186 help="Number of updates steps to accumulate before performing a backward/update pass.",
187 )
188 parser.add_argument(
189 "--gradient_checkpointing",
190 action="store_true",
191 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
192 )
193 parser.add_argument(
194 "--find_lr",
195 action="store_true",
196 help="Automatically find a learning rate (no training).",
197 )
198 parser.add_argument(
199 "--learning_rate",
200 type=float,
201 default=1e-4,
202 help="Initial learning rate (after the potential warmup period) to use.",
203 )
204 parser.add_argument(
205 "--scale_lr",
206 action="store_true",
207 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
208 )
209 parser.add_argument(
210 "--lr_scheduler",
211 type=str,
212 default="one_cycle",
213 help=(
214 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
215 ' "constant", "constant_with_warmup", "one_cycle"]'
216 ),
217 )
218 parser.add_argument(
219 "--lr_warmup_epochs",
220 type=int,
221 default=10,
222 help="Number of steps for the warmup in the lr scheduler."
223 )
224 parser.add_argument(
225 "--lr_cycles",
226 type=int,
227 default=None,
228 help="Number of restart cycles in the lr scheduler."
229 )
230 parser.add_argument(
231 "--lr_warmup_func",
232 type=str,
233 default="cos",
234 help='Choose between ["linear", "cos"]'
235 )
236 parser.add_argument(
237 "--lr_warmup_exp",
238 type=int,
239 default=1,
240 help='If lr_warmup_func is "cos", exponent to modify the function'
241 )
242 parser.add_argument(
243 "--lr_annealing_func",
244 type=str,
245 default="cos",
246 help='Choose between ["linear", "half_cos", "cos"]'
247 )
248 parser.add_argument(
249 "--lr_annealing_exp",
250 type=int,
251 default=1,
252 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
253 )
254 parser.add_argument(
255 "--lr_min_lr",
256 type=float,
257 default=0.04,
258 help="Minimum learning rate in the lr scheduler."
259 )
260 parser.add_argument(
261 "--use_ema",
262 action="store_true",
263 help="Whether to use EMA model."
264 )
265 parser.add_argument(
266 "--ema_inv_gamma",
267 type=float,
268 default=1.0
269 )
270 parser.add_argument(
271 "--ema_power",
272 type=float,
273 default=1
274 )
275 parser.add_argument(
276 "--ema_max_decay",
277 type=float,
278 default=0.9999
279 )
280 parser.add_argument(
281 "--use_8bit_adam",
282 action="store_true",
283 help="Whether or not to use 8-bit Adam from bitsandbytes."
284 )
285 parser.add_argument(
286 "--adam_beta1",
287 type=float,
288 default=0.9,
289 help="The beta1 parameter for the Adam optimizer."
290 )
291 parser.add_argument(
292 "--adam_beta2",
293 type=float,
294 default=0.999,
295 help="The beta2 parameter for the Adam optimizer."
296 )
297 parser.add_argument(
298 "--adam_weight_decay",
299 type=float,
300 default=0,
301 help="Weight decay to use."
302 )
303 parser.add_argument(
304 "--adam_epsilon",
305 type=float,
306 default=1e-08,
307 help="Epsilon value for the Adam optimizer"
308 )
309 parser.add_argument(
310 "--adam_amsgrad",
311 type=bool,
312 default=False,
313 help="Amsgrad value for the Adam optimizer"
314 )
315 parser.add_argument(
316 "--mixed_precision",
317 type=str,
318 default="no",
319 choices=["no", "fp16", "bf16"],
320 help=(
321 "Whether to use mixed precision. Choose"
322 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
323 "and an Nvidia Ampere GPU."
324 ),
325 )
326 parser.add_argument(
327 "--checkpoint_frequency",
328 type=int,
329 default=5,
330 help="How often to save a checkpoint and sample image (in epochs)",
331 )
332 parser.add_argument(
333 "--sample_frequency",
334 type=int,
335 default=1,
336 help="How often to save a checkpoint and sample image (in epochs)",
337 )
338 parser.add_argument(
339 "--sample_image_size",
340 type=int,
341 default=768,
342 help="Size of sample images",
343 )
344 parser.add_argument(
345 "--sample_batches",
346 type=int,
347 default=1,
348 help="Number of sample batches to generate per checkpoint",
349 )
350 parser.add_argument(
351 "--sample_batch_size",
352 type=int,
353 default=1,
354 help="Number of samples to generate per batch",
355 )
356 parser.add_argument(
357 "--valid_set_size",
358 type=int,
359 default=None,
360 help="Number of images in the validation dataset."
361 )
362 parser.add_argument(
363 "--valid_set_repeat",
364 type=int,
365 default=1,
366 help="Times the images in the validation dataset are repeated."
367 )
368 parser.add_argument(
369 "--train_batch_size",
370 type=int,
371 default=1,
372 help="Batch size (per device) for the training dataloader."
373 )
374 parser.add_argument(
375 "--sample_steps",
376 type=int,
377 default=20,
378 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
379 )
380 parser.add_argument(
381 "--prior_loss_weight",
382 type=float,
383 default=1.0,
384 help="The weight of prior preservation loss."
385 )
386 parser.add_argument(
387 "--emb_decay_target",
388 default=0.4,
389 type=float,
390 help="Embedding decay target."
391 )
392 parser.add_argument(
393 "--emb_decay_factor",
394 default=0,
395 type=float,
396 help="Embedding decay factor."
397 )
398 parser.add_argument(
399 "--emb_decay_start",
400 default=1e-4,
401 type=float,
402 help="Embedding decay start offset."
403 )
404 parser.add_argument(
405 "--noise_timesteps",
406 type=int,
407 default=1000,
408 )
409 parser.add_argument(
410 "--resume_from",
411 type=str,
412 default=None,
413 help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
414 )
415 parser.add_argument(
416 "--global_step",
417 type=int,
418 default=0,
419 )
420 parser.add_argument(
421 "--config",
422 type=str,
423 default=None,
424 help="Path to a JSON configuration file containing arguments for invoking this script."
425 )
426
427 args = parser.parse_args()
428 if args.config is not None:
429 args = load_config(args.config)
430 args = parser.parse_args(namespace=argparse.Namespace(**args))
431
432 if args.train_data_file is None:
433 raise ValueError("You must specify --train_data_file")
434
435 if args.pretrained_model_name_or_path is None:
436 raise ValueError("You must specify --pretrained_model_name_or_path")
437
438 if args.project is None:
439 raise ValueError("You must specify --project")
440
441 if isinstance(args.placeholder_tokens, str):
442 args.placeholder_tokens = [args.placeholder_tokens]
443
444 if len(args.placeholder_tokens) == 0:
445        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
446
447 if isinstance(args.initializer_tokens, str):
448 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
449
450 if len(args.initializer_tokens) == 0:
451 raise ValueError("You must specify --initializer_tokens")
452
453 if len(args.placeholder_tokens) != len(args.initializer_tokens):
454 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
455
456 if args.num_vectors is None:
457 args.num_vectors = 1
458
459 if isinstance(args.num_vectors, int):
460 args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
461
462 if len(args.placeholder_tokens) != len(args.num_vectors):
463 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
464
465 if isinstance(args.collection, str):
466 args.collection = [args.collection]
467
468 if isinstance(args.exclude_collections, str):
469 args.exclude_collections = [args.exclude_collections]
470
471 if args.output_dir is None:
472 raise ValueError("You must specify --output_dir")
473
474 return args
475
476
477def main():
478 args = parse_args()
479
480 global_step_offset = args.global_step
481 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
482 output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
483 output_dir.mkdir(parents=True, exist_ok=True)
484
485 accelerator = Accelerator(
486 log_with=LoggerType.TENSORBOARD,
487 logging_dir=f"{output_dir}",
488 gradient_accumulation_steps=args.gradient_accumulation_steps,
489 mixed_precision=args.mixed_precision
490 )
491
492 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
493
494 if args.seed is None:
495 args.seed = torch.random.seed() >> 32
496
497 set_seed(args.seed)
498
499 save_args(output_dir, args)
500
501 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
502 args.pretrained_model_name_or_path)
503
504 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
505 tokenizer.set_dropout(args.vector_dropout)
506
507 vae.enable_slicing()
508 vae.set_use_memory_efficient_attention_xformers(True)
509 unet.set_use_memory_efficient_attention_xformers(True)
510
511 if args.gradient_checkpointing:
512 unet.enable_gradient_checkpointing()
513 text_encoder.gradient_checkpointing_enable()
514
515 if args.embeddings_dir is not None:
516 embeddings_dir = Path(args.embeddings_dir)
517 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
518 raise ValueError("--embeddings_dir must point to an existing directory")
519
520 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
521 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
522
523 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
524 tokenizer=tokenizer,
525 embeddings=embeddings,
526 placeholder_tokens=args.placeholder_tokens,
527 initializer_tokens=args.initializer_tokens,
528 num_vectors=args.num_vectors
529 )
530
531 if len(placeholder_token_ids) != 0:
532 initializer_token_id_lens = [len(id) for id in initializer_token_ids]
533 placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
534 print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
535
536 if args.scale_lr:
537 args.learning_rate = (
538 args.learning_rate * args.gradient_accumulation_steps *
539 args.train_batch_size * accelerator.num_processes
540 )
541
542 if args.find_lr:
543 args.learning_rate = 1e-5
544
545 if args.use_8bit_adam:
546 try:
547 import bitsandbytes as bnb
548 except ImportError:
549 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
550
551 optimizer_class = bnb.optim.AdamW8bit
552 else:
553 optimizer_class = torch.optim.AdamW
554
555 optimizer = optimizer_class(
556 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
557 lr=args.learning_rate,
558 betas=(args.adam_beta1, args.adam_beta2),
559 weight_decay=args.adam_weight_decay,
560 eps=args.adam_epsilon,
561 amsgrad=args.adam_amsgrad,
562 )
563
564 weight_dtype = torch.float32
565 if args.mixed_precision == "fp16":
566 weight_dtype = torch.float16
567 elif args.mixed_precision == "bf16":
568 weight_dtype = torch.bfloat16
569
570 def keyword_filter(item: VlpnDataItem):
571 cond1 = any(
572 keyword in part
573 for keyword in args.placeholder_tokens
574 for part in item.prompt
575 )
576 cond3 = args.collection is None or args.collection in item.collection
577 cond4 = args.exclude_collections is None or not any(
578 collection in item.collection
579 for collection in args.exclude_collections
580 )
581 return cond1 and cond3 and cond4
582
583 datamodule = VlpnDataModule(
584 data_file=args.train_data_file,
585 batch_size=args.train_batch_size,
586 tokenizer=tokenizer,
587 class_subdir=args.class_image_dir,
588 num_class_images=args.num_class_images,
589 size=args.resolution,
590 num_buckets=args.num_buckets,
591 progressive_buckets=args.progressive_buckets,
592 bucket_step_size=args.bucket_step_size,
593 bucket_max_pixels=args.bucket_max_pixels,
594 dropout=args.tag_dropout,
595 shuffle=not args.no_tag_shuffle,
596 template_key=args.train_data_template,
597 valid_set_size=args.valid_set_size,
598 valid_set_repeat=args.valid_set_repeat,
599 seed=args.seed,
600 filter=keyword_filter,
601 dtype=weight_dtype
602 )
603 datamodule.setup()
604
605 train_dataloader = datamodule.train_dataloader
606 val_dataloader = datamodule.val_dataloader
607
608 if args.num_class_images != 0:
609 generate_class_images(
610 accelerator,
611 text_encoder,
612 vae,
613 unet,
614 tokenizer,
615 sample_scheduler,
616 datamodule.data_train,
617 args.sample_batch_size,
618 args.sample_image_size,
619 args.sample_steps
620 )
621
622 lr_scheduler = get_scheduler(
623 args.lr_scheduler,
624 optimizer=optimizer,
625 num_training_steps_per_epoch=len(train_dataloader),
626 gradient_accumulation_steps=args.gradient_accumulation_steps,
627 min_lr=args.lr_min_lr,
628 warmup_func=args.lr_warmup_func,
629 annealing_func=args.lr_annealing_func,
630 warmup_exp=args.lr_warmup_exp,
631 annealing_exp=args.lr_annealing_exp,
632 cycles=args.lr_cycles,
633 train_epochs=args.num_train_epochs,
634 warmup_epochs=args.lr_warmup_epochs,
635 )
636
637 trainer = Trainer(
638 accelerator=accelerator,
639 unet=unet,
640 text_encoder=text_encoder,
641 tokenizer=tokenizer,
642 vae=vae,
643 noise_scheduler=noise_scheduler,
644 sample_scheduler=sample_scheduler,
645 train_dataloader=train_dataloader,
646 val_dataloader=val_dataloader,
647 dtype=weight_dtype,
648 )
649
650 trainer(
651 strategy_class=TextualInversionTrainingStrategy,
652 optimizer=optimizer,
653 lr_scheduler=lr_scheduler,
654 num_train_epochs=args.num_train_epochs,
655 sample_frequency=args.sample_frequency,
656 checkpoint_frequency=args.checkpoint_frequency,
657 global_step_offset=global_step_offset,
658 prior_loss_weight=args.prior_loss_weight,
659 output_dir=output_dir,
660 placeholder_tokens=args.placeholder_tokens,
661 placeholder_token_ids=placeholder_token_ids,
662 learning_rate=args.learning_rate,
663 sample_steps=args.sample_steps,
664 sample_image_size=args.sample_image_size,
665 sample_batch_size=args.sample_batch_size,
666 sample_batches=args.sample_batches,
667 seed=args.seed,
668 )
669
670
671if __name__ == "__main__":
672 main()
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 53776ba..71bad7e 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -20,10 +20,9 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import CheckpointerBase, EMAModel, save_args
+from training.util import CheckpointerBase, EMAModel, save_args, generate_class_images, add_placeholder_tokens, get_models
 from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
diff --git a/train_ti.py b/train_ti.py
index 8631892..deed84c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -19,10 +19,11 @@ from slugify import slugify
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.common import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
+from trainer.base import Checkpointer
+from training.functional import loss_step, train_loop, generate_class_images, add_placeholder_tokens, get_models
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import CheckpointerBase, EMAModel, save_args
+from training.util import EMAModel, save_args
 from models.clip.tokenizer import MultiCLIPTokenizer
 
 logger = get_logger(__name__)
@@ -480,38 +481,20 @@ def parse_args():
     return args
 
 
-class Checkpointer(CheckpointerBase):
+class TextualInversionCheckpointer(Checkpointer):
     def __init__(
         self,
-        weight_dtype: torch.dtype,
-        accelerator: Accelerator,
-        vae: AutoencoderKL,
-        unet: UNet2DConditionModel,
-        tokenizer: MultiCLIPTokenizer,
-        text_encoder: CLIPTextModel,
         ema_embeddings: EMAModel,
-        scheduler,
-        placeholder_tokens,
-        placeholder_token_ids,
         *args,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
 
-        self.weight_dtype = weight_dtype
-        self.accelerator = accelerator
-        self.vae = vae
-        self.unet = unet
-        self.tokenizer = tokenizer
-        self.text_encoder = text_encoder
         self.ema_embeddings = ema_embeddings
-        self.scheduler = scheduler
-        self.placeholder_tokens = placeholder_tokens
-        self.placeholder_token_ids = placeholder_token_ids
 
     @torch.no_grad()
     def checkpoint(self, step, postfix):
-        print("Saving checkpoint for step %d..." % step)
+        print(f"Saving checkpoint for step {step}...")
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
@@ -519,7 +502,8 @@ class Checkpointer(CheckpointerBase):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
         ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
             for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
@@ -528,42 +512,14 @@ class Checkpointer(CheckpointerBase):
                     checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                 )
 
-        del text_encoder
-
-    @torch.no_grad()
+    @torch.inference_mode()
     def save_samples(self, step):
-        unet = self.accelerator.unwrap_model(self.unet)
-        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
-
         ema_context = self.ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if self.ema_embeddings is not None else nullcontext()
+            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
 
         with ema_context:
-            orig_unet_dtype = unet.dtype
-            orig_text_encoder_dtype = text_encoder.dtype
-
-            unet.to(dtype=self.weight_dtype)
-            text_encoder.to(dtype=self.weight_dtype)
-
-            pipeline = VlpnStableDiffusion(
-                text_encoder=text_encoder,
-                vae=self.vae,
-                unet=self.unet,
-                tokenizer=self.tokenizer,
-                scheduler=self.scheduler,
-            ).to(self.accelerator.device)
-            pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-            super().save_samples(pipeline, step)
-
-            unet.to(dtype=orig_unet_dtype)
-            text_encoder.to(dtype=orig_text_encoder_dtype)
-
-            del text_encoder
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            super().save_samples(step)
 
 
 def main():
@@ -806,8 +762,8 @@ def main():
         args.seed,
     )
 
-    checkpointer = Checkpointer(
-        weight_dtype=weight_dtype,
+    checkpointer = TextualInversionCheckpointer(
+        dtype=weight_dtype,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         accelerator=accelerator,
@@ -816,7 +772,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         ema_embeddings=ema_embeddings,
-        scheduler=sample_scheduler,
+        sample_scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         output_dir=output_dir,
diff --git a/trainer/base.py b/trainer/base.py
new file mode 100644
index 0000000..e700dd6
--- /dev/null
+++ b/trainer/base.py
@@ -0,0 +1,544 @@
1from pathlib import Path
2import math
3from contextlib import contextmanager
4from typing import Type, Optional
5import itertools
6from functools import partial
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11from torch.utils.data import DataLoader
12
13from accelerate import Accelerator
14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
16
17from tqdm.auto import tqdm
18from PIL import Image
19
20from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
21from models.clip.tokenizer import MultiCLIPTokenizer
22from models.clip.util import get_extended_embeddings
23from training.util import AverageMeter
24
25
26def make_grid(images, rows, cols):
27 w, h = images[0].size
28 grid = Image.new('RGB', size=(cols*w, rows*h))
29 for i, image in enumerate(images):
30 grid.paste(image, box=(i % cols*w, i//cols*h))
31 return grid
32
33
34class Checkpointer():
35 def __init__(
36 self,
37 accelerator: Accelerator,
38 vae: AutoencoderKL,
39 unet: UNet2DConditionModel,
40 text_encoder: CLIPTextModel,
41 tokenizer: MultiCLIPTokenizer,
42 sample_scheduler,
43 dtype,
44 train_dataloader: DataLoader,
45 val_dataloader: DataLoader,
46 output_dir: Path,
47 sample_steps: int = 20,
48 sample_guidance_scale: float = 7.5,
49 sample_image_size: int = 768,
50 sample_batches: int = 1,
51 sample_batch_size: int = 1,
52 seed: Optional[int] = None,
53 *args,
54 **kwargs,
55 ):
56 self.accelerator = accelerator
57 self.vae = vae
58 self.unet = unet
59 self.text_encoder = text_encoder
60 self.tokenizer = tokenizer
61 self.sample_scheduler = sample_scheduler
62 self.dtype = dtype
63 self.train_dataloader = train_dataloader
64 self.val_dataloader = val_dataloader
65 self.output_dir = output_dir
66 self.sample_steps = sample_steps
67 self.sample_guidance_scale = sample_guidance_scale
68 self.sample_image_size = sample_image_size
69 self.sample_batches = sample_batches
70 self.sample_batch_size = sample_batch_size
71 self.seed = seed if seed is not None else torch.random.seed()
72
73 @torch.no_grad()
74 def checkpoint(self, step: int, postfix: str):
75 pass
76
77 @torch.inference_mode()
78 def save_samples(self, step: int):
79 print(f"Saving samples for step {step}...")
80
81 samples_path = self.output_dir.joinpath("samples")
82
83 grid_cols = min(self.sample_batch_size, 4)
84 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
85
86 unet = self.accelerator.unwrap_model(self.unet)
87 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
88
89 orig_unet_dtype = unet.dtype
90 orig_text_encoder_dtype = text_encoder.dtype
91
92 unet.to(dtype=self.dtype)
93 text_encoder.to(dtype=self.dtype)
94
95 pipeline = VlpnStableDiffusion(
96 text_encoder=text_encoder,
97 vae=self.vae,
98 unet=self.unet,
99 tokenizer=self.tokenizer,
100 scheduler=self.sample_scheduler,
101 ).to(self.accelerator.device)
102 pipeline.set_progress_bar_config(dynamic_ncols=True)
103
104 generator = torch.Generator(device=self.accelerator.device).manual_seed(self.seed)
105
106 for pool, data, gen in [
107 ("stable", self.val_dataloader, generator),
108 ("val", self.val_dataloader, None),
109 ("train", self.train_dataloader, None)
110 ]:
111 all_samples = []
112 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
113 file_path.parent.mkdir(parents=True, exist_ok=True)
114
115 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
116 prompt_ids = [
117 prompt
118 for batch in batches
119 for prompt in batch["prompt_ids"]
120 ]
121 nprompt_ids = [
122 prompt
123 for batch in batches
124 for prompt in batch["nprompt_ids"]
125 ]
126
127 for i in range(self.sample_batches):
128 start = i * self.sample_batch_size
129 end = (i + 1) * self.sample_batch_size
130 prompt = prompt_ids[start:end]
131 nprompt = nprompt_ids[start:end]
132
133 samples = pipeline(
134 prompt=prompt,
135 negative_prompt=nprompt,
136 height=self.sample_image_size,
137 width=self.sample_image_size,
138 generator=gen,
139 guidance_scale=self.sample_guidance_scale,
140 num_inference_steps=self.sample_steps,
141 output_type='pil'
142 ).images
143
144 all_samples += samples
145
146 image_grid = make_grid(all_samples, grid_rows, grid_cols)
147 image_grid.save(file_path, quality=85)
148
149 unet.to(dtype=orig_unet_dtype)
150 text_encoder.to(dtype=orig_text_encoder_dtype)
151
152 del unet
153 del text_encoder
154 del generator
155 del pipeline
156
157 if torch.cuda.is_available():
158 torch.cuda.empty_cache()
159
160
161class TrainingStrategy():
162 def __init__(
163 self,
164 tokenizer: MultiCLIPTokenizer,
165 *args,
166 **kwargs,
167 ):
168 self.tokenizer = tokenizer
169 self.checkpointer = Checkpointer(tokenizer=tokenizer, *args, **kwargs)
170
171 @property
172 def main_model(self) -> nn.Module:
173 ...
174
175 @contextmanager
176 def on_train(self, epoch: int):
177 try:
178 self.tokenizer.train()
179 yield
180 finally:
181 pass
182
183 @contextmanager
184 def on_eval(self):
185 try:
186 self.tokenizer.eval()
187 yield
188 finally:
189 pass
190
191 def on_before_optimize(self, epoch: int):
192 ...
193
194 def on_after_optimize(self, lr: float):
195 ...
196
197    def on_log(self):
198 return {}
199
200
201def loss_step(
202 vae: AutoencoderKL,
203 unet: UNet2DConditionModel,
204 text_encoder: CLIPTextModel,
205 seed: int,
206 noise_scheduler,
207 prior_loss_weight: float,
208 step: int,
209 batch: dict,
210 eval: bool = False
211):
212 # Convert images to latent space
213 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
214 latents = latents * 0.18215
215
216 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
217
218 # Sample noise that we'll add to the latents
219 noise = torch.randn(
220 latents.shape,
221 dtype=latents.dtype,
222 layout=latents.layout,
223 device=latents.device,
224 generator=generator
225 )
226 bsz = latents.shape[0]
227 # Sample a random timestep for each image
228 timesteps = torch.randint(
229 0,
230 noise_scheduler.config.num_train_timesteps,
231 (bsz,),
232 generator=generator,
233 device=latents.device,
234 )
235 timesteps = timesteps.long()
236
237 # Add noise to the latents according to the noise magnitude at each timestep
238 # (this is the forward diffusion process)
239 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
240 noisy_latents = noisy_latents.to(dtype=unet.dtype)
241
242 # Get the text embedding for conditioning
243 encoder_hidden_states = get_extended_embeddings(
244 text_encoder,
245 batch["input_ids"],
246 batch["attention_mask"]
247 )
248 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
249
250 # Predict the noise residual
251 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
252
253 # Get the target for loss depending on the prediction type
254 if noise_scheduler.config.prediction_type == "epsilon":
255 target = noise
256 elif noise_scheduler.config.prediction_type == "v_prediction":
257 target = noise_scheduler.get_velocity(latents, noise, timesteps)
258 else:
259 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
260
261 if batch["with_prior"].all():
262 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
263 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
264 target, target_prior = torch.chunk(target, 2, dim=0)
265
266 # Compute instance loss
267 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
268
269 # Compute prior loss
270 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
271
272 # Add the prior loss to the instance loss.
273 loss = loss + prior_loss_weight * prior_loss
274 else:
275 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
276
277 acc = (model_pred == target).float().mean()
278
279 return loss, acc, bsz
280
281
282def train_loop(
283 strategy: TrainingStrategy,
284 accelerator: Accelerator,
285 vae: AutoencoderKL,
286 unet: UNet2DConditionModel,
287 text_encoder: CLIPTextModel,
288 train_dataloader: DataLoader,
289 val_dataloader: DataLoader,
290 seed: int,
291 optimizer: torch.optim.Optimizer,
292 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
293 noise_scheduler,
294 prior_loss_weight: float = 1.0,
295 sample_frequency: int = 10,
296 checkpoint_frequency: int = 50,
297 global_step_offset: int = 0,
298 num_epochs: int = 100,
299):
300 num_training_steps_per_epoch = math.ceil(
301 len(train_dataloader) / accelerator.gradient_accumulation_steps
302 )
303 num_val_steps_per_epoch = len(val_dataloader)
304
305 num_training_steps = num_training_steps_per_epoch * num_epochs
306 num_val_steps = num_val_steps_per_epoch * num_epochs
307
308 global_step = 0
309
310 avg_loss = AverageMeter()
311 avg_acc = AverageMeter()
312
313 avg_loss_val = AverageMeter()
314 avg_acc_val = AverageMeter()
315
316 max_acc_val = 0.0
317
318 local_progress_bar = tqdm(
319 range(num_training_steps_per_epoch + num_val_steps_per_epoch),
320 disable=not accelerator.is_local_main_process,
321 dynamic_ncols=True
322 )
323 local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
324
325 global_progress_bar = tqdm(
326 range(num_training_steps + num_val_steps),
327 disable=not accelerator.is_local_main_process,
328 dynamic_ncols=True
329 )
330 global_progress_bar.set_description("Total progress")
331
332 loss_step_ = partial(
333 loss_step,
334 vae,
335 unet,
336 text_encoder,
337 seed,
338 noise_scheduler,
339 prior_loss_weight
340 )
341
342 try:
343 for epoch in range(num_epochs):
344 if accelerator.is_main_process:
345 if epoch % sample_frequency == 0 and epoch != 0:
346 strategy.checkpointer.save_samples(global_step + global_step_offset)
347
348 if epoch % checkpoint_frequency == 0 and epoch != 0:
349 strategy.checkpointer.checkpoint(global_step + global_step_offset, "training")
350
351 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
352 local_progress_bar.reset()
353
354 strategy.main_model.train()
355
356 with strategy.on_train(epoch):
357 for step, batch in enumerate(train_dataloader):
358 with accelerator.accumulate(strategy.main_model):
359 loss, acc, bsz = loss_step_(step, batch)
360
361 accelerator.backward(loss)
362
363 strategy.on_before_optimize(epoch)
364
365 optimizer.step()
366 lr_scheduler.step()
367 optimizer.zero_grad(set_to_none=True)
368
369 avg_loss.update(loss.detach_(), bsz)
370 avg_acc.update(acc.detach_(), bsz)
371
372 # Checks if the accelerator has performed an optimization step behind the scenes
373 if accelerator.sync_gradients:
374 strategy.on_after_optimize(lr_scheduler.get_last_lr()[0])
375
376 local_progress_bar.update(1)
377 global_progress_bar.update(1)
378
379 global_step += 1
380
381 logs = {
382 "train/loss": avg_loss.avg.item(),
383 "train/acc": avg_acc.avg.item(),
384 "train/cur_loss": loss.item(),
385 "train/cur_acc": acc.item(),
386 "lr": lr_scheduler.get_last_lr()[0],
387 }
388 logs.update(strategy.on_log())
389
390 accelerator.log(logs, step=global_step)
391
392 local_progress_bar.set_postfix(**logs)
393
394 if global_step >= num_training_steps:
395 break
396
397 accelerator.wait_for_everyone()
398
399 strategy.main_model.eval()
400
401 cur_loss_val = AverageMeter()
402 cur_acc_val = AverageMeter()
403
404 with torch.inference_mode(), strategy.on_eval():
405 for step, batch in enumerate(val_dataloader):
406 loss, acc, bsz = loss_step_(step, batch, True)
407
408 loss = loss.detach_()
409 acc = acc.detach_()
410
411 cur_loss_val.update(loss, bsz)
412 cur_acc_val.update(acc, bsz)
413
414 avg_loss_val.update(loss, bsz)
415 avg_acc_val.update(acc, bsz)
416
417 local_progress_bar.update(1)
418 global_progress_bar.update(1)
419
420 logs = {
421 "val/loss": avg_loss_val.avg.item(),
422 "val/acc": avg_acc_val.avg.item(),
423 "val/cur_loss": loss.item(),
424 "val/cur_acc": acc.item(),
425 }
426 local_progress_bar.set_postfix(**logs)
427
428 logs["val/cur_loss"] = cur_loss_val.avg.item()
429 logs["val/cur_acc"] = cur_acc_val.avg.item()
430
431 accelerator.log(logs, step=global_step)
432
433 local_progress_bar.clear()
434 global_progress_bar.clear()
435
436 if accelerator.is_main_process:
437 if avg_acc_val.avg.item() > max_acc_val:
438 accelerator.print(
439 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
440 strategy.checkpointer.checkpoint(global_step + global_step_offset, "milestone")
441 max_acc_val = avg_acc_val.avg.item()
442
443 # Create the pipeline using using the trained modules and save it.
444 if accelerator.is_main_process:
445 print("Finished!")
446 strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
447 strategy.checkpointer.save_samples(global_step + global_step_offset)
448 accelerator.end_training()
449
450 except KeyboardInterrupt:
451 if accelerator.is_main_process:
452 print("Interrupted")
453 strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
454 accelerator.end_training()
455
456
457class Trainer():
458 def __init__(
459 self,
460 accelerator: Accelerator,
461 unet: UNet2DConditionModel,
462 text_encoder: CLIPTextModel,
463 tokenizer: MultiCLIPTokenizer,
464 vae: AutoencoderKL,
465 noise_scheduler: DDPMScheduler,
466 sample_scheduler: DPMSolverMultistepScheduler,
467 train_dataloader: DataLoader,
468 val_dataloader: DataLoader,
469 dtype: torch.dtype,
470 ):
471 self.accelerator = accelerator
472 self.unet = unet
473 self.text_encoder = text_encoder
474 self.tokenizer = tokenizer
475 self.vae = vae
476 self.noise_scheduler = noise_scheduler
477 self.sample_scheduler = sample_scheduler
478 self.train_dataloader = train_dataloader
479 self.val_dataloader = val_dataloader
480 self.dtype = dtype
481
482 def __call__(
483 self,
484 strategy_class: Type[TrainingStrategy],
485 optimizer,
486 lr_scheduler,
487 num_train_epochs: int = 100,
488 sample_frequency: int = 20,
489 checkpoint_frequency: int = 50,
490 global_step_offset: int = 0,
491 prior_loss_weight: float = 0,
492 seed: Optional[int] = None,
493 **kwargs,
494 ):
495 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = self.accelerator.prepare(
496 self.unet, self.text_encoder, optimizer, self.train_dataloader, self.val_dataloader, lr_scheduler
497 )
498
499 self.vae.to(self.accelerator.device, dtype=self.dtype)
500
501 for model in (unet, text_encoder, self.vae):
502 model.requires_grad_(False)
503 model.eval()
504
505 if seed is None:
506 seed = torch.random.seed()
507
508 strategy = strategy_class(
509 accelerator=self.accelerator,
510 vae=self.vae,
511 unet=unet,
512 text_encoder=text_encoder,
513 tokenizer=self.tokenizer,
514 sample_scheduler=self.sample_scheduler,
515 train_dataloader=train_dataloader,
516 val_dataloader=val_dataloader,
517 dtype=self.dtype,
518 seed=seed,
519 **kwargs
520 )
521
522 if self.accelerator.is_main_process:
523 self.accelerator.init_trackers("textual_inversion")
524
525 train_loop(
526 strategy=strategy,
527 accelerator=self.accelerator,
528 vae=self.vae,
529 unet=unet,
530 text_encoder=text_encoder,
531 train_dataloader=train_dataloader,
532 val_dataloader=val_dataloader,
533 seed=seed,
534 optimizer=optimizer,
535 lr_scheduler=lr_scheduler,
536 noise_scheduler=self.noise_scheduler,
537 prior_loss_weight=prior_loss_weight,
538 sample_frequency=sample_frequency,
539 checkpoint_frequency=checkpoint_frequency,
540 global_step_offset=global_step_offset,
541 num_epochs=num_train_epochs,
542 )
543
544 self.accelerator.free_memory()
diff --git a/trainer/dreambooth.py b/trainer/dreambooth.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/trainer/dreambooth.py
diff --git a/trainer/ti.py b/trainer/ti.py
new file mode 100644
index 0000000..15cf747
--- /dev/null
+++ b/trainer/ti.py
@@ -0,0 +1,164 @@
1from contextlib import contextmanager, nullcontext
2
3import torch
4
5from slugify import slugify
6
7from diffusers import UNet2DConditionModel
8from transformers import CLIPTextModel
9
10from trainer.base import TrainingStrategy, Checkpointer
11from training.util import EMAModel
12
13
14class TextualInversionCheckpointer(Checkpointer):
15 def __init__(
16 self,
17 ema_embeddings: EMAModel,
18 *args,
19 **kwargs,
20 ):
21 super().__init__(*args, **kwargs)
22
23 self.ema_embeddings = ema_embeddings
24
25 @torch.no_grad()
26 def checkpoint(self, step, postfix):
27 print(f"Saving checkpoint for step {step}...")
28
29 checkpoints_path = self.output_dir.joinpath("checkpoints")
30 checkpoints_path.mkdir(parents=True, exist_ok=True)
31
32 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
33
34 ema_context = self.ema_embeddings.apply_temporary(
35 text_encoder.text_model.embeddings.temp_token_embedding.parameters()
36 ) if self.ema_embeddings is not None else nullcontext()
37
38 with ema_context:
39 for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
40 text_encoder.text_model.embeddings.save_embed(
41 ids,
42 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
43 )
44
45 @torch.inference_mode()
46 def save_samples(self, step):
47 ema_context = self.ema_embeddings.apply_temporary(
48 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
49 ) if self.ema_embeddings is not None else nullcontext()
50
51 with ema_context:
52 super().save_samples(step)
53
54
55class TextualInversionTrainingStrategy(TrainingStrategy):
56 def __init__(
57 self,
58 unet: UNet2DConditionModel,
59 text_encoder: CLIPTextModel,
60 placeholder_tokens: list[str],
61 placeholder_token_ids: list[list[int]],
62 learning_rate: float,
63 gradient_checkpointing: bool = False,
64 use_emb_decay: bool = False,
65 emb_decay_target: float = 0.4,
66 emb_decay_factor: float = 1,
67 emb_decay_start: float = 1e-4,
68 use_ema: bool = False,
69 ema_inv_gamma: float = 1.0,
70 ema_power: int = 1,
71 ema_max_decay: float = 0.9999,
72 *args,
73 **kwargs,
74 ):
75 super().__init__(
76 unet=unet,
77 text_encoder=text_encoder,
78 *args,
79 **kwargs
80 )
81
82 self.text_encoder = text_encoder
83 self.unet = unet
84
85 self.placeholder_tokens = placeholder_tokens
86 self.placeholder_token_ids = placeholder_token_ids
87
88 self.gradient_checkpointing = gradient_checkpointing
89
90 self.learning_rate = learning_rate
91 self.use_emb_decay = use_emb_decay
92 self.emb_decay_target = emb_decay_target
93 self.emb_decay_factor = emb_decay_factor
94 self.emb_decay_start = emb_decay_start
95
96 self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
97
98 self.ema_embeddings = None
99
100 if use_ema:
101 self.ema_embeddings = EMAModel(
102 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
103 inv_gamma=ema_inv_gamma,
104 power=ema_power,
105 max_value=ema_max_decay,
106 )
107
108 self.checkpointer = TextualInversionCheckpointer(
109 unet=unet,
110 text_encoder=text_encoder,
111 ema_embeddings=self.ema_embeddings,
112 *args,
113 **kwargs
114 )
115
116 @property
117 def main_model(self):
118 return self.text_encoder
119
120 @contextmanager
121 def on_train(self, epoch: int):
122 try:
123 if self.gradient_checkpointing:
124 self.unet.train()
125
126 with super().on_eval():
127 yield
128 finally:
129 pass
130
131 @contextmanager
132 def on_eval(self):
133 try:
134 if self.gradient_checkpointing:
135 self.unet.eval()
136
137 ema_context = self.ema_embeddings.apply_temporary(
138 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
139 ) if self.ema_embeddings is not None else nullcontext()
140
141 with ema_context, super().on_eval():
142 yield
143 finally:
144 pass
145
146 @torch.no_grad()
147 def on_after_optimize(self, lr: float):
148 if self.use_emb_decay:
149 self.text_encoder.text_model.embeddings.normalize(
150 self.emb_decay_target,
151 min(1.0, max(0.0, self.emb_decay_factor * ((lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start))))
152 )
153
154 if self.ema_embeddings is not None:
155 self.ema_embeddings.step(self.text_encoder.text_model.embeddings.temp_token_embedding.parameters())
156
157 def on_log(self):
158 log = super().on_log()
159 added = {}
160
161 if self.ema_embeddings is not None:
162 added = {"ema_decay": self.ema_embeddings.decay}
163
164        return {**log, **added}
diff --git a/training/common.py b/training/functional.py
index 5d1e3f9..2d81eca 100644
--- a/training/common.py
+++ b/training/functional.py
@@ -16,19 +16,14 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import AverageMeter, CheckpointerBase
+from training.util import AverageMeter
+from trainer.base import Checkpointer
 
 
-def noop(*args, **kwards):
-    pass
-
-
-def noop_ctx(*args, **kwards):
-    return nullcontext()
-
-
-def noop_on_log():
-    return {}
+def const(result=None):
+    def fn(*args, **kwargs):
+        return result
+    return fn
 
 
 def generate_class_images(
@@ -210,7 +205,7 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     model: torch.nn.Module,
-    checkpointer: CheckpointerBase,
+    checkpointer: Checkpointer,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -218,11 +213,11 @@ def train_loop(
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
-    on_log: Callable[[], dict[str, Any]] = noop_on_log,
-    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-    on_before_optimize: Callable[[int], None] = noop,
-    on_after_optimize: Callable[[float], None] = noop,
-    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+    on_log: Callable[[], dict[str, Any]] = const({}),
+    on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
+    on_before_optimize: Callable[[int], None] = const(),
+    on_after_optimize: Callable[[float], None] = const(),
+    on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
     num_val_steps_per_epoch = len(val_dataloader)
diff --git a/training/lora.py b/training/lora.py
deleted file mode 100644
index 3857d78..0000000
--- a/training/lora.py
+++ /dev/null
@@ -1,107 +0,0 @@
1import torch
2import torch.nn as nn
3
4from diffusers import ModelMixin, ConfigMixin
5from diffusers.configuration_utils import register_to_config
6from diffusers.models.cross_attention import CrossAttention
7from diffusers.utils.import_utils import is_xformers_available
8
9
10if is_xformers_available():
11 import xformers
12 import xformers.ops
13else:
14 xformers = None
15
16
17class LoRALinearLayer(nn.Module):
18 def __init__(self, in_features, out_features, rank=4):
19 super().__init__()
20
21 if rank > min(in_features, out_features):
22 raise ValueError(
23 f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
24 )
25
26 self.lora_down = nn.Linear(in_features, rank, bias=False)
27 self.lora_up = nn.Linear(rank, out_features, bias=False)
28 self.scale = 1.0
29
30 nn.init.normal_(self.lora_down.weight, std=1 / rank)
31 nn.init.zeros_(self.lora_up.weight)
32
33 def forward(self, hidden_states):
34 down_hidden_states = self.lora_down(hidden_states)
35 up_hidden_states = self.lora_up(down_hidden_states)
36
37 return up_hidden_states
38
39
40class LoRACrossAttnProcessor(nn.Module):
41 def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
42 super().__init__()
43
44 self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
45 self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
46 self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
47 self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
48
49 def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
50 batch_size, sequence_length, _ = hidden_states.shape
51 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
52
53 query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
54 query = attn.head_to_batch_dim(query)
55
56 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
57
58 key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
59 value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
60
61 key = attn.head_to_batch_dim(key)
62 value = attn.head_to_batch_dim(value)
63
64 attention_probs = attn.get_attention_scores(query, key, attention_mask)
65 hidden_states = torch.bmm(attention_probs, value)
66 hidden_states = attn.batch_to_head_dim(hidden_states)
67
68 # linear proj
69 hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
70 # dropout
71 hidden_states = attn.to_out[1](hidden_states)
72
73 return hidden_states
74
75
76class LoRAXFormersCrossAttnProcessor(nn.Module):
77 def __init__(self, hidden_size, cross_attention_dim, rank=4):
78 super().__init__()
79
80 self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
81 self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
82 self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
83 self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
84
85 def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
86 batch_size, sequence_length, _ = hidden_states.shape
87 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
88
89 query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
90 query = attn.head_to_batch_dim(query).contiguous()
91
92 encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
93
94 key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
95 value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
96
97 key = attn.head_to_batch_dim(key).contiguous()
98 value = attn.head_to_batch_dim(value).contiguous()
99
100 hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
101
102 # linear proj
103 hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
104 # dropout
105 hidden_states = attn.to_out[1](hidden_states)
106
107 return hidden_states
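How these processors would typically be attached to a UNet — a sketch adapted from the upstream diffusers LoRA training example, assuming a diffusers version that exposes `unet.attn_processors` and `unet.set_attn_processor`; `unet` and `lora_attn_procs` are illustrative names:

    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # self-attention ("attn1") takes no encoder_hidden_states, so no cross_attention_dim
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size, cross_attention_dim, rank=4)
    unet.set_attn_processor(lora_attn_procs)

The XFormers variant would be registered instead when memory-efficient attention is enabled on the pipeline.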
diff --git a/training/util.py b/training/util.py
index 781cf04..a292edd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,12 +1,40 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import copy 3import copy
4import itertools 4from typing import Iterable, Union
5from typing import Iterable, Optional
6from contextlib import contextmanager 5from contextlib import contextmanager
7 6
8import torch 7import torch
9from PIL import Image 8
9from transformers import CLIPTextModel
10from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
11
12from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
13from models.clip.tokenizer import MultiCLIPTokenizer
14from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
15
16
17class TrainingStrategy():
18 @property
19 def main_model(self) -> torch.nn.Module:
20 ...
21
22 @contextmanager
23 def on_train(self, epoch: int):
24 yield
25
26 @contextmanager
27 def on_eval(self):
28 yield
29
30 def on_before_optimize(self, epoch: int):
31 ...
32
33 def on_after_optimize(self, lr: float):
34 ...
35
36 def on_log(self):
37 return {}
10 38
11 39
12def save_args(basepath: Path, args, extra={}): 40def save_args(basepath: Path, args, extra={}):
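The hooks on TrainingStrategy above are no-op stubs meant to be overridden; a minimal sketch of a subclass (illustrative only, not part of this commit):

    class TextEncoderOnlyStrategy(TrainingStrategy):
        def __init__(self, text_encoder):
            self.text_encoder = text_encoder

        @property
        def main_model(self) -> torch.nn.Module:
            return self.text_encoder

        @contextmanager
        def on_train(self, epoch: int):
            self.text_encoder.train()
            try:
                yield
            finally:
                self.text_encoder.eval()

        @contextmanager
        def on_eval(self):
            with torch.no_grad():
                yield

        def on_log(self):
            return {"strategy": "text_encoder_only"}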
@@ -16,12 +44,93 @@ def save_args(basepath: Path, args, extra={}):
16 json.dump(info, f, indent=4) 44 json.dump(info, f, indent=4)
17 45
18 46
19def make_grid(images, rows, cols): 47def generate_class_images(
20 w, h = images[0].size 48 accelerator,
21 grid = Image.new('RGB', size=(cols*w, rows*h)) 49 text_encoder,
22 for i, image in enumerate(images): 50 vae,
23 grid.paste(image, box=(i % cols*w, i//cols*h)) 51 unet,
24 return grid 52 tokenizer,
53 scheduler,
54 data_train,
55 sample_batch_size,
56 sample_image_size,
57 sample_steps
58):
59 missing_data = [item for item in data_train if not item.class_image_path.exists()]
60
61 if len(missing_data) == 0:
62 return
63
64 batched_data = [
65 missing_data[i:i+sample_batch_size]
66 for i in range(0, len(missing_data), sample_batch_size)
67 ]
68
69 pipeline = VlpnStableDiffusion(
70 text_encoder=text_encoder,
71 vae=vae,
72 unet=unet,
73 tokenizer=tokenizer,
74 scheduler=scheduler,
75 ).to(accelerator.device)
76 pipeline.set_progress_bar_config(dynamic_ncols=True)
77
78 with torch.inference_mode():
79 for batch in batched_data:
80 image_name = [item.class_image_path for item in batch]
81 prompt = [item.cprompt for item in batch]
82 nprompt = [item.nprompt for item in batch]
83
84 images = pipeline(
85 prompt=prompt,
86 negative_prompt=nprompt,
87 height=sample_image_size,
88 width=sample_image_size,
89 num_inference_steps=sample_steps
90 ).images
91
92 for i, image in enumerate(images):
93 image.save(image_name[i])
94
95 del pipeline
96
97 if torch.cuda.is_available():
98 torch.cuda.empty_cache()
99
100
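A sketch of how generate_class_images would be called before training; every name here (accelerator, the model components, datamodule, and the sample_* values) is assumed to come from the surrounding training script:

    generate_class_images(
        accelerator,
        text_encoder, vae, unet, tokenizer, sample_scheduler,
        datamodule.data_train,      # items expose class_image_path / cprompt / nprompt
        sample_batch_size=4,
        sample_image_size=768,
        sample_steps=20,
    )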
101def get_models(pretrained_model_name_or_path: str):
102 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
103 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
104 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
105 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
106 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
107 sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
108 pretrained_model_name_or_path, subfolder='scheduler')
109
110 embeddings = patch_managed_embeddings(text_encoder)
111
112 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
113
114
115def add_placeholder_tokens(
116 tokenizer: MultiCLIPTokenizer,
117 embeddings: ManagedCLIPTextEmbeddings,
118 placeholder_tokens: list[str],
119 initializer_tokens: list[str],
120 num_vectors: Union[list[int], int]
121):
122 initializer_token_ids = [
123 tokenizer.encode(token, add_special_tokens=False)
124 for token in initializer_tokens
125 ]
126 placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
127
128 embeddings.resize(len(tokenizer))
129
130 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
131 embeddings.add_embed(placeholder_token_id, initializer_token_id)
132
133 return placeholder_token_ids, initializer_token_ids
25 134
26 135
27class AverageMeter: 136class AverageMeter:
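Putting the two helpers together — a sketch of the textual-inversion setup path; the model path and the placeholder/initializer tokens are illustrative values:

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "runwayml/stable-diffusion-v1-5")

    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=["<concept>"],
        initializer_tokens=["person"],
        num_vectors=1,              # may also be a per-token list
    )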
@@ -38,93 +147,6 @@ class AverageMeter:
38 self.avg = self.sum / self.count 147 self.avg = self.sum / self.count
39 148
40 149
41class CheckpointerBase:
42 def __init__(
43 self,
44 train_dataloader,
45 val_dataloader,
46 output_dir: Path,
47 sample_steps: int = 20,
48 sample_guidance_scale: float = 7.5,
49 sample_image_size: int = 768,
50 sample_batches: int = 1,
51 sample_batch_size: int = 1,
52 seed: Optional[int] = None
53 ):
54 self.train_dataloader = train_dataloader
55 self.val_dataloader = val_dataloader
56 self.output_dir = output_dir
57 self.sample_image_size = sample_image_size
58 self.sample_steps = sample_steps
59 self.sample_guidance_scale = sample_guidance_scale
60 self.sample_batches = sample_batches
61 self.sample_batch_size = sample_batch_size
62 self.seed = seed if seed is not None else torch.random.seed()
63
64 @torch.no_grad()
65 def checkpoint(self, step: int, postfix: str):
66 pass
67
68 @torch.inference_mode()
69 def save_samples(self, pipeline, step: int):
70 samples_path = Path(self.output_dir).joinpath("samples")
71
72 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
73
74 grid_cols = min(self.sample_batch_size, 4)
75 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
76
77 for pool, data, gen in [
78 ("stable", self.val_dataloader, generator),
79 ("val", self.val_dataloader, None),
80 ("train", self.train_dataloader, None)
81 ]:
82 all_samples = []
83 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
84 file_path.parent.mkdir(parents=True, exist_ok=True)
85
86 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
87 prompt_ids = [
88 prompt
89 for batch in batches
90 for prompt in batch["prompt_ids"]
91 ]
92 nprompt_ids = [
93 prompt
94 for batch in batches
95 for prompt in batch["nprompt_ids"]
96 ]
97
98 for i in range(self.sample_batches):
99 start = i * self.sample_batch_size
100 end = (i + 1) * self.sample_batch_size
101 prompt = prompt_ids[start:end]
102 nprompt = nprompt_ids[start:end]
103
104 samples = pipeline(
105 prompt=prompt,
106 negative_prompt=nprompt,
107 height=self.sample_image_size,
108 width=self.sample_image_size,
109 generator=gen,
110 guidance_scale=self.sample_guidance_scale,
111 num_inference_steps=self.sample_steps,
112 output_type='pil'
113 ).images
114
115 all_samples += samples
116
117 del samples
118
119 image_grid = make_grid(all_samples, grid_rows, grid_cols)
120 image_grid.save(file_path, quality=85)
121
122 del all_samples
123 del image_grid
124
125 del generator
126
127
128# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 150# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
129class EMAModel: 151class EMAModel:
130 """ 152 """