author      Volpeon <git@volpeon.ink>    2022-10-12 08:18:22 +0200
committer   Volpeon <git@volpeon.ink>    2022-10-12 08:18:22 +0200
commit      f5b656d21c5b449eed6ce212e909043c124f79ee (patch)
tree        905f20900433f1e77840cd66417395168e0eec7f
parent      Added EMA support to Textual Inversion (diff)
Various updates
-rw-r--r--   data/csv.py                                            3
-rw-r--r--   dreambooth.py                                         53
-rw-r--r--   environment.yaml                                       4
-rw-r--r--   infer.py                                              10
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py   11
-rw-r--r--   schedulers/scheduling_euler_a.py                     210
-rw-r--r--   textual_inversion.py                                  74
7 files changed, 142 insertions, 223 deletions
diff --git a/data/csv.py b/data/csv.py
index 8637ac1..253ce9e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -68,13 +68,12 @@ class CSVDataModule(pl.LightningDataModule):
                     item.nprompt if "nprompt" in item else ""
                 )
                 for item in data
-                if "skip" not in item or item.skip != "x"
                 for i in range(image_multiplier)
             ]

     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        metadata = list(metadata.itertuples())
+        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"]
         num_images = len(metadata)

         valid_set_size = int(num_images * 0.2)
diff --git a/dreambooth.py b/dreambooth.py
index 02f83c6..775aea2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -167,7 +167,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1.0
+        default=7 / 8
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -468,20 +468,20 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
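The hunk above replaces string-concatenated paths with the subfolder argument that from_pretrained accepts, which also works when the model is referenced by hub id rather than a local directory, and constructs the EMA wrapper right after the U-Net is loaded. A minimal sketch of the same pattern (the model id is a placeholder):

    from diffusers import AutoencoderKL, UNet2DConditionModel
    from diffusers.training_utils import EMAModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model = "CompVis/stable-diffusion-v1-4"  # placeholder: any local path or hub id
    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(model, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet')

    # EMA copy of the U-Net, using the same arguments as dreambooth.py.
    ema_unet = EMAModel(unet, inv_gamma=1.0, power=7 / 8, max_value=0.9999)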
@@ -538,7 +538,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -629,16 +629,10 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_unet = EMAModel(
-        unet,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
+    ema_unet.averaged_model.to(accelerator.device)

     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
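ema_unet.averaged_model is a separate copy of the U-Net, so it has to be moved to the accelerator device alongside the frozen text_encoder and vae. Since ema_unet is None when --use_ema is disabled, a guarded variant (editorial sketch, not in the commit) would look like:

    # Move frozen models (and the EMA copy, if enabled) to the training device.
    text_encoder.to(accelerator.device)
    vae.to(accelerator.device)
    if ema_unet is not None:
        ema_unet.averaged_model.to(accelerator.device)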
@@ -698,7 +692,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -709,7 +703,7 @@ def main():

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             unet.train()
@@ -720,9 +714,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
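For reference, the training step encodes images to latents, scales them by 0.18215 (the Stable Diffusion latent scale factor), and has the DDPM scheduler add noise at a randomly sampled timestep. A self-contained sketch of that noising step with diffusers' DDPMScheduler (tensor shapes are placeholders, assuming a reasonably recent diffusers version):

    import torch
    from diffusers import DDPMScheduler

    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    latents = torch.randn(1, 4, 64, 64)  # stand-in for vae.encode(...).latent_dist.sample() * 0.18215
    noise = torch.randn(latents.shape)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,)).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)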
@@ -737,8 +730,7 @@ def main():
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                     # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -840,7 +832,8 @@ def main():
                 global_progress_bar.clear()

                 if min_val_loss > val_loss:
-                    accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                     min_val_loss = val_loss

                 if sample_checkpoint and accelerator.is_main_process:
diff --git a/environment.yaml b/environment.yaml
index 5ecc5a8..de35645 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -6,7 +6,7 @@ dependencies:
   - cudatoolkit=11.3
   - numpy=1.22.3
   - pip=20.3
-  - python=3.8.10
+  - python=3.9.13
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
@@ -32,6 +32,6 @@ dependencies:
     - test-tube>=0.7.5
     - torch-fidelity==0.3.0
     - torchmetrics==0.9.3
-    - transformers==4.22.2
+    - transformers==4.23.1
     - triton==2.0.0.dev20220924
     - xformers==0.0.13
diff --git a/infer.py b/infer.py
index 70851fd..5bd4abc 100644
--- a/infer.py
+++ b/infer.py
@@ -22,7 +22,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 default_args = {
     "model": None,
     "scheduler": "euler_a",
-    "precision": "fp16",
+    "precision": "fp32",
     "embeddings_dir": "embeddings",
     "output_dir": "output/inference",
     "config": None,
@@ -205,10 +205,10 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
 def create_pipeline(model, scheduler, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")

-    tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
+    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='/tokenizer', torch_dtype=dtype)
+    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='/text_encoder', torch_dtype=dtype)
+    vae = AutoencoderKL.from_pretrained(model, subfolder='/vae', torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(model, subfolder='/unet', torch_dtype=dtype)

     load_embeddings(tokenizer, text_encoder, embeddings_dir)

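Unlike the dreambooth.py and textual_inversion.py hunks in this commit, these calls pass the subfolder values with a leading slash (subfolder='/tokenizer'); whether that resolves depends on how from_pretrained joins the path. The slash-free spelling used elsewhere in the commit is the conventional form, e.g. (model id is a placeholder):

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='tokenizer')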
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index bfecd1c..8927a78 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
-from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
+from schedulers.scheduling_euler_a import EulerAScheduler

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -284,10 +284,9 @@ class VlpnStableDiffusion(DiffusionPipeline):

             noise_pred = None
             if isinstance(self.scheduler, EulerAScheduler):
-                sigma = t.reshape(1)
-                sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
-                noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
-                                                text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
+                c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)
+                eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample
+                noise_pred = latent_model_input + eps * c_out
             else:
                 # predict the noise residual
                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -305,7 +304,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         image = self.vae.decode(latents).sample

         image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

         if output_type == "pil":
             image = self.numpy_to_pil(image)
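With the helpers moved onto EulerAScheduler, the branch above appears to be the k-diffusion eps-denoiser written out inline: prepare_input returns c_out = -sigma and c_in = 1 / (sigma ** 2 + 1) ** 0.5 (see get_scalings in schedulers/scheduling_euler_a.py below), so the value stored in noise_pred is the denoised prediction x + c_out * eps(c_in * x), which the scheduler's step later turns into a derivative via to_d. A standalone sketch of the arithmetic (unet_eps stands in for the U-Net call):

    import torch

    def denoise(x, sigma, unet_eps):
        # k-diffusion style scalings for an eps-predicting model (sigma_data = 1).
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1) ** 0.5
        return x + unet_eps(x * c_in, sigma) * c_out

    x = torch.randn(1, 4, 64, 64)
    print(denoise(x, torch.tensor(5.0), lambda z, s: torch.zeros_like(z)).shape)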
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 13ea6b3..6abe971 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -7,113 +7,6 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput


-'''
-helper functions: append_zero(),
-                  t_to_sigma(),
-                  get_sigmas(),
-                  append_dims(),
-                  CFGDenoiserForward(),
-                  get_scalings(),
-                  DSsigma_to_t(),
-                  DiscreteEpsDDPMDenoiserForward(),
-                  to_d(),
-                  get_ancestral_step()
-need cleaning
-'''
-
-
-def append_zero(x):
-    return torch.cat([x, x.new_zeros([1])])
-
-
-def t_to_sigma(t, sigmas):
-    t = t.float()
-    low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
-    return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
-
-
-def get_sigmas(sigmas, n=None):
-    if n is None:
-        return append_zero(sigmas.flip(0))
-    t_max = len(sigmas) - 1  # = 999
-    t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype)
-    return append_zero(t_to_sigma(t, sigmas))
-
-# from k_samplers utils.py
-
-
-def append_dims(x, target_dims):
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-    return x[(...,) + (None,) * dims_to_append]
-
-
-def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None):
-    # x_in = torch.cat([x] * 2)#A# concat the latent
-    # sigma_in = torch.cat([sigma] * 2) #A# concat sigma
-    # cond_in = torch.cat([uncond, cond])
-    # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
-    # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2)
-    # return uncond + (cond - uncond) * cond_scale
-    noise_pred = DiscreteEpsDDPMDenoiserForward(
-        Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in)
-    return noise_pred
-
-# from k_samplers sampling.py
-
-
-def to_d(x, sigma, denoised):
-    """Converts a denoiser output to a Karras ODE derivative."""
-    return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim)
-
-
-def get_scalings(sigma):
-    sigma_data = 1.
-    c_out = -sigma
-    c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
-    return c_out, c_in
-
-# DiscreteSchedule DS
-
-
-def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
-    dists = torch.abs(sigma - DSsigmas[:, None])
-    if quantize:
-        return torch.argmin(dists, dim=0).view(sigma.shape)
-    low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
-    low, high = DSsigmas[low_idx], DSsigmas[high_idx]
-    w = (low - sigma) / (low - high)
-    w = w.clamp(0, 1)
-    t = (1 - w) * low_idx + w * high_idx
-    return t.view(sigma.shape)
-
-
-def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
-    sigma = sigma.to(dtype=input.dtype, device=Unet.device)
-    DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device)
-    c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
-    # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}")
-    eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
-               encoder_hidden_states=kwargs['cond']).sample
-    return input + eps * c_out
-
-
-# from k_samplers sampling.py
-def get_ancestral_step(sigma_from, sigma_to):
-    """Calculates the noise level (sigma_down) to step down to and the amount
-    of noise to add (sigma_up) when doing an ancestral sampling step."""
-    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
-    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
-    return sigma_down, sigma_up
-
-
-'''
-Euler Ancestral Scheduler
-'''
-
-
 class EulerAScheduler(SchedulerMixin, ConfigMixin):
     """
     Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
@@ -154,20 +47,24 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
+        tensor_format: str = "pt",
+        num_inference_steps=None,
+        device='cuda'
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.from_numpy(trained_betas).to(device)
         if beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-        elif beta_schedule == "squaredcos_cap_v2":
-            # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(num_train_timesteps)
+            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps,
+                                        dtype=torch.float32, device=device) ** 2
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        self.device = device
+        self.tensor_format = tensor_format
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

@@ -175,8 +72,12 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         self.init_noise_sigma = 1.0

         # setable values
-        self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        # get sigmas
+        self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+        self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
+        self.set_format(tensor_format=tensor_format)

         # A# take number of steps as input
         # A# store 1) number of steps 2) timesteps 3) schedule
@@ -192,7 +93,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):

         self.num_inference_steps = num_inference_steps
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
+        self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
         self.timesteps = self.sigmas[:-1]
         self.is_scale_input_called = False

@@ -251,8 +152,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         s_prev = self.sigmas[step_prev_index]
         latents = sample

-        sigma_down, sigma_up = get_ancestral_step(s, s_prev)
-        d = to_d(latents, s, model_output)
+        sigma_down, sigma_up = self.get_ancestral_step(s, s_prev)
+        d = self.to_d(latents, s, model_output)
         dt = sigma_down - s
         latents = latents + d * dt
         latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
@@ -313,3 +214,76 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         self.is_scale_input_called = True
         return noisy_samples
+
+    # from k_samplers sampling.py
+
+    def get_ancestral_step(self, sigma_from, sigma_to):
+        """Calculates the noise level (sigma_down) to step down to and the amount
+        of noise to add (sigma_up) when doing an ancestral sampling step."""
+        sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
+        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+        return sigma_down, sigma_up
+
+    def t_to_sigma(self, t, sigmas):
+        t = t.float()
+        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
+
+    def append_zero(self, x):
+        return torch.cat([x, x.new_zeros([1])])
+
+    def get_sigmas(self, sigmas, n=None):
+        if n is None:
+            return self.append_zero(sigmas.flip(0))
+        t_max = len(sigmas) - 1  # = 999
+        device = self.device
+        t = torch.linspace(t_max, 0, n, device=device)
+        # t = torch.linspace(t_max, 0, n, device=sigmas.device)
+        return self.append_zero(self.t_to_sigma(t, sigmas))
+
+    # from k_samplers utils.py
+    def append_dims(self, x, target_dims):
+        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+        dims_to_append = target_dims - x.ndim
+        if dims_to_append < 0:
+            raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+        return x[(...,) + (None,) * dims_to_append]
+
+    # from k_samplers sampling.py
+    def to_d(self, x, sigma, denoised):
+        """Converts a denoiser output to a Karras ODE derivative."""
+        return (x - denoised) / self.append_dims(sigma, x.ndim)
+
+    def get_scalings(self, sigma):
+        sigma_data = 1.
+        c_out = -sigma
+        c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
+        return c_out, c_in
+
+    # DiscreteSchedule DS
+    def DSsigma_to_t(self, sigma, quantize=None):
+        # quantize = self.quantize if quantize is None else quantize
+        quantize = False
+        dists = torch.abs(sigma - self.DSsigmas[:, None])
+        if quantize:
+            return torch.argmin(dists, dim=0).view(sigma.shape)
+        low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
+        low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx]
+        w = (low - sigma) / (low - high)
+        w = w.clamp(0, 1)
+        t = (1 - w) * low_idx + w * high_idx
+        return t.view(sigma.shape)
+
+    def prepare_input(self, latent_in, t, batch_size):
+        sigma = t.reshape(1)  # A# potential bug: doesn't work on samples > 1
+
+        sigma_in = torch.cat([sigma] * 2 * batch_size)
+        # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
+        # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in)
+        c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)]
+
+        sigma_in = self.DSsigma_to_t(sigma_in)
+        # s_in = latent_in.new_ones([latent_in.shape[0]])
+        # sigma_in = sigma_in * s_in
+
+        return c_out, c_in, sigma_in
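get_ancestral_step splits the target noise level so that the deterministic step down (sigma_down) plus the freshly injected noise (sigma_up) reproduce sigma_to in quadrature, i.e. sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2. A quick numeric check of that identity (the sigma values are arbitrary):

    import torch

    def get_ancestral_step(sigma_from, sigma_to):
        sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up

    down, up = get_ancestral_step(torch.tensor(14.6), torch.tensor(9.7))
    print(down, up, (down ** 2 + up ** 2) ** 0.5)  # the last value equals sigma_to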
diff --git a/textual_inversion.py b/textual_inversion.py
index e6d856a..3a3741d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -112,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,31 +149,10 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        default=True,
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -348,7 +326,6 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
-        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -363,7 +340,6 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
-        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -380,8 +356,7 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)

         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -400,8 +375,7 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -507,9 +481,7 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
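The context lines above register the placeholder token with the tokenizer and later look up its id. A minimal sketch of that textual-inversion setup (model id and token are placeholders; the resize call reflects what such scripts typically do after add_tokens, and is an assumption here):

    from transformers import CLIPTextModel, CLIPTokenizer

    model = "CompVis/stable-diffusion-v1-4"  # placeholder
    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder')

    num_added_tokens = tokenizer.add_tokens("<my-token>")
    assert num_added_tokens == 1, "token already exists in the vocabulary"
    placeholder_token_id = tokenizer.convert_tokens_to_ids("<my-token>")
    text_encoder.resize_token_embeddings(len(tokenizer))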
@@ -530,15 +502,10 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='unet')

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -707,13 +674,6 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_text_encoder = EMAModel(
-        text_encoder,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
@@ -757,7 +717,6 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -777,7 +736,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -788,7 +747,7 @@ def main():

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             text_encoder.train()
@@ -799,9 +758,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -859,9 +817,6 @@ def main():

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_text_encoder.step(unet)
-
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

@@ -881,8 +836,6 @@ def main():
                         })

                     logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
-                    if args.use_ema:
-                        logs["ema_decay"] = ema_text_encoder.decay

                     accelerator.log(logs, step=global_step)

@@ -937,7 +890,8 @@ def main():
                 global_progress_bar.clear()

                 if min_val_loss > val_loss:
-                    accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                     checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                     min_val_loss = val_loss
