From 94b676d91382267e7429bd68362019868affd9d1 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 13 Feb 2023 17:19:18 +0100 Subject: Update --- data/csv.py | 8 +-- infer.py | 14 ++--- .../stable_diffusion/vlpn_stable_diffusion.py | 69 ++++++++++++---------- train_dreambooth.py | 10 ++-- train_lora.py | 10 ++-- train_ti.py | 19 +++--- training/functional.py | 2 +- training/strategy/lora.py | 2 +- training/strategy/ti.py | 2 +- util.py | 2 +- 10 files changed, 73 insertions(+), 65 deletions(-) diff --git a/data/csv.py b/data/csv.py index b4c81d7..c5902ed 100644 --- a/data/csv.py +++ b/data/csv.py @@ -42,7 +42,7 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]): def generate_buckets( - items: list[str], + items: Union[list[str], list[Path]], base_size: int, step_size: int = 64, max_pixels: Optional[int] = None, @@ -188,7 +188,7 @@ class VlpnDataModule(): raise ValueError("data_file must be a file") self.data_root = self.data_file.parent - self.class_root = self.data_root.joinpath(class_subdir) + self.class_root = self.data_root / class_subdir self.class_root.mkdir(parents=True, exist_ok=True) self.num_class_images = num_class_images @@ -218,7 +218,7 @@ class VlpnDataModule(): return [ VlpnDataItem( - self.data_root.joinpath(image.format(item["image"])), + self.data_root / image.format(item["image"]), None, prompt_to_keywords( prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")), @@ -249,7 +249,7 @@ class VlpnDataModule(): return [ VlpnDataItem( item.instance_image_path, - self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), + self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", item.prompt, item.cprompt, item.nprompt, diff --git a/infer.py b/infer.py index 42b4e2d..aa75ee5 100644 --- a/infer.py +++ b/infer.py @@ -264,16 +264,16 @@ def generate(output_dir: Path, pipeline, args): if len(args.prompt) != 1: if len(args.project) != 0: - output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}") + output_dir = output_dir / f"{now}_{slugify(args.project)}" else: - output_dir = output_dir.joinpath(now) + output_dir = output_dir / now for prompt in args.prompt: - dir = output_dir.joinpath(slugify(prompt)[:100]) + dir = output_dir / slugify(prompt)[:100] dir.mkdir(parents=True, exist_ok=True) image_dir.append(dir) else: - output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") + output_dir = output_dir / f"{now}_{slugify(args.prompt[0])[:100]}" output_dir.mkdir(parents=True, exist_ok=True) image_dir.append(output_dir) @@ -332,9 +332,9 @@ def generate(output_dir: Path, pipeline, args): basename = f"{seed}_{j // len(args.prompt)}" dir = image_dir[j % len(args.prompt)] - image.save(dir.joinpath(f"{basename}.png")) - image.save(dir.joinpath(f"{basename}.jpg"), quality=85) - with open(dir.joinpath(f"{basename}.txt"), 'w') as f: + image.save(dir / f"{basename}.png") + image.save(dir / f"{basename}.jpg", quality=85) + with open(dir / f"{basename}.txt", 'w') as f: f.write(prompt[j % len(args.prompt)]) if torch.cuda.is_available(): diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 66566b0..cb09fe1 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional, Union, Callable import numpy as np import torch -import torchvision.transforms as T +import torch.nn.functional as F import PIL from diffusers.configuration_utils import FrozenDict @@ -39,6 +39,27 @@ def preprocess(image): return 2.0 * image - 1.0 +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + + return img + + class CrossAttnStoreProcessor: def __init__(self): self.attention_probs = None @@ -46,13 +67,17 @@ class CrossAttnStoreProcessor: def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - query = attn.to_q(hidden_states) - query = attn.head_to_batch_dim(query) - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.cross_attention_norm: + encoder_hidden_states = attn.norm_cross(encoder_hidden_states) + key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) @@ -510,12 +535,12 @@ class VlpnStableDiffusion(DiffusionPipeline): # in https://arxiv.org/pdf/2210.00939.pdf if do_classifier_free_guidance: # DDIM-like prediction of x0 - pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t) + pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) # get the stored attention maps uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) # self-attention-based degrading of latents degraded_latents = self.sag_masking( - pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t) + pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) ) uncond_emb, _ = prompt_embeds.chunk(2) # forward and give guidance @@ -523,12 +548,12 @@ class VlpnStableDiffusion(DiffusionPipeline): noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) else: # DDIM-like prediction of x0 - pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t) + pred_x0 = self.pred_x0(latents, noise_pred, t) # get the stored attention maps cond_attn = store_processor.attention_probs # self-attention-based degrading of latents degraded_latents = self.sag_masking( - pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t) + pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) ) # forward and give guidance degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample @@ -578,8 +603,7 @@ class VlpnStableDiffusion(DiffusionPipeline): attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) # Blur according to the self-attention mask - transform = T.GaussianBlur(kernel_size=9, sigma=1.0) - degraded_latents = transform(original_latents) + degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) # Noise it again to match the noise level @@ -588,19 +612,11 @@ class VlpnStableDiffusion(DiffusionPipeline): return degraded_latents # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step - def pred_x0_from_eps(self, sample, model_output, timestep): - # 1. get previous step value (=t-1) - # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps - - # 2. compute alphas, betas + # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.) + def pred_x0(self, sample, model_output, timestep): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - # alpha_prod_t_prev = ( - # self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod - # ) beta_prod_t = 1 - alpha_prod_t - # 3. compute predicted original sample from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.scheduler.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.scheduler.config.prediction_type == "sample": @@ -614,24 +630,13 @@ class VlpnStableDiffusion(DiffusionPipeline): f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," " or `v_prediction`" ) - # # 4. Clip "predicted x_0" - # if self.scheduler.config.clip_sample: - # pred_original_sample = torch.clamp(pred_original_sample, -1, 1) return pred_original_sample - def pred_eps_from_noise(self, sample, model_output, timestep): - # 1. get previous step value (=t-1) - # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps - - # 2. compute alphas, betas + def pred_epsilon(self, sample, model_output, timestep): alpha_prod_t = self.scheduler.alphas_cumprod[timestep] - # alpha_prod_t_prev = ( - # self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod - # ) beta_prod_t = 1 - alpha_prod_t - # 3. compute predicted eps from model output if self.scheduler.config.prediction_type == "epsilon": pred_eps = model_output elif self.scheduler.config.prediction_type == "sample": diff --git a/train_dreambooth.py b/train_dreambooth.py index 8ac70e8..4c1ec31 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -432,7 +432,7 @@ def main(): args = parse_args() now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) + output_dir = Path(args.output_dir) / slugify(args.project) / now output_dir.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -448,7 +448,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: args.seed = torch.random.seed() >> 32 @@ -513,8 +513,8 @@ def main(): prior_loss_weight=args.prior_loss_weight, ) - checkpoint_output_dir = output_dir.joinpath("model") - sample_output_dir = output_dir.joinpath(f"samples") + checkpoint_output_dir = output_dir / "model" + sample_output_dir = output_dir / "samples" datamodule = VlpnDataModule( data_file=args.train_data_file, @@ -596,7 +596,7 @@ def main(): sample_image_size=args.sample_image_size, ) - plot_metrics(metrics, output_dir.joinpath("lr.png")) + plot_metrics(metrics, output_dir / "lr.png") if __name__ == "__main__": diff --git a/train_lora.py b/train_lora.py index 5fd05cc..a8c1cf6 100644 --- a/train_lora.py +++ b/train_lora.py @@ -392,7 +392,7 @@ def main(): args = parse_args() now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) + output_dir = Path(args.output_dir) / slugify(args.project) / now output_dir.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -408,7 +408,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: args.seed = torch.random.seed() >> 32 @@ -489,8 +489,8 @@ def main(): prior_loss_weight=args.prior_loss_weight, ) - checkpoint_output_dir = output_dir.joinpath("model") - sample_output_dir = output_dir.joinpath(f"samples") + checkpoint_output_dir = output_dir / "model" + sample_output_dir = output_dir/"samples" datamodule = VlpnDataModule( data_file=args.train_data_file, @@ -562,7 +562,7 @@ def main(): sample_image_size=args.sample_image_size, ) - plot_metrics(metrics, output_dir.joinpath("lr.png")) + plot_metrics(metrics, output_dir/"lr.png") if __name__ == "__main__": diff --git a/train_ti.py b/train_ti.py index c79dfa2..171d085 100644 --- a/train_ti.py +++ b/train_ti.py @@ -143,7 +143,7 @@ def parse_args(): parser.add_argument( "--num_buckets", type=int, - default=4, + default=0, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -485,6 +485,9 @@ def parse_args(): if len(args.placeholder_tokens) != len(args.train_data_template): raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") + else: + if isinstance(args.train_data_template, list): + raise ValueError("--train_data_template can't be a list in simultaneous mode") if isinstance(args.collection, str): args.collection = [args.collection] @@ -503,7 +506,7 @@ def main(): global_step_offset = args.global_step now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) + output_dir = Path(args.output_dir)/slugify(args.project)/now output_dir.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( @@ -519,7 +522,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) + logging.basicConfig(filename=output_dir/"log.txt", level=logging.DEBUG) if args.seed is None: args.seed = torch.random.seed() >> 32 @@ -570,7 +573,7 @@ def main(): else: optimizer_class = torch.optim.AdamW - checkpoint_output_dir = output_dir.joinpath("checkpoints") + checkpoint_output_dir = output_dir/"checkpoints" trainer = partial( train, @@ -611,11 +614,11 @@ def main(): return if len(placeholder_tokens) == 1: - sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") - metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") + sample_output_dir = output_dir/f"samples_{placeholder_tokens[0]}" + metrics_output_file = output_dir/f"{placeholder_tokens[0]}.png" else: - sample_output_dir = output_dir.joinpath("samples") - metrics_output_file = output_dir.joinpath(f"lr.png") + sample_output_dir = output_dir/"samples" + metrics_output_file = output_dir/f"lr.png" placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, diff --git a/training/functional.py b/training/functional.py index ccbb4ad..83e70e2 100644 --- a/training/functional.py +++ b/training/functional.py @@ -129,7 +129,7 @@ def save_samples( for pool, data, gen in datasets: all_samples = [] - file_path = output_dir.joinpath(pool, f"step_{step}.jpg") + file_path = output_dir / pool / f"step_{step}.jpg" file_path.parent.mkdir(parents=True, exist_ok=True) batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) diff --git a/training/strategy/lora.py b/training/strategy/lora.py index bc10e58..4dd1100 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -91,7 +91,7 @@ def lora_strategy_callbacks( print(f"Saving checkpoint for step {step}...") unet_ = accelerator.unwrap_model(unet) - unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) + unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}") del unet_ @torch.no_grad() diff --git a/training/strategy/ti.py b/training/strategy/ti.py index da2b81c..0de3cb0 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -138,7 +138,7 @@ def textual_inversion_strategy_callbacks( for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): text_encoder.text_model.embeddings.save_embed( ids, - checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") + checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" ) @torch.no_grad() diff --git a/util.py b/util.py index 545bcb5..2712525 100644 --- a/util.py +++ b/util.py @@ -14,7 +14,7 @@ def load_config(filename): args = config["args"] if "base" in config: - args = load_config(Path(filename).parent.joinpath(config["base"])) | args + args = load_config(Path(filename).parent / config["base"]) | args return args -- cgit v1.2.3-70-g09d2