From 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 16:26:22 +0200 Subject: Fixes --- infer.py | 14 ++++---- .../stable_diffusion/vlpn_stable_diffusion.py | 41 +++++++++++----------- train_dreambooth.py | 30 ++++------------ train_lora.py | 8 ++++- train_ti.py | 18 ++++++---- training/functional.py | 7 ++-- training/strategy/dreambooth.py | 3 -- util/files.py | 16 ++++++--- 8 files changed, 69 insertions(+), 68 deletions(-) diff --git a/infer.py b/infer.py index 3b3b595..0a219a5 100644 --- a/infer.py +++ b/infer.py @@ -46,7 +46,7 @@ default_args = { "model": "stabilityai/stable-diffusion-2-1", "precision": "fp32", "ti_embeddings_dir": "embeddings_ti", - "lora_embedding": None, + "lora_embeddings_dir": None, "output_dir": "output/inference", "config": None, } @@ -99,7 +99,7 @@ def create_args_parser(): type=str, ) parser.add_argument( - "--lora_embedding", + "--lora_embeddings_dir", type=str, ) parser.add_argument( @@ -341,7 +341,7 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None) def create_pipeline(model, dtype): - print("Loading Stable Diffusion pipeline...") + print(f"Loading Stable Diffusion pipeline: {model}...") tokenizer = MultiCLIPTokenizer.from_pretrained( model, subfolder="tokenizer", torch_dtype=dtype @@ -435,11 +435,11 @@ def generate(output_dir: Path, pipeline, args): negative_prompt=args.negative_prompt, height=args.height, width=args.width, + generator=generator, + guidance_scale=args.guidance_scale, num_images_per_prompt=args.batch_size, num_inference_steps=args.steps, - guidance_scale=args.guidance_scale, sag_scale=args.sag_scale, - generator=generator, image=init_image, strength=args.image_noise, ).images @@ -527,8 +527,8 @@ def main(): pipeline = create_pipeline(args.model, dtype) - load_embeddings_dir(pipeline, args.ti_embeddings_dir) - load_lora(pipeline, args.lora_embedding) + # load_embeddings_dir(pipeline, args.ti_embeddings_dir) + # load_lora(pipeline, args.lora_embeddings_dir) # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) cmd_parser = create_cmd_parser() diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 98703d5..204276e 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -9,6 +9,7 @@ import torch.nn.functional as F import PIL from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import is_accelerate_available from diffusers import ( AutoencoderKL, @@ -161,6 +162,7 @@ class VlpnStableDiffusion(DiffusionPipeline): scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" @@ -443,14 +445,6 @@ class VlpnStableDiffusion(DiffusionPipeline): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - @torch.no_grad() def __call__( self, @@ 
-544,6 +538,8 @@ class VlpnStableDiffusion(DiffusionPipeline): do_classifier_free_guidance = guidance_scale > 1.0 do_self_attention_guidance = sag_scale > 0.0 prep_from_image = isinstance(image, PIL.Image.Image) + if not prep_from_image: + strength = 1 # 3. Encode input prompt prompt_embeds = self.encode_prompt( @@ -577,7 +573,7 @@ class VlpnStableDiffusion(DiffusionPipeline): ) else: latents = self.prepare_latents( - batch_size, + batch_size * num_images_per_prompt, num_channels_latents, height, width, @@ -623,9 +619,12 @@ class VlpnStableDiffusion(DiffusionPipeline): noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) - noise_pred = rescale_noise_cfg( - noise_pred, noise_pred_text, guidance_rescale=guidance_rescale - ) + if guidance_rescale > 0.0: + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=guidance_rescale, + ) if do_self_attention_guidance: # classifier-free guidance produces two chunks of attention map @@ -690,17 +689,17 @@ class VlpnStableDiffusion(DiffusionPipeline): has_nsfw_concept = None - if output_type == "latent": + if not output_type == "latent": + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + else: image = latents - elif output_type == "pil": - # 9. Post-processing - image = self.decode_latents(latents) - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 9. Post-processing - image = self.decode_latents(latents) + do_denormalize = [True] * image.shape[0] + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/train_dreambooth.py b/train_dreambooth.py index c8f03ea..be4da1a 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -17,7 +17,6 @@ from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed # from diffusers.models.attention_processor import AttnProcessor -from diffusers.utils.import_utils import is_xformers_available import transformers import numpy as np @@ -48,25 +47,6 @@ hidet.torch.dynamo_config.use_tensor_core(True) hidet.torch.dynamo_config.search_space(0) -def patch_xformers(dtype): - if is_xformers_available(): - import xformers - import xformers.ops - - orig_xformers_memory_efficient_attention = ( - xformers.ops.memory_efficient_attention - ) - - def xformers_memory_efficient_attention( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs - ): - return orig_xformers_memory_efficient_attention( - query.to(dtype), key.to(dtype), value.to(dtype), **kwargs - ) - - xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention - - def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -223,6 +203,12 @@ def parse_args(): nargs="*", help="A collection to filter the dataset.", ) + parser.add_argument( + "--validation_prompts", + type=str, + nargs="*", + help="Prompts for additional validation images", + ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) @@ -476,7 +462,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=10, + default=15, help="Number of steps for sample generation. 
Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -622,8 +608,6 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - patch_xformers(weight_dtype) - logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: diff --git a/train_lora.py b/train_lora.py index fbec009..2a43252 100644 --- a/train_lora.py +++ b/train_lora.py @@ -235,6 +235,12 @@ def parse_args(): nargs="*", help="A collection to filter the dataset.", ) + parser.add_argument( + "--validation_prompts", + type=str, + nargs="*", + help="Prompts for additional validation images", + ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) @@ -545,7 +551,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=10, + default=15, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( diff --git a/train_ti.py b/train_ti.py index 8c63493..89f4113 100644 --- a/train_ti.py +++ b/train_ti.py @@ -159,6 +159,12 @@ def parse_args(): nargs="*", help="A collection to filter the dataset.", ) + parser.add_argument( + "--validation_prompts", + type=str, + nargs="*", + help="Prompts for additional validation images", + ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) @@ -456,7 +462,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=10, + default=15, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -852,11 +858,6 @@ def main(): sample_image_size=args.sample_image_size, ) - optimizer = create_optimizer( - text_encoder.text_model.embeddings.token_embedding.parameters(), - lr=args.learning_rate, - ) - data_generator = torch.Generator(device="cpu").manual_seed(args.seed) data_npgenerator = np.random.default_rng(args.seed) @@ -957,6 +958,11 @@ def main(): avg_loss_val = AverageMeter() avg_acc_val = AverageMeter() + optimizer = create_optimizer( + text_encoder.text_model.embeddings.token_embedding.parameters(), + lr=args.learning_rate, + ) + while True: if len(auto_cycles) != 0: response = auto_cycles.pop(0) diff --git a/training/functional.py b/training/functional.py index 43b03ac..546aaff 100644 --- a/training/functional.py +++ b/training/functional.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import math from contextlib import _GeneratorContextManager, nullcontext from typing import Callable, Any, Tuple, Union, Optional, Protocol +from types import MethodType from functools import partial from pathlib import Path import itertools @@ -108,6 +109,7 @@ def save_samples( output_dir: Path, seed: int, step: int, + validation_prompts: list[str] = [], cycle: int = 1, batch_size: int = 1, num_batches: int = 1, @@ -136,7 +138,6 @@ def save_samples( if val_dataloader is not None: datasets.append(("stable", val_dataloader, generator)) - datasets.append(("val", val_dataloader, None)) for pool, data, gen in datasets: all_samples = [] @@ -165,7 +166,6 @@ def save_samples( guidance_scale=guidance_scale, sag_scale=0, num_inference_steps=num_steps, - output_type=None, ).images all_samples.append(torch.from_numpy(samples)) @@ -803,4 +803,7 @@ def train( accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) accelerator.unwrap_model(unet, keep_fp32_wrapper=False) + text_encoder.forward = 
MethodType(text_encoder.forward, text_encoder) + unet.forward = MethodType(unet.forward, unet) + accelerator.free_memory() diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0f64747..bd853e2 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -155,9 +155,6 @@ def dreambooth_strategy_callbacks( unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) - unet_.forward = MethodType(unet_.forward, unet_) - text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) - with ema_context(): pipeline = VlpnStableDiffusion( text_encoder=text_encoder_, diff --git a/util/files.py b/util/files.py index 2712525..73ff802 100644 --- a/util/files.py +++ b/util/files.py @@ -8,7 +8,7 @@ from safetensors import safe_open def load_config(filename): - with open(filename, 'rt') as f: + with open(filename, "rt") as f: config = json.load(f) args = config["args"] @@ -19,11 +19,17 @@ def load_config(filename): return args -def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path): +def load_embeddings_from_dir( + tokenizer: MultiCLIPTokenizer, + embeddings: ManagedCLIPTextEmbeddings, + embeddings_dir: Path, +): if not embeddings_dir.exists() or not embeddings_dir.is_dir(): - return [] + return [], [] - filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()] + filenames = [ + filename for filename in embeddings_dir.iterdir() if filename.is_file() + ] tokens = [filename.stem for filename in filenames] new_ids: list[list[int]] = [] @@ -39,7 +45,7 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC embeddings.resize(len(tokenizer)) - for (new_id, embeds) in zip(new_ids, new_embeds): + for new_id, embeds in zip(new_ids, new_embeds): embeddings.add_embed(new_id, embeds) return tokens, new_ids -- cgit v1.2.3-54-g00ecf
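
Note on the pipelines/stable_diffusion/vlpn_stable_diffusion.py hunks above: the hand-rolled decode_latents() helper is dropped in favour of diffusers' VaeImageProcessor, which the pipeline now builds in __init__ and calls at the end of __call__. For reference, here is a minimal standalone sketch of that decoding path under the same API used in the patch (AutoencoderKL plus VaeImageProcessor.postprocess); the helper name decode_to_pil is illustrative only and not part of the commit.

# Minimal sketch of the new decoding path (assumes a loaded AutoencoderKL `vae`;
# `decode_to_pil` is an illustrative name, not something the patch introduces).
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor


def decode_to_pil(vae: AutoencoderKL, latents: torch.Tensor, output_type: str = "pil"):
    # same scale factor computation as VlpnStableDiffusion.__init__
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

    if output_type == "latent":
        image = latents
    else:
        # undo the scaling applied at encode time, then decode to image space
        image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

    # postprocess() maps [-1, 1] tensors to PIL images / numpy arrays / tensors,
    # and returns latents unchanged when output_type == "latent"
    do_denormalize = [True] * image.shape[0]
    return image_processor.postprocess(
        image, output_type=output_type, do_denormalize=do_denormalize
    )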