author     Volpeon <git@volpeon.ink>   2022-10-03 11:26:31 +0200
committer  Volpeon <git@volpeon.ink>   2022-10-03 11:26:31 +0200
commit     0f493e1ac8406de061861ed390f283e821180e79
tree       0186a40130f095f1a3bdaa3bf4064a5bd5d35187 /textual_inversion.py
parent     Small performance improvements
Use euler_a for samples in learning scripts; backported improvement from Dreambooth to Textual Inversion
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  308
1 file changed, 164 insertions(+), 144 deletions(-)
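For orientation before the full diff: sample generation in Checkpointer.save_samples() now builds the repository's VlpnStableDiffusion pipeline around an Euler-ancestral scheduler instead of StableDiffusionPipeline with LMSDiscreteScheduler, and seeds a torch.Generator per sample batch. A condensed sketch of the new call sequence, excerpted and simplified from the hunks below (the text encoder, VAE, UNet, tokenizer, prompt list and sampling parameters stand in for the Checkpointer attributes and arguments shown there):

import torch
from transformers import CLIPFeatureExtractor
from schedulers.scheduling_euler_a import EulerAScheduler
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion

# Euler-ancestral scheduler replaces LMSDiscreteScheduler for sample images
scheduler = EulerAScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)

# Custom pipeline; the NoCheck safety-checker stub is no longer needed
pipeline = VlpnStableDiffusion(
    text_encoder=unwrapped,   # unwrapped copy of the text encoder being trained
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    scheduler=scheduler,
    feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to(device)
pipeline.enable_attention_slicing()

# A seeded generator per batch keeps sample grids comparable between checkpoints
generator = torch.Generator(device="cuda").manual_seed(seed + i)
samples = pipeline(
    prompt=prompt,
    height=resolution,
    width=resolution,
    guidance_scale=7.5,
    eta=0.0,
    num_inference_steps=30,   # the new --sample_steps default
    generator=generator,
    output_type='pil'
)["sample"]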
diff --git a/textual_inversion.py b/textual_inversion.py
index 399d876..7a7d7fc 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -3,6 +3,8 @@ import itertools
 import math
 import os
 import datetime
+import logging
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -13,12 +15,13 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
-from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 import os
 
@@ -44,10 +47,10 @@ def parse_args():
         help="Pretrained tokenizer name or path if not the same as model_name",
     )
     parser.add_argument(
-        "--train_data_dir",
+        "--train_data_file",
         type=str,
         default=None,
-        help="A folder containing the training data."
+        help="A CSV file containing the training data."
     )
     parser.add_argument(
         "--placeholder_token",
@@ -146,6 +149,11 @@ def parse_args():
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -225,7 +233,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=50,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -261,8 +269,8 @@ def parse_args():
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
-    if args.train_data_dir is None:
-        raise ValueError("You must specify --train_data_dir")
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
 
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
@@ -333,53 +341,51 @@ class Checkpointer:
     @torch.no_grad()
     def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
         print("Saving checkpoint for step %d..." % step)
-        with self.accelerator.autocast():
-            if path is None:
-                checkpoints_path = f"{self.output_dir}/checkpoints"
-                os.makedirs(checkpoints_path, exist_ok=True)
+
+        if path is None:
+            checkpoints_path = f"{self.output_dir}/checkpoints"
+            os.makedirs(checkpoints_path, exist_ok=True)
 
-            unwrapped = self.accelerator.unwrap_model(text_encoder)
+        unwrapped = self.accelerator.unwrap_model(text_encoder)
 
-            # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-            learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        # Save a checkpoint
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
 
-            filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-            if path is not None:
-                torch.save(learned_embeds_dict, path)
-            else:
-                torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
-                torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
-            del unwrapped
-            del learned_embeds
+        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        if path is not None:
+            torch.save(learned_embeds_dict, path)
+        else:
+            torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
+            torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
+
+        del unwrapped
+        del learned_embeds
 
     @torch.no_grad()
-    def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = f"{self.output_dir}/samples/{mode}"
-        os.makedirs(samples_path, exist_ok=True)
-        checker = NoCheck()
+    def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
+        samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped = self.accelerator.unwrap_model(text_encoder)
+        scheduler = EulerAScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+
         # Save a sample image
-        pipeline = StableDiffusionPipeline(
+        pipeline = VlpnStableDiffusion(
             text_encoder=unwrapped,
             vae=self.vae,
             unet=self.unet,
             tokenizer=self.tokenizer,
-            scheduler=LMSDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            ),
-            safety_checker=NoCheck(),
+            scheduler=scheduler,
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
 
-        data = {
-            "training": self.datamodule.train_dataloader(),
-            "validation": self.datamodule.val_dataloader(),
-        }[mode]
+        train_data = self.datamodule.train_dataloader()
+        val_data = self.datamodule.val_dataloader()
 
-        if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
+        if self.stable_sample_batches > 0:
             stable_latents = torch.randn(
                 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
                 device=pipeline.device,
@@ -387,14 +393,17 @@ class Checkpointer:
             )
 
             all_samples = []
-            filename = f"stable_step_%d.png" % (step)
+            file_path = samples_path.joinpath("stable", f"step_{step}.png")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
+            data_enum = enumerate(val_data)
 
             # Generate and save stable samples
             for i in range(0, self.stable_sample_batches):
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
+                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
+
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
 
                 with self.accelerator.autocast():
                     samples = pipeline(
@@ -405,67 +414,64 @@ class Checkpointer:
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
+                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                 all_samples += samples
+
+                del generator
                 del samples
 
             image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(f"{samples_path}/{filename}")
+            image_grid.save(file_path)
 
             del all_samples
             del image_grid
             del stable_latents
 
-        all_samples = []
-        filename = f"step_%d.png" % (step)
+        for data, pool in [(val_data, "val"), (train_data, "train")]:
+            all_samples = []
+            file_path = samples_path.joinpath(pool, f"step_{step}.png")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
 
-        data_enum = enumerate(data)
+            data_enum = enumerate(data)
 
-        # Generate and save random samples
-        for i in range(0, self.random_sample_batches):
-            prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(0, self.random_sample_batches):
+                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
 
-            with self.accelerator.autocast():
-                samples = pipeline(
-                    prompt=prompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    guidance_scale=guidance_scale,
-                    eta=eta,
-                    num_inference_steps=num_inference_steps,
-                    output_type='pil'
-                )["sample"]
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
 
-            all_samples += samples
-            del samples
+                with self.accelerator.autocast():
+                    samples = pipeline(
+                        prompt=prompt,
+                        height=self.sample_image_size,
+                        width=self.sample_image_size,
+                        guidance_scale=guidance_scale,
+                        eta=eta,
+                        num_inference_steps=num_inference_steps,
+                        generator=generator,
+                        output_type='pil'
+                    )["sample"]
 
-        image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
-        image_grid.save(f"{samples_path}/{filename}")
+                all_samples += samples
 
-        del all_samples
-        del image_grid
+                del generator
+                del samples
 
-        del checker
-        del unwrapped
-        del pipeline
-        torch.cuda.empty_cache()
+            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid.save(file_path)
 
+            del all_samples
+            del image_grid
 
-class ImageToLatents():
-    def __init__(self, vae):
-        self.vae = vae
-        self.encoded_pixel_values_cache = {}
+        del unwrapped
+        del scheduler
+        del pipeline
 
-    @torch.no_grad()
-    def __call__(self, batch):
-        key = "|".join(batch["key"])
-        if self.encoded_pixel_values_cache.get(key, None) is None:
-            self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist
-        latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
-        return latents
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
 
 def main():
@@ -473,17 +479,17 @@ def main():
 
     global_step_offset = 0
    if args.resume_from is not None:
-        basepath = f"{args.resume_from}"
+        basepath = Path(args.resume_from)
         print("Resuming state from %s" % args.resume_from)
-        with open(f"{basepath}/resume.json", 'r') as f:
+        with open(basepath.joinpath("resume.json"), 'r') as f:
             state = json.load(f)
             global_step_offset = state["args"].get("global_step", 0)
 
         print("We've trained %d steps so far" % global_step_offset)
    else:
         now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-        basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}"
-        os.makedirs(basepath, exist_ok=True)
+        basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
+        basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
@@ -492,6 +498,8 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -570,8 +578,19 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
     # Initialize the optimizer
-    optimizer = torch.optim.AdamW(
+    optimizer = optimizer_class(
         text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
@@ -585,7 +604,7 @@ def main():
     )
 
     datamodule = CSVDataModule(
-        data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer,
+        data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer,
         size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats,
         center_crop=args.center_crop)
 
@@ -608,13 +627,12 @@ def main():
         sample_batch_size=args.sample_batch_size,
         random_sample_batches=args.random_sample_batches,
         stable_sample_batches=args.stable_sample_batches,
-        seed=args.seed
+        seed=args.seed or torch.random.seed()
     )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(
-        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -643,9 +661,10 @@ def main():
         (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-    # Afterwards we recalculate our number of training epochs
-    args.num_train_epochs = math.ceil(
-        args.max_train_steps / num_update_steps_per_epoch)
+
+    num_val_steps_per_epoch = len(val_dataloader)
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    val_steps = num_val_steps_per_epoch * num_epochs
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
@@ -656,7 +675,7 @@ def main():
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Num Epochs = {num_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
@@ -666,22 +685,22 @@ def main():
     global_step = 0
     min_val_loss = np.inf
 
-    imageToLatents = ImageToLatents(vae)
-
-    checkpointer.save_samples(
-        "validation",
-        0,
-        text_encoder,
-        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+    if accelerator.is_main_process:
+        checkpointer.save_samples(
+            0,
+            text_encoder,
+            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Global steps")
+    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+                              disable=not accelerator.is_local_main_process)
+    local_progress_bar.set_description("Batch X out of Y")
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps")
+    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar.set_description("Total progress")
 
     try:
-        for epoch in range(args.num_train_epochs):
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
             local_progress_bar.reset()
 
             text_encoder.train()
@@ -689,27 +708,30 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
-                    with accelerator.autocast():
-                        # Convert images to latent space
-                        latents = imageToLatents(batch)
+                    # Convert images to latent space
+                    with torch.no_grad():
+                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                        latents = latents * 0.18215
 
-                        # Sample noise that we'll add to the latents
-                        noise = torch.randn(latents.shape).to(latents.device)
-                        bsz = latents.shape[0]
-                        # Sample a random timestep for each image
-                        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
-                                                  (bsz,), device=latents.device).long()
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn(latents.shape).to(latents.device)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
 
-                        # Add noise to the latents according to the noise magnitude at each timestep
-                        # (this is the forward diffusion process)
-                        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                        # Get the text embedding for conditioning
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
-                        # Predict the noise residual
-                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    with accelerator.autocast():
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     accelerator.backward(loss)
@@ -727,32 +749,27 @@ def main():
                     optimizer.step()
                     if not accelerator.optimizer_step_was_skipped:
                         lr_scheduler.step()
-                    optimizer.zero_grad()
+                    optimizer.zero_grad(set_to_none=True)
 
                     loss = loss.detach().item()
                     train_loss += loss
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
                     global_step += 1
 
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                        progress_bar.clear()
                         local_progress_bar.clear()
+                        global_progress_bar.clear()
 
                         checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
                         save_resume_file(basepath, args, {
                             "global_step": global_step + global_step_offset,
                             "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                         })
-                        checkpointer.save_samples(
-                            "training",
-                            global_step + global_step_offset,
-                            text_encoder,
-                            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
                 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
                 local_progress_bar.set_postfix(**logs)
@@ -762,17 +779,21 @@ def main():
 
             train_loss /= len(train_dataloader)
 
+            accelerator.wait_for_everyone()
+
             text_encoder.eval()
             val_loss = 0.0
 
             for step, batch in enumerate(val_dataloader):
-                with torch.no_grad(), accelerator.autocast():
-                    latents = imageToLatents(batch)
+                with torch.no_grad():
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
 
                     noise = torch.randn(latents.shape).to(latents.device)
                     bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
-                                              (bsz,), device=latents.device).long()
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
 
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
@@ -782,14 +803,15 @@ def main():
 
                     noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    with accelerator.autocast():
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     loss = loss.detach().item()
                     val_loss += loss
 
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
                 logs = {"mode": "validation", "loss": loss}
                 local_progress_bar.set_postfix(**logs)
@@ -798,21 +820,19 @@ def main():
 
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
-            progress_bar.clear()
             local_progress_bar.clear()
+            global_progress_bar.clear()
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                 min_val_loss = val_loss
 
-            checkpointer.save_samples(
-                "validation",
-                global_step + global_step_offset,
-                text_encoder,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
-
-            accelerator.wait_for_everyone()
+            if accelerator.is_main_process:
+                checkpointer.save_samples(
+                    global_step + global_step_offset,
+                    text_encoder,
+                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process: