diff options
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 71 |
1 file changed, 41 insertions, 30 deletions
diff --git a/dreambooth.py b/dreambooth.py index 8c4bf50..7b34fce 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -15,7 +15,7 @@ import torch.utils.checkpoint | |||
15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel |
19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 19 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
20 | from diffusers.training_utils import EMAModel | 20 | from diffusers.training_utils import EMAModel |
21 | from PIL import Image | 21 | from PIL import Image |
@@ -23,7 +23,6 @@ from tqdm.auto import tqdm | |||
23 | from transformers import CLIPTextModel, CLIPTokenizer | 23 | from transformers import CLIPTextModel, CLIPTokenizer |
24 | from slugify import slugify | 24 | from slugify import slugify |
25 | 25 | ||
26 | from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler | ||
27 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
28 | from data.csv import CSVDataModule | 27 | from data.csv import CSVDataModule |
29 | from training.optimization import get_one_cycle_schedule | 28 | from training.optimization import get_one_cycle_schedule |
@@ -144,7 +143,7 @@ def parse_args(): | |||
144 | parser.add_argument( | 143 | parser.add_argument( |
145 | "--max_train_steps", | 144 | "--max_train_steps", |
146 | type=int, | 145 | type=int, |
147 | default=6000, | 146 | default=None, |
148 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | 147 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", |
149 | ) | 148 | ) |
150 | parser.add_argument( | 149 | parser.add_argument( |
@@ -211,7 +210,7 @@ def parse_args(): | |||
211 | parser.add_argument( | 210 | parser.add_argument( |
212 | "--ema_power", | 211 | "--ema_power", |
213 | type=float, | 212 | type=float, |
214 | default=7 / 8 | 213 | default=6/7 |
215 | ) | 214 | ) |
216 | parser.add_argument( | 215 | parser.add_argument( |
217 | "--ema_max_decay", | 216 | "--ema_max_decay", |
@@ -284,6 +283,12 @@ def parse_args(): | |||
284 | help="Number of samples to generate per batch", | 283 | help="Number of samples to generate per batch", |
285 | ) | 284 | ) |
286 | parser.add_argument( | 285 | parser.add_argument( |
286 | "--valid_set_size", | ||
287 | type=int, | ||
288 | default=None, | ||
289 | help="Number of images in the validation dataset." | ||
290 | ) | ||
291 | parser.add_argument( | ||
287 | "--train_batch_size", | 292 | "--train_batch_size", |
288 | type=int, | 293 | type=int, |
289 | default=1, | 294 | default=1, |
@@ -292,7 +297,7 @@ def parse_args(): | |||
292 | parser.add_argument( | 297 | parser.add_argument( |
293 | "--sample_steps", | 298 | "--sample_steps", |
294 | type=int, | 299 | type=int, |
295 | default=30, | 300 | default=25, |
296 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 301 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
297 | ) | 302 | ) |
298 | parser.add_argument( | 303 | parser.add_argument( |
@@ -461,7 +466,7 @@ class Checkpointer: | |||
461 | self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) | 466 | self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) |
462 | unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 467 | unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
463 | 468 | ||
464 | scheduler = EulerAncestralDiscreteScheduler( | 469 | scheduler = DPMSolverMultistepScheduler( |
465 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 470 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
466 | ) | 471 | ) |
467 | 472 | ||
@@ -487,23 +492,30 @@ class Checkpointer: | |||
487 | with torch.inference_mode(): | 492 | with torch.inference_mode(): |
488 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | 493 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
489 | all_samples = [] | 494 | all_samples = [] |
490 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 495 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
491 | file_path.parent.mkdir(parents=True, exist_ok=True) | 496 | file_path.parent.mkdir(parents=True, exist_ok=True) |
492 | 497 | ||
493 | data_enum = enumerate(data) | 498 | data_enum = enumerate(data) |
494 | 499 | ||
500 | batches = [ | ||
501 | batch | ||
502 | for j, batch in data_enum | ||
503 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | ||
504 | ] | ||
505 | prompts = [ | ||
506 | prompt.format(identifier=self.instance_identifier) | ||
507 | for batch in batches | ||
508 | for prompt in batch["prompts"] | ||
509 | ] | ||
510 | nprompts = [ | ||
511 | prompt | ||
512 | for batch in batches | ||
513 | for prompt in batch["nprompts"] | ||
514 | ] | ||
515 | |||
495 | for i in range(self.sample_batches): | 516 | for i in range(self.sample_batches): |
496 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] | 517 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] |
497 | prompt = [ | 518 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] |
498 | prompt.format(identifier=self.instance_identifier) | ||
499 | for batch in batches | ||
500 | for prompt in batch["prompts"] | ||
501 | ][:self.sample_batch_size] | ||
502 | nprompt = [ | ||
503 | prompt | ||
504 | for batch in batches | ||
505 | for prompt in batch["nprompts"] | ||
506 | ][:self.sample_batch_size] | ||
507 | 519 | ||
508 | samples = pipeline( | 520 | samples = pipeline( |
509 | prompt=prompt, | 521 | prompt=prompt, |
@@ -523,7 +535,7 @@ class Checkpointer: | |||
523 | del samples | 535 | del samples |
524 | 536 | ||
525 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | 537 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
526 | image_grid.save(file_path) | 538 | image_grid.save(file_path, quality=85) |
527 | 539 | ||
528 | del all_samples | 540 | del all_samples |
529 | del image_grid | 541 | del image_grid |
@@ -576,6 +588,12 @@ def main(): | |||
576 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') | 588 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') |
577 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') | 589 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') |
578 | 590 | ||
591 | unet.set_use_memory_efficient_attention_xformers(True) | ||
592 | |||
593 | if args.gradient_checkpointing: | ||
594 | unet.enable_gradient_checkpointing() | ||
595 | text_encoder.gradient_checkpointing_enable() | ||
596 | |||
579 | ema_unet = None | 597 | ema_unet = None |
580 | if args.use_ema: | 598 | if args.use_ema: |
581 | ema_unet = EMAModel( | 599 | ema_unet = EMAModel( |
@@ -586,12 +604,6 @@ def main(): | |||
586 | device=accelerator.device | 604 | device=accelerator.device |
587 | ) | 605 | ) |
588 | 606 | ||
589 | unet.set_use_memory_efficient_attention_xformers(True) | ||
590 | |||
591 | if args.gradient_checkpointing: | ||
592 | unet.enable_gradient_checkpointing() | ||
593 | text_encoder.gradient_checkpointing_enable() | ||
594 | |||
595 | # Freeze text_encoder and vae | 607 | # Freeze text_encoder and vae |
596 | freeze_params(vae.parameters()) | 608 | freeze_params(vae.parameters()) |
597 | 609 | ||
@@ -726,7 +738,7 @@ def main(): | |||
726 | size=args.resolution, | 738 | size=args.resolution, |
727 | repeats=args.repeats, | 739 | repeats=args.repeats, |
728 | center_crop=args.center_crop, | 740 | center_crop=args.center_crop, |
729 | valid_set_size=args.sample_batch_size*args.sample_batches, | 741 | valid_set_size=args.valid_set_size, |
730 | num_workers=args.dataloader_num_workers, | 742 | num_workers=args.dataloader_num_workers, |
731 | collate_fn=collate_fn | 743 | collate_fn=collate_fn |
732 | ) | 744 | ) |
@@ -743,7 +755,7 @@ def main(): | |||
743 | for i in range(0, len(missing_data), args.sample_batch_size) | 755 | for i in range(0, len(missing_data), args.sample_batch_size) |
744 | ] | 756 | ] |
745 | 757 | ||
746 | scheduler = EulerAncestralDiscreteScheduler( | 758 | scheduler = DPMSolverMultistepScheduler( |
747 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | 759 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" |
748 | ) | 760 | ) |
749 | 761 | ||
@@ -962,6 +974,8 @@ def main(): | |||
962 | optimizer.step() | 974 | optimizer.step() |
963 | if not accelerator.optimizer_step_was_skipped: | 975 | if not accelerator.optimizer_step_was_skipped: |
964 | lr_scheduler.step() | 976 | lr_scheduler.step() |
977 | if args.use_ema: | ||
978 | ema_unet.step(unet) | ||
965 | optimizer.zero_grad(set_to_none=True) | 979 | optimizer.zero_grad(set_to_none=True) |
966 | 980 | ||
967 | loss = loss.detach().item() | 981 | loss = loss.detach().item() |
@@ -969,9 +983,6 @@ def main(): | |||
969 | 983 | ||
970 | # Checks if the accelerator has performed an optimization step behind the scenes | 984 | # Checks if the accelerator has performed an optimization step behind the scenes |
971 | if accelerator.sync_gradients: | 985 | if accelerator.sync_gradients: |
972 | if args.use_ema: | ||
973 | ema_unet.step(unet) | ||
974 | |||
975 | local_progress_bar.update(1) | 986 | local_progress_bar.update(1) |
976 | global_progress_bar.update(1) | 987 | global_progress_bar.update(1) |
977 | 988 | ||