path: root/textual_dreambooth.py
Diffstat (limited to 'textual_dreambooth.py')
-rw-r--r--  textual_dreambooth.py  948
1 file changed, 948 insertions, 0 deletions
diff --git a/textual_dreambooth.py b/textual_dreambooth.py
new file mode 100644
index 0000000..a46953d
--- /dev/null
+++ b/textual_dreambooth.py
@@ -0,0 +1,948 @@
1import argparse
2import itertools
3import math
4import os
5import datetime
6import logging
7from pathlib import Path
8
9import numpy as np
10import torch
11import torch.nn.functional as F
12import torch.utils.checkpoint
13
14from accelerate import Accelerator
15from accelerate.logging import get_logger
16from accelerate.utils import LoggerType, set_seed
17from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
18from schedulers.scheduling_euler_a import EulerAScheduler
19from diffusers.optimization import get_scheduler
20from PIL import Image
21from tqdm.auto import tqdm
22from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
23from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25import json
26import os
27
28from data.dreambooth.csv import CSVDataModule
29
30logger = get_logger(__name__)
31
32
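# TF32 matmuls trade a little precision for a substantial speedup on Ampere and newer GPUs.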
33torch.backends.cuda.matmul.allow_tf32 = True
34
35
36def parse_args():
37 parser = argparse.ArgumentParser(
38 description="Simple example of a training script."
39 )
40 parser.add_argument(
41 "--pretrained_model_name_or_path",
42 type=str,
43 default=None,
44 help="Path to pretrained model or model identifier from huggingface.co/models.",
45 )
46 parser.add_argument(
47 "--tokenizer_name",
48 type=str,
49 default=None,
50 help="Pretrained tokenizer name or path if not the same as model_name",
51 )
52 parser.add_argument(
53 "--train_data_file",
54 type=str,
55 default=None,
56 help="A CSV file containing the training data."
57 )
58 parser.add_argument(
59 "--placeholder_token",
60 type=str,
61 default=None,
62 help="A token to use as a placeholder for the concept.",
63 )
64 parser.add_argument(
65 "--initializer_token",
66 type=str,
67 default=None,
68 help="A token to use as the initializer word."
69 )
70 parser.add_argument(
71 "--num_vec_per_token",
72 type=int,
73 default=1,
74 help=(
75 "The number of vectors used to represent the placeholder token. The higher the number, the better the"
76 " result at the cost of editability. This can be fixed by prompt editing."
77 ),
78 )
79 parser.add_argument(
80 "--initialize_rest_random",
81 action="store_true",
82 help="Initialize the rest of the placeholder tokens with random vectors."
83 )
84 parser.add_argument(
85 "--use_class_images",
86 action="store_true",
87 default=True,
88 help="Include class images in the loss calculation a la Dreambooth.",
89 )
90 parser.add_argument(
91 "--repeats",
92 type=int,
93 default=100,
94 help="How many times to repeat the training data.")
95 parser.add_argument(
96 "--output_dir",
97 type=str,
98 default="output/text-inversion",
99 help="The output directory where the model predictions and checkpoints will be written.",
100 )
101 parser.add_argument(
102 "--seed",
103 type=int,
104 default=None,
105 help="A seed for reproducible training.")
106 parser.add_argument(
107 "--resolution",
108 type=int,
109 default=512,
110 help=(
111 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
112 " resolution"
113 ),
114 )
115 parser.add_argument(
116 "--center_crop",
117 action="store_true",
118 help="Whether to center crop images before resizing to resolution"
119 )
120 parser.add_argument(
121 "--num_train_epochs",
122 type=int,
123 default=100)
124 parser.add_argument(
125 "--max_train_steps",
126 type=int,
127 default=5000,
128 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
129 )
130 parser.add_argument(
131 "--gradient_accumulation_steps",
132 type=int,
133 default=1,
134 help="Number of update steps to accumulate before performing a backward/update pass.",
135 )
136 parser.add_argument(
137 "--gradient_checkpointing",
138 action="store_true",
139 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
140 )
141 parser.add_argument(
142 "--learning_rate",
143 type=float,
144 default=1e-4,
145 help="Initial learning rate (after the potential warmup period) to use.",
146 )
147 parser.add_argument(
148 "--scale_lr",
149 action="store_true",
150 default=True,
151 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
152 )
153 parser.add_argument(
154 "--lr_scheduler",
155 type=str,
156 default="constant",
157 help=(
158 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
159 ' "constant", "constant_with_warmup"]'
160 ),
161 )
162 parser.add_argument(
163 "--lr_warmup_steps",
164 type=int,
165 default=500,
166 help="Number of steps for the warmup in the lr scheduler."
167 )
168 parser.add_argument(
169 "--use_8bit_adam",
170 action="store_true",
171 help="Whether or not to use 8-bit Adam from bitsandbytes."
172 )
173 parser.add_argument(
174 "--adam_beta1",
175 type=float,
176 default=0.9,
177 help="The beta1 parameter for the Adam optimizer."
178 )
179 parser.add_argument(
180 "--adam_beta2",
181 type=float,
182 default=0.999,
183 help="The beta2 parameter for the Adam optimizer."
184 )
185 parser.add_argument(
186 "--adam_weight_decay",
187 type=float,
188 default=1e-2,
189 help="Weight decay to use."
190 )
191 parser.add_argument(
192 "--adam_epsilon",
193 type=float,
194 default=1e-08,
195 help="Epsilon value for the Adam optimizer"
196 )
197 parser.add_argument(
198 "--mixed_precision",
199 type=str,
200 default="no",
201 choices=["no", "fp16", "bf16"],
202 help=(
203 "Whether to use mixed precision. Choose"
204 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
205 "and an Nvidia Ampere GPU."
206 ),
207 )
208 parser.add_argument(
209 "--local_rank",
210 type=int,
211 default=-1,
212 help="For distributed training: local_rank"
213 )
214 parser.add_argument(
215 "--checkpoint_frequency",
216 type=int,
217 default=500,
218 help="How often to save a checkpoint and sample image",
219 )
220 parser.add_argument(
221 "--sample_image_size",
222 type=int,
223 default=512,
224 help="Size of sample images",
225 )
226 parser.add_argument(
227 "--sample_batches",
228 type=int,
229 default=1,
230 help="Number of sample batches to generate per checkpoint",
231 )
232 parser.add_argument(
233 "--sample_batch_size",
234 type=int,
235 default=1,
236 help="Number of samples to generate per batch",
237 )
238 parser.add_argument(
239 "--train_batch_size",
240 type=int,
241 default=1,
242 help="Batch size (per device) for the training dataloader."
243 )
244 parser.add_argument(
245 "--sample_steps",
246 type=int,
247 default=30,
248 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
249 )
250 parser.add_argument(
251 "--prior_loss_weight",
252 type=float,
253 default=1.0,
254 help="The weight of prior preservation loss."
255 )
256 parser.add_argument(
257 "--resume_from",
258 type=str,
259 default=None,
260 help="Path to a directory to resume training from (e.g. logs/token_name/2022-09-22T23-36-27)"
261 )
262 parser.add_argument(
263 "--resume_checkpoint",
264 type=str,
265 default=None,
266 help="Path to a specific checkpoint to resume training from (e.g. logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
267 )
268 parser.add_argument(
269 "--config",
270 type=str,
271 default=None,
272 help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
273 )
274
275 args = parser.parse_args()
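# Re-parse with the stored arguments as a pre-filled namespace: values from resume.json
# (or --config) act as defaults, while flags given explicitly on the command line still win.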
276 if args.resume_from is not None:
277 with open(f"{args.resume_from}/resume.json", 'rt') as f:
278 args = parser.parse_args(
279 namespace=argparse.Namespace(**json.load(f)["args"]))
280 elif args.config is not None:
281 with open(args.config, 'rt') as f:
282 args = parser.parse_args(
283 namespace=argparse.Namespace(**json.load(f)["args"]))
284
285 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
286 if env_local_rank != -1 and env_local_rank != args.local_rank:
287 args.local_rank = env_local_rank
288
289 if args.train_data_file is None:
290 raise ValueError("You must specify --train_data_file")
291
292 if args.pretrained_model_name_or_path is None:
293 raise ValueError("You must specify --pretrained_model_name_or_path")
294
295 if args.placeholder_token is None:
296 raise ValueError("You must specify --placeholder_token")
297
298 if args.initializer_token is None:
299 raise ValueError("You must specify --initializer_token")
300
301 if args.output_dir is None:
302 raise ValueError("You must specify --output_dir")
303
304 return args
305
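# Example invocation (all paths and token names below are illustrative):
#   accelerate launch textual_dreambooth.py \
#     --pretrained_model_name_or_path models/stable-diffusion-v1-4 \
#     --train_data_file data/concept/train.csv \
#     --placeholder_token "<concept>" \
#     --initializer_token "dog" \
#     --num_vec_per_token 4 \
#     --use_8bit_adam \
#     --gradient_checkpointing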
306
307def freeze_params(params):
308 for param in params:
309 param.requires_grad = False
310
311
312def save_resume_file(basepath, args, extra={}):
313 info = {"args": vars(args)}
314 info["args"].update(extra)
315 with open(f"{basepath}/resume.json", "w") as f:
316 json.dump(info, f, indent=4)
317
318
319def make_grid(images, rows, cols):
320 w, h = images[0].size
321 grid = Image.new('RGB', size=(cols*w, rows*h))
322 for i, image in enumerate(images):
323 grid.paste(image, box=(i % cols*w, i//cols*h))
324 return grid
325
326
327def add_tokens_and_get_placeholder_token(args, token_ids, tokenizer, text_encoder):
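# Expand the placeholder into num_vec_per_token sub-tokens ("<ph>_0 <ph>_1 ..."), register them
# with the tokenizer, and initialize their embeddings by cycling through the initializer token
# ids; with --initialize_rest_random only the trailing sub-tokens copy the initializer and the
# rest start from random vectors.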
328 assert args.num_vec_per_token >= len(token_ids)
329 placeholder_tokens = [f"{args.placeholder_token}_{i}" for i in range(args.num_vec_per_token)]
330
331 for placeholder_token in placeholder_tokens:
332 num_added_tokens = tokenizer.add_tokens(placeholder_token)
333 if num_added_tokens == 0:
334 raise ValueError(
335 f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
336 " `placeholder_token` that is not already in the tokenizer."
337 )
338
339 placeholder_token = " ".join(placeholder_tokens)
340 placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
341
342 print(f"The placeholder tokens are {placeholder_token} while the ids are {placeholder_token_ids}")
343
344 text_encoder.resize_token_embeddings(len(tokenizer))
345 token_embeds = text_encoder.get_input_embeddings().weight.data
346
347 if args.initialize_rest_random:
348 # The idea is that the placeholder tokens form adjectives as in x x x white dog.
349 for i, placeholder_token_id in enumerate(placeholder_token_ids):
350 if len(placeholder_token_ids) - i < len(token_ids):
351 token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]]
352 else:
353 token_embeds[placeholder_token_id] = torch.rand_like(token_embeds[placeholder_token_id])
354 else:
355 for i, placeholder_token_id in enumerate(placeholder_token_ids):
356 token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]]
357
358 return placeholder_token, placeholder_token_ids
359
360
361class Checkpointer:
362 def __init__(
363 self,
364 datamodule,
365 accelerator,
366 vae,
367 unet,
368 tokenizer,
369 placeholder_token,
370 placeholder_token_ids,
371 output_dir,
372 sample_image_size,
373 sample_batches,
374 sample_batch_size,
375 seed
376 ):
377 self.datamodule = datamodule
378 self.accelerator = accelerator
379 self.vae = vae
380 self.unet = unet
381 self.tokenizer = tokenizer
382 self.placeholder_token = placeholder_token
383 self.placeholder_token_ids = placeholder_token_ids
384 self.output_dir = output_dir
385 self.sample_image_size = sample_image_size
386 self.seed = seed or torch.random.seed()
387 self.sample_batches = sample_batches
388 self.sample_batch_size = sample_batch_size
389
390 @torch.no_grad()
391 def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
392 print("Saving checkpoint for step %d..." % step)
393
394 if path is None:
395 checkpoints_path = f"{self.output_dir}/checkpoints"
396 os.makedirs(checkpoints_path, exist_ok=True)
397
398 unwrapped = self.accelerator.unwrap_model(text_encoder)
399
400 # Save a checkpoint
401 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_ids]
402 learned_embeds_dict = {}
403 for i, placeholder_token in enumerate(self.placeholder_token.split(" ")):
404 learned_embeds_dict[placeholder_token] = learned_embeds[i].detach().cpu()
405
406 filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
407 if path is not None:
408 torch.save(learned_embeds_dict, path)
409 else:
410 torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
411 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
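# The saved file maps each placeholder sub-token to its embedding tensor. A minimal
# sketch for loading it at inference time (assuming a matching tokenizer/text_encoder):
#   embeds = torch.load("last.bin")
#   for token, embedding in embeds.items():
#       tokenizer.add_tokens(token)
#       text_encoder.resize_token_embeddings(len(tokenizer))
#       token_id = tokenizer.convert_tokens_to_ids(token)
#       text_encoder.get_input_embeddings().weight.data[token_id] = embedding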
412
413 del unwrapped
414 del learned_embeds
415
416 @torch.no_grad()
417 def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
418 samples_path = Path(self.output_dir).joinpath("samples")
419
420 unwrapped = self.accelerator.unwrap_model(text_encoder)
421 scheduler = EulerAScheduler(
422 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
423 )
424
425 # Save a sample image
426 pipeline = VlpnStableDiffusion(
427 text_encoder=unwrapped,
428 vae=self.vae,
429 unet=self.unet,
430 tokenizer=self.tokenizer,
431 scheduler=scheduler,
432 ).to(self.accelerator.device)
433 pipeline.enable_attention_slicing()
434
435 train_data = self.datamodule.train_dataloader()
436 val_data = self.datamodule.val_dataloader()
437
438 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
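# Fixed latents reused at every checkpoint so the "stable" grid shows progress on identical noise.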
439 stable_latents = torch.randn(
440 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
441 device=pipeline.device,
442 generator=generator,
443 )
444
445 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
446 all_samples = []
447 file_path = samples_path.joinpath(pool, f"step_{step}.png")
448 file_path.parent.mkdir(parents=True, exist_ok=True)
449
450 data_enum = enumerate(data)
451
452 for i in range(self.sample_batches):
453 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
454 prompt = [prompt.format(self.placeholder_token)
455 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
456 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
457
458 with self.accelerator.autocast():
459 samples = pipeline(
460 prompt=prompt,
461 negative_prompt=nprompt,
462 height=self.sample_image_size,
463 width=self.sample_image_size,
464 latents=latents[:len(prompt)] if latents is not None else None,
465 generator=generator if latents is not None else None,
466 guidance_scale=guidance_scale,
467 eta=eta,
468 num_inference_steps=num_inference_steps,
469 output_type='pil'
470 )["sample"]
471
472 all_samples += samples
473
474 del samples
475
476 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
477 image_grid.save(file_path)
478
479 del all_samples
480 del image_grid
481
482 del unwrapped
483 del scheduler
484 del pipeline
485 del generator
486 del stable_latents
487
488 if torch.cuda.is_available():
489 torch.cuda.empty_cache()
490
491
492def main():
493 args = parse_args()
494
495 global_step_offset = 0
496 if args.resume_from is not None:
497 basepath = Path(args.resume_from)
498 print("Resuming state from %s" % args.resume_from)
499 with open(basepath.joinpath("resume.json"), 'r') as f:
500 state = json.load(f)
501 global_step_offset = state["args"].get("global_step", 0)
502
503 print("We've trained %d steps so far" % global_step_offset)
504 else:
505 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
506 basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
507 basepath.mkdir(parents=True, exist_ok=True)
508
509 accelerator = Accelerator(
510 log_with=LoggerType.TENSORBOARD,
511 logging_dir=f"{basepath}",
512 gradient_accumulation_steps=args.gradient_accumulation_steps,
513 mixed_precision=args.mixed_precision
514 )
515
516 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
517
518 # If passed along, set the training seed now.
519 if args.seed is not None:
520 set_seed(args.seed)
521
522 # Load the tokenizer and add the placeholder token as an additional special token
523 if args.tokenizer_name:
524 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
525 elif args.pretrained_model_name_or_path:
526 tokenizer = CLIPTokenizer.from_pretrained(
527 args.pretrained_model_name_or_path + '/tokenizer'
528 )
529
530 # Load models and create wrapper for stable diffusion
531 text_encoder = CLIPTextModel.from_pretrained(
532 args.pretrained_model_name_or_path + '/text_encoder',
533 )
534 vae = AutoencoderKL.from_pretrained(
535 args.pretrained_model_name_or_path + '/vae',
536 )
537 unet = UNet2DConditionModel.from_pretrained(
538 args.pretrained_model_name_or_path + '/unet',
539 )
540
541 if args.gradient_checkpointing:
542 unet.enable_gradient_checkpointing()
543
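# Compute attention over a subset of heads at a time to reduce peak VRAM usage.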
544 slice_size = unet.config.attention_head_dim // 2
545 unet.set_attention_slice(slice_size)
546
547 token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
548 # Regardless of how many ids the initializer token encodes to, the placeholder embeddings cycle through them.
549 placeholder_token, placeholder_token_ids = add_tokens_and_get_placeholder_token(
550 args, token_ids, tokenizer, text_encoder)
551
552 # if args.resume_checkpoint is not None:
553 # token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
554 # args.placeholder_token]
555 # else:
556 # token_embeds[placeholder_token_id] = initializer_token_embeddings
557
558 # Freeze vae and unet
559 freeze_params(vae.parameters())
560 freeze_params(unet.parameters())
561 # Freeze all parameters except for the token embeddings in text encoder
562 params_to_freeze = itertools.chain(
563 text_encoder.text_model.encoder.parameters(),
564 text_encoder.text_model.final_layer_norm.parameters(),
565 text_encoder.text_model.embeddings.position_embedding.parameters(),
566 )
567 freeze_params(params_to_freeze)
568
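# Scale the base learning rate by the effective batch size (batch size * accumulation steps * processes).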
569 if args.scale_lr:
570 args.learning_rate = (
571 args.learning_rate * args.gradient_accumulation_steps *
572 args.train_batch_size * accelerator.num_processes
573 )
574
575 # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
576 if args.use_8bit_adam:
577 try:
578 import bitsandbytes as bnb
579 except ImportError:
580 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
581
582 optimizer_class = bnb.optim.AdamW8bit
583 else:
584 optimizer_class = torch.optim.AdamW
585
586 # Initialize the optimizer
587 optimizer = optimizer_class(
588 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
589 lr=args.learning_rate,
590 betas=(args.adam_beta1, args.adam_beta2),
591 weight_decay=args.adam_weight_decay,
592 eps=args.adam_epsilon,
593 )
594
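# Training noise schedule: the scaled-linear beta schedule Stable Diffusion was trained with.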
595 noise_scheduler = DDPMScheduler(
596 beta_start=0.00085,
597 beta_end=0.012,
598 beta_schedule="scaled_linear",
599 num_train_timesteps=1000
600 )
601
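# Build batches of instance examples; when class images are enabled, their prompts and images
# are appended so instance and prior-preservation losses come from a single forward pass.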
602 def collate_fn(examples):
603 prompts = [example["prompts"] for example in examples]
604 nprompts = [example["nprompts"] for example in examples]
605 input_ids = [example["instance_prompt_ids"] for example in examples]
606 pixel_values = [example["instance_images"] for example in examples]
607
608 # concat class and instance examples for prior preservation
609 if args.use_class_images and "class_prompt_ids" in examples[0]:
610 input_ids += [example["class_prompt_ids"] for example in examples]
611 pixel_values += [example["class_images"] for example in examples]
612
613 pixel_values = torch.stack(pixel_values)
614 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
615
616 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
617
618 batch = {
619 "prompts": prompts,
620 "nprompts": nprompts,
621 "input_ids": input_ids,
622 "pixel_values": pixel_values,
623 }
624 return batch
625
626 datamodule = CSVDataModule(
627 data_file=args.train_data_file,
628 batch_size=args.train_batch_size,
629 tokenizer=tokenizer,
630 instance_identifier=placeholder_token,
631 class_identifier=args.initializer_token if args.use_class_images else None,
632 class_subdir="ti_cls",
633 size=args.resolution,
634 repeats=args.repeats,
635 center_crop=args.center_crop,
636 valid_set_size=args.sample_batch_size*args.sample_batches,
637 collate_fn=collate_fn
638 )
639
640 datamodule.prepare_data()
641 datamodule.setup()
642
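# Generate any class images that are still missing so they can serve as prior-preservation
# targets during training.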
643 if args.use_class_images:
644 missing_data = [item for item in datamodule.data if not item[1].exists()]
645
646 if len(missing_data) != 0:
647 batched_data = [missing_data[i:i+args.sample_batch_size]
648 for i in range(0, len(missing_data), args.sample_batch_size)]
649
650 scheduler = EulerAScheduler(
651 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
652 )
653
654 pipeline = VlpnStableDiffusion(
655 text_encoder=text_encoder,
656 vae=vae,
657 unet=unet,
658 tokenizer=tokenizer,
659 scheduler=scheduler,
660 ).to(accelerator.device)
661 pipeline.enable_attention_slicing()
662
663 for batch in batched_data:
664 image_name = [p[1] for p in batch]
665 prompt = [p[2].format(args.initializer_token) for p in batch]
666 nprompt = [p[3] for p in batch]
667
668 with accelerator.autocast():
669 images = pipeline(
670 prompt=prompt,
671 negative_prompt=nprompt,
672 num_inference_steps=args.sample_steps
673 ).images
674
675 for i, image in enumerate(images):
676 image.save(image_name[i])
677
678 del pipeline
679
680 if torch.cuda.is_available():
681 torch.cuda.empty_cache()
682
683 train_dataloader = datamodule.train_dataloader()
684 val_dataloader = datamodule.val_dataloader()
685
686 checkpointer = Checkpointer(
687 datamodule=datamodule,
688 accelerator=accelerator,
689 vae=vae,
690 unet=unet,
691 tokenizer=tokenizer,
692 placeholder_token=args.placeholder_token,
693 placeholder_token_ids=placeholder_token_ids,
694 output_dir=basepath,
695 sample_image_size=args.sample_image_size,
696 sample_batch_size=args.sample_batch_size,
697 sample_batches=args.sample_batches,
698 seed=args.seed
699 )
700
701 # Scheduler and math around the number of training steps.
702 overrode_max_train_steps = False
703 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
704 if args.max_train_steps is None:
705 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
706 overrode_max_train_steps = True
707
708 lr_scheduler = get_scheduler(
709 args.lr_scheduler,
710 optimizer=optimizer,
711 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
712 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
713 )
714
715 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
716 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
717 )
718
719 # Move vae and unet to device
720 vae.to(accelerator.device)
721 unet.to(accelerator.device)
722
723 # Keep vae and unet in eval mode as we don't train these
724 vae.eval()
725 unet.eval()
726
727 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
728 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
729 if overrode_max_train_steps:
730 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
731
732 num_val_steps_per_epoch = len(val_dataloader)
733 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
734 val_steps = num_val_steps_per_epoch * num_epochs
735
736 # We need to initialize the trackers we use, and also store our configuration.
737 # The trackers initialize automatically on the main process.
738 if accelerator.is_main_process:
739 accelerator.init_trackers("textual_inversion", config=vars(args))
740
741 # Train!
742 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
743
744 logger.info("***** Running training *****")
745 logger.info(f" Num Epochs = {num_epochs}")
746 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
747 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
748 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
749 logger.info(f" Total optimization steps = {args.max_train_steps}")
750 # Only show the progress bar once on each machine.
751
752 global_step = 0
753 min_val_loss = np.inf
754
755 if accelerator.is_main_process:
756 checkpointer.save_samples(
757 0,
758 text_encoder,
759 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
760
761 local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
762 disable=not accelerator.is_local_main_process)
763 local_progress_bar.set_description("Epoch X out of Y")
764
765 global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
766 global_progress_bar.set_description("Total progress")
767
768 try:
769 for epoch in range(num_epochs):
770 local_progress_bar.set_description(f"Epoch {epoch + 1} out of {num_epochs}")
771 local_progress_bar.reset()
772
773 text_encoder.train()
774 train_loss = 0.0
775
776 for step, batch in enumerate(train_dataloader):
777 with accelerator.accumulate(text_encoder):
778 # Convert images to latent space
779 with torch.no_grad():
780 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
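# 0.18215 is Stable Diffusion's VAE latent scaling factor, bringing latents to roughly unit variance.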
781 latents = latents * 0.18215
782
783 # Sample noise that we'll add to the latents
784 noise = torch.randn(latents.shape).to(latents.device)
785 bsz = latents.shape[0]
786 # Sample a random timestep for each image
787 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
788 (bsz,), device=latents.device)
789 timesteps = timesteps.long()
790
791 # Add noise to the latents according to the noise magnitude at each timestep
792 # (this is the forward diffusion process)
793 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
794
795 # Get the text embedding for conditioning
796 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
797
798 # Predict the noise residual
799 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
800
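# The batch holds instance examples first and class examples second (see collate_fn),
# so chunking in half separates the instance loss from the prior-preservation loss.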
801 if args.use_class_images:
802 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
803 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
804 noise, noise_prior = torch.chunk(noise, 2, dim=0)
805
806 # Compute instance loss
807 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
808
809 # Compute prior loss
810 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
811
812 # Add the prior loss to the instance loss.
813 loss = loss + args.prior_loss_weight * prior_loss
814 else:
815 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
816
817 accelerator.backward(loss)
818
819 # Zero out the gradients for all token embeddings except the newly added
820 # embeddings for the concept, as we only want to optimize the concept embeddings
821 if accelerator.num_processes > 1:
822 grads = text_encoder.module.get_input_embeddings().weight.grad
823 else:
824 grads = text_encoder.get_input_embeddings().weight.grad
825 # Get the index for tokens that we want to zero the grads for
826 grad_mask = torch.arange(len(tokenizer)) != placeholder_token_ids[0]
827 for i in range(1, len(placeholder_token_ids)):
828 grad_mask = grad_mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i])
829 grads.data[grad_mask, :] = grads.data[grad_mask, :].fill_(0)
830
831 optimizer.step()
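# Skip the LR scheduler step if the optimizer step was skipped (e.g. on an fp16 gradient overflow).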
832 if not accelerator.optimizer_step_was_skipped:
833 lr_scheduler.step()
834 optimizer.zero_grad(set_to_none=True)
835
836 loss = loss.detach().item()
837 train_loss += loss
838
839 # Checks if the accelerator has performed an optimization step behind the scenes
840 if accelerator.sync_gradients:
841 local_progress_bar.update(1)
842 global_progress_bar.update(1)
843
844 global_step += 1
845
846 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
847 local_progress_bar.clear()
848 global_progress_bar.clear()
849
850 checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
851 save_resume_file(basepath, args, {
852 "global_step": global_step + global_step_offset,
853 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
854 })
855
856 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
857 local_progress_bar.set_postfix(**logs)
858
859 if global_step >= args.max_train_steps:
860 break
861
862 train_loss /= len(train_dataloader)
863
864 accelerator.wait_for_everyone()
865
866 text_encoder.eval()
867 val_loss = 0.0
868
869 for step, batch in enumerate(val_dataloader):
870 with torch.no_grad():
871 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
872 latents = latents * 0.18215
873
874 noise = torch.randn(latents.shape).to(latents.device)
875 bsz = latents.shape[0]
876 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
877 (bsz,), device=latents.device)
878 timesteps = timesteps.long()
879
880 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
881
882 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
883
884 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
885
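# Gather predictions from all processes so the validation loss covers the full validation set.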
886 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
887
888 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
889
890 loss = loss.detach().item()
891 val_loss += loss
892
893 if accelerator.sync_gradients:
894 local_progress_bar.update(1)
895 global_progress_bar.update(1)
896
897 logs = {"mode": "validation", "loss": loss}
898 local_progress_bar.set_postfix(**logs)
899
900 val_loss /= len(val_dataloader)
901
902 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
903
904 local_progress_bar.clear()
905 global_progress_bar.clear()
906
907 if min_val_loss > val_loss:
908 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
909 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
910 min_val_loss = val_loss
911
912 if accelerator.is_main_process:
913 checkpointer.save_samples(
914 global_step + global_step_offset,
915 text_encoder,
916 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
917
918 # Training finished: save the final checkpoint and resume state.
919 if accelerator.is_main_process:
920 print("Finished! Saving final checkpoint and resume state.")
921 checkpointer.checkpoint(
922 global_step + global_step_offset,
923 "end",
924 text_encoder,
925 path=f"{basepath}/learned_embeds.bin"
926 )
927
928 save_resume_file(basepath, args, {
929 "global_step": global_step + global_step_offset,
930 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
931 })
932
933 accelerator.end_training()
934
935 except KeyboardInterrupt:
936 if accelerator.is_main_process:
937 print("Interrupted, saving checkpoint and resume state...")
938 checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
939 save_resume_file(basepath, args, {
940 "global_step": global_step + global_step_offset,
941 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
942 })
943 accelerator.end_training()
944 quit()
945
946
947if __name__ == "__main__":
948 main()