summary | refs | log | tree | commit | diff | stats
path: root/dreambooth.py
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2022-11-14 17:09:58 +0100
committer: Volpeon <git@volpeon.ink> 2022-11-14 17:09:58 +0100
commit: 2ad46871e2ead985445da2848a4eb7072b6e48aa (patch)
tree: 3137923e2c00fe1d3cd37ddcc93c8a847b0c0762 /dreambooth.py
parent: Update (diff)
download: textual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.tar.gz
textual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.tar.bz2
textual-inversion-diff-2ad46871e2ead985445da2848a4eb7072b6e48aa.zip
Update
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  71
1 file changed, 41 insertions(+), 30 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 8c4bf50..7b34fce 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
15from accelerate import Accelerator 15from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel 18from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup 19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
20from diffusers.training_utils import EMAModel 20from diffusers.training_utils import EMAModel
21from PIL import Image 21from PIL import Image
@@ -23,7 +23,6 @@ from tqdm.auto import tqdm
23from transformers import CLIPTextModel, CLIPTokenizer 23from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify 24from slugify import slugify
25 25
26from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
28from data.csv import CSVDataModule 27from data.csv import CSVDataModule
29from training.optimization import get_one_cycle_schedule 28from training.optimization import get_one_cycle_schedule
@@ -144,7 +143,7 @@ def parse_args():
144 parser.add_argument( 143 parser.add_argument(
145 "--max_train_steps", 144 "--max_train_steps",
146 type=int, 145 type=int,
147 default=6000, 146 default=None,
148 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 147 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
149 ) 148 )
150 parser.add_argument( 149 parser.add_argument(
@@ -211,7 +210,7 @@ def parse_args():
211 parser.add_argument( 210 parser.add_argument(
212 "--ema_power", 211 "--ema_power",
213 type=float, 212 type=float,
214 default=7 / 8 213 default=6/7
215 ) 214 )
216 parser.add_argument( 215 parser.add_argument(
217 "--ema_max_decay", 216 "--ema_max_decay",
@@ -284,6 +283,12 @@ def parse_args():
284 help="Number of samples to generate per batch", 283 help="Number of samples to generate per batch",
285 ) 284 )
286 parser.add_argument( 285 parser.add_argument(
286 "--valid_set_size",
287 type=int,
288 default=None,
289 help="Number of images in the validation dataset."
290 )
291 parser.add_argument(
287 "--train_batch_size", 292 "--train_batch_size",
288 type=int, 293 type=int,
289 default=1, 294 default=1,
@@ -292,7 +297,7 @@ def parse_args():
292 parser.add_argument( 297 parser.add_argument(
293 "--sample_steps", 298 "--sample_steps",
294 type=int, 299 type=int,
295 default=30, 300 default=25,
296 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 301 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
297 ) 302 )
298 parser.add_argument( 303 parser.add_argument(
@@ -461,7 +466,7 @@ class Checkpointer:
461 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) 466 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
462 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) 467 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
463 468
464 scheduler = EulerAncestralDiscreteScheduler( 469 scheduler = DPMSolverMultistepScheduler(
465 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 470 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
466 ) 471 )
467 472
@@ -487,23 +492,30 @@ class Checkpointer:
487 with torch.inference_mode(): 492 with torch.inference_mode():
488 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: 493 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
489 all_samples = [] 494 all_samples = []
490 file_path = samples_path.joinpath(pool, f"step_{step}.png") 495 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
491 file_path.parent.mkdir(parents=True, exist_ok=True) 496 file_path.parent.mkdir(parents=True, exist_ok=True)
492 497
493 data_enum = enumerate(data) 498 data_enum = enumerate(data)
494 499
500 batches = [
501 batch
502 for j, batch in data_enum
503 if j * data.batch_size < self.sample_batch_size * self.sample_batches
504 ]
505 prompts = [
506 prompt.format(identifier=self.instance_identifier)
507 for batch in batches
508 for prompt in batch["prompts"]
509 ]
510 nprompts = [
511 prompt
512 for batch in batches
513 for prompt in batch["nprompts"]
514 ]
515
495 for i in range(self.sample_batches): 516 for i in range(self.sample_batches):
496 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 517 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
497 prompt = [ 518 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
498 prompt.format(identifier=self.instance_identifier)
499 for batch in batches
500 for prompt in batch["prompts"]
501 ][:self.sample_batch_size]
502 nprompt = [
503 prompt
504 for batch in batches
505 for prompt in batch["nprompts"]
506 ][:self.sample_batch_size]
507 519
508 samples = pipeline( 520 samples = pipeline(
509 prompt=prompt, 521 prompt=prompt,
@@ -523,7 +535,7 @@ class Checkpointer:
523 del samples 535 del samples
524 536
525 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) 537 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
526 image_grid.save(file_path) 538 image_grid.save(file_path, quality=85)
527 539
528 del all_samples 540 del all_samples
529 del image_grid 541 del image_grid
@@ -576,6 +588,12 @@ def main():
576 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') 588 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
577 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') 589 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
578 590
591 unet.set_use_memory_efficient_attention_xformers(True)
592
593 if args.gradient_checkpointing:
594 unet.enable_gradient_checkpointing()
595 text_encoder.gradient_checkpointing_enable()
596
579 ema_unet = None 597 ema_unet = None
580 if args.use_ema: 598 if args.use_ema:
581 ema_unet = EMAModel( 599 ema_unet = EMAModel(
@@ -586,12 +604,6 @@ def main():
586 device=accelerator.device 604 device=accelerator.device
587 ) 605 )
588 606
589 unet.set_use_memory_efficient_attention_xformers(True)
590
591 if args.gradient_checkpointing:
592 unet.enable_gradient_checkpointing()
593 text_encoder.gradient_checkpointing_enable()
594
595 # Freeze text_encoder and vae 607 # Freeze text_encoder and vae
596 freeze_params(vae.parameters()) 608 freeze_params(vae.parameters())
597 609
@@ -726,7 +738,7 @@ def main():
726 size=args.resolution, 738 size=args.resolution,
727 repeats=args.repeats, 739 repeats=args.repeats,
728 center_crop=args.center_crop, 740 center_crop=args.center_crop,
729 valid_set_size=args.sample_batch_size*args.sample_batches, 741 valid_set_size=args.valid_set_size,
730 num_workers=args.dataloader_num_workers, 742 num_workers=args.dataloader_num_workers,
731 collate_fn=collate_fn 743 collate_fn=collate_fn
732 ) 744 )
@@ -743,7 +755,7 @@ def main():
743 for i in range(0, len(missing_data), args.sample_batch_size) 755 for i in range(0, len(missing_data), args.sample_batch_size)
744 ] 756 ]
745 757
746 scheduler = EulerAncestralDiscreteScheduler( 758 scheduler = DPMSolverMultistepScheduler(
747 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" 759 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
748 ) 760 )
749 761
@@ -962,6 +974,8 @@ def main():
962 optimizer.step() 974 optimizer.step()
963 if not accelerator.optimizer_step_was_skipped: 975 if not accelerator.optimizer_step_was_skipped:
964 lr_scheduler.step() 976 lr_scheduler.step()
977 if args.use_ema:
978 ema_unet.step(unet)
965 optimizer.zero_grad(set_to_none=True) 979 optimizer.zero_grad(set_to_none=True)
966 980
967 loss = loss.detach().item() 981 loss = loss.detach().item()
@@ -969,9 +983,6 @@ def main():
969 983
970 # Checks if the accelerator has performed an optimization step behind the scenes 984 # Checks if the accelerator has performed an optimization step behind the scenes
971 if accelerator.sync_gradients: 985 if accelerator.sync_gradients:
972 if args.use_ema:
973 ema_unet.step(unet)
974
975 local_progress_bar.update(1) 986 local_progress_bar.update(1)
976 global_progress_bar.update(1) 987 global_progress_bar.update(1)
977 988