diff --git a/textual_inversion.py b/textual_inversion.py
new file mode 100644
index 0000000..aa8e744
--- /dev/null
+++ b/textual_inversion.py
@@ -0,0 +1,847 @@
import argparse
import itertools
import json
import math
import os
import datetime

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from pipelines.stable_diffusion.no_check import NoCheck
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from slugify import slugify

from data.textual_inversion.csv import CSVDataModule

logger = get_logger(__name__)


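# This script is normally launched with Hugging Face Accelerate; the paths and token below
# are purely illustrative:
#
#   accelerate launch textual_inversion.py \
#       --pretrained_model_name_or_path /path/to/stable-diffusion \
#       --train_data_dir /path/to/concept/images \
#       --placeholder_token "<my-concept>" \
#       --initializer_token "sculpture" \
#       --output_dir text-inversion-model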
def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        default=None,
        help="A token to use as initializer word."
    )
    parser.add_argument(
        "--vectors_per_token",
        type=int,
        default=1,
        help="Vectors per token."
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=100,
        help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images; all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution."
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer."
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank."
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=500,
        help="How often to save a checkpoint and sample image.",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=512,
        help="Size of sample images.",
    )
    parser.add_argument(
        "--stable_sample_batches",
        type=int,
        default=1,
        help="Number of fixed-seed sample batches to generate per checkpoint.",
    )
    parser.add_argument(
        "--random_sample_batches",
        type=int,
        default=1,
        help="Number of random-seed sample batches to generate per checkpoint.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=50,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (e.g., logs/token_name/2022-09-22T23-36-27)."
    )
    parser.add_argument(
        "--resume_checkpoint",
        type=str,
        default=None,
        help="Path to a specific checkpoint to resume training from (e.g., logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
    )

    args = parser.parse_args()
    if args.resume_from is not None:
        with open(f"{args.resume_from}/resume.json", 'rt') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))
    elif args.config is not None:
        with open(args.config, 'rt') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.train_data_dir is None:
        raise ValueError("You must specify --train_data_dir")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.placeholder_token is None:
        raise ValueError("You must specify --placeholder_token")

    if args.initializer_token is None:
        raise ValueError("You must specify --initializer_token")

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def freeze_params(params):
    for param in params:
        param.requires_grad = False


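# resume.json stores the parsed argparse namespace under "args" (plus extras such as
# "global_step" and "resume_checkpoint"); parse_args() reads it back via --resume_from.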
def save_resume_file(basepath, args, extra={}):
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(f"{basepath}/resume.json", "w") as f:
        json.dump(info, f, indent=4)


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=((i % cols) * w, (i // cols) * h))
    return grid


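# Checkpointer persists the learned placeholder embedding and renders sample image grids
# from the in-training text encoder at regular intervals.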
class Checkpointer:
    def __init__(
        self,
        datamodule,
        accelerator,
        vae,
        unet,
        tokenizer,
        placeholder_token,
        placeholder_token_id,
        output_dir,
        sample_image_size,
        random_sample_batches,
        sample_batch_size,
        stable_sample_batches,
        seed
    ):
        self.datamodule = datamodule
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.output_dir = output_dir
        self.sample_image_size = sample_image_size
        self.seed = seed
        self.random_sample_batches = random_sample_batches
        self.sample_batch_size = sample_batch_size
        self.stable_sample_batches = stable_sample_batches

    @torch.no_grad()
    def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
        print("Saving checkpoint for step %d..." % step)
        with self.accelerator.autocast():
            if path is None:
                checkpoints_path = f"{self.output_dir}/checkpoints"
                os.makedirs(checkpoints_path, exist_ok=True)

            unwrapped = self.accelerator.unwrap_model(text_encoder)

            # Save a checkpoint: a dict mapping the placeholder token to its embedding vector
            learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
            learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}

            filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
            if path is not None:
                torch.save(learned_embeds_dict, path)
            else:
                torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
                # last.bin always holds the most recent embedding, for --resume_checkpoint
                torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
            del unwrapped
            del learned_embeds

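    # save_samples renders image grids with the current embedding: for validation, a set of
    # fixed-seed ("stable") batches that stay comparable across checkpoints, plus freshly
    # seeded random batches for both modes.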
    @torch.no_grad()
    def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
        samples_path = f"{self.output_dir}/samples/{mode}"
        os.makedirs(samples_path, exist_ok=True)
        checker = NoCheck()

        unwrapped = self.accelerator.unwrap_model(text_encoder)
        # Save a sample image
        pipeline = StableDiffusionPipeline(
            text_encoder=unwrapped,
            vae=self.vae,
            unet=self.unet,
            tokenizer=self.tokenizer,
            scheduler=LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            ),
            safety_checker=NoCheck(),
            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
        ).to(self.accelerator.device)
        pipeline.enable_attention_slicing()

        data = {
            "training": self.datamodule.train_dataloader(),
            "validation": self.datamodule.val_dataloader(),
        }[mode]

        if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
            stable_latents = torch.randn(
                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
                device=pipeline.device,
                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
            )

            all_samples = []
            filename = "stable_step_%d.png" % (step)

            data_enum = enumerate(data)

            # Generate and save stable samples
            for i in range(0, self.stable_sample_batches):
                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                    batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]

                with self.accelerator.autocast():
                    samples = pipeline(
                        prompt=prompt,
                        height=self.sample_image_size,
                        latents=stable_latents[:len(prompt)],
                        width=self.sample_image_size,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil'
                    )["sample"]

                all_samples += samples
                del samples

            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
            image_grid.save(f"{samples_path}/{filename}")

            del all_samples
            del image_grid
            del stable_latents

        all_samples = []
        filename = "step_%d.png" % (step)

        data_enum = enumerate(data)

        # Generate and save random samples
        for i in range(0, self.random_sample_batches):
            prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]

            with self.accelerator.autocast():
                samples = pipeline(
                    prompt=prompt,
                    height=self.sample_image_size,
                    width=self.sample_image_size,
                    guidance_scale=guidance_scale,
                    eta=eta,
                    num_inference_steps=num_inference_steps,
                    output_type='pil'
                )["sample"]

            all_samples += samples
            del samples

        image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
        image_grid.save(f"{samples_path}/{filename}")

        del all_samples
        del image_grid

        del checker
        del unwrapped
        del pipeline
        torch.cuda.empty_cache()


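# ImageToLatents caches the VAE posterior for each batch (keyed by the batch's image keys) so
# repeated epochs skip re-encoding, and scales latents by 0.18215, the Stable Diffusion latent
# scaling factor.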
class ImageToLatents():
    def __init__(self, vae):
        self.vae = vae
        self.encoded_pixel_values_cache = {}

    @torch.no_grad()
    def __call__(self, batch):
        key = "|".join(batch["key"])
        if self.encoded_pixel_values_cache.get(key, None) is None:
            self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist
        latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
        return latents


def main():
    args = parse_args()

    global_step_offset = 0
    if args.resume_from is not None:
        basepath = f"{args.resume_from}"
        print("Resuming state from %s" % args.resume_from)
        with open(f"{basepath}/resume.json", 'r') as f:
            state = json.load(f)
        global_step_offset = state["args"].get("global_step", 0)

        print("We've trained %d steps so far" % global_step_offset)
    else:
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}"
        os.makedirs(basepath, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{basepath}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(
            args.pretrained_model_name_or_path + '/tokenizer'
        )

    # Add the placeholder token to the tokenizer
    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
            " `placeholder_token` that is not already in the tokenizer."
        )

    # Convert the initializer_token and placeholder_token to ids
    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if args.vectors_per_token % len(initializer_token_ids) != 0:
        raise ValueError(
            f"vectors_per_token ({args.vectors_per_token}) must be divisible by the number of initializer tokens ({len(initializer_token_ids)}).")

    initializer_token_ids = torch.tensor(initializer_token_ids)
    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path + '/text_encoder',
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path + '/vae',
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path + '/unet',
    )

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    slice_size = unet.config.attention_head_dim // 2
    unet.set_attention_slice(slice_size)

    # Resize the token embeddings as we are adding new special tokens to the tokenizer
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Initialise the newly added placeholder token with the embeddings of the initializer token
    token_embeds = text_encoder.get_input_embeddings().weight.data

    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)

    if args.resume_checkpoint is not None:
        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
            args.placeholder_token]
    else:
        token_embeds[placeholder_token_id] = initializer_token_embeddings

    # Freeze vae and unet
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    # Freeze all parameters except for the token embeddings in the text encoder
    params_to_freeze = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    freeze_params(params_to_freeze)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )
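        # e.g. with the defaults (lr 1e-4, train_batch_size 1, no accumulation, one process)
        # the rate is unchanged; with 2 processes and train_batch_size 4 it becomes 8e-4.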

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # TODO (patil-suraj): load scheduler using args
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
    )

    datamodule = CSVDataModule(
        data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer,
        size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats,
        center_crop=args.center_crop)

    datamodule.prepare_data()
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader()
    val_dataloader = datamodule.val_dataloader()
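    # Batches from the CSV datamodule carry "key", "pixel_values", "input_ids" and "prompt"
    # entries, which are consumed by ImageToLatents, the train/validation loops and the sampler.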

    checkpointer = Checkpointer(
        datamodule=datamodule,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        placeholder_token=args.placeholder_token,
        placeholder_token_id=placeholder_token_id,
        output_dir=basepath,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        random_sample_batches=args.random_sample_batches,
        stable_sample_batches=args.stable_sample_batches,
        seed=args.seed
    )

    # Scheduler and math around the number of training steps. The per-epoch count includes
    # the validation batches because the progress bars below advance through both loops.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    # Move vae and unet to device
    vae.to(accelerator.device)
    unet.to(accelerator.device)

    # Keep vae and unet in eval mode as we don't train these
    vae.eval()
    unet.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(
        args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")

    global_step = 0
    min_val_loss = np.inf

    imageToLatents = ImageToLatents(vae)

    checkpointer.save_samples(
        "validation",
        0,
        text_encoder,
        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

    # Only show the progress bars once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Global steps")

    local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
    local_progress_bar.set_description("Steps")

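    # Each epoch below runs a training pass (updating only the placeholder embedding) and a
    # noise-prediction validation pass; new validation-loss minima trigger a "milestone" checkpoint.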
    try:
        for epoch in range(args.num_train_epochs):
            local_progress_bar.reset()

            text_encoder.train()
            train_loss = 0.0

            for step, batch in enumerate(train_dataloader):
                with accelerator.accumulate(text_encoder):
                    with accelerator.autocast():
                        # Convert images to latent space
                        latents = imageToLatents(batch)

                        # Sample noise that we'll add to the latents
                        noise = torch.randn(latents.shape).to(latents.device)
                        bsz = latents.shape[0]
                        # Sample a random timestep for each image
                        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
                                                  (bsz,), device=latents.device).long()

                        # Add noise to the latents according to the noise magnitude at each timestep
                        # (this is the forward diffusion process)
                        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                        # Get the text embedding for conditioning
                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                        # Predict the noise residual
                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                        accelerator.backward(loss)

                        # Zero out the gradients for all token embeddings except the newly added
                        # embeddings for the concept, as we only want to optimize the concept embeddings
                        if accelerator.num_processes > 1:
                            grads = text_encoder.module.get_input_embeddings().weight.grad
                        else:
                            grads = text_encoder.get_input_embeddings().weight.grad
                        # Get the index for tokens that we want to zero the grads for
                        index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
                        grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

                        optimizer.step()
                        if not accelerator.optimizer_step_was_skipped:
                            lr_scheduler.step()
                        optimizer.zero_grad()

                        loss = loss.detach().item()
                        train_loss += loss

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    local_progress_bar.update(1)

                    global_step += 1

                    if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
                        progress_bar.clear()
                        local_progress_bar.clear()

                        checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
                        save_resume_file(basepath, args, {
                            "global_step": global_step + global_step_offset,
                            "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                        })
                        checkpointer.save_samples(
                            "training",
                            global_step + global_step_offset,
                            text_encoder,
                            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

                logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
                local_progress_bar.set_postfix(**logs)

                if global_step >= args.max_train_steps:
                    break

            train_loss /= len(train_dataloader)

            text_encoder.eval()
            val_loss = 0.0

            for step, batch in enumerate(val_dataloader):
                with torch.no_grad(), accelerator.autocast():
                    latents = imageToLatents(batch)

                    noise = torch.randn(latents.shape).to(latents.device)
                    bsz = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
                                              (bsz,), device=latents.device).long()

                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))

                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                    loss = loss.detach().item()
                    val_loss += loss

                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    local_progress_bar.update(1)

                logs = {"mode": "validation", "loss": loss}
                local_progress_bar.set_postfix(**logs)

            val_loss /= len(val_dataloader)

            accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)

            progress_bar.clear()
            local_progress_bar.clear()

            if min_val_loss > val_loss:
                accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                min_val_loss = val_loss

            checkpointer.save_samples(
                "validation",
                global_step + global_step_offset,
                text_encoder,
                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)

        accelerator.wait_for_everyone()

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished! Saving final checkpoint and resume state.")
            checkpointer.checkpoint(
                global_step + global_step_offset,
                "end",
                text_encoder,
                path=f"{basepath}/learned_embeds.bin"
            )

            save_resume_file(basepath, args, {
                "global_step": global_step + global_step_offset,
                "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
            })

            accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted, saving checkpoint and resume state...")
            checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
            save_resume_file(basepath, args, {
                "global_step": global_step + global_step_offset,
                "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
            })
            accelerator.end_training()
        quit()


if __name__ == "__main__":
    main()