Diffstat:
-rw-r--r--  data/csv.py                                            2
-rw-r--r--  dreambooth.py                                         71
-rw-r--r--  infer.py                                              21
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py   33
-rw-r--r--  schedulers/scheduling_euler_ancestral_discrete.py    261
-rw-r--r--  textual_inversion.py                                  13
-rw-r--r--  training/optimization.py                               2
7 files changed, 87 insertions(+), 316 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 793fbf8..67ac43b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -93,7 +93,7 @@ class CSVDataModule(pl.LightningDataModule):
         items = [item for item in items if not "skip" in item or item["skip"] != True]
         num_images = len(items)
 
-        valid_set_size = int(num_images * 0.2)
+        valid_set_size = int(num_images * 0.1)
         if self.valid_set_size:
             valid_set_size = min(valid_set_size, self.valid_set_size)
         valid_set_size = max(valid_set_size, 1)
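
The default validation split drops from 20% to 10% of the dataset, and the new --valid_set_size flag (added below in dreambooth.py and textual_inversion.py) caps it. A minimal sketch of the resulting sizing rule, with a hypothetical helper name:

    # Hypothetical helper illustrating the split logic above.
    def compute_valid_set_size(num_images, requested=None):
        size = int(num_images * 0.1)      # now 10% of the data, down from 20%
        if requested:
            size = min(size, requested)   # --valid_set_size acts as an upper bound
        return max(size, 1)               # never drop below one validation image
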
diff --git a/dreambooth.py b/dreambooth.py
index 8c4bf50..7b34fce 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -15,7 +15,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
@@ -23,7 +23,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
@@ -144,7 +143,7 @@ def parse_args():
144 parser.add_argument( 143 parser.add_argument(
145 "--max_train_steps", 144 "--max_train_steps",
146 type=int, 145 type=int,
147 default=6000, 146 default=None,
148 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 147 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
149 ) 148 )
150 parser.add_argument( 149 parser.add_argument(
@@ -211,7 +210,7 @@ def parse_args():
211 parser.add_argument( 210 parser.add_argument(
212 "--ema_power", 211 "--ema_power",
213 type=float, 212 type=float,
214 default=7 / 8 213 default=6/7
215 ) 214 )
216 parser.add_argument( 215 parser.add_argument(
217 "--ema_max_decay", 216 "--ema_max_decay",
@@ -284,6 +283,12 @@ def parse_args():
284 help="Number of samples to generate per batch", 283 help="Number of samples to generate per batch",
285 ) 284 )
286 parser.add_argument( 285 parser.add_argument(
286 "--valid_set_size",
287 type=int,
288 default=None,
289 help="Number of images in the validation dataset."
290 )
291 parser.add_argument(
287 "--train_batch_size", 292 "--train_batch_size",
288 type=int, 293 type=int,
289 default=1, 294 default=1,
@@ -292,7 +297,7 @@ def parse_args():
292 parser.add_argument( 297 parser.add_argument(
293 "--sample_steps", 298 "--sample_steps",
294 type=int, 299 type=int,
295 default=30, 300 default=25,
296 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 301 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
297 ) 302 )
298 parser.add_argument( 303 parser.add_argument(
@@ -461,7 +466,7 @@ class Checkpointer:
             self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
-        scheduler = EulerAncestralDiscreteScheduler(
+        scheduler = DPMSolverMultistepScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -487,23 +492,30 @@ class Checkpointer:
         with torch.inference_mode():
             for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                 all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                 file_path.parent.mkdir(parents=True, exist_ok=True)
 
                 data_enum = enumerate(data)
 
+                batches = [
+                    batch
+                    for j, batch in data_enum
+                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
+                ]
+                prompts = [
+                    prompt.format(identifier=self.instance_identifier)
+                    for batch in batches
+                    for prompt in batch["prompts"]
+                ]
+                nprompts = [
+                    prompt
+                    for batch in batches
+                    for prompt in batch["nprompts"]
+                ]
+
                 for i in range(self.sample_batches):
-                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                    prompt = [
-                        prompt.format(identifier=self.instance_identifier)
-                        for batch in batches
-                        for prompt in batch["prompts"]
-                    ][:self.sample_batch_size]
-                    nprompt = [
-                        prompt
-                        for batch in batches
-                        for prompt in batch["nprompts"]
-                    ][:self.sample_batch_size]
+                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
                     samples = pipeline(
                         prompt=prompt,
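
This refactor fixes a subtle bug: data_enum = enumerate(data) produces a one-shot iterator, so the old per-iteration comprehension drained it on the first pass through range(self.sample_batches), leaving later sample batches with no prompts. Collecting prompts/nprompts once and slicing per batch avoids that. A tiny demonstration of the underlying pitfall:

    it = enumerate(["a", "b", "c"])
    first = [x for _, x in it]    # ['a', 'b', 'c']
    second = [x for _, x in it]   # [] -- the iterator is already exhausted
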
@@ -523,7 +535,7 @@ class Checkpointer:
                     del samples
 
                 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path)
+                image_grid.save(file_path, quality=85)
 
                 del all_samples
                 del image_grid
@@ -576,6 +588,12 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
 
+    unet.set_use_memory_efficient_attention_xformers(True)
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
     ema_unet = None
     if args.use_ema:
         ema_unet = EMAModel(
@@ -586,12 +604,6 @@ def main():
             device=accelerator.device
         )
 
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
@@ -726,7 +738,7 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.sample_batches,
+        valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
         collate_fn=collate_fn
     )
@@ -743,7 +755,7 @@ def main():
        for i in range(0, len(missing_data), args.sample_batch_size)
     ]
 
-    scheduler = EulerAncestralDiscreteScheduler(
+    scheduler = DPMSolverMultistepScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )
@@ -962,6 +974,8 @@ def main():
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
+                    if args.use_ema:
+                        ema_unet.step(unet)
                 optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
@@ -969,9 +983,6 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
-                if args.use_ema:
-                    ema_unet.step(unet)
-
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
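
Moving the EMA update next to lr_scheduler.step() ties it to optimizer steps that actually happened: under mixed precision, the gradient scaler can skip a step on overflow, and accelerate exposes that via optimizer_step_was_skipped. The resulting update order, consolidated from the two hunks above:

    optimizer.step()
    if not accelerator.optimizer_step_was_skipped:
        lr_scheduler.step()
        if args.use_ema:
            ema_unet.step(unet)   # EMA now tracks only applied updates
    optimizer.zero_grad(set_to_none=True)
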
diff --git a/infer.py b/infer.py
index 9bc9efe..9b0ec1f 100644
--- a/infer.py
+++ b/infer.py
@@ -8,11 +8,10 @@ from pathlib import Path
 import torch
 import json
 from PIL import Image
-from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler
+from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 
 
@@ -21,7 +20,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 
 default_args = {
     "model": None,
-    "scheduler": "euler_a",
+    "scheduler": "dpmpp",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
     "output_dir": "output/inference",
@@ -65,7 +64,7 @@ def create_args_parser():
65 parser.add_argument( 64 parser.add_argument(
66 "--scheduler", 65 "--scheduler",
67 type=str, 66 type=str,
68 choices=["plms", "ddim", "klms", "euler_a"], 67 choices=["plms", "ddim", "klms", "dpmpp", "euler_a"],
69 ) 68 )
70 parser.add_argument( 69 parser.add_argument(
71 "--precision", 70 "--precision",
@@ -222,6 +221,10 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
         scheduler = DDIMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False
         )
+    elif scheduler == "dpmpp":
+        scheduler = DPMSolverMultistepScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
     else:
         scheduler = EulerAncestralDiscreteScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
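
The new "dpmpp" choice (now the default) wires up diffusers' DPM-Solver++ multistep scheduler, which typically converges in fewer steps than ancestral Euler; the sample_steps default in dreambooth.py drops from 30 to 25 accordingly. A standalone usage sketch:

    from diffusers import DPMSolverMultistepScheduler

    scheduler = DPMSolverMultistepScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )
    scheduler.set_timesteps(25)   # DPM-Solver++ usually does well at ~20-25 steps
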
@@ -282,7 +285,8 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -312,15 +316,16 @@ class CmdParse(cmd.Cmd):
312 316
313 try: 317 try:
314 args = run_parser(self.parser, default_cmds, elements) 318 args = run_parser(self.parser, default_cmds, elements)
319
320 if len(args.prompt) == 0:
321 print('Try again with a prompt!')
322 return
315 except SystemExit: 323 except SystemExit:
316 self.parser.print_help() 324 self.parser.print_help()
317 except Exception as e: 325 except Exception as e:
318 print(e) 326 print(e)
319 return 327 return
320 328
321 if len(args.prompt) == 0:
322 print('Try again with a prompt!')
323
324 try: 329 try:
325 generate(self.output_dir, self.pipeline, args) 330 generate(self.output_dir, self.pipeline, args)
326 except KeyboardInterrupt: 331 except KeyboardInterrupt:
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 36942f0..ba057ba 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -8,11 +8,20 @@ import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers.utils import is_accelerate_available
-from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import (
+    AutoencoderKL,
+    DiffusionPipeline,
+    UNet2DConditionModel,
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
-from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
 from models.clip.prompt import PromptProcessor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -33,7 +42,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
33 text_encoder: CLIPTextModel, 42 text_encoder: CLIPTextModel,
34 tokenizer: CLIPTokenizer, 43 tokenizer: CLIPTokenizer,
35 unet: UNet2DConditionModel, 44 unet: UNet2DConditionModel,
36 scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler], 45 scheduler: Union[
46 DDIMScheduler,
47 PNDMScheduler,
48 LMSDiscreteScheduler,
49 EulerDiscreteScheduler,
50 EulerAncestralDiscreteScheduler,
51 DPMSolverMultistepScheduler,
52 ],
37 **kwargs, 53 **kwargs,
38 ): 54 ):
39 super().__init__() 55 super().__init__()
@@ -252,19 +268,14 @@ class VlpnStableDiffusion(DiffusionPipeline):
         latents = 0.18215 * latents
 
         # expand init_latents for batch_size
-        latents = torch.cat([latents] * batch_size)
+        latents = torch.cat([latents] * batch_size, dim=0)
 
         # get the original timestep using init_timestep
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
 
-        if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
 
         # add noise to latents using the timesteps
         noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
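
Every scheduler supported here exposes a timesteps tensor ordered from noisiest to cleanest, so the DDIM/PNDM special case collapses into a single lookup. A sketch of the img2img start-step selection, assuming set_timesteps(num_inference_steps) has already been called:

    # strength=1.0 starts from pure noise; smaller values start later in the chain.
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    t_start = scheduler.timesteps[-init_timestep]          # count back from the clean end
    timesteps = torch.tensor([t_start] * batch_size, device=device)
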
diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py
deleted file mode 100644
index cef50fe..0000000
--- a/schedulers/scheduling_euler_ancestral_discrete.py
+++ /dev/null
@@ -1,261 +0,0 @@
-# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
-
-import numpy as np
-import torch
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils import BaseOutput, deprecate, logging
-from diffusers.schedulers.scheduling_utils import SchedulerMixin
-
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-
-@dataclass
-# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
-class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
-    """
-    Output class for the scheduler's step function output.
-
-    Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
-            denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
-            `pred_original_sample` can be used to preview progress or for guidance.
-    """
-
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
-
-
-class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
-    """
-    Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
-    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
-
-    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
-    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-    [`~ConfigMixin.from_config`] functions.
-
-    Args:
-        num_train_timesteps (`int`): number of diffusion steps used to train the model.
-        beta_start (`float`): the starting `beta` value of inference.
-        beta_end (`float`): the final `beta` value.
-        beta_schedule (`str`):
-            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear` or `scaled_linear`.
-        trained_betas (`np.ndarray`, optional):
-            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        num_train_timesteps: int = 1000,
-        beta_start: float = 0.0001,
-        beta_end: float = 0.02,
-        beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
-    ):
-        if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
-        elif beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = (
-                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-            )
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
-        self.sigmas = torch.from_numpy(sigmas)
-
-        # standard deviation of the initial noise distribution
-        self.init_noise_sigma = self.sigmas.max()
-
-        # setable values
-        self.num_inference_steps = None
-        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
-        self.timesteps = torch.from_numpy(timesteps)
-        self.is_scale_input_called = False
-
-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
-        """
-        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
-
-        Args:
-            sample (`torch.FloatTensor`): input sample
-            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
-
-        Returns:
-            `torch.FloatTensor`: scaled input sample
-        """
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-        step_index = (self.timesteps == timestep).nonzero().item()
-        sigma = self.sigmas[step_index]
-        sample = sample / ((sigma**2 + 1) ** 0.5)
-        self.is_scale_input_called = True
-        return sample
-
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
-        """
-        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
-
-        Args:
-            num_inference_steps (`int`):
-                the number of diffusion steps used when generating samples with a pre-trained model.
-            device (`str` or `torch.device`, optional):
-                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
-        """
-        self.num_inference_steps = num_inference_steps
-
-        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
-        self.sigmas = torch.from_numpy(sigmas).to(device=device)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device)
-
-    def step(
-        self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
-        generator: Optional[torch.Generator] = None,
-        return_dict: bool = True,
-    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
-        """
-        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
-        process from the learned model outputs (most often the predicted noise).
-
-        Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-            timestep (`float`): current timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
-                current instance of sample being created by diffusion process.
-            generator (`torch.Generator`, optional): Random number generator.
-            return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
-
-        Returns:
-            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
-            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
-            a `tuple`. When returning a tuple, the first element is the sample tensor.
-
-        """
-
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
-            raise ValueError(
-                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
-                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
-                " one of the `scheduler.timesteps` as a timestep.",
-            )
-
-        if not self.is_scale_input_called:
-            logger.warn(
-                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
-                "See `StableDiffusionPipeline` for a usage example."
-            )
-
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        step_index = (self.timesteps == timestep).nonzero().item()
-        sigma = self.sigmas[step_index]
-
-        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
-        pred_original_sample = sample - sigma * model_output
-        sigma_from = self.sigmas[step_index]
-        sigma_to = self.sigmas[step_index + 1]
-        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
-        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
-
-        # 2. Convert to an ODE derivative
-        derivative = (sample - pred_original_sample) / sigma
-
-        dt = sigma_down - sigma
-
-        prev_sample = sample + derivative * dt
-
-        device = model_output.device if torch.is_tensor(model_output) else "cpu"
-        noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
-        prev_sample = prev_sample + noise * sigma_up
-
-        if not return_dict:
-            return (prev_sample,)
-
-        return EulerAncestralDiscreteSchedulerOutput(
-            prev_sample=prev_sample, pred_original_sample=pred_original_sample
-        )
-
-    def add_noise(
-        self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
-        # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
-        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
-            # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
-            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
-        else:
-            self.timesteps = self.timesteps.to(original_samples.device)
-            timesteps = timesteps.to(original_samples.device)
-
-        schedule_timesteps = self.timesteps
-
-        if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
-            deprecate(
-                "timesteps as indices",
-                "0.8.0",
-                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
-                " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
-                " pass values from `scheduler.timesteps` as timesteps.",
-                standard_warn=False,
-            )
-            step_indices = timesteps
-        else:
-            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
-
-        sigma = self.sigmas[step_indices].flatten()
-        while len(sigma.shape) < len(original_samples.shape):
-            sigma = sigma.unsqueeze(-1)
-
-        noisy_samples = original_samples + noise * sigma
-        return noisy_samples
-
-    def __len__(self):
-        return self.config.num_train_timesteps
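
The vendored scheduler is deleted because diffusers now ships an equivalent EulerAncestralDiscreteScheduler; the updated imports in infer.py and textual_inversion.py pull it from the library instead, constructed with the same Stable Diffusion betas:

    from diffusers import EulerAncestralDiscreteScheduler

    scheduler = EulerAncestralDiscreteScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    )
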
diff --git a/textual_inversion.py b/textual_inversion.py
index 578c054..999161b 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -15,14 +15,13 @@ import torch.utils.checkpoint
15from accelerate import Accelerator 15from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 18from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup 19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
20from PIL import Image 20from PIL import Image
21from tqdm.auto import tqdm 21from tqdm.auto import tqdm
22from transformers import CLIPTextModel, CLIPTokenizer 22from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify 23from slugify import slugify
24 24
25from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from data.csv import CSVDataModule 26from data.csv import CSVDataModule
28from training.optimization import get_one_cycle_schedule 27from training.optimization import get_one_cycle_schedule
@@ -134,7 +133,7 @@ def parse_args():
134 parser.add_argument( 133 parser.add_argument(
135 "--max_train_steps", 134 "--max_train_steps",
136 type=int, 135 type=int,
137 default=10000, 136 default=None,
138 help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 137 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
139 ) 138 )
140 parser.add_argument( 139 parser.add_argument(
@@ -252,6 +251,12 @@ def parse_args():
252 help="Number of samples to generate per batch", 251 help="Number of samples to generate per batch",
253 ) 252 )
254 parser.add_argument( 253 parser.add_argument(
254 "--valid_set_size",
255 type=int,
256 default=None,
257 help="Number of images in the validation dataset."
258 )
259 parser.add_argument(
255 "--train_batch_size", 260 "--train_batch_size",
256 type=int, 261 type=int,
257 default=1, 262 default=1,
@@ -637,7 +642,7 @@ def main():
637 size=args.resolution, 642 size=args.resolution,
638 repeats=args.repeats, 643 repeats=args.repeats,
639 center_crop=args.center_crop, 644 center_crop=args.center_crop,
640 valid_set_size=args.sample_batch_size*args.sample_batches, 645 valid_set_size=args.valid_set_size,
641 num_workers=args.dataloader_num_workers, 646 num_workers=args.dataloader_num_workers,
642 collate_fn=collate_fn 647 collate_fn=collate_fn
643 ) 648 )
diff --git a/training/optimization.py b/training/optimization.py
index 0fd7ec8..0e603fa 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -6,7 +6,7 @@ from diffusers.utils import logging
6logger = logging.get_logger(__name__) 6logger = logging.get_logger(__name__)
7 7
8 8
9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.43, last_epoch=-1): 9def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1):
10 """ 10 """
11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 11 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 12 a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
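
The one-cycle peak shifts from 43% to 40% of training, giving the anneal phase slightly more room. A rough sketch of the schedule's shape under the new default (linear variant only; the actual implementation in training/optimization.py also supports cosine annealing):

    def one_cycle_factor(step, num_training_steps, min_lr=0.05, mid_point=0.4):
        peak = int(num_training_steps * mid_point)   # LR now peaks at 40% of training
        if step < peak:
            return step / max(1, peak)               # warmup: 0 -> 1
        progress = (step - peak) / max(1, num_training_steps - peak)
        return max(min_lr, 1.0 - progress)           # anneal: 1 -> min_lr
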