summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--train.py672
-rw-r--r--train_ti.py49
-rw-r--r--trainer_old/base.py538
-rw-r--r--trainer_old/dreambooth.py0
-rw-r--r--trainer_old/ti.py168
-rw-r--r--training/functional.py63
-rw-r--r--training/strategy/ti.py20
7 files changed, 40 insertions, 1470 deletions
diff --git a/train.py b/train.py
deleted file mode 100644
index d8644c4..0000000
--- a/train.py
+++ /dev/null
@@ -1,672 +0,0 @@
1import argparse
2import datetime
3import logging
4from pathlib import Path
5
6import torch
7import torch.utils.checkpoint
8
9from accelerate import Accelerator
10from accelerate.logging import get_logger
11from accelerate.utils import LoggerType, set_seed
12from slugify import slugify
13
14from data.csv import VlpnDataModule, VlpnDataItem
15from util import load_config, load_embeddings_from_dir
16
17from trainer.ti import TextualInversionTrainingStrategy
18from trainer.base import Trainer
19from training.optimization import get_scheduler
20from training.util import save_args, generate_class_images, add_placeholder_tokens, get_models
21
22logger = get_logger(__name__)
23
24
25torch.backends.cuda.matmul.allow_tf32 = True
26torch.backends.cudnn.benchmark = True
27
28
29def parse_args():
30 parser = argparse.ArgumentParser(
31 description="Simple example of a training script."
32 )
33 parser.add_argument(
34 "--pretrained_model_name_or_path",
35 type=str,
36 default=None,
37 help="Path to pretrained model or model identifier from huggingface.co/models.",
38 )
39 parser.add_argument(
40 "--tokenizer_name",
41 type=str,
42 default=None,
43 help="Pretrained tokenizer name or path if not the same as model_name",
44 )
45 parser.add_argument(
46 "--train_data_file",
47 type=str,
48 default=None,
49 help="A CSV file containing the training data."
50 )
51 parser.add_argument(
52 "--train_data_template",
53 type=str,
54 default="template",
55 )
56 parser.add_argument(
57 "--project",
58 type=str,
59 default=None,
60 help="The name of the current project.",
61 )
62 parser.add_argument(
63 "--placeholder_tokens",
64 type=str,
65 nargs='*',
66 help="A token to use as a placeholder for the concept.",
67 )
68 parser.add_argument(
69 "--initializer_tokens",
70 type=str,
71 nargs='*',
72 help="A token to use as initializer word."
73 )
74 parser.add_argument(
75 "--num_vectors",
76 type=int,
77 nargs='*',
78 help="Number of vectors per embedding."
79 )
80 parser.add_argument(
81 "--num_class_images",
82 type=int,
83 default=1,
84 help="How many class images to generate."
85 )
86 parser.add_argument(
87 "--class_image_dir",
88 type=str,
89 default="cls",
90 help="The directory where class images will be saved.",
91 )
92 parser.add_argument(
93 "--exclude_collections",
94 type=str,
95 nargs='*',
96 help="Exclude all items with a listed collection.",
97 )
98 parser.add_argument(
99 "--output_dir",
100 type=str,
101 default="output/text-inversion",
102 help="The output directory where the model predictions and checkpoints will be written.",
103 )
104 parser.add_argument(
105 "--embeddings_dir",
106 type=str,
107 default=None,
108 help="The embeddings directory where Textual Inversion embeddings are stored.",
109 )
110 parser.add_argument(
111 "--collection",
112 type=str,
113 nargs='*',
114 help="A collection to filter the dataset.",
115 )
116 parser.add_argument(
117 "--seed",
118 type=int,
119 default=None,
120 help="A seed for reproducible training."
121 )
122 parser.add_argument(
123 "--resolution",
124 type=int,
125 default=768,
126 help=(
127 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
128 " resolution"
129 ),
130 )
131 parser.add_argument(
132 "--num_buckets",
133 type=int,
134 default=0,
135 help="Number of aspect ratio buckets in either direction.",
136 )
137 parser.add_argument(
138 "--progressive_buckets",
139 action="store_true",
140 help="Include images in smaller buckets as well.",
141 )
142 parser.add_argument(
143 "--bucket_step_size",
144 type=int,
145 default=64,
146 help="Step size between buckets.",
147 )
148 parser.add_argument(
149 "--bucket_max_pixels",
150 type=int,
151 default=None,
152 help="Maximum pixels per bucket.",
153 )
154 parser.add_argument(
155 "--tag_dropout",
156 type=float,
157 default=0,
158 help="Tag dropout probability.",
159 )
160 parser.add_argument(
161 "--no_tag_shuffle",
162 action="store_true",
163 help="Shuffle tags.",
164 )
165 parser.add_argument(
166 "--vector_dropout",
167 type=int,
168 default=0,
169 help="Vector dropout probability.",
170 )
171 parser.add_argument(
172 "--vector_shuffle",
173 type=str,
174 default="auto",
175 help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
176 )
177 parser.add_argument(
178 "--num_train_epochs",
179 type=int,
180 default=100
181 )
182 parser.add_argument(
183 "--gradient_accumulation_steps",
184 type=int,
185 default=1,
186 help="Number of updates steps to accumulate before performing a backward/update pass.",
187 )
188 parser.add_argument(
189 "--gradient_checkpointing",
190 action="store_true",
191 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
192 )
193 parser.add_argument(
194 "--find_lr",
195 action="store_true",
196 help="Automatically find a learning rate (no training).",
197 )
198 parser.add_argument(
199 "--learning_rate",
200 type=float,
201 default=1e-4,
202 help="Initial learning rate (after the potential warmup period) to use.",
203 )
204 parser.add_argument(
205 "--scale_lr",
206 action="store_true",
207 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
208 )
209 parser.add_argument(
210 "--lr_scheduler",
211 type=str,
212 default="one_cycle",
213 help=(
214 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
215 ' "constant", "constant_with_warmup", "one_cycle"]'
216 ),
217 )
218 parser.add_argument(
219 "--lr_warmup_epochs",
220 type=int,
221 default=10,
222 help="Number of steps for the warmup in the lr scheduler."
223 )
224 parser.add_argument(
225 "--lr_cycles",
226 type=int,
227 default=None,
228 help="Number of restart cycles in the lr scheduler."
229 )
230 parser.add_argument(
231 "--lr_warmup_func",
232 type=str,
233 default="cos",
234 help='Choose between ["linear", "cos"]'
235 )
236 parser.add_argument(
237 "--lr_warmup_exp",
238 type=int,
239 default=1,
240 help='If lr_warmup_func is "cos", exponent to modify the function'
241 )
242 parser.add_argument(
243 "--lr_annealing_func",
244 type=str,
245 default="cos",
246 help='Choose between ["linear", "half_cos", "cos"]'
247 )
248 parser.add_argument(
249 "--lr_annealing_exp",
250 type=int,
251 default=1,
252 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
253 )
254 parser.add_argument(
255 "--lr_min_lr",
256 type=float,
257 default=0.04,
258 help="Minimum learning rate in the lr scheduler."
259 )
260 parser.add_argument(
261 "--use_ema",
262 action="store_true",
263 help="Whether to use EMA model."
264 )
265 parser.add_argument(
266 "--ema_inv_gamma",
267 type=float,
268 default=1.0
269 )
270 parser.add_argument(
271 "--ema_power",
272 type=float,
273 default=1
274 )
275 parser.add_argument(
276 "--ema_max_decay",
277 type=float,
278 default=0.9999
279 )
280 parser.add_argument(
281 "--use_8bit_adam",
282 action="store_true",
283 help="Whether or not to use 8-bit Adam from bitsandbytes."
284 )
285 parser.add_argument(
286 "--adam_beta1",
287 type=float,
288 default=0.9,
289 help="The beta1 parameter for the Adam optimizer."
290 )
291 parser.add_argument(
292 "--adam_beta2",
293 type=float,
294 default=0.999,
295 help="The beta2 parameter for the Adam optimizer."
296 )
297 parser.add_argument(
298 "--adam_weight_decay",
299 type=float,
300 default=0,
301 help="Weight decay to use."
302 )
303 parser.add_argument(
304 "--adam_epsilon",
305 type=float,
306 default=1e-08,
307 help="Epsilon value for the Adam optimizer"
308 )
309 parser.add_argument(
310 "--adam_amsgrad",
311 type=bool,
312 default=False,
313 help="Amsgrad value for the Adam optimizer"
314 )
315 parser.add_argument(
316 "--mixed_precision",
317 type=str,
318 default="no",
319 choices=["no", "fp16", "bf16"],
320 help=(
321 "Whether to use mixed precision. Choose"
322 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
323 "and an Nvidia Ampere GPU."
324 ),
325 )
326 parser.add_argument(
327 "--checkpoint_frequency",
328 type=int,
329 default=5,
330 help="How often to save a checkpoint and sample image (in epochs)",
331 )
332 parser.add_argument(
333 "--sample_frequency",
334 type=int,
335 default=1,
336 help="How often to save a checkpoint and sample image (in epochs)",
337 )
338 parser.add_argument(
339 "--sample_image_size",
340 type=int,
341 default=768,
342 help="Size of sample images",
343 )
344 parser.add_argument(
345 "--sample_batches",
346 type=int,
347 default=1,
348 help="Number of sample batches to generate per checkpoint",
349 )
350 parser.add_argument(
351 "--sample_batch_size",
352 type=int,
353 default=1,
354 help="Number of samples to generate per batch",
355 )
356 parser.add_argument(
357 "--valid_set_size",
358 type=int,
359 default=None,
360 help="Number of images in the validation dataset."
361 )
362 parser.add_argument(
363 "--valid_set_repeat",
364 type=int,
365 default=1,
366 help="Times the images in the validation dataset are repeated."
367 )
368 parser.add_argument(
369 "--train_batch_size",
370 type=int,
371 default=1,
372 help="Batch size (per device) for the training dataloader."
373 )
374 parser.add_argument(
375 "--sample_steps",
376 type=int,
377 default=20,
378 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
379 )
380 parser.add_argument(
381 "--prior_loss_weight",
382 type=float,
383 default=1.0,
384 help="The weight of prior preservation loss."
385 )
386 parser.add_argument(
387 "--emb_decay_target",
388 default=0.4,
389 type=float,
390 help="Embedding decay target."
391 )
392 parser.add_argument(
393 "--emb_decay_factor",
394 default=0,
395 type=float,
396 help="Embedding decay factor."
397 )
398 parser.add_argument(
399 "--emb_decay_start",
400 default=1e-4,
401 type=float,
402 help="Embedding decay start offset."
403 )
404 parser.add_argument(
405 "--noise_timesteps",
406 type=int,
407 default=1000,
408 )
409 parser.add_argument(
410 "--resume_from",
411 type=str,
412 default=None,
413 help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
414 )
415 parser.add_argument(
416 "--global_step",
417 type=int,
418 default=0,
419 )
420 parser.add_argument(
421 "--config",
422 type=str,
423 default=None,
424 help="Path to a JSON configuration file containing arguments for invoking this script."
425 )
426
427 args = parser.parse_args()
428 if args.config is not None:
429 args = load_config(args.config)
430 args = parser.parse_args(namespace=argparse.Namespace(**args))
431
432 if args.train_data_file is None:
433 raise ValueError("You must specify --train_data_file")
434
435 if args.pretrained_model_name_or_path is None:
436 raise ValueError("You must specify --pretrained_model_name_or_path")
437
438 if args.project is None:
439 raise ValueError("You must specify --project")
440
441 if isinstance(args.placeholder_tokens, str):
442 args.placeholder_tokens = [args.placeholder_tokens]
443
444 if len(args.placeholder_tokens) == 0:
445 args.placeholder_tokens = [f"<*{i}>" for i in range(args.initializer_tokens)]
446
447 if isinstance(args.initializer_tokens, str):
448 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
449
450 if len(args.initializer_tokens) == 0:
451 raise ValueError("You must specify --initializer_tokens")
452
453 if len(args.placeholder_tokens) != len(args.initializer_tokens):
454 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
455
456 if args.num_vectors is None:
457 args.num_vectors = 1
458
459 if isinstance(args.num_vectors, int):
460 args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
461
462 if len(args.placeholder_tokens) != len(args.num_vectors):
463 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
464
465 if isinstance(args.collection, str):
466 args.collection = [args.collection]
467
468 if isinstance(args.exclude_collections, str):
469 args.exclude_collections = [args.exclude_collections]
470
471 if args.output_dir is None:
472 raise ValueError("You must specify --output_dir")
473
474 return args
475
476
477def main():
478 args = parse_args()
479
480 global_step_offset = args.global_step
481 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
482 output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
483 output_dir.mkdir(parents=True, exist_ok=True)
484
485 accelerator = Accelerator(
486 log_with=LoggerType.TENSORBOARD,
487 logging_dir=f"{output_dir}",
488 gradient_accumulation_steps=args.gradient_accumulation_steps,
489 mixed_precision=args.mixed_precision
490 )
491
492 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
493
494 if args.seed is None:
495 args.seed = torch.random.seed() >> 32
496
497 set_seed(args.seed)
498
499 save_args(output_dir, args)
500
501 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
502 args.pretrained_model_name_or_path)
503
504 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
505 tokenizer.set_dropout(args.vector_dropout)
506
507 vae.enable_slicing()
508 vae.set_use_memory_efficient_attention_xformers(True)
509 unet.set_use_memory_efficient_attention_xformers(True)
510
511 if args.gradient_checkpointing:
512 unet.enable_gradient_checkpointing()
513 text_encoder.gradient_checkpointing_enable()
514
515 if args.embeddings_dir is not None:
516 embeddings_dir = Path(args.embeddings_dir)
517 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
518 raise ValueError("--embeddings_dir must point to an existing directory")
519
520 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
521 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
522
523 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
524 tokenizer=tokenizer,
525 embeddings=embeddings,
526 placeholder_tokens=args.placeholder_tokens,
527 initializer_tokens=args.initializer_tokens,
528 num_vectors=args.num_vectors
529 )
530
531 if len(placeholder_token_ids) != 0:
532 initializer_token_id_lens = [len(id) for id in initializer_token_ids]
533 placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
534 print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
535
536 if args.scale_lr:
537 args.learning_rate = (
538 args.learning_rate * args.gradient_accumulation_steps *
539 args.train_batch_size * accelerator.num_processes
540 )
541
542 if args.find_lr:
543 args.learning_rate = 1e-5
544
545 if args.use_8bit_adam:
546 try:
547 import bitsandbytes as bnb
548 except ImportError:
549 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
550
551 optimizer_class = bnb.optim.AdamW8bit
552 else:
553 optimizer_class = torch.optim.AdamW
554
555 optimizer = optimizer_class(
556 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
557 lr=args.learning_rate,
558 betas=(args.adam_beta1, args.adam_beta2),
559 weight_decay=args.adam_weight_decay,
560 eps=args.adam_epsilon,
561 amsgrad=args.adam_amsgrad,
562 )
563
564 weight_dtype = torch.float32
565 if args.mixed_precision == "fp16":
566 weight_dtype = torch.float16
567 elif args.mixed_precision == "bf16":
568 weight_dtype = torch.bfloat16
569
570 def keyword_filter(item: VlpnDataItem):
571 cond1 = any(
572 keyword in part
573 for keyword in args.placeholder_tokens
574 for part in item.prompt
575 )
576 cond3 = args.collection is None or args.collection in item.collection
577 cond4 = args.exclude_collections is None or not any(
578 collection in item.collection
579 for collection in args.exclude_collections
580 )
581 return cond1 and cond3 and cond4
582
583 datamodule = VlpnDataModule(
584 data_file=args.train_data_file,
585 batch_size=args.train_batch_size,
586 tokenizer=tokenizer,
587 class_subdir=args.class_image_dir,
588 num_class_images=args.num_class_images,
589 size=args.resolution,
590 num_buckets=args.num_buckets,
591 progressive_buckets=args.progressive_buckets,
592 bucket_step_size=args.bucket_step_size,
593 bucket_max_pixels=args.bucket_max_pixels,
594 dropout=args.tag_dropout,
595 shuffle=not args.no_tag_shuffle,
596 template_key=args.train_data_template,
597 valid_set_size=args.valid_set_size,
598 valid_set_repeat=args.valid_set_repeat,
599 seed=args.seed,
600 filter=keyword_filter,
601 dtype=weight_dtype
602 )
603 datamodule.setup()
604
605 train_dataloader = datamodule.train_dataloader
606 val_dataloader = datamodule.val_dataloader
607
608 if args.num_class_images != 0:
609 generate_class_images(
610 accelerator,
611 text_encoder,
612 vae,
613 unet,
614 tokenizer,
615 sample_scheduler,
616 datamodule.data_train,
617 args.sample_batch_size,
618 args.sample_image_size,
619 args.sample_steps
620 )
621
622 lr_scheduler = get_scheduler(
623 args.lr_scheduler,
624 optimizer=optimizer,
625 num_training_steps_per_epoch=len(train_dataloader),
626 gradient_accumulation_steps=args.gradient_accumulation_steps,
627 min_lr=args.lr_min_lr,
628 warmup_func=args.lr_warmup_func,
629 annealing_func=args.lr_annealing_func,
630 warmup_exp=args.lr_warmup_exp,
631 annealing_exp=args.lr_annealing_exp,
632 cycles=args.lr_cycles,
633 train_epochs=args.num_train_epochs,
634 warmup_epochs=args.lr_warmup_epochs,
635 )
636
637 trainer = Trainer(
638 accelerator=accelerator,
639 unet=unet,
640 text_encoder=text_encoder,
641 tokenizer=tokenizer,
642 vae=vae,
643 noise_scheduler=noise_scheduler,
644 sample_scheduler=sample_scheduler,
645 train_dataloader=train_dataloader,
646 val_dataloader=val_dataloader,
647 dtype=weight_dtype,
648 )
649
650 trainer(
651 strategy_class=TextualInversionTrainingStrategy,
652 optimizer=optimizer,
653 lr_scheduler=lr_scheduler,
654 num_train_epochs=args.num_train_epochs,
655 sample_frequency=args.sample_frequency,
656 checkpoint_frequency=args.checkpoint_frequency,
657 global_step_offset=global_step_offset,
658 prior_loss_weight=args.prior_loss_weight,
659 output_dir=output_dir,
660 placeholder_tokens=args.placeholder_tokens,
661 placeholder_token_ids=placeholder_token_ids,
662 learning_rate=args.learning_rate,
663 sample_steps=args.sample_steps,
664 sample_image_size=args.sample_image_size,
665 sample_batch_size=args.sample_batch_size,
666 sample_batches=args.sample_batches,
667 seed=args.seed,
668 )
669
670
671if __name__ == "__main__":
672 main()
diff --git a/train_ti.py b/train_ti.py
index 2fd325b..3c9810f 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,7 +3,6 @@ import datetime
3import logging 3import logging
4from functools import partial 4from functools import partial
5from pathlib import Path 5from pathlib import Path
6from contextlib import contextmanager, nullcontext
7 6
8import torch 7import torch
9import torch.utils.checkpoint 8import torch.utils.checkpoint
@@ -16,7 +15,6 @@ from slugify import slugify
16 15
17from util import load_config, load_embeddings_from_dir 16from util import load_config, load_embeddings_from_dir
18from data.csv import VlpnDataModule, VlpnDataItem 17from data.csv import VlpnDataModule, VlpnDataItem
19from trainer_old.base import Checkpointer
20from training.functional import train, generate_class_images, add_placeholder_tokens, get_models 18from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
21from training.strategy.ti import textual_inversion_strategy 19from training.strategy.ti import textual_inversion_strategy
22from training.optimization import get_scheduler 20from training.optimization import get_scheduler
@@ -483,51 +481,6 @@ def parse_args():
483 return args 481 return args
484 482
485 483
486class TextualInversionCheckpointer(Checkpointer):
487 def __init__(
488 self,
489 ema_embeddings: EMAModel,
490 placeholder_tokens: list[str],
491 placeholder_token_ids: list[list[int]],
492 *args,
493 **kwargs,
494 ):
495 super().__init__(*args, **kwargs)
496
497 self.ema_embeddings = ema_embeddings
498 self.placeholder_tokens = placeholder_tokens
499 self.placeholder_token_ids = placeholder_token_ids
500
501 @torch.no_grad()
502 def checkpoint(self, step, postfix):
503 print(f"Saving checkpoint for step {step}...")
504
505 checkpoints_path = self.output_dir.joinpath("checkpoints")
506 checkpoints_path.mkdir(parents=True, exist_ok=True)
507
508 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
509
510 ema_context = self.ema_embeddings.apply_temporary(
511 text_encoder.text_model.embeddings.temp_token_embedding.parameters()
512 ) if self.ema_embeddings is not None else nullcontext()
513
514 with ema_context:
515 for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
516 text_encoder.text_model.embeddings.save_embed(
517 ids,
518 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
519 )
520
521 @torch.no_grad()
522 def save_samples(self, step):
523 ema_context = self.ema_embeddings.apply_temporary(
524 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
525 ) if self.ema_embeddings is not None else nullcontext()
526
527 with ema_context:
528 super().save_samples(step)
529
530
531def main(): 484def main():
532 args = parse_args() 485 args = parse_args()
533 486
@@ -769,7 +722,7 @@ def main():
769 checkpoint_frequency=args.checkpoint_frequency, 722 checkpoint_frequency=args.checkpoint_frequency,
770 global_step_offset=global_step_offset, 723 global_step_offset=global_step_offset,
771 prior_loss_weight=args.prior_loss_weight, 724 prior_loss_weight=args.prior_loss_weight,
772 **strategy, 725 callbacks=strategy,
773 ) 726 )
774 727
775 728
diff --git a/trainer_old/base.py b/trainer_old/base.py
deleted file mode 100644
index 5903d96..0000000
--- a/trainer_old/base.py
+++ /dev/null
@@ -1,538 +0,0 @@
1from pathlib import Path
2import math
3from contextlib import contextmanager
4from typing import Type, Optional
5import itertools
6from functools import partial
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11from torch.utils.data import DataLoader
12
13from accelerate import Accelerator
14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
16
17from tqdm.auto import tqdm
18from PIL import Image
19
20from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
21from models.clip.tokenizer import MultiCLIPTokenizer
22from models.clip.util import get_extended_embeddings
23from training.util import AverageMeter
24
25
26def make_grid(images, rows, cols):
27 w, h = images[0].size
28 grid = Image.new('RGB', size=(cols*w, rows*h))
29 for i, image in enumerate(images):
30 grid.paste(image, box=(i % cols*w, i//cols*h))
31 return grid
32
33
34class Checkpointer():
35 def __init__(
36 self,
37 accelerator: Accelerator,
38 vae: AutoencoderKL,
39 unet: UNet2DConditionModel,
40 text_encoder: CLIPTextModel,
41 tokenizer: MultiCLIPTokenizer,
42 sample_scheduler,
43 dtype,
44 train_dataloader: DataLoader,
45 val_dataloader: DataLoader,
46 output_dir: Path,
47 sample_steps: int = 20,
48 sample_guidance_scale: float = 7.5,
49 sample_image_size: int = 768,
50 sample_batches: int = 1,
51 sample_batch_size: int = 1,
52 seed: Optional[int] = None,
53 *args,
54 **kwargs,
55 ):
56 self.accelerator = accelerator
57 self.vae = vae
58 self.unet = unet
59 self.text_encoder = text_encoder
60 self.tokenizer = tokenizer
61 self.sample_scheduler = sample_scheduler
62 self.dtype = dtype
63 self.train_dataloader = train_dataloader
64 self.val_dataloader = val_dataloader
65 self.output_dir = output_dir
66 self.sample_steps = sample_steps
67 self.sample_guidance_scale = sample_guidance_scale
68 self.sample_image_size = sample_image_size
69 self.sample_batches = sample_batches
70 self.sample_batch_size = sample_batch_size
71 self.seed = seed if seed is not None else torch.random.seed()
72
73 @torch.no_grad()
74 def checkpoint(self, step: int, postfix: str):
75 pass
76
77 @torch.no_grad()
78 def save_samples(self, step: int):
79 print(f"Saving samples for step {step}...")
80
81 samples_path = self.output_dir.joinpath("samples")
82
83 grid_cols = min(self.sample_batch_size, 4)
84 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
85
86 unet = self.accelerator.unwrap_model(self.unet)
87 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
88
89 orig_unet_dtype = unet.dtype
90 orig_text_encoder_dtype = text_encoder.dtype
91
92 unet.to(dtype=self.dtype)
93 text_encoder.to(dtype=self.dtype)
94
95 pipeline = VlpnStableDiffusion(
96 text_encoder=text_encoder,
97 vae=self.vae,
98 unet=self.unet,
99 tokenizer=self.tokenizer,
100 scheduler=self.sample_scheduler,
101 ).to(self.accelerator.device)
102 pipeline.set_progress_bar_config(dynamic_ncols=True)
103
104 generator = torch.Generator(device=self.accelerator.device).manual_seed(self.seed)
105
106 for pool, data, gen in [
107 ("stable", self.val_dataloader, generator),
108 ("val", self.val_dataloader, None),
109 ("train", self.train_dataloader, None)
110 ]:
111 all_samples = []
112 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
113 file_path.parent.mkdir(parents=True, exist_ok=True)
114
115 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
116 prompt_ids = [
117 prompt
118 for batch in batches
119 for prompt in batch["prompt_ids"]
120 ]
121 nprompt_ids = [
122 prompt
123 for batch in batches
124 for prompt in batch["nprompt_ids"]
125 ]
126
127 for i in range(self.sample_batches):
128 start = i * self.sample_batch_size
129 end = (i + 1) * self.sample_batch_size
130 prompt = prompt_ids[start:end]
131 nprompt = nprompt_ids[start:end]
132
133 samples = pipeline(
134 prompt=prompt,
135 negative_prompt=nprompt,
136 height=self.sample_image_size,
137 width=self.sample_image_size,
138 generator=gen,
139 guidance_scale=self.sample_guidance_scale,
140 num_inference_steps=self.sample_steps,
141 output_type='pil'
142 ).images
143
144 all_samples += samples
145
146 image_grid = make_grid(all_samples, grid_rows, grid_cols)
147 image_grid.save(file_path, quality=85)
148
149 unet.to(dtype=orig_unet_dtype)
150 text_encoder.to(dtype=orig_text_encoder_dtype)
151
152 del unet
153 del text_encoder
154 del generator
155 del pipeline
156
157 if torch.cuda.is_available():
158 torch.cuda.empty_cache()
159
160
161class TrainingStrategy():
162 def __init__(
163 self,
164 tokenizer: MultiCLIPTokenizer,
165 *args,
166 **kwargs,
167 ):
168 self.tokenizer = tokenizer
169 self.checkpointer = Checkpointer(tokenizer=tokenizer, *args, **kwargs)
170
171 @property
172 def main_model(self) -> nn.Module:
173 ...
174
175 @contextmanager
176 def on_train(self, epoch: int):
177 self.tokenizer.train()
178 yield
179
180 @contextmanager
181 def on_eval(self):
182 self.tokenizer.eval()
183 yield
184
185 def on_before_optimize(self, epoch: int):
186 ...
187
188 def on_after_optimize(self, lr: float):
189 ...
190
191 def on_log():
192 return {}
193
194
195def loss_step(
196 vae: AutoencoderKL,
197 unet: UNet2DConditionModel,
198 text_encoder: CLIPTextModel,
199 seed: int,
200 noise_scheduler,
201 prior_loss_weight: float,
202 step: int,
203 batch: dict,
204 eval: bool = False
205):
206 # Convert images to latent space
207 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
208 latents = latents * 0.18215
209
210 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
211
212 # Sample noise that we'll add to the latents
213 noise = torch.randn(
214 latents.shape,
215 dtype=latents.dtype,
216 layout=latents.layout,
217 device=latents.device,
218 generator=generator
219 )
220 bsz = latents.shape[0]
221 # Sample a random timestep for each image
222 timesteps = torch.randint(
223 0,
224 noise_scheduler.config.num_train_timesteps,
225 (bsz,),
226 generator=generator,
227 device=latents.device,
228 )
229 timesteps = timesteps.long()
230
231 # Add noise to the latents according to the noise magnitude at each timestep
232 # (this is the forward diffusion process)
233 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
234 noisy_latents = noisy_latents.to(dtype=unet.dtype)
235
236 # Get the text embedding for conditioning
237 encoder_hidden_states = get_extended_embeddings(
238 text_encoder,
239 batch["input_ids"],
240 batch["attention_mask"]
241 )
242 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
243
244 # Predict the noise residual
245 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
246
247 # Get the target for loss depending on the prediction type
248 if noise_scheduler.config.prediction_type == "epsilon":
249 target = noise
250 elif noise_scheduler.config.prediction_type == "v_prediction":
251 target = noise_scheduler.get_velocity(latents, noise, timesteps)
252 else:
253 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
254
255 if batch["with_prior"].all():
256 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
257 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
258 target, target_prior = torch.chunk(target, 2, dim=0)
259
260 # Compute instance loss
261 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
262
263 # Compute prior loss
264 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
265
266 # Add the prior loss to the instance loss.
267 loss = loss + prior_loss_weight * prior_loss
268 else:
269 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
270
271 acc = (model_pred == target).float().mean()
272
273 return loss, acc, bsz
274
275
276def train_loop(
277 strategy: TrainingStrategy,
278 accelerator: Accelerator,
279 vae: AutoencoderKL,
280 unet: UNet2DConditionModel,
281 text_encoder: CLIPTextModel,
282 train_dataloader: DataLoader,
283 val_dataloader: DataLoader,
284 seed: int,
285 optimizer: torch.optim.Optimizer,
286 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
287 noise_scheduler,
288 prior_loss_weight: float = 1.0,
289 sample_frequency: int = 10,
290 checkpoint_frequency: int = 50,
291 global_step_offset: int = 0,
292 num_epochs: int = 100,
293):
294 num_training_steps_per_epoch = math.ceil(
295 len(train_dataloader) / accelerator.gradient_accumulation_steps
296 )
297 num_val_steps_per_epoch = len(val_dataloader)
298
299 num_training_steps = num_training_steps_per_epoch * num_epochs
300 num_val_steps = num_val_steps_per_epoch * num_epochs
301
302 global_step = 0
303
304 avg_loss = AverageMeter()
305 avg_acc = AverageMeter()
306
307 avg_loss_val = AverageMeter()
308 avg_acc_val = AverageMeter()
309
310 max_acc_val = 0.0
311
312 local_progress_bar = tqdm(
313 range(num_training_steps_per_epoch + num_val_steps_per_epoch),
314 disable=not accelerator.is_local_main_process,
315 dynamic_ncols=True
316 )
317 local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
318
319 global_progress_bar = tqdm(
320 range(num_training_steps + num_val_steps),
321 disable=not accelerator.is_local_main_process,
322 dynamic_ncols=True
323 )
324 global_progress_bar.set_description("Total progress")
325
326 loss_step_ = partial(
327 loss_step,
328 vae,
329 unet,
330 text_encoder,
331 seed,
332 noise_scheduler,
333 prior_loss_weight
334 )
335
336 try:
337 for epoch in range(num_epochs):
338 if accelerator.is_main_process:
339 if epoch % sample_frequency == 0 and epoch != 0:
340 strategy.checkpointer.save_samples(global_step + global_step_offset)
341
342 if epoch % checkpoint_frequency == 0 and epoch != 0:
343 strategy.checkpointer.checkpoint(global_step + global_step_offset, "training")
344
345 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
346 local_progress_bar.reset()
347
348 strategy.main_model.train()
349
350 with strategy.on_train(epoch):
351 for step, batch in enumerate(train_dataloader):
352 with accelerator.accumulate(strategy.main_model):
353 loss, acc, bsz = loss_step_(step, batch)
354
355 accelerator.backward(loss)
356
357 strategy.on_before_optimize(epoch)
358
359 optimizer.step()
360 lr_scheduler.step()
361 optimizer.zero_grad(set_to_none=True)
362
363 avg_loss.update(loss.detach_(), bsz)
364 avg_acc.update(acc.detach_(), bsz)
365
366 # Checks if the accelerator has performed an optimization step behind the scenes
367 if accelerator.sync_gradients:
368 strategy.on_after_optimize(lr_scheduler.get_last_lr()[0])
369
370 local_progress_bar.update(1)
371 global_progress_bar.update(1)
372
373 global_step += 1
374
375 logs = {
376 "train/loss": avg_loss.avg.item(),
377 "train/acc": avg_acc.avg.item(),
378 "train/cur_loss": loss.item(),
379 "train/cur_acc": acc.item(),
380 "lr": lr_scheduler.get_last_lr()[0],
381 }
382 logs.update(strategy.on_log())
383
384 accelerator.log(logs, step=global_step)
385
386 local_progress_bar.set_postfix(**logs)
387
388 if global_step >= num_training_steps:
389 break
390
391 accelerator.wait_for_everyone()
392
393 strategy.main_model.eval()
394
395 cur_loss_val = AverageMeter()
396 cur_acc_val = AverageMeter()
397
398 with torch.inference_mode(), strategy.on_eval():
399 for step, batch in enumerate(val_dataloader):
400 loss, acc, bsz = loss_step_(step, batch, True)
401
402 loss = loss.detach_()
403 acc = acc.detach_()
404
405 cur_loss_val.update(loss, bsz)
406 cur_acc_val.update(acc, bsz)
407
408 avg_loss_val.update(loss, bsz)
409 avg_acc_val.update(acc, bsz)
410
411 local_progress_bar.update(1)
412 global_progress_bar.update(1)
413
414 logs = {
415 "val/loss": avg_loss_val.avg.item(),
416 "val/acc": avg_acc_val.avg.item(),
417 "val/cur_loss": loss.item(),
418 "val/cur_acc": acc.item(),
419 }
420 local_progress_bar.set_postfix(**logs)
421
422 logs["val/cur_loss"] = cur_loss_val.avg.item()
423 logs["val/cur_acc"] = cur_acc_val.avg.item()
424
425 accelerator.log(logs, step=global_step)
426
427 local_progress_bar.clear()
428 global_progress_bar.clear()
429
430 if accelerator.is_main_process:
431 if avg_acc_val.avg.item() > max_acc_val:
432 accelerator.print(
433 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
434 strategy.checkpointer.checkpoint(global_step + global_step_offset, "milestone")
435 max_acc_val = avg_acc_val.avg.item()
436
437 # Create the pipeline using using the trained modules and save it.
438 if accelerator.is_main_process:
439 print("Finished!")
440 strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
441 strategy.checkpointer.save_samples(global_step + global_step_offset)
442 accelerator.end_training()
443
444 except KeyboardInterrupt:
445 if accelerator.is_main_process:
446 print("Interrupted")
447 strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
448 accelerator.end_training()
449
450
451class Trainer():
452 def __init__(
453 self,
454 accelerator: Accelerator,
455 unet: UNet2DConditionModel,
456 text_encoder: CLIPTextModel,
457 tokenizer: MultiCLIPTokenizer,
458 vae: AutoencoderKL,
459 noise_scheduler: DDPMScheduler,
460 sample_scheduler: DPMSolverMultistepScheduler,
461 train_dataloader: DataLoader,
462 val_dataloader: DataLoader,
463 dtype: torch.dtype,
464 ):
465 self.accelerator = accelerator
466 self.unet = unet
467 self.text_encoder = text_encoder
468 self.tokenizer = tokenizer
469 self.vae = vae
470 self.noise_scheduler = noise_scheduler
471 self.sample_scheduler = sample_scheduler
472 self.train_dataloader = train_dataloader
473 self.val_dataloader = val_dataloader
474 self.dtype = dtype
475
476 def __call__(
477 self,
478 strategy_class: Type[TrainingStrategy],
479 optimizer,
480 lr_scheduler,
481 num_train_epochs: int = 100,
482 sample_frequency: int = 20,
483 checkpoint_frequency: int = 50,
484 global_step_offset: int = 0,
485 prior_loss_weight: float = 0,
486 seed: Optional[int] = None,
487 **kwargs,
488 ):
489 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = self.accelerator.prepare(
490 self.unet, self.text_encoder, optimizer, self.train_dataloader, self.val_dataloader, lr_scheduler
491 )
492
493 self.vae.to(self.accelerator.device, dtype=self.dtype)
494
495 for model in (unet, text_encoder, self.vae):
496 model.requires_grad_(False)
497 model.eval()
498
499 if seed is None:
500 seed = torch.random.seed()
501
502 strategy = strategy_class(
503 accelerator=self.accelerator,
504 vae=self.vae,
505 unet=unet,
506 text_encoder=text_encoder,
507 tokenizer=self.tokenizer,
508 sample_scheduler=self.sample_scheduler,
509 train_dataloader=train_dataloader,
510 val_dataloader=val_dataloader,
511 dtype=self.dtype,
512 seed=seed,
513 **kwargs
514 )
515
516 if self.accelerator.is_main_process:
517 self.accelerator.init_trackers("textual_inversion")
518
519 train_loop(
520 strategy=strategy,
521 accelerator=self.accelerator,
522 vae=self.vae,
523 unet=unet,
524 text_encoder=text_encoder,
525 train_dataloader=train_dataloader,
526 val_dataloader=val_dataloader,
527 seed=seed,
528 optimizer=optimizer,
529 lr_scheduler=lr_scheduler,
530 noise_scheduler=self.noise_scheduler,
531 prior_loss_weight=prior_loss_weight,
532 sample_frequency=sample_frequency,
533 checkpoint_frequency=checkpoint_frequency,
534 global_step_offset=global_step_offset,
535 num_epochs=num_train_epochs,
536 )
537
538 self.accelerator.free_memory()
diff --git a/trainer_old/dreambooth.py b/trainer_old/dreambooth.py
deleted file mode 100644
index e69de29..0000000
--- a/trainer_old/dreambooth.py
+++ /dev/null
diff --git a/trainer_old/ti.py b/trainer_old/ti.py
deleted file mode 100644
index 66393af..0000000
--- a/trainer_old/ti.py
+++ /dev/null
@@ -1,168 +0,0 @@
1from contextlib import contextmanager, nullcontext
2
3import torch
4
5from slugify import slugify
6
7from diffusers import UNet2DConditionModel
8from transformers import CLIPTextModel
9
10from trainer.base import TrainingStrategy, Checkpointer
11from training.util import EMAModel
12
13
14class TextualInversionCheckpointer(Checkpointer):
15 def __init__(
16 self,
17 ema_embeddings: EMAModel,
18 placeholder_tokens: list[str],
19 placeholder_token_ids: list[list[int]],
20 *args,
21 **kwargs,
22 ):
23 super().__init__(*args, **kwargs)
24
25 self.ema_embeddings = ema_embeddings
26 self.placeholder_tokens = placeholder_tokens
27 self.placeholder_token_ids = placeholder_token_ids
28
29 @torch.no_grad()
30 def checkpoint(self, step, postfix):
31 print(f"Saving checkpoint for step {step}...")
32
33 checkpoints_path = self.output_dir.joinpath("checkpoints")
34 checkpoints_path.mkdir(parents=True, exist_ok=True)
35
36 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
37
38 ema_context = self.ema_embeddings.apply_temporary(
39 text_encoder.text_model.embeddings.temp_token_embedding.parameters()
40 ) if self.ema_embeddings is not None else nullcontext()
41
42 with ema_context:
43 for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
44 text_encoder.text_model.embeddings.save_embed(
45 ids,
46 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
47 )
48
49 @torch.no_grad()
50 def save_samples(self, step):
51 ema_context = self.ema_embeddings.apply_temporary(
52 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
53 ) if self.ema_embeddings is not None else nullcontext()
54
55 with ema_context:
56 super().save_samples(step)
57
58
59class TextualInversionTrainingStrategy(TrainingStrategy):
60 def __init__(
61 self,
62 unet: UNet2DConditionModel,
63 text_encoder: CLIPTextModel,
64 placeholder_tokens: list[str],
65 placeholder_token_ids: list[list[int]],
66 learning_rate: float,
67 gradient_checkpointing: bool = False,
68 use_emb_decay: bool = False,
69 emb_decay_target: float = 0.4,
70 emb_decay_factor: float = 1,
71 emb_decay_start: float = 1e-4,
72 use_ema: bool = False,
73 ema_inv_gamma: float = 1.0,
74 ema_power: int = 1,
75 ema_max_decay: float = 0.9999,
76 *args,
77 **kwargs,
78 ):
79 super().__init__(
80 unet=unet,
81 text_encoder=text_encoder,
82 *args,
83 **kwargs
84 )
85
86 self.text_encoder = text_encoder
87 self.unet = unet
88
89 self.placeholder_tokens = placeholder_tokens
90 self.placeholder_token_ids = placeholder_token_ids
91
92 self.gradient_checkpointing = gradient_checkpointing
93
94 self.learning_rate = learning_rate
95 self.use_emb_decay = use_emb_decay
96 self.emb_decay_target = emb_decay_target
97 self.emb_decay_factor = emb_decay_factor
98 self.emb_decay_start = emb_decay_start
99
100 self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
101
102 self.ema_embeddings = None
103
104 if use_ema:
105 self.ema_embeddings = EMAModel(
106 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
107 inv_gamma=ema_inv_gamma,
108 power=ema_power,
109 max_value=ema_max_decay,
110 )
111
112 self.checkpointer = TextualInversionCheckpointer(
113 unet=unet,
114 text_encoder=text_encoder,
115 ema_embeddings=self.ema_embeddings,
116 *args,
117 **kwargs
118 )
119
120 @property
121 def main_model(self):
122 return self.text_encoder
123
124 @contextmanager
125 def on_train(self, epoch: int):
126 try:
127 if self.gradient_checkpointing:
128 self.unet.train()
129
130 with super().on_eval():
131 yield
132 finally:
133 pass
134
135 @contextmanager
136 def on_eval(self):
137 try:
138 if self.gradient_checkpointing:
139 self.unet.eval()
140
141 ema_context = self.ema_embeddings.apply_temporary(
142 self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
143 ) if self.ema_embeddings is not None else nullcontext()
144
145 with ema_context, super().on_eval():
146 yield
147 finally:
148 pass
149
150 @torch.no_grad()
151 def on_after_optimize(self, lr: float):
152 if self.use_emb_decay:
153 self.text_encoder.text_model.embeddings.normalize(
154 self.emb_decay_target,
155 min(1.0, max(0.0, self.emb_decay_factor * ((lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start))))
156 )
157
158 if self.ema_embeddings is not None:
159 self.ema_embeddings.step(self.text_encoder.text_model.embeddings.temp_token_embedding.parameters())
160
161 def on_log(self):
162 log = super().on_log()
163 added = {}
164
165 if self.ema_embeddings is not None:
166 added = {"ema_decay": self.ema_embeddings.decay}
167
168 return log.update(added)
diff --git a/training/functional.py b/training/functional.py
index e54c9c8..4ca7470 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,3 +1,4 @@
1from dataclasses import dataclass
1import math 2import math
2from contextlib import _GeneratorContextManager, nullcontext 3from contextlib import _GeneratorContextManager, nullcontext
3from typing import Callable, Any, Tuple, Union, Optional 4from typing import Callable, Any, Tuple, Union, Optional
@@ -14,6 +15,7 @@ from transformers import CLIPTextModel
14from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler 15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
15 16
16from tqdm.auto import tqdm 17from tqdm.auto import tqdm
18from PIL import Image
17 19
18from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 20from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
19from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 21from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
@@ -28,6 +30,18 @@ def const(result=None):
28 return fn 30 return fn
29 31
30 32
33@dataclass
34class TrainingCallbacks():
35 on_prepare: Callable[[float], None] = const()
36 on_log: Callable[[], dict[str, Any]] = const({})
37 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
38 on_before_optimize: Callable[[int], None] = const()
39 on_after_optimize: Callable[[float], None] = const()
40 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
41 on_sample: Callable[[int], None] = const()
42 on_checkpoint: Callable[[int, str], None] = const()
43
44
31def make_grid(images, rows, cols): 45def make_grid(images, rows, cols):
32 w, h = images[0].size 46 w, h = images[0].size
33 grid = Image.new('RGB', size=(cols*w, rows*h)) 47 grid = Image.new('RGB', size=(cols*w, rows*h))
@@ -341,13 +355,7 @@ def train_loop(
341 checkpoint_frequency: int = 50, 355 checkpoint_frequency: int = 50,
342 global_step_offset: int = 0, 356 global_step_offset: int = 0,
343 num_epochs: int = 100, 357 num_epochs: int = 100,
344 on_log: Callable[[], dict[str, Any]] = const({}), 358 callbacks: TrainingCallbacks = TrainingCallbacks(),
345 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
346 on_before_optimize: Callable[[int], None] = const(),
347 on_after_optimize: Callable[[float], None] = const(),
348 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
349 on_sample: Callable[[int], None] = const(),
350 on_checkpoint: Callable[[int, str], None] = const(),
351): 359):
352 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) 360 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
353 num_val_steps_per_epoch = len(val_dataloader) 361 num_val_steps_per_epoch = len(val_dataloader)
@@ -383,24 +391,24 @@ def train_loop(
383 for epoch in range(num_epochs): 391 for epoch in range(num_epochs):
384 if accelerator.is_main_process: 392 if accelerator.is_main_process:
385 if epoch % sample_frequency == 0: 393 if epoch % sample_frequency == 0:
386 on_sample(global_step + global_step_offset) 394 callbacks.on_sample(global_step + global_step_offset)
387 395
388 if epoch % checkpoint_frequency == 0 and epoch != 0: 396 if epoch % checkpoint_frequency == 0 and epoch != 0:
389 on_checkpoint(global_step + global_step_offset, "training") 397 callbacks.on_checkpoint(global_step + global_step_offset, "training")
390 398
391 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 399 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
392 local_progress_bar.reset() 400 local_progress_bar.reset()
393 401
394 model.train() 402 model.train()
395 403
396 with on_train(epoch): 404 with callbacks.on_train(epoch):
397 for step, batch in enumerate(train_dataloader): 405 for step, batch in enumerate(train_dataloader):
398 with accelerator.accumulate(model): 406 with accelerator.accumulate(model):
399 loss, acc, bsz = loss_step(step, batch) 407 loss, acc, bsz = loss_step(step, batch)
400 408
401 accelerator.backward(loss) 409 accelerator.backward(loss)
402 410
403 on_before_optimize(epoch) 411 callbacks.on_before_optimize(epoch)
404 412
405 optimizer.step() 413 optimizer.step()
406 lr_scheduler.step() 414 lr_scheduler.step()
@@ -411,7 +419,7 @@ def train_loop(
411 419
412 # Checks if the accelerator has performed an optimization step behind the scenes 420 # Checks if the accelerator has performed an optimization step behind the scenes
413 if accelerator.sync_gradients: 421 if accelerator.sync_gradients:
414 on_after_optimize(lr_scheduler.get_last_lr()[0]) 422 callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0])
415 423
416 local_progress_bar.update(1) 424 local_progress_bar.update(1)
417 global_progress_bar.update(1) 425 global_progress_bar.update(1)
@@ -425,7 +433,7 @@ def train_loop(
425 "train/cur_acc": acc.item(), 433 "train/cur_acc": acc.item(),
426 "lr": lr_scheduler.get_last_lr()[0], 434 "lr": lr_scheduler.get_last_lr()[0],
427 } 435 }
428 logs.update(on_log()) 436 logs.update(callbacks.on_log())
429 437
430 accelerator.log(logs, step=global_step) 438 accelerator.log(logs, step=global_step)
431 439
@@ -441,7 +449,7 @@ def train_loop(
441 cur_loss_val = AverageMeter() 449 cur_loss_val = AverageMeter()
442 cur_acc_val = AverageMeter() 450 cur_acc_val = AverageMeter()
443 451
444 with torch.inference_mode(), on_eval(): 452 with torch.inference_mode(), callbacks.on_eval():
445 for step, batch in enumerate(val_dataloader): 453 for step, batch in enumerate(val_dataloader):
446 loss, acc, bsz = loss_step(step, batch, True) 454 loss, acc, bsz = loss_step(step, batch, True)
447 455
@@ -477,20 +485,20 @@ def train_loop(
477 if avg_acc_val.avg.item() > max_acc_val: 485 if avg_acc_val.avg.item() > max_acc_val:
478 accelerator.print( 486 accelerator.print(
479 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") 487 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
480 on_checkpoint(global_step + global_step_offset, "milestone") 488 callbacks.on_checkpoint(global_step + global_step_offset, "milestone")
481 max_acc_val = avg_acc_val.avg.item() 489 max_acc_val = avg_acc_val.avg.item()
482 490
483 # Create the pipeline using using the trained modules and save it. 491 # Create the pipeline using using the trained modules and save it.
484 if accelerator.is_main_process: 492 if accelerator.is_main_process:
485 print("Finished!") 493 print("Finished!")
486 on_checkpoint(global_step + global_step_offset, "end") 494 callbacks.on_checkpoint(global_step + global_step_offset, "end")
487 on_sample(global_step + global_step_offset) 495 callbacks.on_sample(global_step + global_step_offset)
488 accelerator.end_training() 496 accelerator.end_training()
489 497
490 except KeyboardInterrupt: 498 except KeyboardInterrupt:
491 if accelerator.is_main_process: 499 if accelerator.is_main_process:
492 print("Interrupted") 500 print("Interrupted")
493 on_checkpoint(global_step + global_step_offset, "end") 501 callbacks.on_checkpoint(global_step + global_step_offset, "end")
494 accelerator.end_training() 502 accelerator.end_training()
495 503
496 504
@@ -511,14 +519,7 @@ def train(
511 checkpoint_frequency: int = 50, 519 checkpoint_frequency: int = 50,
512 global_step_offset: int = 0, 520 global_step_offset: int = 0,
513 prior_loss_weight: float = 0, 521 prior_loss_weight: float = 0,
514 on_prepare: Callable[[], dict[str, Any]] = const({}), 522 callbacks: TrainingCallbacks = TrainingCallbacks(),
515 on_log: Callable[[], dict[str, Any]] = const({}),
516 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()),
517 on_before_optimize: Callable[[int], None] = const(),
518 on_after_optimize: Callable[[float], None] = const(),
519 on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()),
520 on_sample: Callable[[int], None] = const(),
521 on_checkpoint: Callable[[int, str], None] = const(),
522): 523):
523 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 524 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
524 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler 525 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
@@ -530,7 +531,7 @@ def train(
530 model.requires_grad_(False) 531 model.requires_grad_(False)
531 model.eval() 532 model.eval()
532 533
533 on_prepare() 534 callbacks.on_prepare()
534 535
535 loss_step_ = partial( 536 loss_step_ = partial(
536 loss_step, 537 loss_step,
@@ -557,13 +558,7 @@ def train(
557 checkpoint_frequency=checkpoint_frequency, 558 checkpoint_frequency=checkpoint_frequency,
558 global_step_offset=global_step_offset, 559 global_step_offset=global_step_offset,
559 num_epochs=num_train_epochs, 560 num_epochs=num_train_epochs,
560 on_log=on_log, 561 callbacks=callbacks,
561 on_train=on_train,
562 on_before_optimize=on_before_optimize,
563 on_after_optimize=on_after_optimize,
564 on_eval=on_eval,
565 on_sample=on_sample,
566 on_checkpoint=on_checkpoint,
567 ) 562 )
568 563
569 accelerator.free_memory() 564 accelerator.free_memory()
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 83dc566..6f8384f 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,7 +15,7 @@ from slugify import slugify
15 15
16from models.clip.tokenizer import MultiCLIPTokenizer 16from models.clip.tokenizer import MultiCLIPTokenizer
17from training.util import EMAModel 17from training.util import EMAModel
18from training.functional import save_samples 18from training.functional import TrainingCallbacks, save_samples
19 19
20 20
21def textual_inversion_strategy( 21def textual_inversion_strategy(
@@ -153,12 +153,12 @@ def textual_inversion_strategy(
153 with ema_context: 153 with ema_context:
154 save_samples_(step=step) 154 save_samples_(step=step)
155 155
156 return { 156 return TrainingCallbacks(
157 "on_prepare": on_prepare, 157 on_prepare=on_prepare,
158 "on_train": on_train, 158 on_train=on_train,
159 "on_eval": on_eval, 159 on_eval=on_eval,
160 "on_after_optimize": on_after_optimize, 160 on_after_optimize=on_after_optimize,
161 "on_log": on_log, 161 on_log=on_log,
162 "on_checkpoint": on_checkpoint, 162 on_checkpoint=on_checkpoint,
163 "on_sample": on_sample, 163 on_sample=on_sample,
164 } 164 )