Diffstat (limited to 'textual_dreambooth.py')
-rw-r--r--  textual_dreambooth.py  917
1 file changed, 0 insertions(+), 917 deletions(-)
diff --git a/textual_dreambooth.py b/textual_dreambooth.py
deleted file mode 100644
index c07d98b..0000000
--- a/textual_dreambooth.py
+++ /dev/null
@@ -1,917 +0,0 @@
import argparse
import itertools
import math
import os
import datetime
import logging
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from schedulers.scheduling_euler_a import EulerAScheduler
from diffusers.optimization import get_scheduler
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from slugify import slugify
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
import json

from data.dreambooth.csv import CSVDataModule

logger = get_logger(__name__)

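# TF32 matmuls trade a small amount of float32 precision for a large speedup on Ampere-class GPUs.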
torch.backends.cuda.matmul.allow_tf32 = True


def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path, if not the same as model_name."
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        default=None,
        help="A token to use as the initializer word."
    )
    parser.add_argument(
        "--use_class_images",
        action="store_true",
        default=True,
        help="Include class images in the loss calculation a la Dreambooth.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=100,
        help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/text-inversion",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to the training resolution."
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer."
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank"
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=500,
        help="How often (in steps) to save a checkpoint and sample images.",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=512,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=30,
        help="Number of inference steps for sample generation. Higher values produce more detailed samples but take longer.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of the prior preservation loss."
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (e.g. logs/token_name/2022-09-22T23-36-27)"
    )
    parser.add_argument(
        "--resume_checkpoint",
        type=str,
        default=None,
        help="Path to a specific checkpoint to resume training from (e.g. logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
    )

    args = parser.parse_args()
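    # Re-parse on top of a namespace pre-filled from the JSON file so that flags
    # given explicitly on the command line still override the saved values.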
    if args.resume_from is not None:
        with open(f"{args.resume_from}/resume.json", 'rt') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))
    elif args.config is not None:
        with open(args.config, 'rt') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.placeholder_token is None:
        raise ValueError("You must specify --placeholder_token")

    if args.initializer_token is None:
        raise ValueError("You must specify --initializer_token")

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def freeze_params(params):
    for param in params:
        param.requires_grad = False


def save_resume_file(basepath, args, extra={}):
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(f"{basepath}/resume.json", "w") as f:
        json.dump(info, f, indent=4)


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols*w, i//cols*h))
    return grid


class Checkpointer:
    def __init__(
        self,
        datamodule,
        accelerator,
        vae,
        unet,
        tokenizer,
        placeholder_token,
        placeholder_token_id,
        output_dir,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed
    ):
        self.datamodule = datamodule
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.output_dir = output_dir
        self.sample_image_size = sample_image_size
        self.seed = seed or torch.random.seed()
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.no_grad()
    def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
        print("Saving checkpoint for step %d..." % step)

        if path is None:
            checkpoints_path = f"{self.output_dir}/checkpoints"
            os.makedirs(checkpoints_path, exist_ok=True)

        unwrapped = self.accelerator.unwrap_model(text_encoder)

        # Save a checkpoint
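        # Only the embedding row of the placeholder token is stored; this small
        # {token: tensor} dict is the entire textual-inversion artifact.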
        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}

        filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
        if path is not None:
            torch.save(learned_embeds_dict, path)
        else:
            torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
            torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")

        del unwrapped
        del learned_embeds

    @torch.no_grad()
    def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
        samples_path = Path(self.output_dir).joinpath("samples")

        unwrapped = self.accelerator.unwrap_model(text_encoder)
        scheduler = EulerAScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )

        # Save a sample image
        pipeline = VlpnStableDiffusion(
            text_encoder=unwrapped,
            vae=self.vae,
            unet=self.unet,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
        ).to(self.accelerator.device)
        pipeline.enable_attention_slicing()

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()

        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
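        # Fixed latents for the "stable" pool below keep those samples visually
        # comparable from checkpoint to checkpoint; the other pools draw fresh noise.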
        stable_latents = torch.randn(
            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
            device=pipeline.device,
            generator=generator,
        )

        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
            all_samples = []
            file_path = samples_path.joinpath(pool, f"step_{step}.png")
            file_path.parent.mkdir(parents=True, exist_ok=True)

            data_enum = enumerate(data)

            for i in range(self.sample_batches):
                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
                prompt = [prompt.format(self.placeholder_token)
                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]

                with self.accelerator.autocast():
                    samples = pipeline(
                        prompt=prompt,
                        negative_prompt=nprompt,
                        height=self.sample_image_size,
                        width=self.sample_image_size,
                        latents=latents[:len(prompt)] if latents is not None else None,
                        generator=generator if latents is not None else None,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil'
                    )["sample"]

                all_samples += samples

                del samples

            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
            image_grid.save(file_path)

            del all_samples
            del image_grid

        del unwrapped
        del scheduler
        del pipeline
        del generator
        del stable_latents

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def main():
    args = parse_args()

    global_step_offset = 0
    if args.resume_from is not None:
        basepath = Path(args.resume_from)
        print("Resuming state from %s" % args.resume_from)
        with open(basepath.joinpath("resume.json"), 'r') as f:
            state = json.load(f)
            global_step_offset = state["args"].get("global_step", 0)

        print("We've trained %d steps so far" % global_step_offset)
    else:
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
        basepath.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{basepath}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path + '/tokenizer'
        )

    # Add the placeholder token to the tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token and placeholder_token to ids
    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(initializer_token_ids) > 1:
        raise ValueError(
            f"The initializer token must map to a single token id, but it maps to {len(initializer_token_ids)}.")

    initializer_token_ids = torch.tensor(initializer_token_ids)
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path + '/text_encoder',
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path + '/vae',
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path + '/unet',
    )

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

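    # Attention slicing computes attention in smaller chunks, trading a little
    # speed for a lower peak memory footprint during the UNet forward pass.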
    slice_size = unet.config.attention_head_dim // 2
    unet.set_attention_slice(slice_size)

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data

    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)

    if args.resume_checkpoint is not None:
        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
            args.placeholder_token]
    else:
        token_embeds[placeholder_token_id] = initializer_token_embeddings

    # Freeze vae and unet
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    # Freeze all parameters except for the token embeddings in the text encoder
    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    freeze_params(params_to_freeze)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )
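        # i.e. the learning rate is scaled linearly with the effective batch size
        # (per-device batch size * accumulation steps * number of processes).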

    # Use 8-bit Adam for lower memory usage, or to fine-tune the model on 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Initialize the optimizer
    optimizer = optimizer_class(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

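    # DDPM noise schedule matching the scaled-linear betas used to train the
    # Stable Diffusion v1 checkpoints (1000 timesteps, beta 0.00085 -> 0.012).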
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000
    )

    def collate_fn(examples):
        prompts = [example["prompts"] for example in examples]
        nprompts = [example["nprompts"] for example in examples]
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # concat class and instance examples for prior preservation
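        # Instance examples come first and class examples second; the training loop
        # relies on this ordering when it splits the model output with torch.chunk.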
        if args.use_class_images and "class_prompt_ids" in examples[0]:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)

        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

        batch = {
            "prompts": prompts,
            "nprompts": nprompts,
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch

    datamodule = CSVDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        instance_identifier=args.placeholder_token,
        class_identifier=args.initializer_token if args.use_class_images else None,
        class_subdir="ti_cls",
        size=args.resolution,
        repeats=args.repeats,
        center_crop=args.center_crop,
        valid_set_size=args.sample_batch_size*args.sample_batches,
        collate_fn=collate_fn
    )

    datamodule.prepare_data()
    datamodule.setup()

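    # Generate any missing class (prior-preservation) images up front, using the
    # initializer token as the class prompt, so they can be mixed into the batches.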
    if args.use_class_images:
        missing_data = [item for item in datamodule.data if not item[1].exists()]

        if len(missing_data) != 0:
            batched_data = [missing_data[i:i+args.sample_batch_size]
                            for i in range(0, len(missing_data), args.sample_batch_size)]

            scheduler = EulerAScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )

            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
                scheduler=scheduler,
            ).to(accelerator.device)
            pipeline.enable_attention_slicing()

            for batch in batched_data:
                image_name = [p[1] for p in batch]
                prompt = [p[2].format(args.initializer_token) for p in batch]
                nprompt = [p[3] for p in batch]

                with accelerator.autocast():
                    images = pipeline(
                        prompt=prompt,
                        negative_prompt=nprompt,
                        num_inference_steps=args.sample_steps
                    ).images

                for i, image in enumerate(images):
                    image.save(image_name[i])

            del pipeline

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    train_dataloader = datamodule.train_dataloader()
    val_dataloader = datamodule.val_dataloader()

    checkpointer = Checkpointer(
        datamodule=datamodule,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        placeholder_token=args.placeholder_token,
        placeholder_token_id=placeholder_token_id,
        output_dir=basepath,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        seed=args.seed
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    # Move vae and unet to device
    vae.to(accelerator.device)
    unet.to(accelerator.device)

    # Keep vae and unet in eval mode as we don't train these
    vae.eval()
    unet.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    num_val_steps_per_epoch = len(val_dataloader)
    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    val_steps = num_val_steps_per_epoch * num_epochs

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    # Only show the progress bars once on each machine.

    global_step = 0
    min_val_loss = np.inf

    if accelerator.is_main_process:
        checkpointer.save_samples(
            0,
            text_encoder,
            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
                              disable=not accelerator.is_local_main_process)
    local_progress_bar.set_description("Epoch X out of Y")

    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
    global_progress_bar.set_description("Total progress")

    try:
        for epoch in range(num_epochs):
            local_progress_bar.set_description(f"Epoch {epoch + 1} out of {num_epochs}")
            local_progress_bar.reset()

            text_encoder.train()
            train_loss = 0.0

            for step, batch in enumerate(train_dataloader):
                with accelerator.accumulate(text_encoder):
                    # Convert images to latent space
                    with torch.no_grad():
                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                        latents = latents * 0.18215
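                        # 0.18215 is the latent scaling factor the Stable Diffusion autoencoder was trained with.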

                    # Sample noise that we'll add to the latents
                    noise = torch.randn(latents.shape).to(latents.device)
                    bsz = latents.shape[0]
                    # Sample a random timestep for each image
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                              (bsz,), device=latents.device)
                    timesteps = timesteps.long()

                    # Add noise to the latents according to the noise magnitude at each timestep
                    # (this is the forward diffusion process)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get the text embedding for conditioning
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                    # Predict the noise residual
                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    if args.use_class_images:
                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                        noise, noise_prior = torch.chunk(noise, 2, dim=0)

                        # Compute instance loss
                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                        # Compute prior loss
                        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

                        # Add the prior loss to the instance loss.
                        loss = loss + args.prior_loss_weight * prior_loss
                    else:
                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                    accelerator.backward(loss)

                    # Zero out the gradients for all token embeddings except the newly added
                    # embeddings for the concept, as we only want to optimize the concept embeddings
                    if accelerator.num_processes > 1:
                        grads = text_encoder.module.get_input_embeddings().weight.grad
                    else:
                        grads = text_encoder.get_input_embeddings().weight.grad
                    # Get the index for tokens that we want to zero the grads for
                    index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
                    grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                    optimizer.step()
                    if not accelerator.optimizer_step_was_skipped:
                        lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                    loss = loss.detach().item()
                    train_loss += loss

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    local_progress_bar.update(1)
                    global_progress_bar.update(1)

                    global_step += 1

                    if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                        local_progress_bar.clear()
                        global_progress_bar.clear()

                        checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
                        save_resume_file(basepath, args, {
                            "global_step": global_step + global_step_offset,
                            "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                        })

                logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
                local_progress_bar.set_postfix(**logs)

                if global_step >= args.max_train_steps:
                    break

            train_loss /= len(train_dataloader)

            accelerator.wait_for_everyone()

            text_encoder.eval()
            val_loss = 0.0

            for step, batch in enumerate(val_dataloader):
                with torch.no_grad():
                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                    latents = latents * 0.18215

                    noise = torch.randn(latents.shape).to(latents.device)
                    bsz = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                              (bsz,), device=latents.device)
                    timesteps = timesteps.long()

                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

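                    # Gather predictions from all processes so the validation loss reflects the full validation set.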
                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                    loss = loss.detach().item()
                    val_loss += loss

                if accelerator.sync_gradients:
                    local_progress_bar.update(1)
                    global_progress_bar.update(1)

                logs = {"mode": "validation", "loss": loss}
                local_progress_bar.set_postfix(**logs)

            val_loss /= len(val_dataloader)

            accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)

            local_progress_bar.clear()
            global_progress_bar.clear()

            if min_val_loss > val_loss:
                accelerator.print(f"Validation loss reached a new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                min_val_loss = val_loss

            if accelerator.is_main_process:
                checkpointer.save_samples(
                    global_step + global_step_offset,
                    text_encoder,
                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished! Saving final checkpoint and resume state.")
            checkpointer.checkpoint(
                global_step + global_step_offset,
                "end",
                text_encoder,
                path=f"{basepath}/learned_embeds.bin"
            )

            save_resume_file(basepath, args, {
                "global_step": global_step + global_step_offset,
                "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
            })

            accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted, saving checkpoint and resume state...")
            checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
            save_resume_file(basepath, args, {
                "global_step": global_step + global_step_offset,
                "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
            })
            accelerator.end_training()
        quit()


if __name__ == "__main__":
    main()