| | | |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2022-11-14 17:09:58 +0100 |
| committer | Volpeon <git@volpeon.ink> | 2022-11-14 17:09:58 +0100 |
| commit | 2ad46871e2ead985445da2848a4eb7072b6e48aa (patch) | |
| tree | 3137923e2c00fe1d3cd37ddcc93c8a847b0c0762 /dreambooth.py | |
| parent | Update (diff) | |
Update
Diffstat (limited to 'dreambooth.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | dreambooth.py | 71 |

1 file changed, 41 insertions(+), 30 deletions(-)
```diff
diff --git a/dreambooth.py b/dreambooth.py
index 8c4bf50..7b34fce 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
```
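This hunk swaps the preview sampler from the locally vendored Euler Ancestral scheduler to diffusers' built-in `DPMSolverMultistepScheduler` (the vendored import is dropped in the next hunk). DPM-Solver++ is a higher-order multistep solver that typically reaches comparable quality in fewer steps, which fits the lowered `--sample_steps` default further down. A minimal sketch of the new scheduler in a stock diffusers pipeline, with a placeholder model id and prompt that are not taken from this commit:

```python
# Minimal sketch, not this repository's VlpnStableDiffusion pipeline:
# plug DPMSolverMultistepScheduler into a stock diffusers pipeline.
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

scheduler = DPMSolverMultistepScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model id
    scheduler=scheduler,
)
# A multistep solver usually converges in fewer steps than Euler Ancestral.
image = pipe("a photo of a cat", num_inference_steps=25).images[0]
```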
```diff
@@ -23,7 +23,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
```
```diff
@@ -144,7 +143,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=6000,
+        default=None,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
```
```diff
@@ -211,7 +210,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=7 / 8
+        default=6/7
     )
     parser.add_argument(
         "--ema_max_decay",
```
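`--ema_power` controls how fast the EMA decay warms up. Assuming the warmup schedule diffusers' `EMAModel` used at the time (not shown in this diff), the per-step decay is `1 - (1 + step / inv_gamma) ** -power`, clamped to `ema_max_decay`. Lowering the power from 7/8 to 6/7 yields a slightly smaller decay at every step, so the averaged weights follow the live weights a bit more closely. A sketch of that assumed schedule:

```python
# Sketch of the assumed EMA decay warmup from diffusers' EMAModel:
# decay = 1 - (1 + step / inv_gamma) ** -power, clamped to max_decay.
def ema_decay(step: int, power: float, inv_gamma: float = 1.0,
              max_decay: float = 0.9999) -> float:
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(0.0, min(value, max_decay))

for step in (100, 1_000, 10_000):
    print(step, ema_decay(step, power=7 / 8), ema_decay(step, power=6 / 7))
# The 6/7 power gives slightly smaller decay values, so the EMA tracks
# recent training progress a little more aggressively.
```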
```diff
@@ -284,6 +283,12 @@ def parse_args():
         help="Number of samples to generate per batch",
     )
     parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
+    parser.add_argument(
         "--train_batch_size",
         type=int,
         default=1,
```
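The new `--valid_set_size` flag decouples the validation split from the sampling grid; previously the data module always received `sample_batch_size * sample_batches` (see the `CSVDataModule` hunk below). A hypothetical sketch of how a caller might resolve the `None` default; whether `CSVDataModule` falls back exactly this way internally is an assumption, not something this commit shows:

```python
from typing import Optional


def resolve_valid_set_size(valid_set_size: Optional[int],
                           sample_batch_size: int,
                           sample_batches: int) -> int:
    # Hypothetical fallback: cover one full sample grid when the flag is
    # omitted, mirroring the value the script passed before this change.
    if valid_set_size is not None:
        return valid_set_size
    return sample_batch_size * sample_batches
```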
```diff
@@ -292,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=30,
+        default=25,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
```
```diff
@@ -461,7 +466,7 @@ class Checkpointer:
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        scheduler = EulerAncestralDiscreteScheduler(
+        scheduler = DPMSolverMultistepScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
 
```
```diff
@@ -487,23 +492,30 @@ class Checkpointer:
         with torch.inference_mode():
             for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                 all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                 file_path.parent.mkdir(parents=True, exist_ok=True)
 
                 data_enum = enumerate(data)
 
+                batches = [
+                    batch
+                    for j, batch in data_enum
+                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
+                ]
+                prompts = [
+                    prompt.format(identifier=self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ]
+                nprompts = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ]
+
                 for i in range(self.sample_batches):
-                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                    prompt = [
-                        prompt.format(identifier=self.instance_identifier)
-                        for batch in batches
-                        for prompt in batch["prompts"]
-                    ][:self.sample_batch_size]
-                    nprompt = [
-                        prompt
-                        for batch in batches
-                        for prompt in batch["nprompts"]
-                    ][:self.sample_batch_size]
+                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
                     samples = pipeline(
                         prompt=prompt,
@@ -523,7 +535,7 @@ class Checkpointer:
                     del samples
 
                 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path)
+                image_grid.save(file_path, quality=85)
 
                 del all_samples
                 del image_grid
```
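This refactor fixes a subtle bug: `data_enum` is an iterator, so the old code exhausted it on the first pass through the sampling loop and later iterations saw no batches at all; it also capped the fetch at `self.sample_batch_size` instead of `sample_batch_size * sample_batches`. The new code materializes enough batches once, flattens the prompt lists, and slices one window per sample batch. The slicing pattern in isolation, with illustrative data:

```python
# Standalone illustration of the new windowed slicing; the data is made
# up, only the indexing pattern matches the refactored code.
sample_batch_size = 2
sample_batches = 3
prompts = [f"prompt {i}" for i in range(sample_batch_size * sample_batches)]

for i in range(sample_batches):
    window = prompts[i * sample_batch_size:(i + 1) * sample_batch_size]
    print(i, window)  # every batch gets a disjoint, non-empty window
```

The switch from `step_{step}.png` to `.jpg` with `quality=85` additionally shrinks the sample grids on disk.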
```diff
@@ -576,6 +588,12 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
     ema_unet = None
     if args.use_ema:
         ema_unet = EMAModel(
@@ -586,12 +604,6 @@ def main():
             device=accelerator.device
         )
 
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
```
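Enabling xformers attention and gradient checkpointing now happens before the EMA model is created instead of after. Since the EMA wrapper keeps its own copy of the UNet, configuring the module first plausibly keeps that copy consistent with the training model. A sketch of the resulting order, assuming the `EMAModel` constructor snapshots the module it is given (as diffusers' implementation of this era did); the model id and keyword values are placeholders:

```python
# Sketch of the ordering this hunk establishes; requires xformers to be
# installed, and EMAModel's keyword arguments are assumed from the
# diffusers version of this era.
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder
)

# 1. Mutate the UNet first...
unet.set_use_memory_efficient_attention_xformers(True)
unet.enable_gradient_checkpointing()

# 2. ...then let the EMA wrapper copy the already-configured module.
ema_unet = EMAModel(unet, inv_gamma=1.0, power=6 / 7, max_value=0.9999)
```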
```diff
@@ -726,7 +738,7 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.sample_batches,
+        valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )
```
```diff
@@ -743,7 +755,7 @@ def main():
         for i in range(0, len(missing_data), args.sample_batch_size)
     ]
 
-    scheduler = EulerAncestralDiscreteScheduler(
+    scheduler = DPMSolverMultistepScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
     )
 
```
```diff
@@ -962,6 +974,8 @@ def main():
                     optimizer.step()
                     if not accelerator.optimizer_step_was_skipped:
                         lr_scheduler.step()
+                        if args.use_ema:
+                            ema_unet.step(unet)
                     optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
@@ -969,9 +983,6 @@ def main():
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_unet.step(unet)
-
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
```
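Finally, the EMA update moves from the `sync_gradients` block to sit alongside `lr_scheduler.step()`, guarded by `optimizer_step_was_skipped`. Under mixed precision, the gradient scaler can skip an optimizer step after detecting inf/NaN gradients; advancing the EMA on such a step would fold unchanged weights into the average and desynchronize its step counter. A toy illustration of the guard, where all training objects are stand-ins:

```python
# Toy illustration: the EMA advances only when the optimizer actually
# stepped; CountingEMA stands in for the real EMA wrapper.
class CountingEMA:
    def __init__(self) -> None:
        self.updates = 0

    def step(self, model) -> None:
        self.updates += 1


ema = CountingEMA()
for step_was_skipped in (False, True, False):  # e.g. AMP skips on inf/NaN
    # optimizer.step() would run here.
    if not step_was_skipped:
        ema.step(model=None)  # weights changed: fold them into the average
print(ema.updates)  # 2 -- the skipped step no longer advanced the EMA
```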
