summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py31
-rw-r--r--models/clip/util.py23
-rw-r--r--train_lora.py3
-rw-r--r--training/attention_processor.py47
-rw-r--r--training/functional.py4
5 files changed, 43 insertions, 65 deletions
diff --git a/data/csv.py b/data/csv.py
index 81e8b6b..14380e8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,7 +100,14 @@ def generate_buckets(
100 return buckets, bucket_items, bucket_assignments 100 return buckets, bucket_items, bucket_assignments
101 101
102 102
103def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool, with_prior_preservation: bool, examples): 103def collate_fn(
104 dtype: torch.dtype,
105 tokenizer: CLIPTokenizer,
106 max_token_id_length: Optional[int],
107 with_guidance: bool,
108 with_prior_preservation: bool,
109 examples
110):
104 prompt_ids = [example["prompt_ids"] for example in examples] 111 prompt_ids = [example["prompt_ids"] for example in examples]
105 nprompt_ids = [example["nprompt_ids"] for example in examples] 112 nprompt_ids = [example["nprompt_ids"] for example in examples]
106 113
@@ -115,10 +122,10 @@ def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_guidance: bool
115 pixel_values = torch.stack(pixel_values) 122 pixel_values = torch.stack(pixel_values)
116 pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format) 123 pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)
117 124
118 prompts = unify_input_ids(tokenizer, prompt_ids) 125 prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
119 nprompts = unify_input_ids(tokenizer, nprompt_ids) 126 nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
120 inputs = unify_input_ids(tokenizer, input_ids) 127 inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
121 negative_inputs = unify_input_ids(tokenizer, negative_input_ids) 128 negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
122 129
123 batch = { 130 batch = {
124 "prompt_ids": prompts.input_ids, 131 "prompt_ids": prompts.input_ids,
@@ -176,6 +183,7 @@ class VlpnDataModule():
176 batch_size: int, 183 batch_size: int,
177 data_file: str, 184 data_file: str,
178 tokenizer: CLIPTokenizer, 185 tokenizer: CLIPTokenizer,
186 constant_prompt_length: bool = False,
179 class_subdir: str = "cls", 187 class_subdir: str = "cls",
180 with_guidance: bool = False, 188 with_guidance: bool = False,
181 num_class_images: int = 1, 189 num_class_images: int = 1,
@@ -212,6 +220,9 @@ class VlpnDataModule():
212 self.num_class_images = num_class_images 220 self.num_class_images = num_class_images
213 self.with_guidance = with_guidance 221 self.with_guidance = with_guidance
214 222
223 self.constant_prompt_length = constant_prompt_length
224 self.max_token_id_length = None
225
215 self.tokenizer = tokenizer 226 self.tokenizer = tokenizer
216 self.size = size 227 self.size = size
217 self.num_buckets = num_buckets 228 self.num_buckets = num_buckets
@@ -301,14 +312,20 @@ class VlpnDataModule():
301 items = self.prepare_items(template, expansions, items) 312 items = self.prepare_items(template, expansions, items)
302 items = self.filter_items(items) 313 items = self.filter_items(items)
303 self.npgenerator.shuffle(items) 314 self.npgenerator.shuffle(items)
315
316 if self.constant_prompt_length:
317 all_input_ids = unify_input_ids(
318 self.tokenizer,
319 [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
320 ).input_ids
321 self.max_token_id_length = all_input_ids.shape[1]
304 322
305 num_images = len(items) 323 num_images = len(items)
306
307 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 324 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
308 train_set_size = max(num_images - valid_set_size, 1) 325 train_set_size = max(num_images - valid_set_size, 1)
309 valid_set_size = num_images - train_set_size 326 valid_set_size = num_images - train_set_size
310 327
311 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.with_guidance, self.num_class_images != 0) 328 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
312 329
313 if valid_set_size == 0: 330 if valid_set_size == 0:
314 data_train, data_val = items, items 331 data_train, data_val = items, items
diff --git a/models/clip/util.py b/models/clip/util.py
index 883de6a..f94fbc7 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,14 +5,21 @@ import torch
5from transformers import CLIPTokenizer, CLIPTextModel 5from transformers import CLIPTokenizer, CLIPTextModel
6 6
7 7
8def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]]): 8def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None):
9 return tokenizer.pad( 9 if max_length is None:
10 {"input_ids": input_ids}, 10 return tokenizer.pad(
11 padding=True, 11 {"input_ids": input_ids},
12 pad_to_multiple_of=tokenizer.model_max_length, 12 padding=True,
13 return_tensors="pt" 13 pad_to_multiple_of=tokenizer.model_max_length,
14 ) 14 return_tensors="pt"
15 15 )
16 else:
17 return tokenizer.pad(
18 {"input_ids": input_ids},
19 padding="max_length",
20 max_length=max_length,
21 return_tensors="pt"
22 )
16 23
17def get_extended_embeddings( 24def get_extended_embeddings(
18 text_encoder: CLIPTextModel, 25 text_encoder: CLIPTextModel,
diff --git a/train_lora.py b/train_lora.py
index a58bef7..12d7e72 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -49,7 +49,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
49torch.backends.cudnn.benchmark = True 49torch.backends.cudnn.benchmark = True
50 50
51torch._dynamo.config.log_level = logging.WARNING 51torch._dynamo.config.log_level = logging.WARNING
52torch._dynamo.config.suppress_errors = True 52# torch._dynamo.config.suppress_errors = True
53 53
54hidet.torch.dynamo_config.use_tensor_core(True) 54hidet.torch.dynamo_config.use_tensor_core(True)
55hidet.torch.dynamo_config.search_space(0) 55hidet.torch.dynamo_config.search_space(0)
@@ -992,6 +992,7 @@ def main():
992 VlpnDataModule, 992 VlpnDataModule,
993 data_file=args.train_data_file, 993 data_file=args.train_data_file,
994 tokenizer=tokenizer, 994 tokenizer=tokenizer,
995 constant_prompt_length=args.compile_unet,
995 class_subdir=args.class_image_dir, 996 class_subdir=args.class_image_dir,
996 with_guidance=args.guidance_scale != 0, 997 with_guidance=args.guidance_scale != 0,
997 num_class_images=args.num_class_images, 998 num_class_images=args.num_class_images,
diff --git a/training/attention_processor.py b/training/attention_processor.py
deleted file mode 100644
index 4309bd4..0000000
--- a/training/attention_processor.py
+++ /dev/null
@@ -1,47 +0,0 @@
1from typing import Callable, Optional, Union
2
3import xformers
4import xformers.ops
5
6from diffusers.models.attention_processor import Attention
7
8
9class XFormersAttnProcessor:
10 def __init__(self, attention_op: Optional[Callable] = None):
11 self.attention_op = attention_op
12
13 def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
14 batch_size, sequence_length, _ = (
15 hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
16 )
17
18 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
19
20 query = attn.to_q(hidden_states)
21
22 if encoder_hidden_states is None:
23 encoder_hidden_states = hidden_states
24 elif attn.norm_cross:
25 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
26
27 key = attn.to_k(encoder_hidden_states)
28 value = attn.to_v(encoder_hidden_states)
29
30 query = attn.head_to_batch_dim(query).contiguous()
31 key = attn.head_to_batch_dim(key).contiguous()
32 value = attn.head_to_batch_dim(value).contiguous()
33
34 query = query.to(key.dtype)
35 value = value.to(key.dtype)
36
37 hidden_states = xformers.ops.memory_efficient_attention(
38 query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
39 )
40 hidden_states = hidden_states.to(query.dtype)
41 hidden_states = attn.batch_to_head_dim(hidden_states)
42
43 # linear proj
44 hidden_states = attn.to_out[0](hidden_states)
45 # dropout
46 hidden_states = attn.to_out[1](hidden_states)
47 return hidden_states
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..10560e5 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
710 vae = torch.compile(vae, backend='hidet') 710 vae = torch.compile(vae, backend='hidet')
711 711
712 if compile_unet: 712 if compile_unet:
713 unet = torch.compile(unet, backend='hidet') 713 # unet = torch.compile(unet, backend='hidet')
714 # unet = torch.compile(unet, mode="reduce-overhead") 714 unet = torch.compile(unet, mode="reduce-overhead")
715 715
716 callbacks = strategy.callbacks( 716 callbacks = strategy.callbacks(
717 accelerator=accelerator, 717 accelerator=accelerator,