 data/csv.py                                          |   2
 dreambooth.py                                        |  71
 infer.py                                             |  21
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  33
 schedulers/scheduling_euler_ancestral_discrete.py    | 261
 textual_inversion.py                                 |  13
 training/optimization.py                             |   2
 7 files changed, 87 insertions(+), 316 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 793fbf8..67ac43b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -93,7 +93,7 @@ class CSVDataModule(pl.LightningDataModule):
         items = [item for item in items if not "skip" in item or item["skip"] != True]
         num_images = len(items)
 
-        valid_set_size = int(num_images * 0.2)
+        valid_set_size = int(num_images * 0.1)
         if self.valid_set_size:
             valid_set_size = min(valid_set_size, self.valid_set_size)
         valid_set_size = max(valid_set_size, 1)
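For context, the split sizing after this change reserves 10% of the images for validation, caps it at an explicit --valid_set_size when one is given, and never goes below one image. A minimal standalone sketch (the function name is invented for illustration):

```python
from typing import Optional

def compute_valid_set_size(num_images: int, requested_size: Optional[int] = None) -> int:
    # default: 10% of the dataset (was 20% before this change)
    valid_set_size = int(num_images * 0.1)
    # an explicit --valid_set_size acts as an upper bound
    if requested_size:
        valid_set_size = min(valid_set_size, requested_size)
    # always keep at least one validation image
    return max(valid_set_size, 1)

assert compute_valid_set_size(100) == 10
assert compute_valid_set_size(100, requested_size=4) == 4
assert compute_valid_set_size(3) == 1
```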
diff --git a/dreambooth.py b/dreambooth.py
index 8c4bf50..7b34fce 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
@@ -23,7 +23,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
@@ -144,7 +143,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=6000,
+        default=None,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -211,7 +210,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=7 / 8
+        default=6/7
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -284,6 +283,12 @@ def parse_args():
         help="Number of samples to generate per batch",
     )
     parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
+    parser.add_argument(
         "--train_batch_size",
         type=int,
         default=1,
@@ -292,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=30,
+        default=25,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -461,7 +466,7 @@ class Checkpointer:
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        scheduler = EulerAncestralDiscreteScheduler(
+        scheduler = DPMSolverMultistepScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
 
@@ -487,23 +492,30 @@ class Checkpointer:
         with torch.inference_mode():
             for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                 all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                 file_path.parent.mkdir(parents=True, exist_ok=True)
 
                 data_enum = enumerate(data)
 
+                batches = [
+                    batch
+                    for j, batch in data_enum
+                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
+                ]
+                prompts = [
+                    prompt.format(identifier=self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ]
+                nprompts = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ]
+
                 for i in range(self.sample_batches):
-                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                    prompt = [
-                        prompt.format(identifier=self.instance_identifier)
-                        for batch in batches
-                        for prompt in batch["prompts"]
-                    ][:self.sample_batch_size]
-                    nprompt = [
-                        prompt
-                        for batch in batches
-                        for prompt in batch["nprompts"]
-                    ][:self.sample_batch_size]
+                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
                     samples = pipeline(
                         prompt=prompt,
@@ -523,7 +535,7 @@ class Checkpointer:
                     del samples
 
                 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path)
+                image_grid.save(file_path, quality=85)
 
                 del all_samples
                 del image_grid
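The refactor above collects every sample prompt once, outside the batch loop, and then lets each batch take a contiguous slice. A small self-contained sketch of that slicing, with toy values standing in for the Checkpointer attributes:

```python
sample_batches = 2
sample_batch_size = 3

# stand-in for the flattened `prompts` list built from the dataloader batches
prompts = [f"prompt {n}" for n in range(sample_batches * sample_batch_size)]

for i in range(sample_batches):
    # each iteration consumes a disjoint, contiguous slice of the prompt list
    batch = prompts[i * sample_batch_size:(i + 1) * sample_batch_size]
    print(i, batch)
# 0 ['prompt 0', 'prompt 1', 'prompt 2']
# 1 ['prompt 3', 'prompt 4', 'prompt 5']
```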
@@ -576,6 +588,12 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
     ema_unet = None
     if args.use_ema:
         ema_unet = EMAModel(
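xformers attention and gradient checkpointing are now enabled right after the models are loaded, before the EMA copy is created, so the EMA model is built from an already-configured UNet. A minimal sketch of that ordering, assuming a hypothetical checkpoint name and the same diffusers/transformers methods used in the hunk:

```python
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

model_name = "runwayml/stable-diffusion-v1-5"  # hypothetical; any SD 1.x checkpoint layout works

unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")

# memory-efficient attention is enabled unconditionally
unet.set_use_memory_efficient_attention_xformers(True)

# checkpointing stays opt-in via --gradient_checkpointing
gradient_checkpointing = True
if gradient_checkpointing:
    unet.enable_gradient_checkpointing()
    text_encoder.gradient_checkpointing_enable()
```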
@@ -586,12 +604,6 @@ def main():
             device=accelerator.device
         )
 
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
@@ -726,7 +738,7 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.sample_batches,
+        valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )
@@ -743,7 +755,7 @@ def main():
         for i in range(0, len(missing_data), args.sample_batch_size)
     ]
 
-    scheduler = EulerAncestralDiscreteScheduler(
+    scheduler = DPMSolverMultistepScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
     )
 
@@ -962,6 +974,8 @@ def main():
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
+                if args.use_ema:
+                    ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
@@ -969,9 +983,6 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
-                if args.use_ema:
-                    ema_unet.step(unet)
-
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
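With this pair of hunks the EMA update runs directly after optimizer.step() instead of under accelerator.sync_gradients, so the averaged weights track every real optimizer step. A toy, self-contained illustration of that placement (ToyEMA is only a stand-in for diffusers' EMAModel):

```python
import torch

class ToyEMA:
    """Minimal exponential moving average of a single parameter (illustration only)."""
    def __init__(self, param: torch.Tensor, decay: float = 0.99):
        self.shadow = param.detach().clone()
        self.decay = decay

    def step(self, param: torch.Tensor) -> None:
        self.shadow.mul_(self.decay).add_(param.detach(), alpha=1 - self.decay)

param = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.1)
ema = ToyEMA(param)

for _ in range(3):
    loss = (param - 1.0).pow(2).sum()
    loss.backward()
    optimizer.step()
    ema.step(param)  # update the moving average right after each real optimizer step
    optimizer.zero_grad(set_to_none=True)

print(param.item(), ema.shadow.item())
```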
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -8,11 +8,10 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
+from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
 
@@ -21,7 +20,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 
 default_args = {
     "model": None,
-    "scheduler": "euler_a",
+    "scheduler": "dpmpp",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
     "output_dir": "output/inference",
@@ -65,7 +64,7 @@ def create_args_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "euler_a"],
+        choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
     )
     parser.add_argument(
         "--precision",
@@ -222,6 +221,10 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
         scheduler = DDIMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
         )
+    elif scheduler == "dpmpp":
+        scheduler = DPMSolverMultistepScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
     else:
         scheduler = EulerAncestralDiscreteScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
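The new "dpmpp" branch maps to diffusers' DPMSolverMultistepScheduler with the same SD 1.x beta settings as the other schedulers. A sketch of the same selection written as a lookup table (the DDIM/PNDM-specific keyword arguments from infer.py are omitted here for brevity):

```python
from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)

SCHEDULER_CLASSES = {
    "plms": PNDMScheduler,
    "ddim": DDIMScheduler,
    "klms": LMSDiscreteScheduler,
    "dpmpp": DPMSolverMultistepScheduler,
    "euler_a": EulerAncestralDiscreteScheduler,
}

def make_scheduler(name: str):
    # the SD 1.x beta schedule used throughout this repository
    return SCHEDULER_CLASSES[name](
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )

scheduler = make_scheduler("dpmpp")
```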
@@ -282,7 +285,8 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -312,15 +316,16 @@ class CmdParse(cmd.Cmd):
 
         try:
             args = run_parser(self.parser, default_cmds, elements)
+
+            if len(args.prompt) == 0:
+                print('Try again with a prompt!')
+                return
         except SystemExit:
             self.parser.print_help()
         except Exception as e:
             print(e)
             return
 
-        if len(args.prompt) == 0:
-            print('Try again with a prompt!')
-
         try:
             generate(self.output_dir, self.pipeline, args)
         except KeyboardInterrupt:
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 36942f0..ba057ba 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -8,11 +8,20 @@ import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers.utils import is_accelerate_available
-from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DiffusionPipeline,
+    UNet2DConditionModel,
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -33,7 +42,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler],
+        scheduler: Union[
+            DDIMScheduler,
+            PNDMScheduler,
+            LMSDiscreteScheduler,
+            EulerDiscreteScheduler,
+            EulerAncestralDiscreteScheduler,
+            DPMSolverMultistepScheduler,
+        ],
         **kwargs,
     ):
         super().__init__()
@@ -252,19 +268,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latents = 0.18215 * latents
 
             # expand init_latents for batch_size
-            latents = torch.cat([latents] * batch_size)
+            latents = torch.cat([latents] * batch_size, dim=0)
 
             # get the original timestep using init_timestep
             init_timestep = int(num_inference_steps * strength) + offset
             init_timestep = min(init_timestep, num_inference_steps)
 
-            if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler):
-                timesteps = torch.tensor(
-                    [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-                )
-            else:
-                timesteps = self.scheduler.timesteps[-init_timestep]
-                timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+            timesteps = self.scheduler.timesteps[-init_timestep]
+            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
 
             # add noise to latents using the timesteps
             noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
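The simplified img2img branch always derives the start timestep from the scheduler's own timestep table. A worked example with hypothetical values (50 inference steps, strength 0.6, offset 0):

```python
num_inference_steps = 50
strength = 0.6
offset = 0
batch_size = 2

init_timestep = int(num_inference_steps * strength) + offset  # 30
init_timestep = min(init_timestep, num_inference_steps)       # still 30

# scheduler.timesteps runs from noisy to clean, so indexing with
# -init_timestep picks the timestep from which 30 denoising steps remain.
timesteps = list(range(1000, 0, -20))  # stand-in for scheduler.timesteps (50 entries)
t_start = timesteps[-init_timestep]
print(init_timestep, t_start, [t_start] * batch_size)  # 30 600 [600, 600]
```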
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
deleted file mode 100644
index cef50fe..0000000
--- a/schedulers/scheduling_euler_ancestral_discrete.py
+++ /dev/null
@@ -1,261 +0,0 @@
-# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import BaseOutput, deprecate, logging
-from diffusers.schedulers.scheduling_utils import SchedulerMixin
-
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-
-@dataclass
-# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
-class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
-    """
-    Output class for the scheduler's step function output.
-
-    Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
-            denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
-            `pred_original_sample` can be used to preview progress or for guidance.
-    """
-
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
-
-
-class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
-    """
-    Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
-    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
-
-    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
-    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
-
-    Args:
-        num_train_timesteps (`int`): number of diffusion steps used to train the model.
-        beta_start (`float`): the starting `beta` value of inference.
-        beta_end (`float`): the final `beta` value.
-        beta_schedule (`str`):
-            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear` or `scaled_linear`.
-        trained_betas (`np.ndarray`, optional):
-            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        num_train_timesteps: int = 1000,
-        beta_start: float = 0.0001,
-        beta_end: float = 0.02,
-        beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
-    ):
-        if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
-        elif beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = (
-                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-            )
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
-        self.sigmas = torch.from_numpy(sigmas)
-
-        # standard deviation of the initial noise distribution
-        self.init_noise_sigma = self.sigmas.max()
-
-        # setable values
-        self.num_inference_steps = None
-        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
-        self.timesteps = torch.from_numpy(timesteps)
-        self.is_scale_input_called = False
-
-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
-        """
-        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
-
-        Args:
-            sample (`torch.FloatTensor`): input sample
-            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
-
-        Returns:
-            `torch.FloatTensor`: scaled input sample
-        """
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-        step_index = (self.timesteps == timestep).nonzero().item()
-        sigma = self.sigmas[step_index]
-        sample = sample / ((sigma**2 + 1) ** 0.5)
-        self.is_scale_input_called = True
-        return sample
-
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
-        """
-        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
-
-        Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
-            device (`str` or `torch.device`, optional):
-                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
-        """
-        self.num_inference_steps = num_inference_steps
-
-        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
-        self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
-
-    def step(
-        self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
-        generator: Optional[torch.Generator] = None,
-        return_dict: bool = True,
-    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
-        """
-        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
-        process from the learned model outputs (most often the predicted noise).
-
-        Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-            timestep (`float`): current timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
-                current instance of sample being created by diffusion process.
-            generator (`torch.Generator`, optional): Random number generator.
-            return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
-
-        Returns:
-            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
-            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
-            a `tuple`. When returning a tuple, the first element is the sample tensor.
-
-        """
-
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
-            raise ValueError(
-                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
-                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
-                " one of the `scheduler.timesteps` as a timestep.",
-            )
-
-        if not self.is_scale_input_called:
-            logger.warn(
-                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
-                "See `StableDiffusionPipeline` for a usage example."
-            )
-
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        step_index = (self.timesteps == timestep).nonzero().item()
-        sigma = self.sigmas[step_index]
-
-        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma * model_output
-        sigma_from = self.sigmas[step_index]
-        sigma_to = self.sigmas[step_index + 1]
-        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
-        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
-
-        # 2. Convert to an ODE derivative
-        derivative = (sample - pred_original_sample) / sigma
-
-        dt = sigma_down - sigma
-
-        prev_sample = sample + derivative * dt
-
-        device = model_output.device if torch.is_tensor(model_output) else "cpu"
-        noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
-        prev_sample = prev_sample + noise * sigma_up
-
-        if not return_dict:
-            return (prev_sample,)
-
-        return EulerAncestralDiscreteSchedulerOutput(
-            prev_sample=prev_sample, pred_original_sample=pred_original_sample
-        )
-
-    def add_noise(
-        self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
-        # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
-        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
-            # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
-            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
-        else:
-            self.timesteps = self.timesteps.to(original_samples.device)
-            timesteps = timesteps.to(original_samples.device)
-
-        schedule_timesteps = self.timesteps
-
-        if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
-            deprecate(
-                "timesteps as indices",
-                "0.8.0",
-                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
-                " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
-                " pass values from `scheduler.timesteps` as timesteps.",
-                standard_warn=False,
-            )
-            step_indices = timesteps
-        else:
-            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = self.sigmas[step_indices].flatten()
-        while len(sigma.shape) < len(original_samples.shape):
-            sigma = sigma.unsqueeze(-1)
-
-        noisy_samples = original_samples + noise * sigma
-        return noisy_samples
-
-    def __len__(self):
-        return self.config.num_train_timesteps
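The file above was a local copy of the k-diffusion-style ancestral Euler scheduler; it is removed because the scripts now import EulerAncestralDiscreteScheduler straight from diffusers (see the import changes in infer.py, textual_inversion.py and the pipeline). For reference, a tiny sketch of the noise split its step() computed, with hypothetical adjacent sigmas:

```python
# Ancestral Euler step: the move from sigma_from to sigma_to is split into a
# deterministic part (down to sigma_down) plus fresh noise of size sigma_up,
# chosen so that sigma_down**2 + sigma_up**2 == sigma_to**2.
sigma_from, sigma_to = 2.0, 1.0  # hypothetical adjacent sigmas

sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

print(round(sigma_up, 4), round(sigma_down, 4))  # 0.866 0.5
assert abs((sigma_up**2 + sigma_down**2) - sigma_to**2) < 1e-9
```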
diff --git a/textual_inversion.py b/textual_inversion.py
index 578c054..999161b 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -15,14 +15,13 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
@@ -134,7 +133,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=10000,
+        default=None,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -252,6 +251,12 @@ def parse_args():
         help="Number of samples to generate per batch",
     )
     parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
+    parser.add_argument(
         "--train_batch_size",
         type=int,
         default=1,
@@ -637,7 +642,7 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.sample_batches,
+        valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )
diff --git a/training/optimization.py b/training/optimization.py
index 0fd7ec8..0e603fa 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ from diffusers.utils import logging
 logger = logging.get_logger(__name__)
 
 
-def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.43, last_epoch=-1):
+def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1):
     """
     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
     a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
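The mid_point default drops from 0.43 to 0.4, so the warm-up phase of the one-cycle schedule now peaks slightly earlier. The real implementation lives in training/optimization.py; the following is only a rough standalone sketch of a one-cycle-style schedule with a mid_point fraction and a min_lr floor, and may differ from get_one_cycle_schedule in detail:

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def one_cycle_sketch(optimizer, num_training_steps, min_lr=0.05, mid_point=0.4):
    peak = int(num_training_steps * mid_point)

    def lr_lambda(step):
        if step < peak:
            # linear warm-up to the full learning rate at `mid_point` of training
            return step / max(1, peak)
        # cosine anneal from 1.0 down to min_lr over the remaining steps
        progress = (step - peak) / max(1, num_training_steps - peak)
        return min_lr + (1.0 - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = one_cycle_sketch(optimizer, num_training_steps=100)
```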
