author     Volpeon <git@volpeon.ink>  2022-09-27 10:03:12 +0200
committer  Volpeon <git@volpeon.ink>  2022-09-27 10:03:12 +0200
commit     5d2abb1749b5d2f2f22ad603b5c2bf9182864520 (patch)
tree       d122d75322dff5cce3f2eb6cac0efe375320b9fd /main.py
parent     Use diffusers fork with Flash Attention (diff)
download   textual-inversion-diff-5d2abb1749b5d2f2f22ad603b5c2bf9182864520.tar.gz
           textual-inversion-diff-5d2abb1749b5d2f2f22ad603b5c2bf9182864520.tar.bz2
           textual-inversion-diff-5d2abb1749b5d2f2f22ad603b5c2bf9182864520.zip
More cleanup
Diffstat (limited to 'main.py')
-rw-r--r--  main.py  850
1 file changed, 0 insertions(+), 850 deletions(-)
diff --git a/main.py b/main.py
deleted file mode 100644
index 51b64c1..0000000
--- a/main.py
+++ /dev/null
@@ -1,850 +0,0 @@
1import argparse
2import itertools
3import math
4import os
5import random
6import datetime
7from pathlib import Path
8from typing import Optional
9
10import numpy as np
11import torch
12import torch.nn.functional as F
13import torch.utils.checkpoint
14
15from accelerate import Accelerator
16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
19from diffusers.optimization import get_scheduler
20from pipelines.stable_diffusion.no_check import NoCheck
21from PIL import Image
22from tqdm.auto import tqdm
23from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
24from slugify import slugify
25import json
26import os
27
28from data import CSVDataModule
29
30logger = get_logger(__name__)
31
32
33def parse_args():
34 parser = argparse.ArgumentParser(
35 description="Simple example of a training script."
36 )
37 parser.add_argument(
38 "--pretrained_model_name_or_path",
39 type=str,
40 default=None,
41 help="Path to pretrained model or model identifier from huggingface.co/models.",
42 )
43 parser.add_argument(
44 "--tokenizer_name",
45 type=str,
46 default=None,
47 help="Pretrained tokenizer name or path if not the same as model_name",
48 )
49 parser.add_argument(
50 "--train_data_dir",
51 type=str,
52 default=None,
53 help="A folder containing the training data."
54 )
55 parser.add_argument(
56 "--placeholder_token",
57 type=str,
58 default=None,
59 help="A token to use as a placeholder for the concept.",
60 )
61 parser.add_argument(
62 "--initializer_token",
63 type=str,
64 default=None,
65 help="A token to use as initializer word."
66 )
67 parser.add_argument(
68 "--vectors_per_token",
69 type=int,
70 default=1,
71 help="Number of embedding vectors to learn per placeholder token."
72 )
73 parser.add_argument(
74 "--repeats",
75 type=int,
76 default=100,
77 help="How many times to repeat the training data.")
78 parser.add_argument(
79 "--output_dir",
80 type=str,
81 default="text-inversion-model",
82 help="The output directory where the model predictions and checkpoints will be written.",
83 )
84 parser.add_argument(
85 "--seed",
86 type=int,
87 default=None,
88 help="A seed for reproducible training.")
89 parser.add_argument(
90 "--resolution",
91 type=int,
92 default=512,
93 help=(
94 "The resolution for input images; all images in the train/validation dataset will be resized to this"
95 " resolution"
96 ),
97 )
98 parser.add_argument(
99 "--center_crop",
100 action="store_true",
101 help="Whether to center crop images before resizing to resolution"
102 )
103 parser.add_argument(
104 "--num_train_epochs",
105 type=int,
106 default=100)
107 parser.add_argument(
108 "--max_train_steps",
109 type=int,
110 default=5000,
111 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
112 )
113 parser.add_argument(
114 "--gradient_accumulation_steps",
115 type=int,
116 default=1,
117 help="Number of update steps to accumulate before performing a backward/update pass.",
118 )
119 parser.add_argument(
120 "--gradient_checkpointing",
121 action="store_true",
122 help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
123 )
124 parser.add_argument(
125 "--learning_rate",
126 type=float,
127 default=1e-4,
128 help="Initial learning rate (after the potential warmup period) to use.",
129 )
130 parser.add_argument(
131 "--scale_lr",
132 action="store_true",
133 default=True,
134 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
135 )
136 parser.add_argument(
137 "--lr_scheduler",
138 type=str,
139 default="constant",
140 help=(
141 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
142 ' "constant", "constant_with_warmup"]'
143 ),
144 )
145 parser.add_argument(
146 "--lr_warmup_steps",
147 type=int,
148 default=500,
149 help="Number of steps for the warmup in the lr scheduler."
150 )
151 parser.add_argument(
152 "--adam_beta1",
153 type=float,
154 default=0.9,
155 help="The beta1 parameter for the Adam optimizer."
156 )
157 parser.add_argument(
158 "--adam_beta2",
159 type=float,
160 default=0.999,
161 help="The beta2 parameter for the Adam optimizer."
162 )
163 parser.add_argument(
164 "--adam_weight_decay",
165 type=float,
166 default=1e-2,
167 help="Weight decay to use."
168 )
169 parser.add_argument(
170 "--adam_epsilon",
171 type=float,
172 default=1e-08,
173 help="Epsilon value for the Adam optimizer"
174 )
175 parser.add_argument(
176 "--mixed_precision",
177 type=str,
178 default="no",
179 choices=["no", "fp16", "bf16"],
180 help=(
181 "Whether to use mixed precision. Choose "
182 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
183 "and an Nvidia Ampere GPU."
184 ),
185 )
186 parser.add_argument(
187 "--local_rank",
188 type=int,
189 default=-1,
190 help="For distributed training: local_rank"
191 )
192 parser.add_argument(
193 "--checkpoint_frequency",
194 type=int,
195 default=500,
196 help="How often (in steps) to save a checkpoint and sample images",
197 )
198 parser.add_argument(
199 "--sample_image_size",
200 type=int,
201 default=512,
202 help="Size of sample images",
203 )
204 parser.add_argument(
205 "--stable_sample_batches",
206 type=int,
207 default=1,
208 help="Number of fixed seed sample batches to generate per checkpoint",
209 )
210 parser.add_argument(
211 "--random_sample_batches",
212 type=int,
213 default=1,
214 help="Number of random seed sample batches to generate per checkpoint",
215 )
216 parser.add_argument(
217 "--sample_batch_size",
218 type=int,
219 default=1,
220 help="Number of samples to generate per batch",
221 )
222 parser.add_argument(
223 "--train_batch_size",
224 type=int,
225 default=1,
226 help="Batch size (per device) for the training dataloader."
227 )
228 parser.add_argument(
229 "--sample_steps",
230 type=int,
231 default=50,
232 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
233 )
234 parser.add_argument(
235 "--resume_from",
236 type=str,
237 default=None,
238 help="Path to a directory to resume training from (e.g., logs/token_name/2022-09-22T23-36-27)"
239 )
240 parser.add_argument(
241 "--resume_checkpoint",
242 type=str,
243 default=None,
244 help="Path to a specific checkpoint to resume training from (e.g., logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
245 )
246 parser.add_argument(
247 "--config",
248 type=str,
249 default=None,
250 help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
251 )
252
253 args = parser.parse_args()
254 if args.resume_from is not None:
255 with open(f"{args.resume_from}/resume.json", 'rt') as f:
256 args = parser.parse_args(
257 namespace=argparse.Namespace(**json.load(f)["args"]))
258 elif args.config is not None:
259 with open(args.config, 'rt') as f:
260 args = parser.parse_args(
261 namespace=argparse.Namespace(**json.load(f)["args"]))
262
263 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
264 if env_local_rank != -1 and env_local_rank != args.local_rank:
265 args.local_rank = env_local_rank
266
267 if args.train_data_dir is None:
268 raise ValueError("You must specify --train_data_dir")
269
270 if args.pretrained_model_name_or_path is None:
271 raise ValueError("You must specify --pretrained_model_name_or_path")
272
273 if args.placeholder_token is None:
274 raise ValueError("You must specify --placeholder_token")
275
276 if args.initializer_token is None:
277 raise ValueError("You must specify --initializer_token")
278
279 if args.output_dir is None:
280 raise ValueError("You must specify --output_dir")
281
282 return args
283
284
285def freeze_params(params):
286 for param in params:
287 param.requires_grad = False
288
289
290def save_resume_file(basepath, args, extra={}):
291 info = {"args": vars(args)}
292 info["args"].update(extra)
293 with open(f"{basepath}/resume.json", "w") as f:
294 json.dump(info, f, indent=4)
295
296
297def make_grid(images, rows, cols):
298 w, h = images[0].size
299 grid = Image.new('RGB', size=(cols*w, rows*h))
300 for i, image in enumerate(images):
301 grid.paste(image, box=(i % cols*w, i//cols*h))
302 return grid
303
304
305class Checkpointer:
306 def __init__(
307 self,
308 datamodule,
309 accelerator,
310 vae,
311 unet,
312 tokenizer,
313 placeholder_token,
314 placeholder_token_id,
315 output_dir,
316 sample_image_size,
317 random_sample_batches,
318 sample_batch_size,
319 stable_sample_batches,
320 seed
321 ):
322 self.datamodule = datamodule
323 self.accelerator = accelerator
324 self.vae = vae
325 self.unet = unet
326 self.tokenizer = tokenizer
327 self.placeholder_token = placeholder_token
328 self.placeholder_token_id = placeholder_token_id
329 self.output_dir = output_dir
330 self.sample_image_size = sample_image_size
331 self.seed = seed
332 self.random_sample_batches = random_sample_batches
333 self.sample_batch_size = sample_batch_size
334 self.stable_sample_batches = stable_sample_batches
335
336 @torch.no_grad()
337 def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
338 print("Saving checkpoint for step %d..." % step)
339 with self.accelerator.autocast():
340 if path is None:
341 checkpoints_path = f"{self.output_dir}/checkpoints"
342 os.makedirs(checkpoints_path, exist_ok=True)
343
344 unwrapped = self.accelerator.unwrap_model(text_encoder)
345
346 # Save a checkpoint
347 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
348 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
349
350 filename = "%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
351 if path is not None:
352 torch.save(learned_embeds_dict, path)
353 else:
354 torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
355 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
356 del unwrapped
357 del learned_embeds
358
359 @torch.no_grad()
360 def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
361 samples_path = f"{self.output_dir}/samples/{mode}"
362 os.makedirs(samples_path, exist_ok=True)
363 checker = NoCheck()
364
365 unwrapped = self.accelerator.unwrap_model(text_encoder)
366 # Save a sample image
367 pipeline = StableDiffusionPipeline(
368 text_encoder=unwrapped,
369 vae=self.vae,
370 unet=self.unet,
371 tokenizer=self.tokenizer,
372 scheduler=LMSDiscreteScheduler(
373 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
374 ),
375 safety_checker=NoCheck(),
376 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
377 ).to(self.accelerator.device)
378 pipeline.enable_attention_slicing()
379
380 data = {
381 "training": self.datamodule.train_dataloader(),
382 "validation": self.datamodule.val_dataloader(),
383 }[mode]
384
385 if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
386 stable_latents = torch.randn(
387 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
388 device=pipeline.device,
389 generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
390 )
391
392 all_samples = []
393 filename = "stable_step_%d.png" % step
394
395 data_enum = enumerate(data)
396
397 # Generate and save stable samples
398 for i in range(0, self.stable_sample_batches):
399 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
400 batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
401
402 with self.accelerator.autocast():
403 samples = pipeline(
404 prompt=prompt,
405 height=self.sample_image_size,
406 latents=stable_latents[:len(prompt)],
407 width=self.sample_image_size,
408 guidance_scale=guidance_scale,
409 eta=eta,
410 num_inference_steps=num_inference_steps,
411 output_type='pil'
412 )["sample"]
413
414 all_samples += samples
415 del samples
416
417 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
418 image_grid.save(f"{samples_path}/{filename}")
419
420 del all_samples
421 del image_grid
422 del stable_latents
423
424 all_samples = []
425 filename = "step_%d.png" % step
426
427 data_enum = enumerate(data)
428
429 # Generate and save random samples
430 for i in range(0, self.random_sample_batches):
431 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
432 batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
433
434 with self.accelerator.autocast():
435 samples = pipeline(
436 prompt=prompt,
437 height=self.sample_image_size,
438 width=self.sample_image_size,
439 guidance_scale=guidance_scale,
440 eta=eta,
441 num_inference_steps=num_inference_steps,
442 output_type='pil'
443 )["sample"]
444
445 all_samples += samples
446 del samples
447
448 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
449 image_grid.save(f"{samples_path}/{filename}")
450
451 del all_samples
452 del image_grid
453
454 del checker
455 del unwrapped
456 del pipeline
457 torch.cuda.empty_cache()
458
459
460class ImageToLatents():
461 def __init__(self, vae):
462 self.vae = vae
463 self.encoded_pixel_values_cache = {}
464
465 @torch.no_grad()
466 def __call__(self, batch):
467 key = "|".join(batch["key"])
468 if self.encoded_pixel_values_cache.get(key, None) is None:
469 self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist
470 latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
471 return latents
472
473
474def main():
475 args = parse_args()
476
477 global_step_offset = 0
478 if args.resume_from is not None:
479 basepath = f"{args.resume_from}"
480 print("Resuming state from %s" % args.resume_from)
481 with open(f"{basepath}/resume.json", 'r') as f:
482 state = json.load(f)
483 global_step_offset = state["args"].get("global_step", 0)
484
485 print("We've trained %d steps so far" % global_step_offset)
486 else:
487 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
488 basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}"
489 os.makedirs(basepath, exist_ok=True)
490
491 accelerator = Accelerator(
492 log_with=LoggerType.TENSORBOARD,
493 logging_dir=f"{basepath}",
494 gradient_accumulation_steps=args.gradient_accumulation_steps,
495 mixed_precision=args.mixed_precision
496 )
497
498 # If passed along, set the training seed now.
499 if args.seed is not None:
500 set_seed(args.seed)
501
502 # Load the tokenizer and add the placeholder token as an additional special token
503 if args.tokenizer_name:
504 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
505 elif args.pretrained_model_name_or_path:
506 tokenizer = CLIPTokenizer.from_pretrained(
507 args.pretrained_model_name_or_path + '/tokenizer'
508 )
509
510 # Add the placeholder token to the tokenizer
511 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
512 if num_added_tokens == 0:
513 raise ValueError(
514 f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
515 " `placeholder_token` that is not already in the tokenizer."
516 )
517
518 # Convert the initializer_token, placeholder_token to ids
519 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
520 # Check if initializer_token is a single token or a sequence of tokens
521 if args.vectors_per_token % len(initializer_token_ids) != 0:
522 raise ValueError(
523 f"vectors_per_token ({args.vectors_per_token}) must be divisible by the number of initializer tokens ({len(initializer_token_ids)}).")
524
525 initializer_token_ids = torch.tensor(initializer_token_ids)
526 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
527
528 # Load models and create wrapper for stable diffusion
529 text_encoder = CLIPTextModel.from_pretrained(
530 args.pretrained_model_name_or_path + '/text_encoder',
531 )
532 vae = AutoencoderKL.from_pretrained(
533 args.pretrained_model_name_or_path + '/vae',
534 )
535 unet = UNet2DConditionModel.from_pretrained(
536 args.pretrained_model_name_or_path + '/unet',
537 )
538
539 if args.gradient_checkpointing:
540 unet.enable_gradient_checkpointing()
541
542 slice_size = unet.config.attention_head_dim // 2
543 unet.set_attention_slice(slice_size)
544
545 # Resize the token embeddings as we are adding new special tokens to the tokenizer
546 text_encoder.resize_token_embeddings(len(tokenizer))
547
548 # Initialise the newly added placeholder token with the embeddings of the initializer token
549 token_embeds = text_encoder.get_input_embeddings().weight.data
550
551 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
552
553 if args.resume_checkpoint is not None:
554 token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
555 args.placeholder_token]
556 else:
557 token_embeds[placeholder_token_id] = initializer_token_embeddings
558
559 # Freeze vae and unet
560 freeze_params(vae.parameters())
561 freeze_params(unet.parameters())
562 # Freeze all parameters except for the token embeddings in text encoder
563 params_to_freeze = itertools.chain(
564 text_encoder.text_model.encoder.parameters(),
565 text_encoder.text_model.final_layer_norm.parameters(),
566 text_encoder.text_model.embeddings.position_embedding.parameters(),
567 )
568 freeze_params(params_to_freeze)
569
570 if args.scale_lr:
571 args.learning_rate = (
572 args.learning_rate * args.gradient_accumulation_steps *
573 args.train_batch_size * accelerator.num_processes
574 )
575
576 # Initialize the optimizer
577 optimizer = torch.optim.AdamW(
578 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
579 lr=args.learning_rate,
580 betas=(args.adam_beta1, args.adam_beta2),
581 weight_decay=args.adam_weight_decay,
582 eps=args.adam_epsilon,
583 )
584
585 # TODO (patil-suraj): load scheduler using args
586 noise_scheduler = DDPMScheduler(
587 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
588 )
589
590 datamodule = CSVDataModule(
591 data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer,
592 size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats,
593 center_crop=args.center_crop)
594
595 datamodule.prepare_data()
596 datamodule.setup()
597
598 train_dataloader = datamodule.train_dataloader()
599 val_dataloader = datamodule.val_dataloader()
600
601 checkpointer = Checkpointer(
602 datamodule=datamodule,
603 accelerator=accelerator,
604 vae=vae,
605 unet=unet,
606 tokenizer=tokenizer,
607 placeholder_token=args.placeholder_token,
608 placeholder_token_id=placeholder_token_id,
609 output_dir=basepath,
610 sample_image_size=args.sample_image_size,
611 sample_batch_size=args.sample_batch_size,
612 random_sample_batches=args.random_sample_batches,
613 stable_sample_batches=args.stable_sample_batches,
614 seed=args.seed
615 )
616
617 # Scheduler and math around the number of training steps.
618 overrode_max_train_steps = False
619 num_update_steps_per_epoch = math.ceil(
620 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
621 if args.max_train_steps is None:
622 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
623 overrode_max_train_steps = True
624
625 lr_scheduler = get_scheduler(
626 args.lr_scheduler,
627 optimizer=optimizer,
628 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
629 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
630 )
631
632 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
633 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
634 )
635
636 # Move vae and unet to device
637 vae.to(accelerator.device)
638 unet.to(accelerator.device)
639
640 # Keep vae and unet in eval mode as we don't train these
641 vae.eval()
642 unet.eval()
643
644 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
645 num_update_steps_per_epoch = math.ceil(
646 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
647 if overrode_max_train_steps:
648 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
649 # Afterwards we recalculate our number of training epochs
650 args.num_train_epochs = math.ceil(
651 args.max_train_steps / num_update_steps_per_epoch)
652
653 # We need to initialize the trackers we use, and also store our configuration.
654 # The trackers initialize automatically on the main process.
655 if accelerator.is_main_process:
656 accelerator.init_trackers("textual_inversion", config=vars(args))
657
658 # Train!
659 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
660
661 logger.info("***** Running training *****")
662 logger.info(f" Num Epochs = {args.num_train_epochs}")
663 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
664 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
665 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
666 logger.info(f" Total optimization steps = {args.max_train_steps}")
667 # Only show the progress bar once on each machine.
668
669 global_step = 0
670 min_val_loss = np.inf
671
672 imageToLatents = ImageToLatents(vae)
673
674 checkpointer.save_samples(
675 "validation",
676 0,
677 text_encoder,
678 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
679
680 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
681 progress_bar.set_description("Global steps")
682
683 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
684 local_progress_bar.set_description("Steps")
685
686 try:
687 for epoch in range(args.num_train_epochs):
688 local_progress_bar.reset()
689
690 text_encoder.train()
691 train_loss = 0.0
692
693 for step, batch in enumerate(train_dataloader):
694 with accelerator.accumulate(text_encoder):
695 with accelerator.autocast():
696 # Convert images to latent space
697 latents = imageToLatents(batch)
698
699 # Sample noise that we'll add to the latents
700 noise = torch.randn(latents.shape).to(latents.device)
701 bsz = latents.shape[0]
702 # Sample a random timestep for each image
703 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
704 (bsz,), device=latents.device).long()
705
706 # Add noise to the latents according to the noise magnitude at each timestep
707 # (this is the forward diffusion process)
708 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
709
710 # Get the text embedding for conditioning
711 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
712
713 # Predict the noise residual
714 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
715
716 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
717
718 accelerator.backward(loss)
719
720 # Zero out the gradients for all token embeddings except the newly added
721 # embeddings for the concept, as we only want to optimize the concept embeddings
722 if accelerator.num_processes > 1:
723 grads = text_encoder.module.get_input_embeddings().weight.grad
724 else:
725 grads = text_encoder.get_input_embeddings().weight.grad
726 # Get the index for tokens that we want to zero the grads for
727 index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
728 grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
729
730 optimizer.step()
731 if not accelerator.optimizer_step_was_skipped:
732 lr_scheduler.step()
733 optimizer.zero_grad()
734
735 loss = loss.detach().item()
736 train_loss += loss
737
738 # Checks if the accelerator has performed an optimization step behind the scenes
739 if accelerator.sync_gradients:
740 progress_bar.update(1)
741 local_progress_bar.update(1)
742
743 global_step += 1
744
745 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
746 progress_bar.clear()
747 local_progress_bar.clear()
748
749 checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
750 save_resume_file(basepath, args, {
751 "global_step": global_step + global_step_offset,
752 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
753 })
754 checkpointer.save_samples(
755 "training",
756 global_step + global_step_offset,
757 text_encoder,
758 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
759
760 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
761 local_progress_bar.set_postfix(**logs)
762
763 if global_step >= args.max_train_steps:
764 break
765
766 train_loss /= len(train_dataloader)
767
768 text_encoder.eval()
769 val_loss = 0.0
770
771 for step, batch in enumerate(val_dataloader):
772 with torch.no_grad(), accelerator.autocast():
773 latents = imageToLatents(batch)
774
775 noise = torch.randn(latents.shape).to(latents.device)
776 bsz = latents.shape[0]
777 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
778 (bsz,), device=latents.device).long()
779
780 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
781
782 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
783
784 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
785
786 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
787
788 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
789
790 loss = loss.detach().item()
791 val_loss += loss
792
793 if accelerator.sync_gradients:
794 progress_bar.update(1)
795 local_progress_bar.update(1)
796
797 logs = {"mode": "validation", "loss": loss}
798 local_progress_bar.set_postfix(**logs)
799
800 val_loss /= len(val_dataloader)
801
802 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
803
804 progress_bar.clear()
805 local_progress_bar.clear()
806
807 if min_val_loss > val_loss:
808 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
809 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
810 min_val_loss = val_loss
811
812 checkpointer.save_samples(
813 "validation",
814 global_step + global_step_offset,
815 text_encoder,
816 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
817
818 accelerator.wait_for_everyone()
819
820 # Save the final learned embeddings and resume state.
821 if accelerator.is_main_process:
822 print("Finished! Saving final checkpoint and resume state.")
823 checkpointer.checkpoint(
824 global_step + global_step_offset,
825 "end",
826 text_encoder,
827 path=f"{basepath}/learned_embeds.bin"
828 )
829
830 save_resume_file(basepath, args, {
831 "global_step": global_step + global_step_offset,
832 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
833 })
834
835 accelerator.end_training()
836
837 except KeyboardInterrupt:
838 if accelerator.is_main_process:
839 print("Interrupted, saving checkpoint and resume state...")
840 checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
841 save_resume_file(basepath, args, {
842 "global_step": global_step + global_step_offset,
843 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
844 })
845 accelerator.end_training()
846 quit()
847
848
849if __name__ == "__main__":
850 main()