path: root/train_dreambooth.py
author     Volpeon <git@volpeon.ink>  2022-12-19 21:10:58 +0100
committer  Volpeon <git@volpeon.ink>  2022-12-19 21:10:58 +0100
commit     9b808b6ca102cfec0c273626a0bcadf897b7c942 (patch)
tree       446311b3c6dca74ac9f9f4e055e2eba5f9cae9e5 /train_dreambooth.py
parent     Avoid increased VRAM usage on validation (diff)
Improved dataset prompt handling, fixed
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  1133
1 file changed, 1133 insertions, 0 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
new file mode 100644
index 0000000..3eecf9c
--- /dev/null
+++ b/train_dreambooth.py
@@ -0,0 +1,1133 @@
1import argparse
2import itertools
3import math
4import datetime
5import logging
6import json
7from pathlib import Path
8
9import torch
10import torch.nn.functional as F
11import torch.utils.checkpoint
12
13from accelerate import Accelerator
14from accelerate.logging import get_logger
15from accelerate.utils import LoggerType, set_seed
16from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
17from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
18from diffusers.training_utils import EMAModel
19from PIL import Image
20from tqdm.auto import tqdm
21from transformers import CLIPTextModel, CLIPTokenizer
22from slugify import slugify
23
24from common import load_text_embeddings
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from pipelines.util import set_use_memory_efficient_attention_xformers
27from data.csv import CSVDataModule
28from training.optimization import get_one_cycle_schedule
29from models.clip.prompt import PromptProcessor
30
31logger = get_logger(__name__)
32
33
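# Trade a little matmul precision for speed (TF32) and let cuDNN auto-tune kernels for the fixed input size.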
34torch.backends.cuda.matmul.allow_tf32 = True
35torch.backends.cudnn.benchmark = True
36
37
38def parse_args():
39 parser = argparse.ArgumentParser(
40 description="Simple example of a training script."
41 )
42 parser.add_argument(
43 "--pretrained_model_name_or_path",
44 type=str,
45 default=None,
46 help="Path to pretrained model or model identifier from huggingface.co/models.",
47 )
48 parser.add_argument(
49 "--tokenizer_name",
50 type=str,
51 default=None,
52 help="Pretrained tokenizer name or path if not the same as model_name",
53 )
54 parser.add_argument(
55 "--train_data_file",
56 type=str,
57 default=None,
58 help="A folder containing the training data."
59 )
60 parser.add_argument(
61 "--train_data_template",
62 type=str,
63 default="template",
64 )
65 parser.add_argument(
66 "--instance_identifier",
67 type=str,
68 default=None,
69 help="Identifier substituted into training prompts for the instance (the concept being trained).",
70 )
71 parser.add_argument(
72 "--class_identifier",
73 type=str,
74 default=None,
75 help="Identifier substituted into prompts for the class images used for prior preservation.",
76 )
77 parser.add_argument(
78 "--placeholder_token",
79 type=str,
80 nargs='*',
81 default=[],
82 help="A token to use as a placeholder for the concept.",
83 )
84 parser.add_argument(
85 "--initializer_token",
86 type=str,
87 nargs='*',
88 default=[],
89 help="A token to use as initializer word."
90 )
91 parser.add_argument(
92 "--train_text_encoder",
93 action="store_true",
94 default=True,
95 help="Whether to train the whole text encoder."
96 )
97 parser.add_argument(
98 "--train_text_encoder_epochs",
99 type=int, default=999999,
100 help="Number of epochs the text encoder will be trained."
101 )
102 parser.add_argument(
103 "--tag_dropout",
104 type=float,
105 default=0.1,
106 help="Tag dropout probability.",
107 )
108 parser.add_argument(
109 "--num_class_images",
110 type=int,
111 default=400,
112 help="How many class images to generate."
113 )
114 parser.add_argument(
115 "--repeats",
116 type=int,
117 default=1,
118 help="How many times to repeat the training data."
119 )
120 parser.add_argument(
121 "--output_dir",
122 type=str,
123 default="output/dreambooth",
124 help="The output directory where the model predictions and checkpoints will be written.",
125 )
126 parser.add_argument(
127 "--embeddings_dir",
128 type=str,
129 default=None,
130 help="The embeddings directory where Textual Inversion embeddings are stored.",
131 )
132 parser.add_argument(
133 "--mode",
134 type=str,
135 default=None,
136 help="A mode to filter the dataset.",
137 )
138 parser.add_argument(
139 "--seed",
140 type=int,
141 default=None,
142 help="A seed for reproducible training."
143 )
144 parser.add_argument(
145 "--resolution",
146 type=int,
147 default=768,
148 help=(
149 "The resolution for input images; all images in the train/validation dataset will be resized to this"
150 " resolution."
151 ),
152 )
153 parser.add_argument(
154 "--center_crop",
155 action="store_true",
156 help="Whether to center crop images before resizing to resolution"
157 )
158 parser.add_argument(
159 "--dataloader_num_workers",
160 type=int,
161 default=0,
162 help=(
163 "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
164 " process."
165 ),
166 )
167 parser.add_argument(
168 "--num_train_epochs",
169 type=int,
170 default=100
171 )
172 parser.add_argument(
173 "--max_train_steps",
174 type=int,
175 default=None,
176 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
177 )
178 parser.add_argument(
179 "--gradient_accumulation_steps",
180 type=int,
181 default=1,
182 help="Number of updates steps to accumulate before performing a backward/update pass.",
183 )
184 parser.add_argument(
185 "--gradient_checkpointing",
186 action="store_true",
187 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
188 )
189 parser.add_argument(
190 "--learning_rate_unet",
191 type=float,
192 default=2e-6,
193 help="Initial learning rate (after the potential warmup period) to use.",
194 )
195 parser.add_argument(
196 "--learning_rate_text",
197 type=float,
198 default=2e-6,
199 help="Initial learning rate (after the potential warmup period) to use.",
200 )
201 parser.add_argument(
202 "--scale_lr",
203 action="store_true",
204 default=True,
205 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
206 )
207 parser.add_argument(
208 "--lr_scheduler",
209 type=str,
210 default="one_cycle",
211 help=(
212 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
213 ' "constant", "constant_with_warmup", "one_cycle"]'
214 ),
215 )
216 parser.add_argument(
217 "--lr_warmup_epochs",
218 type=int,
219 default=10,
220 help="Number of steps for the warmup in the lr scheduler."
221 )
222 parser.add_argument(
223 "--lr_cycles",
224 type=int,
225 default=None,
226 help="Number of restart cycles in the lr scheduler (if supported)."
227 )
228 parser.add_argument(
229 "--use_ema",
230 action="store_true",
231 default=True,
232 help="Whether to use EMA model."
233 )
234 parser.add_argument(
235 "--ema_inv_gamma",
236 type=float,
237 default=1.0
238 )
239 parser.add_argument(
240 "--ema_power",
241 type=float,
242 default=6/7
243 )
244 parser.add_argument(
245 "--ema_max_decay",
246 type=float,
247 default=0.9999
248 )
249 parser.add_argument(
250 "--use_8bit_adam",
251 action="store_true",
252 default=True,
253 help="Whether or not to use 8-bit Adam from bitsandbytes."
254 )
255 parser.add_argument(
256 "--adam_beta1",
257 type=float,
258 default=0.9,
259 help="The beta1 parameter for the Adam optimizer."
260 )
261 parser.add_argument(
262 "--adam_beta2",
263 type=float,
264 default=0.999,
265 help="The beta2 parameter for the Adam optimizer."
266 )
267 parser.add_argument(
268 "--adam_weight_decay",
269 type=float,
270 default=1e-2,
271 help="Weight decay to use."
272 )
273 parser.add_argument(
274 "--adam_epsilon",
275 type=float,
276 default=1e-08,
277 help="Epsilon value for the Adam optimizer"
278 )
279 parser.add_argument(
280 "--mixed_precision",
281 type=str,
282 default="no",
283 choices=["no", "fp16", "bf16"],
284 help=(
285 "Whether to use mixed precision. Choose "
286 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
287 "and an Nvidia Ampere GPU."
288 ),
289 )
290 parser.add_argument(
291 "--sample_frequency",
292 type=int,
293 default=1,
294 help="How often (in epochs) to generate sample images",
295 )
296 parser.add_argument(
297 "--sample_image_size",
298 type=int,
299 default=768,
300 help="Size of sample images",
301 )
302 parser.add_argument(
303 "--sample_batches",
304 type=int,
305 default=1,
306 help="Number of sample batches to generate per checkpoint",
307 )
308 parser.add_argument(
309 "--sample_batch_size",
310 type=int,
311 default=1,
312 help="Number of samples to generate per batch",
313 )
314 parser.add_argument(
315 "--valid_set_size",
316 type=int,
317 default=None,
318 help="Number of images in the validation dataset."
319 )
320 parser.add_argument(
321 "--train_batch_size",
322 type=int,
323 default=1,
324 help="Batch size (per device) for the training dataloader."
325 )
326 parser.add_argument(
327 "--sample_steps",
328 type=int,
329 default=15,
330 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
331 )
332 parser.add_argument(
333 "--prior_loss_weight",
334 type=float,
335 default=1.0,
336 help="The weight of prior preservation loss."
337 )
338 parser.add_argument(
339 "--max_grad_norm",
340 default=1.0,
341 type=float,
342 help="Max gradient norm."
343 )
344 parser.add_argument(
345 "--noise_timesteps",
346 type=int,
347 default=1000,
348 )
349 parser.add_argument(
350 "--config",
351 type=str,
352 default=None,
353 help="Path to a JSON configuration file containing arguments for invoking this script."
354 )
355
356 args = parser.parse_args()
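# A config file passed via --config pre-populates the argument namespace; flags given
# explicitly on the command line still take precedence on the second parse. A minimal
# (hypothetical) config could look like:
# {
#     "args": {
#         "pretrained_model_name_or_path": "path/to/stable-diffusion",
#         "train_data_file": "data/train.csv",
#         "instance_identifier": "<token>"
#     }
# }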
357 if args.config is not None:
358 with open(args.config, 'rt') as f:
359 args = parser.parse_args(
360 namespace=argparse.Namespace(**json.load(f)["args"]))
361
362 if args.train_data_file is None:
363 raise ValueError("You must specify --train_data_file")
364
365 if args.pretrained_model_name_or_path is None:
366 raise ValueError("You must specify --pretrained_model_name_or_path")
367
368 if args.instance_identifier is None:
369 raise ValueError("You must specify --instance_identifier")
370
371 if isinstance(args.initializer_token, str):
372 args.initializer_token = [args.initializer_token]
373
374 if isinstance(args.placeholder_token, str):
375 args.placeholder_token = [args.placeholder_token]
376
377 if len(args.placeholder_token) == 0:
378 args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
379
380 if len(args.placeholder_token) != len(args.initializer_token):
381 raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
382
383 if args.output_dir is None:
384 raise ValueError("You must specify --output_dir")
385
386 return args
387
388
389def save_args(basepath: Path, args, extra={}):
390 info = {"args": vars(args)}
391 info["args"].update(extra)
392 with open(basepath.joinpath("args.json"), "w") as f:
393 json.dump(info, f, indent=4)
394
395
396def freeze_params(params):
397 for param in params:
398 param.requires_grad = False
399
400
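# Paste equally sized PIL images into a rows x cols grid (used for the sample sheets saved by the Checkpointer).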
401def make_grid(images, rows, cols):
402 w, h = images[0].size
403 grid = Image.new('RGB', size=(cols*w, rows*h))
404 for i, image in enumerate(images):
405 grid.paste(image, box=(i % cols*w, i//cols*h))
406 return grid
407
408
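# Running average of a scalar; used to smooth loss/accuracy values for logging and the progress bars.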
409class AverageMeter:
410 def __init__(self, name=None):
411 self.name = name
412 self.reset()
413
414 def reset(self):
415 self.sum = self.count = self.avg = 0
416
417 def update(self, val, n=1):
418 self.sum += val * n
419 self.count += n
420 self.avg = self.sum / self.count
421
422
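# Handles saving the final pipeline and periodic sample image grids for the train/val/stable prompt pools.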
423class Checkpointer:
424 def __init__(
425 self,
426 datamodule,
427 accelerator,
428 vae,
429 unet,
430 ema_unet,
431 tokenizer,
432 text_encoder,
433 scheduler,
434 output_dir: Path,
435 instance_identifier,
436 placeholder_token,
437 placeholder_token_id,
438 sample_image_size,
439 sample_batches,
440 sample_batch_size,
441 seed
442 ):
443 self.datamodule = datamodule
444 self.accelerator = accelerator
445 self.vae = vae
446 self.unet = unet
447 self.ema_unet = ema_unet
448 self.tokenizer = tokenizer
449 self.text_encoder = text_encoder
450 self.scheduler = scheduler
451 self.output_dir = output_dir
452 self.instance_identifier = instance_identifier
453 self.placeholder_token = placeholder_token
454 self.placeholder_token_id = placeholder_token_id
455 self.sample_image_size = sample_image_size
456 self.seed = seed or torch.random.seed()
457 self.sample_batches = sample_batches
458 self.sample_batch_size = sample_batch_size
459
460 @torch.no_grad()
461 def save_model(self):
462 print("Saving model...")
463
464 unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
465 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
466
467 pipeline = VlpnStableDiffusion(
468 text_encoder=text_encoder,
469 vae=self.vae,
470 unet=unet,
471 tokenizer=self.tokenizer,
472 scheduler=self.scheduler,
473 )
474 pipeline.save_pretrained(self.output_dir.joinpath("model"))
475
476 del unet
477 del text_encoder
478 del pipeline
479
480 if torch.cuda.is_available():
481 torch.cuda.empty_cache()
482
483 @torch.no_grad()
484 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
485 samples_path = Path(self.output_dir).joinpath("samples")
486
487 unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
488 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
489
490 pipeline = VlpnStableDiffusion(
491 text_encoder=text_encoder,
492 vae=self.vae,
493 unet=unet,
494 tokenizer=self.tokenizer,
495 scheduler=self.scheduler,
496 ).to(self.accelerator.device)
497 pipeline.set_progress_bar_config(dynamic_ncols=True)
498
499 train_data = self.datamodule.train_dataloader()
500 val_data = self.datamodule.val_dataloader()
501
502 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
503 stable_latents = torch.randn(
504 (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
505 device=pipeline.device,
506 generator=generator,
507 )
508
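# The "stable" pool reuses the fixed latents above so grids from different checkpoints are directly comparable; "val" and "train" sample fresh noise.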
509 with torch.autocast("cuda"), torch.inference_mode():
510 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
511 all_samples = []
512 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
513 file_path.parent.mkdir(parents=True, exist_ok=True)
514
515 data_enum = enumerate(data)
516
517 batches = [
518 batch
519 for j, batch in data_enum
520 if j * data.batch_size < self.sample_batch_size * self.sample_batches
521 ]
522 prompts = [
523 prompt.format(identifier=self.instance_identifier)
524 for batch in batches
525 for prompt in batch["prompts"]
526 ]
527 nprompts = [
528 prompt
529 for batch in batches
530 for prompt in batch["nprompts"]
531 ]
532
533 for i in range(self.sample_batches):
534 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
535 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
536
537 samples = pipeline(
538 prompt=prompt,
539 negative_prompt=nprompt,
540 height=self.sample_image_size,
541 width=self.sample_image_size,
542 image=latents[:len(prompt)] if latents is not None else None,
543 generator=generator if latents is not None else None,
544 guidance_scale=guidance_scale,
545 eta=eta,
546 num_inference_steps=num_inference_steps,
547 output_type='pil'
548 ).images
549
550 all_samples += samples
551
552 del samples
553
554 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
555 image_grid.save(file_path, quality=85)
556
557 del all_samples
558 del image_grid
559
560 del unet
561 del text_encoder
562 del pipeline
563 del generator
564 del stable_latents
565
566 if torch.cuda.is_available():
567 torch.cuda.empty_cache()
568
569
570def main():
571 args = parse_args()
572
573 instance_identifier = args.instance_identifier
574
575 if len(args.placeholder_token) != 0:
576 instance_identifier = instance_identifier.format(args.placeholder_token[0])
577
578 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
579 basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
580 basepath.mkdir(parents=True, exist_ok=True)
581
582 accelerator = Accelerator(
583 log_with=LoggerType.TENSORBOARD,
584 logging_dir=f"{basepath}",
585 gradient_accumulation_steps=args.gradient_accumulation_steps,
586 mixed_precision=args.mixed_precision
587 )
588
589 if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
590 raise ValueError(
591 "Gradient accumulation is not supported when training the text encoder in distributed training. "
592 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
593 )
594
595 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
596
597 args.seed = args.seed or (torch.random.seed() >> 32)
598 set_seed(args.seed)
599
600 save_args(basepath, args)
601
602 # Load the tokenizer and add the placeholder token as an additional special token
603 if args.tokenizer_name:
604 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
605 elif args.pretrained_model_name_or_path:
606 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
607
608 # Load models and create wrapper for stable diffusion
609 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
610 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
611 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
612 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
613 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
614 args.pretrained_model_name_or_path, subfolder='scheduler')
615
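# Reduce VRAM usage: sliced VAE decoding and xformers memory-efficient attention for both UNet and VAE.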
616 vae.enable_slicing()
617 set_use_memory_efficient_attention_xformers(unet, True)
618 set_use_memory_efficient_attention_xformers(vae, True)
619
620 if args.gradient_checkpointing:
621 unet.enable_gradient_checkpointing()
622 text_encoder.gradient_checkpointing_enable()
623
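# Optionally keep an exponential moving average of the UNet weights; when enabled, samples and the saved model use the EMA copy.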
624 ema_unet = None
625 if args.use_ema:
626 ema_unet = EMAModel(
627 unet,
628 inv_gamma=args.ema_inv_gamma,
629 power=args.ema_power,
630 max_value=args.ema_max_decay,
631 device=accelerator.device
632 )
633
634 # Freeze the VAE; the text encoder is frozen below depending on --train_text_encoder
635 vae.requires_grad_(False)
636
637 if args.embeddings_dir is not None:
638 embeddings_dir = Path(args.embeddings_dir)
639 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
640 raise ValueError("--embeddings_dir must point to an existing directory")
641 added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
642 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
643
644 if len(args.placeholder_token) != 0:
645 # Convert the initializer_token, placeholder_token to ids
646 initializer_token_ids = torch.stack([
647 torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
648 for token in args.initializer_token
649 ])
650
651 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
652 print(f"Added {num_added_tokens} new tokens.")
653
654 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
655
656 # Resize the token embeddings as we are adding new special tokens to the tokenizer
657 text_encoder.resize_token_embeddings(len(tokenizer))
658
659 token_embeds = text_encoder.get_input_embeddings().weight.data
660 original_token_embeds = token_embeds.clone().to(accelerator.device)
661 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
662
663 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
664 token_embeds[token_id] = embeddings
665 else:
666 placeholder_token_id = []
667
668 if args.train_text_encoder:
669 print(f"Training entire text encoder.")
670 else:
671 print(f"Training added text embeddings")
672
673 freeze_params(itertools.chain(
674 text_encoder.text_model.encoder.parameters(),
675 text_encoder.text_model.final_layer_norm.parameters(),
676 text_encoder.text_model.embeddings.position_embedding.parameters(),
677 ))
678
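# Indices of all tokens except the newly added placeholders; their embeddings are restored after each optimizer step when the text encoder itself is not trained.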
679 index_fixed_tokens = torch.arange(len(tokenizer))
680 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
681
682 prompt_processor = PromptProcessor(tokenizer, text_encoder)
683
684 if args.scale_lr:
685 args.learning_rate_unet = (
686 args.learning_rate_unet * args.gradient_accumulation_steps *
687 args.train_batch_size * accelerator.num_processes
688 )
689 args.learning_rate_text = (
690 args.learning_rate_text * args.gradient_accumulation_steps *
691 args.train_batch_size * accelerator.num_processes
692 )
693
694 # Use 8-bit Adam for lower memory usage, e.g. to fine-tune the model on GPUs with 16 GB of VRAM
695 if args.use_8bit_adam:
696 try:
697 import bitsandbytes as bnb
698 except ImportError:
699 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
700
701 optimizer_class = bnb.optim.AdamW8bit
702 else:
703 optimizer_class = torch.optim.AdamW
704
705 if args.train_text_encoder:
706 text_encoder_params_to_optimize = text_encoder.parameters()
707 else:
708 text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()
709
710 # Initialize the optimizer
711 optimizer = optimizer_class(
712 [
713 {
714 'params': unet.parameters(),
715 'lr': args.learning_rate_unet,
716 },
717 {
718 'params': text_encoder_params_to_optimize,
719 'lr': args.learning_rate_text,
720 }
721 ],
722 betas=(args.adam_beta1, args.adam_beta2),
723 weight_decay=args.adam_weight_decay,
724 eps=args.adam_epsilon,
725 )
726
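# Dtype used for the frozen VAE and the input pixel values when mixed precision is enabled.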
727 weight_dtype = torch.float32
728 if args.mixed_precision == "fp16":
729 weight_dtype = torch.float16
730 elif args.mixed_precision == "bf16":
731 weight_dtype = torch.bfloat16
732
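# Build a batch from instance examples and, when prior preservation is active (num_class_images != 0), append the class examples so both halves are denoised in a single forward pass.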
733 def collate_fn(examples):
734 prompts = [example["prompts"] for example in examples]
735 nprompts = [example["nprompts"] for example in examples]
736 input_ids = [example["instance_prompt_ids"] for example in examples]
737 pixel_values = [example["instance_images"] for example in examples]
738
739 # concat class and instance examples for prior preservation
740 if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
741 input_ids += [example["class_prompt_ids"] for example in examples]
742 pixel_values += [example["class_images"] for example in examples]
743
744 pixel_values = torch.stack(pixel_values)
745 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
746
747 inputs = prompt_processor.unify_input_ids(input_ids)
748
749 batch = {
750 "prompts": prompts,
751 "nprompts": nprompts,
752 "input_ids": inputs.input_ids,
753 "pixel_values": pixel_values,
754 "attention_mask": inputs.attention_mask,
755 }
756 return batch
757
758 datamodule = CSVDataModule(
759 data_file=args.train_data_file,
760 batch_size=args.train_batch_size,
761 prompt_processor=prompt_processor,
762 instance_identifier=instance_identifier,
763 class_identifier=args.class_identifier,
764 class_subdir="cls",
765 num_class_images=args.num_class_images,
766 size=args.resolution,
767 repeats=args.repeats,
768 mode=args.mode,
769 dropout=args.tag_dropout,
770 center_crop=args.center_crop,
771 template_key=args.train_data_template,
772 valid_set_size=args.valid_set_size,
773 num_workers=args.dataloader_num_workers,
774 collate_fn=collate_fn
775 )
776
777 datamodule.prepare_data()
778 datamodule.setup()
779
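# Prior preservation: generate any missing class images with the current pipeline before training starts.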
780 if args.num_class_images != 0:
781 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
782
783 if len(missing_data) != 0:
784 batched_data = [
785 missing_data[i:i+args.sample_batch_size]
786 for i in range(0, len(missing_data), args.sample_batch_size)
787 ]
788
789 pipeline = VlpnStableDiffusion(
790 text_encoder=text_encoder,
791 vae=vae,
792 unet=unet,
793 tokenizer=tokenizer,
794 scheduler=checkpoint_scheduler,
795 ).to(accelerator.device)
796 pipeline.set_progress_bar_config(dynamic_ncols=True)
797
798 with torch.autocast("cuda"), torch.inference_mode():
799 for batch in batched_data:
800 image_name = [item.class_image_path for item in batch]
801 prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
802 nprompt = [item.nprompt for item in batch]
803
804 images = pipeline(
805 prompt=prompt,
806 negative_prompt=nprompt,
807 num_inference_steps=args.sample_steps
808 ).images
809
810 for i, image in enumerate(images):
811 image.save(image_name[i])
812
813 del pipeline
814
815 if torch.cuda.is_available():
816 torch.cuda.empty_cache()
817
818 train_dataloader = datamodule.train_dataloader()
819 val_dataloader = datamodule.val_dataloader()
820
821 # Scheduler and math around the number of training steps.
822 overrode_max_train_steps = False
823 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
824 if args.max_train_steps is None:
825 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
826 overrode_max_train_steps = True
827 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
828
829 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
830
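# Choose the learning-rate schedule; warmup is specified in epochs and was converted to optimizer steps above.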
831 if args.lr_scheduler == "one_cycle":
832 lr_scheduler = get_one_cycle_schedule(
833 optimizer=optimizer,
834 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
835 )
836 elif args.lr_scheduler == "cosine_with_restarts":
837 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
838 optimizer=optimizer,
839 num_warmup_steps=warmup_steps,
840 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
841 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
842 ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
843 )
844 else:
845 lr_scheduler = get_scheduler(
846 args.lr_scheduler,
847 optimizer=optimizer,
848 num_warmup_steps=warmup_steps,
849 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
850 )
851
852 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
853 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
854 )
855
856 # Move the VAE to the device (the UNet and text encoder were prepared by the accelerator above)
857 vae.to(accelerator.device, dtype=weight_dtype)
858
859 # Keep the VAE in eval mode as we don't train it
860 vae.eval()
861
862 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
863 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
864 if overrode_max_train_steps:
865 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
866
867 num_val_steps_per_epoch = len(val_dataloader)
868 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
869 val_steps = num_val_steps_per_epoch * num_epochs
870
871 # We need to initialize the trackers we use, and also store our configuration.
872 # The trackers initialize automatically on the main process.
873 if accelerator.is_main_process:
874 config = vars(args).copy()
875 config["initializer_token"] = " ".join(config["initializer_token"])
876 config["placeholder_token"] = " ".join(config["placeholder_token"])
877 accelerator.init_trackers("dreambooth", config=config)
878
879 # Train!
880 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
881
882 logger.info("***** Running training *****")
883 logger.info(f" Num Epochs = {num_epochs}")
884 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
885 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
886 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
887 logger.info(f" Total optimization steps = {args.max_train_steps}")
888 # Only show the progress bar once on each machine.
889
890 global_step = 0
891
892 avg_loss = AverageMeter()
893 avg_acc = AverageMeter()
894
895 avg_loss_val = AverageMeter()
896 avg_acc_val = AverageMeter()
897
898 max_acc_val = 0.0
899
900 checkpointer = Checkpointer(
901 datamodule=datamodule,
902 accelerator=accelerator,
903 vae=vae,
904 unet=unet,
905 ema_unet=ema_unet,
906 tokenizer=tokenizer,
907 text_encoder=text_encoder,
908 scheduler=checkpoint_scheduler,
909 output_dir=basepath,
910 instance_identifier=instance_identifier,
911 placeholder_token=args.placeholder_token,
912 placeholder_token_id=placeholder_token_id,
913 sample_image_size=args.sample_image_size,
914 sample_batch_size=args.sample_batch_size,
915 sample_batches=args.sample_batches,
916 seed=args.seed
917 )
918
919 if accelerator.is_main_process:
920 checkpointer.save_samples(0, args.sample_steps)
921
922 local_progress_bar = tqdm(
923 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
924 disable=not accelerator.is_local_main_process,
925 dynamic_ncols=True
926 )
927 local_progress_bar.set_description("Epoch X / Y")
928
929 global_progress_bar = tqdm(
930 range(args.max_train_steps + val_steps),
931 disable=not accelerator.is_local_main_process,
932 dynamic_ncols=True
933 )
934 global_progress_bar.set_description("Total progress")
935
936 try:
937 for epoch in range(num_epochs):
938 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
939 local_progress_bar.reset()
940
941 unet.train()
942
943 if epoch < args.train_text_encoder_epochs:
944 text_encoder.train()
945 elif epoch == args.train_text_encoder_epochs:
946 freeze_params(text_encoder.parameters())
947
948 sample_checkpoint = False
949
950 for step, batch in enumerate(train_dataloader):
951 with accelerator.accumulate(unet):
952 # Convert images to latent space
953 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
954 latents = latents * 0.18215
955
956 # Sample noise that we'll add to the latents
957 noise = torch.randn_like(latents)
958 bsz = latents.shape[0]
959 # Sample a random timestep for each image
960 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
961 (bsz,), device=latents.device)
962 timesteps = timesteps.long()
963
964 # Add noise to the latents according to the noise magnitude at each timestep
965 # (this is the forward diffusion process)
966 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
967
968 # Get the text embedding for conditioning
969 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
970
971 # Predict the noise residual
972 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
973
974 # Get the target for loss depending on the prediction type
975 if noise_scheduler.config.prediction_type == "epsilon":
976 target = noise
977 elif noise_scheduler.config.prediction_type == "v_prediction":
978 target = noise_scheduler.get_velocity(latents, noise, timesteps)
979 else:
980 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
981
982 if args.num_class_images != 0:
983 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
984 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
985 target, target_prior = torch.chunk(target, 2, dim=0)
986
987 # Compute instance loss
988 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
989
990 # Compute prior loss
991 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
992
993 # Add the prior loss to the instance loss.
994 loss = loss + args.prior_loss_weight * prior_loss
995 else:
996 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
997
998 acc = (model_pred == target).float().mean()
999
1000 accelerator.backward(loss)
1001
1002 if accelerator.sync_gradients:
1003 params_to_clip = (
1004 itertools.chain(unet.parameters(), text_encoder.parameters())
1005 if args.train_text_encoder and epoch < args.train_text_encoder_epochs
1006 else unet.parameters()
1007 )
1008 accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1009
1010 optimizer.step()
1011 if not accelerator.optimizer_step_was_skipped:
1012 lr_scheduler.step()
1013 if args.use_ema:
1014 ema_unet.step(unet)
1015 optimizer.zero_grad(set_to_none=True)
1016
1017 if not args.train_text_encoder:
1018 # Let's make sure we don't update any embedding weights besides the newly added token
1019 with torch.no_grad():
1020 text_encoder.get_input_embeddings(
1021 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
1022
1023 avg_loss.update(loss.detach_(), bsz)
1024 avg_acc.update(acc.detach_(), bsz)
1025
1026 # Checks if the accelerator has performed an optimization step behind the scenes
1027 if accelerator.sync_gradients:
1028 local_progress_bar.update(1)
1029 global_progress_bar.update(1)
1030
1031 global_step += 1
1032
1033 logs = {
1034 "train/loss": avg_loss.avg.item(),
1035 "train/acc": avg_acc.avg.item(),
1036 "train/cur_loss": loss.item(),
1037 "train/cur_acc": acc.item(),
1038 "lr/unet": lr_scheduler.get_last_lr()[0],
1039 "lr/text": lr_scheduler.get_last_lr()[1]
1040 }
1041 if args.use_ema:
1042 logs["ema_decay"] = 1 - ema_unet.decay
1043
1044 accelerator.log(logs, step=global_step)
1045
1046 local_progress_bar.set_postfix(**logs)
1047
1048 if global_step >= args.max_train_steps:
1049 break
1050
1051 accelerator.wait_for_everyone()
1052
1053 unet.eval()
1054 text_encoder.eval()
1055
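# Validation: run the same denoising objective without gradients and track the averaged loss/accuracy.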
1056 with torch.inference_mode():
1057 for step, batch in enumerate(val_dataloader):
1058 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
1059 latents = latents * 0.18215
1060
1061 noise = torch.randn_like(latents)
1062 bsz = latents.shape[0]
1063 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
1064 (bsz,), device=latents.device)
1065 timesteps = timesteps.long()
1066
1067 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1068
1069 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
1070
1071 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1072
1073 # Get the target for loss depending on the prediction type
1074 if noise_scheduler.config.prediction_type == "epsilon":
1075 target = noise
1076 elif noise_scheduler.config.prediction_type == "v_prediction":
1077 target = noise_scheduler.get_velocity(latents, noise, timesteps)
1078 else:
1079 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1080
1081 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1082
1083 acc = (model_pred == target).float().mean()
1084
1085 avg_loss_val.update(loss.detach_(), bsz)
1086 avg_acc_val.update(acc.detach_(), bsz)
1087
1088 if accelerator.sync_gradients:
1089 local_progress_bar.update(1)
1090 global_progress_bar.update(1)
1091
1092 logs = {
1093 "val/loss": avg_loss_val.avg.item(),
1094 "val/acc": avg_acc_val.avg.item(),
1095 "val/cur_loss": loss.item(),
1096 "val/cur_acc": acc.item(),
1097 }
1098 local_progress_bar.set_postfix(**logs)
1099
1100 accelerator.log({
1101 "val/loss": avg_loss_val.avg.item(),
1102 "val/acc": avg_acc_val.avg.item(),
1103 }, step=global_step)
1104
1105 local_progress_bar.clear()
1106 global_progress_bar.clear()
1107
1108 if avg_acc_val.avg.item() > max_acc_val:
1109 accelerator.print(
1110 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
1111 max_acc_val = avg_acc_val.avg.item()
1112
1113 if accelerator.is_main_process:
1114 if (epoch + 1) % args.sample_frequency == 0:
1115 checkpointer.save_samples(global_step, args.sample_steps)
1116
1117 # Create the pipeline using the trained modules and save it.
1118 if accelerator.is_main_process:
1119 print("Finished! Saving final checkpoint and resume state.")
1120 checkpointer.save_model()
1121
1122 accelerator.end_training()
1123
1124 except KeyboardInterrupt:
1125 if accelerator.is_main_process:
1126 print("Interrupted, saving checkpoint and resume state...")
1127 checkpointer.save_model()
1128 accelerator.end_training()
1129 quit()
1130
1131
1132if __name__ == "__main__":
1133 main()