-rw-r--r--  infer.py                                              | 14
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 41
-rw-r--r--  train_dreambooth.py                                   | 30
-rw-r--r--  train_lora.py                                         |  8
-rw-r--r--  train_ti.py                                           | 18
-rw-r--r--  training/functional.py                                |  7
-rw-r--r--  training/strategy/dreambooth.py                       |  3
-rw-r--r--  util/files.py                                         | 16
8 files changed, 69 insertions, 68 deletions
diff --git a/infer.py b/infer.py
index 3b3b595..0a219a5 100644
--- a/infer.py
+++ b/infer.py
@@ -46,7 +46,7 @@ default_args = {
46 "model": "stabilityai/stable-diffusion-2-1", 46 "model": "stabilityai/stable-diffusion-2-1",
47 "precision": "fp32", 47 "precision": "fp32",
48 "ti_embeddings_dir": "embeddings_ti", 48 "ti_embeddings_dir": "embeddings_ti",
49 "lora_embedding": None, 49 "lora_embeddings_dir": None,
50 "output_dir": "output/inference", 50 "output_dir": "output/inference",
51 "config": None, 51 "config": None,
52} 52}
@@ -99,7 +99,7 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--lora_embedding",
+        "--lora_embeddings_dir",
         type=str,
     )
     parser.add_argument(
@@ -341,7 +341,7 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)


 def create_pipeline(model, dtype):
-    print("Loading Stable Diffusion pipeline...")
+    print(f"Loading Stable Diffusion pipeline: {model}...")

     tokenizer = MultiCLIPTokenizer.from_pretrained(
         model, subfolder="tokenizer", torch_dtype=dtype
@@ -435,11 +435,11 @@ def generate(output_dir: Path, pipeline, args):
             negative_prompt=args.negative_prompt,
             height=args.height,
             width=args.width,
+            generator=generator,
+            guidance_scale=args.guidance_scale,
             num_images_per_prompt=args.batch_size,
             num_inference_steps=args.steps,
-            guidance_scale=args.guidance_scale,
             sag_scale=args.sag_scale,
-            generator=generator,
             image=init_image,
             strength=args.image_noise,
         ).images
@@ -527,8 +527,8 @@ def main():

     pipeline = create_pipeline(args.model, dtype)

-    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
-    load_lora(pipeline, args.lora_embedding)
+    # load_embeddings_dir(pipeline, args.ti_embeddings_dir)
+    # load_lora(pipeline, args.lora_embeddings_dir)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)

     cmd_parser = create_cmd_parser()
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 98703d5..204276e 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import PIL

 from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import is_accelerate_available
 from diffusers import (
     AutoencoderKL,
@@ -161,6 +162,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -443,14 +445,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

-    def decode_latents(self, latents):
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
     @torch.no_grad()
     def __call__(
         self,
@@ -544,6 +538,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
+        if not prep_from_image:
+            strength = 1

         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -577,7 +573,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
@@ -623,9 +619,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (
                         noise_pred_text - noise_pred_uncond
                     )
-                    noise_pred = rescale_noise_cfg(
-                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
-                    )
+                    if guidance_rescale > 0.0:
+                        noise_pred = rescale_noise_cfg(
+                            noise_pred,
+                            noise_pred_text,
+                            guidance_rescale=guidance_rescale,
+                        )

                 if do_self_attention_guidance:
                     # classifier-free guidance produces two chunks of attention map
@@ -690,17 +689,17 @@ class VlpnStableDiffusion(DiffusionPipeline):

         has_nsfw_concept = None

-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(
+                latents / self.vae.config.scaling_factor, return_dict=False
+            )[0]
+        else:
             image = latents
-        elif output_type == "pil":
-            # 9. Post-processing
-            image = self.decode_latents(latents)

-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 9. Post-processing
-            image = self.decode_latents(latents)
+        do_denormalize = [True] * image.shape[0]
+        image = self.image_processor.postprocess(
+            image, output_type=output_type, do_denormalize=do_denormalize
+        )

         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
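
Note: the hand-rolled decode_latents() removed above is replaced by diffusers' VaeImageProcessor. A minimal standalone sketch of the postprocess() call used here, assuming a diffusers release that ships diffusers.image_processor; the tensor below is a stand-in for decoded VAE output in [-1, 1], not output from this pipeline:

    import torch
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor(vae_scale_factor=8)

    # Stand-in for a decoded image batch, shape (batch, channels, height, width), values in [-1, 1].
    image = torch.rand(2, 3, 64, 64) * 2 - 1

    # do_denormalize selects per-image rescaling from [-1, 1] back to [0, 1];
    # output_type may be "pil", "np", or "pt".
    pil_images = processor.postprocess(
        image, output_type="pil", do_denormalize=[True] * image.shape[0]
    )
    print(type(pil_images[0]))
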
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c8f03ea..be4da1a 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed

 # from diffusers.models.attention_processor import AttnProcessor
-from diffusers.utils.import_utils import is_xformers_available
 import transformers

 import numpy as np
@@ -48,25 +47,6 @@ hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)


-def patch_xformers(dtype):
-    if is_xformers_available():
-        import xformers
-        import xformers.ops
-
-        orig_xformers_memory_efficient_attention = (
-            xformers.ops.memory_efficient_attention
-        )
-
-        def xformers_memory_efficient_attention(
-            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
-        ):
-            return orig_xformers_memory_efficient_attention(
-                query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
-            )
-
-        xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
-
-
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -224,6 +204,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -476,7 +462,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -622,8 +608,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    patch_xformers(weight_dtype)
-
     logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)

     if args.seed is None:
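
Note: the new --validation_prompts flag uses argparse's nargs="*", so it collects zero or more prompts into a list and is None when the flag is omitted. A self-contained sketch of that behavior (the prompts are placeholders, not from this repo):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--validation_prompts",
        type=str,
        nargs="*",
        help="Prompts for additional validation images",
    )

    args = parser.parse_args(
        ["--validation_prompts", "a photo of a dog", "a watercolor of a dog"]
    )
    print(args.validation_prompts)              # ['a photo of a dog', 'a watercolor of a dog']
    print(parser.parse_args([]).validation_prompts)  # None
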
diff --git a/train_lora.py b/train_lora.py
index fbec009..2a43252 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -236,6 +236,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -545,7 +551,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
diff --git a/train_ti.py b/train_ti.py
index 8c63493..89f4113 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -160,6 +160,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -456,7 +462,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -852,11 +858,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )

-    optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=args.learning_rate,
-    )
-
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
     data_npgenerator = np.random.default_rng(args.seed)

@@ -957,6 +958,11 @@ def main():
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()

+    optimizer = create_optimizer(
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        lr=args.learning_rate,
+    )
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
diff --git a/training/functional.py b/training/functional.py
index 43b03ac..546aaff 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union, Optional, Protocol
+from types import MethodType
 from functools import partial
 from pathlib import Path
 import itertools
@@ -108,6 +109,7 @@ def save_samples(
     output_dir: Path,
     seed: int,
     step: int,
+    validation_prompts: list[str] = [],
     cycle: int = 1,
     batch_size: int = 1,
     num_batches: int = 1,
@@ -136,7 +138,6 @@ def save_samples(

     if val_dataloader is not None:
         datasets.append(("stable", val_dataloader, generator))
-        datasets.append(("val", val_dataloader, None))

     for pool, data, gen in datasets:
         all_samples = []
@@ -165,7 +166,6 @@ def save_samples(
                     guidance_scale=guidance_scale,
                     sag_scale=0,
                     num_inference_steps=num_steps,
-                    output_type=None,
                 ).images

                 all_samples.append(torch.from_numpy(samples))
@@ -803,4 +803,7 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)

+    text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
+    unet.forward = MethodType(unet.forward, unet)
+
     accelerator.free_memory()
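
Note: types.MethodType(func, obj) binds a callable to an instance so that obj.forward(...) again receives obj as self; the added lines restore forward as a bound method after accelerator.unwrap_model(). A toy sketch of the pattern (the class and wrapper below are hypothetical, not from this repo or from accelerate):

    from types import MethodType

    class Model:
        def forward(self, x):
            return x + 1

    m = Model()

    # A function assigned directly onto the instance is a plain function, not a bound method.
    def wrapped_forward(self, x):
        print("wrapped call")
        return Model.forward(self, x)

    # MethodType re-binds it to the instance, so `m` is passed as `self` automatically.
    m.forward = MethodType(wrapped_forward, m)
    print(m.forward(1))  # prints "wrapped call", then 2
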
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0f64747..bd853e2 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -155,9 +155,6 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)

-        unet_.forward = MethodType(unet_.forward, unet_)
-        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
-
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
diff --git a/util/files.py b/util/files.py
index 2712525..73ff802 100644
--- a/util/files.py
+++ b/util/files.py
@@ -8,7 +8,7 @@ from safetensors import safe_open


 def load_config(filename):
-    with open(filename, 'rt') as f:
+    with open(filename, "rt") as f:
         config = json.load(f)

     args = config["args"]
@@ -19,11 +19,17 @@ def load_config(filename):
     return args


-def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
+def load_embeddings_from_dir(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    embeddings_dir: Path,
+):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-        return []
+        return [], []

-    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    filenames = [
+        filename for filename in embeddings_dir.iterdir() if filename.is_file()
+    ]
     tokens = [filename.stem for filename in filenames]

     new_ids: list[list[int]] = []
@@ -39,7 +45,7 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC

     embeddings.resize(len(tokenizer))

-    for (new_id, embeds) in zip(new_ids, new_embeds):
+    for new_id, embeds in zip(new_ids, new_embeds):
         embeddings.add_embed(new_id, embeds)

     return tokens, new_ids
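
Note: the empty-directory early return now yields ([], []), matching the normal (tokens, new_ids) shape, so callers can always unpack two values. A hedged caller sketch, assuming the MultiCLIPTokenizer and ManagedCLIPTextEmbeddings objects from this repo are already constructed as `tokenizer` and `embeddings`:

    from pathlib import Path

    # Unpacks the same way whether or not the directory held any embedding files.
    added_tokens, added_ids = load_embeddings_from_dir(
        tokenizer, embeddings, Path("embeddings_ti")
    )
    print(f"loaded {len(added_tokens)} embeddings: {list(zip(added_tokens, added_ids))}")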