author     Volpeon <git@volpeon.ink>   2022-09-26 16:36:42 +0200
committer  Volpeon <git@volpeon.ink>   2022-09-26 16:36:42 +0200
commit     5588b93859c4380082a7e46bf5bef2119ec1907a (patch)
tree       05a8292201912eb6f417eb2740c86df2153d1095 /main.py
Init
Diffstat (limited to 'main.py')
-rw-r--r--   main.py   784
1 file changed, 784 insertions, 0 deletions
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..9bf65a5
--- /dev/null
+++ b/main.py
@@ -0,0 +1,784 @@
1import argparse
2import itertools
3import math
4import os
5import random
6import datetime
7from pathlib import Path
8from typing import Optional
9
10import numpy as np
11import torch
12import torch.nn as nn
13import torch.nn.functional as F
14import torch.utils.checkpoint
15from torch.utils.data import Dataset
16
17import PIL
18from accelerate import Accelerator
19from accelerate.logging import get_logger
20from accelerate.utils import LoggerType, set_seed
21from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
22from diffusers.optimization import get_scheduler
23from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
24from einops import rearrange
25from pipelines.stable_diffusion.no_check import NoCheck
26from huggingface_hub import HfFolder, Repository, whoami
27from PIL import Image
28from tqdm.auto import tqdm
29from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
30from slugify import slugify
31import json
33import sys
34
35from data import CSVDataModule
36
37logger = get_logger(__name__)
38
39
40def parse_args():
41 parser = argparse.ArgumentParser(
42        description="Textual inversion training script.")
43 parser.add_argument(
44 "--pretrained_model_name_or_path",
45 type=str,
46 default=None,
47 help="Path to pretrained model or model identifier from huggingface.co/models.",
48 )
49 parser.add_argument(
50 "--tokenizer_name",
51 type=str,
52 default=None,
53 help="Pretrained tokenizer name or path if not the same as model_name",
54 )
55 parser.add_argument(
56 "--train_data_dir", type=str, default=None, help="A folder containing the training data."
57 )
58 parser.add_argument(
59 "--placeholder_token",
60 type=str,
61 default=None,
62 help="A token to use as a placeholder for the concept.",
63 )
64 parser.add_argument(
65 "--initializer_token", type=str, default=None, help="A token to use as initializer word."
66 )
67 parser.add_argument(
68 "--vectors_per_token", type=int, default=1, help="Vectors per token."
69 )
70 parser.add_argument("--repeats", type=int, default=100,
71 help="How many times to repeat the training data.")
72 parser.add_argument(
73 "--output_dir",
74 type=str,
75 default="text-inversion-model",
76 help="The output directory where the model predictions and checkpoints will be written.",
77 )
78 parser.add_argument("--seed", type=int, default=None,
79 help="A seed for reproducible training.")
80 parser.add_argument(
81 "--resolution",
82 type=int,
83 default=512,
84 help=(
85            "The resolution for input images; all images in the train/validation dataset will be resized to this"
86            " resolution."
87 ),
88 )
89 parser.add_argument(
90 "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
91 )
92 parser.add_argument(
93 "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
94 )
95 parser.add_argument("--num_train_epochs", type=int, default=100)
96 parser.add_argument(
97 "--max_train_steps",
98 type=int,
99 default=5000,
100 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
101 )
102 parser.add_argument(
103 "--gradient_accumulation_steps",
104 type=int,
105 default=1,
106        help="Number of update steps to accumulate before performing a backward/update pass.",
107 )
108 parser.add_argument(
109 "--learning_rate",
110 type=float,
111 default=1e-4,
112 help="Initial learning rate (after the potential warmup period) to use.",
113 )
114 parser.add_argument(
115 "--scale_lr",
116 action="store_true",
117 default=True,
118 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
119 )
120 parser.add_argument(
121 "--lr_scheduler",
122 type=str,
123 default="constant",
124 help=(
125 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
126 ' "constant", "constant_with_warmup"]'
127 ),
128 )
129 parser.add_argument(
130 "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
131 )
132 parser.add_argument("--adam_beta1", type=float, default=0.9,
133 help="The beta1 parameter for the Adam optimizer.")
134 parser.add_argument("--adam_beta2", type=float, default=0.999,
135 help="The beta2 parameter for the Adam optimizer.")
136 parser.add_argument("--adam_weight_decay", type=float,
137 default=1e-2, help="Weight decay to use.")
138 parser.add_argument("--adam_epsilon", type=float, default=1e-08,
139 help="Epsilon value for the Adam optimizer")
140 parser.add_argument(
141 "--mixed_precision",
142 type=str,
143 default="no",
144 choices=["no", "fp16", "bf16"],
145 help=(
146            "Whether to use mixed precision. Choose "
147            "between fp16 and bf16 (bfloat16). bf16 requires PyTorch >= 1.10 "
148            "and an Nvidia Ampere GPU."
149 ),
150 )
151 parser.add_argument("--local_rank", type=int, default=-1,
152 help="For distributed training: local_rank")
153 parser.add_argument(
154 "--checkpoint_frequency",
155 type=int,
156 default=500,
157 help="How often to save a checkpoint and sample image",
158 )
159 parser.add_argument(
160 "--sample_image_size",
161 type=int,
162 default=512,
163 help="Size of sample images",
164 )
165 parser.add_argument(
166 "--stable_sample_batches",
167 type=int,
168 default=1,
169 help="Number of fixed seed sample batches to generate per checkpoint",
170 )
171 parser.add_argument(
172 "--random_sample_batches",
173 type=int,
174 default=1,
175 help="Number of random seed sample batches to generate per checkpoint",
176 )
177 parser.add_argument(
178 "--sample_batch_size",
179 type=int,
180 default=1,
181 help="Number of samples to generate per batch",
182 )
183 parser.add_argument(
184 "--sample_steps",
185 type=int,
186 default=50,
187 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
188 )
189 parser.add_argument(
190 "--resume_from",
191 type=str,
192 default=None,
193        help="Path to a directory to resume training from (e.g. logs/token_name/2022-09-22T23-36-27)"
194 )
195 parser.add_argument(
196 "--resume_checkpoint",
197 type=str,
198 default=None,
199        help="Path to a specific checkpoint to resume training from (e.g. logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
200 )
201 parser.add_argument(
202 "--config",
203 type=str,
204 default=None,
205 help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
206 )
207
208 args = parser.parse_args()
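    # The re-parse below pre-populates the namespace from resume.json (or --config); flags given
    # explicitly on the command line still take precedence over the loaded values.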
209 if args.resume_from is not None:
210 with open(f"{args.resume_from}/resume.json", 'rt') as f:
211 args = parser.parse_args(
212 namespace=argparse.Namespace(**json.load(f)["args"]))
213 elif args.config is not None:
214 with open(args.config, 'rt') as f:
215 args = parser.parse_args(
216 namespace=argparse.Namespace(**json.load(f)))
217
218 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
219 if env_local_rank != -1 and env_local_rank != args.local_rank:
220 args.local_rank = env_local_rank
221
222 if args.train_data_dir is None:
223 raise ValueError("You must specify --train_data_dir")
224
225 if args.pretrained_model_name_or_path is None:
226 raise ValueError("You must specify --pretrained_model_name_or_path")
227
228 if args.placeholder_token is None:
229 raise ValueError("You must specify --placeholder_token")
230
231 if args.initializer_token is None:
232 raise ValueError("You must specify --initializer_token")
233
234 if args.output_dir is None:
235 raise ValueError("You must specify --output_dir")
236
237 return args
238
239
240def freeze_params(params):
241 for param in params:
242 param.requires_grad = False
243
244
245def save_resume_file(basepath, args, extra={}):
246 info = {"args": vars(args)}
247 info["args"].update(extra)
248 with open(f"{basepath}/resume.json", "w") as f:
249 json.dump(info, f, indent=4)
250
251
252def make_grid(images, rows, cols):
253 w, h = images[0].size
254 grid = Image.new('RGB', size=(cols*w, rows*h))
255 for i, image in enumerate(images):
256 grid.paste(image, box=(i % cols*w, i//cols*h))
257 return grid
258
259
260class Checkpointer:
261 def __init__(
262 self,
263 datamodule,
264 accelerator,
265 vae,
266 unet,
267 tokenizer,
268 placeholder_token,
269 placeholder_token_id,
270 output_dir,
271 sample_image_size,
272 random_sample_batches,
273 sample_batch_size,
274 stable_sample_batches,
275 seed
276 ):
277 self.datamodule = datamodule
278 self.accelerator = accelerator
279 self.vae = vae
280 self.unet = unet
281 self.tokenizer = tokenizer
282 self.placeholder_token = placeholder_token
283 self.placeholder_token_id = placeholder_token_id
284 self.output_dir = output_dir
285 self.sample_image_size = sample_image_size
286 self.seed = seed
287 self.random_sample_batches = random_sample_batches
288 self.sample_batch_size = sample_batch_size
289 self.stable_sample_batches = stable_sample_batches
290
291 @torch.no_grad()
292 def checkpoint(self, step, text_encoder, save_samples=True, path=None):
293 print("Saving checkpoint for step %d..." % step)
294 with self.accelerator.autocast():
295 if path is None:
296 checkpoints_path = f"{self.output_dir}/checkpoints"
297 os.makedirs(checkpoints_path, exist_ok=True)
298
299 unwrapped = self.accelerator.unwrap_model(text_encoder)
300
301 # Save a checkpoint
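            # Only the embedding row for the placeholder token is persisted, not the full text encoder.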
302 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
303 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
304
305            filename = f"{slugify(self.placeholder_token)}_{step}.bin"
306 if path is not None:
307 torch.save(learned_embeds_dict, path)
308 else:
309 torch.save(learned_embeds_dict,
310 f"{checkpoints_path}/{filename}")
311 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
312 del unwrapped
313 del learned_embeds
314
315 @torch.no_grad()
316 def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
317 samples_path = f"{self.output_dir}/samples/{mode}"
318 os.makedirs(samples_path, exist_ok=True)
320
321 unwrapped = self.accelerator.unwrap_model(text_encoder)
322 # Save a sample image
323 pipeline = StableDiffusionPipeline(
324 text_encoder=unwrapped,
325 vae=self.vae,
326 unet=self.unet,
327 tokenizer=self.tokenizer,
328 scheduler=LMSDiscreteScheduler(
329 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
330 ),
331 safety_checker=NoCheck(),
332 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
333 ).to(self.accelerator.device)
334 pipeline.enable_attention_slicing()
335
336 data = {
337 "training": self.datamodule.train_dataloader(),
338 "validation": self.datamodule.val_dataloader(),
339 }[mode]
340
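        # "Stable" samples reuse a fixed-seed latent tensor so grids from successive checkpoints
        # are directly comparable; the random batches further below draw fresh noise each time.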
341 if mode == "validation" and self.stable_sample_batches > 0:
342 stable_latents = torch.randn(
343 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
344 device=pipeline.device,
345 generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
346 )
347
348 all_samples = []
349            filename = f"stable_step_{step}.png"
350
351 # Generate and save stable samples
352 for i in range(0, self.stable_sample_batches):
353 prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
354 samples = pipeline(
355 prompt=prompt,
356 height=self.sample_image_size,
357 latents=stable_latents,
358 width=self.sample_image_size,
359 guidance_scale=guidance_scale,
360 eta=eta,
361 num_inference_steps=num_inference_steps,
362 output_type='pil'
363 )["sample"]
364
365 all_samples += samples
366 del samples
367
368 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
369 image_grid.save(f"{samples_path}/{filename}")
370
371 del all_samples
372 del image_grid
373 del stable_latents
374
375 all_samples = []
376        filename = f"step_{step}.png"
377
378 # Generate and save random samples
379 for i in range(0, self.random_sample_batches):
380 prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
381 samples = pipeline(
382 prompt=prompt,
383 height=self.sample_image_size,
384 width=self.sample_image_size,
385 guidance_scale=guidance_scale,
386 eta=eta,
387 num_inference_steps=num_inference_steps,
388 output_type='pil'
389 )["sample"]
390
391 all_samples += samples
392 del samples
393
394 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
395 image_grid.save(f"{samples_path}/{filename}")
396
397 del all_samples
398 del image_grid
399
401 del unwrapped
402 del pipeline
403 torch.cuda.empty_cache()
404
405
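# Helper that encodes image batches into scaled VAE latents, memoizing the result per batch key
# so repeated passes over the same data skip redundant VAE encoding work.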
406class ImageToLatents():
407 def __init__(self, vae):
408 self.vae = vae
409 self.encoded_pixel_values_cache = {}
410
411 @torch.no_grad()
412 def __call__(self, batch):
413 key = "|".join(batch["key"])
414 if self.encoded_pixel_values_cache.get(key, None) is None:
415 self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist
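        # Sample from the cached latent distribution, cast to fp16, and scale by 0.18215
        # (the latent scaling factor used by Stable Diffusion).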
416 latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
417 return latents
418
419
420def main():
421 args = parse_args()
422
423 global_step_offset = 0
424 if args.resume_from is not None:
425 basepath = f"{args.resume_from}"
426 print("Resuming state from %s" % args.resume_from)
427 with open(f"{basepath}/resume.json", 'r') as f:
428 state = json.load(f)
429 global_step_offset = state["args"].get("global_step", 0)
430
431 print("We've trained %d steps so far" % global_step_offset)
432 else:
433 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
434 basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}"
435 os.makedirs(basepath, exist_ok=True)
436
437 accelerator = Accelerator(
438 log_with=LoggerType.TENSORBOARD,
439 logging_dir=f"{basepath}",
440 gradient_accumulation_steps=args.gradient_accumulation_steps,
441 mixed_precision=args.mixed_precision
442 )
443
444 # If passed along, set the training seed now.
445 if args.seed is not None:
446 set_seed(args.seed)
447
448    # Load the tokenizer and add the placeholder token as an additional special token
449 if args.tokenizer_name:
450 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
451 elif args.pretrained_model_name_or_path:
452 tokenizer = CLIPTokenizer.from_pretrained(
453 args.pretrained_model_name_or_path + '/tokenizer'
454 )
455
456    # Add the placeholder token to the tokenizer
457 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
458 if num_added_tokens == 0:
459 raise ValueError(
460 f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
461 " `placeholder_token` that is not already in the tokenizer."
462 )
463
464 # Convert the initializer_token, placeholder_token to ids
465 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
466 # Check if initializer_token is a single token or a sequence of tokens
467 if args.vectors_per_token % len(initializer_token_ids) != 0:
468 raise ValueError(
469            f"vectors_per_token ({args.vectors_per_token}) must be divisible by the number of initializer tokens ({len(initializer_token_ids)}).")
470
471 initializer_token_ids = torch.tensor(initializer_token_ids)
472 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
473
474 # Load models and create wrapper for stable diffusion
475 text_encoder = CLIPTextModel.from_pretrained(
476 args.pretrained_model_name_or_path + '/text_encoder',
477 )
478 vae = AutoencoderKL.from_pretrained(
479 args.pretrained_model_name_or_path + '/vae',
480 )
481 unet = UNet2DConditionModel.from_pretrained(
482 args.pretrained_model_name_or_path + '/unet',
483 )
484
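    # Slice the UNet's attention (half the head dim per slice) to reduce peak VRAM usage
    # at a small speed cost.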
485 slice_size = unet.config.attention_head_dim // 2
486 unet.set_attention_slice(slice_size)
487
488 # Resize the token embeddings as we are adding new special tokens to the tokenizer
489 text_encoder.resize_token_embeddings(len(tokenizer))
490
491 # Initialise the newly added placeholder token with the embeddings of the initializer token
492 token_embeds = text_encoder.get_input_embeddings().weight.data
493
494 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
495
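    # Seed the new token's embedding: either restore it from a previous checkpoint or copy the
    # initializer token's embedding as the starting point.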
496 if args.resume_checkpoint is not None:
497 token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
498 args.placeholder_token]
499 else:
500 token_embeds[placeholder_token_id] = initializer_token_embeddings
501
502 # Freeze vae and unet
503 freeze_params(vae.parameters())
504 freeze_params(unet.parameters())
505 # Freeze all parameters except for the token embeddings in text encoder
506 params_to_freeze = itertools.chain(
507 text_encoder.text_model.encoder.parameters(),
508 text_encoder.text_model.final_layer_norm.parameters(),
509 text_encoder.text_model.embeddings.position_embedding.parameters(),
510 )
511 freeze_params(params_to_freeze)
512
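    # Linear LR scaling: e.g. a base LR of 1e-4 with 2 processes, batch size 2 and
    # 2 accumulation steps yields an effective LR of 8e-4.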
513 if args.scale_lr:
514 args.learning_rate = (
515 args.learning_rate * args.gradient_accumulation_steps *
516 args.train_batch_size * accelerator.num_processes
517 )
518
519 # Initialize the optimizer
520 optimizer = torch.optim.AdamW(
521 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
522 lr=args.learning_rate,
523 betas=(args.adam_beta1, args.adam_beta2),
524 weight_decay=args.adam_weight_decay,
525 eps=args.adam_epsilon,
526 )
527
528    # TODO (patil-suraj): load scheduler using args
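    # These beta values and the scaled_linear schedule match what Stable Diffusion v1 was trained with.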
529 noise_scheduler = DDPMScheduler(
530 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
531 )
532
533 datamodule = CSVDataModule(
534 data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer,
535 size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats,
536 center_crop=args.center_crop)
537
538 datamodule.prepare_data()
539 datamodule.setup()
540
541 train_dataloader = datamodule.train_dataloader()
542 val_dataloader = datamodule.val_dataloader()
543
544 checkpointer = Checkpointer(
545 datamodule=datamodule,
546 accelerator=accelerator,
547 vae=vae,
548 unet=unet,
549 tokenizer=tokenizer,
550 placeholder_token=args.placeholder_token,
551 placeholder_token_id=placeholder_token_id,
552 output_dir=basepath,
553 sample_image_size=args.sample_image_size,
554 sample_batch_size=args.sample_batch_size,
555 random_sample_batches=args.random_sample_batches,
556 stable_sample_batches=args.stable_sample_batches,
557 seed=args.seed
558 )
559
560 # Scheduler and math around the number of training steps.
561 overrode_max_train_steps = False
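    # Validation batches also advance the progress bars below, so they are counted in the
    # per-epoch step budget.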
562 num_update_steps_per_epoch = math.ceil(
563 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
564 if args.max_train_steps is None:
565 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
566 overrode_max_train_steps = True
567
568 lr_scheduler = get_scheduler(
569 args.lr_scheduler,
570 optimizer=optimizer,
571 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
572 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
573 )
574
575 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
576 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
577 )
578
579 # Move vae and unet to device
580 vae.to(accelerator.device)
581 unet.to(accelerator.device)
582
583 # Keep vae and unet in eval mode as we don't train these
584 vae.eval()
585 unet.eval()
586
587 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
588 num_update_steps_per_epoch = math.ceil(
589 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
590 if overrode_max_train_steps:
591 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
592 # Afterwards we recalculate our number of training epochs
593 args.num_train_epochs = math.ceil(
594 args.max_train_steps / num_update_steps_per_epoch)
595
596 # We need to initialize the trackers we use, and also store our configuration.
597    # The trackers initialize automatically on the main process.
598 if accelerator.is_main_process:
599 accelerator.init_trackers("textual_inversion", config=vars(args))
600
601 # Train!
602 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
603
604 logger.info("***** Running training *****")
605 logger.info(f" Num Epochs = {args.num_train_epochs}")
606 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
607 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
608 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
609 logger.info(f" Total optimization steps = {args.max_train_steps}")
610 # Only show the progress bar once on each machine.
611
612 global_step = 0
613 min_val_loss = np.inf
614
615 imageToLatents = ImageToLatents(vae)
616
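    # Render a baseline grid of validation samples at step 0 for later visual comparison.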
617 checkpointer.save_samples(
618 "validation",
619 0,
620 text_encoder,
621 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
622
623 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
624 progress_bar.set_description("Global steps")
625
626 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
627 local_progress_bar.set_description("Steps")
628
629 try:
630 for epoch in range(args.num_train_epochs):
631 local_progress_bar.reset()
632
633 text_encoder.train()
634 train_loss = 0.0
635
636 for step, batch in enumerate(train_dataloader):
637 with accelerator.accumulate(text_encoder):
638 with accelerator.autocast():
639 # Convert images to latent space
640 latents = imageToLatents(batch)
641
642 # Sample noise that we'll add to the latents
643 noise = torch.randn(latents.shape).to(latents.device)
644 bsz = latents.shape[0]
645 # Sample a random timestep for each image
646 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
647 (bsz,), device=latents.device).long()
648
649 # Add noise to the latents according to the noise magnitude at each timestep
650 # (this is the forward diffusion process)
651 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
652
653 # Get the text embedding for conditioning
654 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
655
656 # Predict the noise residual
657 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
658
659 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
660
661 accelerator.backward(loss)
662
663 # Zero out the gradients for all token embeddings except the newly added
664 # embeddings for the concept, as we only want to optimize the concept embeddings
665 if accelerator.num_processes > 1:
666 grads = text_encoder.module.get_input_embeddings().weight.grad
667 else:
668 grads = text_encoder.get_input_embeddings().weight.grad
669 # Get the index for tokens that we want to zero the grads for
670 index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
671 grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
672
673 optimizer.step()
674 if not accelerator.optimizer_step_was_skipped:
675 lr_scheduler.step()
676 optimizer.zero_grad()
677
678 loss = loss.detach().item()
679 train_loss += loss
680
681 # Checks if the accelerator has performed an optimization step behind the scenes
682 if accelerator.sync_gradients:
683 progress_bar.update(1)
684 local_progress_bar.update(1)
685 global_step += 1
686
687 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
688 checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
689 save_resume_file(basepath, args, {
690 "global_step": global_step + global_step_offset,
691 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
692 })
693 checkpointer.save_samples(
694 "training",
695 global_step + global_step_offset,
696 text_encoder,
697 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
698
699 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
700 local_progress_bar.set_postfix(**logs)
701
702 if global_step >= args.max_train_steps:
703 break
704
705 train_loss /= len(train_dataloader)
706
707 text_encoder.eval()
708 val_loss = 0.0
709
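            # Validation mirrors the training forward pass (noise-prediction MSE) without gradient
            # updates; predictions are gathered across processes before computing the loss.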
710 for step, batch in enumerate(val_dataloader):
711 with torch.no_grad(), accelerator.autocast():
712 latents = imageToLatents(batch)
713
714 noise = torch.randn(latents.shape).to(latents.device)
715 bsz = latents.shape[0]
716 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
717 (bsz,), device=latents.device).long()
718
719 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
720
721 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
722
723 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
724
725 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
726
727 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
728
729 loss = loss.detach().item()
730 val_loss += loss
731
732 if accelerator.sync_gradients:
733 progress_bar.update(1)
734 local_progress_bar.update(1)
735
736 logs = {"mode": "validation", "loss": loss}
737 local_progress_bar.set_postfix(**logs)
738
739 val_loss /= len(val_dataloader)
740
741 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
742
743 if min_val_loss > val_loss:
744 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
745 min_val_loss = val_loss
746
747 checkpointer.save_samples(
748 "validation",
749 global_step + global_step_offset,
750 text_encoder,
751 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
752
753 accelerator.wait_for_everyone()
754
755        # Save the final learned embeddings and resume state.
756 if accelerator.is_main_process:
757 print("Finished! Saving final checkpoint and resume state.")
758 checkpointer.checkpoint(
759 global_step + global_step_offset,
760 text_encoder,
761 path=f"{basepath}/learned_embeds.bin"
762 )
763
764 save_resume_file(basepath, args, {
765 "global_step": global_step + global_step_offset,
766 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
767 })
768
769 accelerator.end_training()
770
771 except KeyboardInterrupt:
772 if accelerator.is_main_process:
773 print("Interrupted, saving checkpoint and resume state...")
774 checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
775 save_resume_file(basepath, args, {
776 "global_step": global_step + global_step_offset,
777 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
778 })
779 accelerator.end_training()
780 quit()
781
782
783if __name__ == "__main__":
784 main()