author     Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
commit     8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree       152c99815bbd8b2659d0dabe63c98f63151c97c2
parent     Fix LoRA training with DAdan (diff)
download   textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz
           textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2
           textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip
Update
-rw-r--r--  .vscode/settings.json                                |   6
-rw-r--r--  data/csv.py                                          | 183
-rw-r--r--  data/keywords.py                                     |   8
-rw-r--r--  environment.yaml                                     |  11
-rw-r--r--  infer.py                                             | 124
-rw-r--r--  models/attention/control.py                          | 106
-rw-r--r--  models/attention/hook.py                             |   5
-rw-r--r--  models/attention/structured.py                       |  65
-rw-r--r--  models/clip/embeddings.py                            |  29
-rw-r--r--  models/clip/tokenizer.py                             |  23
-rw-r--r--  models/clip/util.py                                  |  17
-rw-r--r--  models/convnext/discriminator.py                     |  11
-rw-r--r--  models/sparse.py                                     |  12
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 262
-rw-r--r--  train_dreambooth.py                                  | 770
-rw-r--r--  train_lora.py                                        | 489
-rw-r--r--  train_ti.py                                          | 379
-rw-r--r--  training/functional.py                               | 221
-rw-r--r--  training/lr.py                                       |   4
-rw-r--r--  training/optimization.py                             |  38
-rw-r--r--  training/sampler.py                                  |   2
-rw-r--r--  training/strategy/dreambooth.py                      |  29
-rw-r--r--  training/strategy/lora.py                            |  41
-rw-r--r--  training/strategy/ti.py                              |  27
24 files changed, 1873 insertions(+), 989 deletions(-)
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..cee7b74
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,6 @@
1{
2 "[python]": {
3 "editor.defaultFormatter": "ms-python.black-formatter"
4 },
5 "python.formatting.provider": "none"
6}
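
Note: the new workspace settings pin Black (ms-python.black-formatter) as the Python formatter and disable the extension's legacy python.formatting.provider, which is consistent with the Black-style re-wrapping (88-column lines, trailing commas, double quotes) applied across every file below.
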
diff --git a/data/csv.py b/data/csv.py
index 14380e8..d726033 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -49,7 +49,7 @@ def generate_buckets(
49 max_pixels: Optional[int] = None, 49 max_pixels: Optional[int] = None,
50 num_buckets: int = 4, 50 num_buckets: int = 4,
51 progressive_buckets: bool = False, 51 progressive_buckets: bool = False,
52 return_tensor: bool = True 52 return_tensor: bool = True,
53): 53):
54 if max_pixels is None: 54 if max_pixels is None:
55 max_pixels = (base_size + step_size) ** 2 55 max_pixels = (base_size + step_size) ** 2
@@ -62,7 +62,11 @@ def generate_buckets(
62 62
63 for i in range(1, num_buckets + 1): 63 for i in range(1, num_buckets + 1):
64 long_side = base_size + i * step_size 64 long_side = base_size + i * step_size
65 short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) 65 short_side = min(
66 base_size
67 - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
68 base_size,
69 )
66 buckets.append(long_side / short_side) 70 buckets.append(long_side / short_side)
67 buckets.append(short_side / long_side) 71 buckets.append(short_side / long_side)
68 72
@@ -106,7 +110,7 @@ def collate_fn(
106 max_token_id_length: Optional[int], 110 max_token_id_length: Optional[int],
107 with_guidance: bool, 111 with_guidance: bool,
108 with_prior_preservation: bool, 112 with_prior_preservation: bool,
109 examples 113 examples,
110): 114):
111 prompt_ids = [example["prompt_ids"] for example in examples] 115 prompt_ids = [example["prompt_ids"] for example in examples]
112 nprompt_ids = [example["nprompt_ids"] for example in examples] 116 nprompt_ids = [example["nprompt_ids"] for example in examples]
@@ -125,7 +129,9 @@ def collate_fn(
125 prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) 129 prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
126 nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) 130 nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
127 inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) 131 inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
128 negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length) 132 negative_inputs = unify_input_ids(
133 tokenizer, negative_input_ids, max_token_id_length
134 )
129 135
130 batch = { 136 batch = {
131 "prompt_ids": prompts.input_ids, 137 "prompt_ids": prompts.input_ids,
@@ -149,35 +155,39 @@ class VlpnDataItem(NamedTuple):
149 nprompt: str 155 nprompt: str
150 collection: list[str] 156 collection: list[str]
151 157
152 def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None): 158 def full_prompt(
153 return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator) 159 self,
160 dropout: float = 0,
161 shuffle: bool = False,
162 npgenerator: Optional[np.random.Generator] = None,
163 ):
164 return keywords_to_str(
165 self.keywords, [self.prompt], dropout, shuffle, npgenerator
166 )
154 167
155 168
156def keyword_filter( 169def keyword_filter(
157 placeholder_tokens: Optional[list[str]], 170 placeholder_tokens: Optional[list[str]],
158 collections: Optional[list[str]], 171 collections: Optional[list[str]],
159 exclude_collections: Optional[list[str]], 172 exclude_collections: Optional[list[str]],
160 item: VlpnDataItem 173 item: VlpnDataItem,
161): 174):
162 full_prompt = item.full_prompt() 175 full_prompt = item.full_prompt()
163 176
164 cond1 = placeholder_tokens is None or any( 177 cond1 = placeholder_tokens is None or any(
165 token in full_prompt 178 token in full_prompt for token in placeholder_tokens
166 for token in placeholder_tokens
167 ) 179 )
168 cond2 = collections is None or any( 180 cond2 = collections is None or any(
169 collection in item.collection 181 collection in item.collection for collection in collections
170 for collection in collections
171 ) 182 )
172 cond3 = exclude_collections is None or not any( 183 cond3 = exclude_collections is None or not any(
173 collection in item.collection 184 collection in item.collection for collection in exclude_collections
174 for collection in exclude_collections
175 ) 185 )
176 186
177 return cond1 and cond2 and cond3 187 return cond1 and cond2 and cond3
178 188
179 189
180class VlpnDataModule(): 190class VlpnDataModule:
181 def __init__( 191 def __init__(
182 self, 192 self,
183 batch_size: int, 193 batch_size: int,
@@ -222,7 +232,7 @@ class VlpnDataModule():
222 232
223 self.constant_prompt_length = constant_prompt_length 233 self.constant_prompt_length = constant_prompt_length
224 self.max_token_id_length = None 234 self.max_token_id_length = None
225 235
226 self.tokenizer = tokenizer 236 self.tokenizer = tokenizer
227 self.size = size 237 self.size = size
228 self.num_buckets = num_buckets 238 self.num_buckets = num_buckets
@@ -259,23 +269,29 @@ class VlpnDataModule():
259 nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") 269 nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "")
260 collection = item["collection"].split(", ") if "collection" in item else [] 270 collection = item["collection"].split(", ") if "collection" in item else []
261 271
262 saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) 272 saturated_keywords = str_to_keywords(
273 tpl_keywords.format(**keywords), expansions
274 )
263 275
264 inverted_tokens = keywords_to_str([ 276 inverted_tokens = keywords_to_str(
265 f"inv_{token}" 277 [
266 for token in self.placeholder_tokens 278 f"inv_{token}"
267 if token in saturated_keywords 279 for token in self.placeholder_tokens
268 ]) 280 if token in saturated_keywords
281 ]
282 )
269 283
270 items.append(VlpnDataItem( 284 items.append(
271 self.data_root / image, 285 VlpnDataItem(
272 None, 286 self.data_root / image,
273 saturated_keywords, 287 None,
274 tpl_prompt.format(**prompt), 288 saturated_keywords,
275 tpl_cprompt.format(**prompt), 289 tpl_prompt.format(**prompt),
276 tpl_nprompt.format(_inv=inverted_tokens, **nprompt), 290 tpl_cprompt.format(**prompt),
277 collection 291 tpl_nprompt.format(_inv=inverted_tokens, **nprompt),
278 )) 292 collection,
293 )
294 )
279 295
280 return items 296 return items
281 297
@@ -285,13 +301,16 @@ class VlpnDataModule():
285 301
286 return [item for item in items if self.filter(item)] 302 return [item for item in items if self.filter(item)]
287 303
288 def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: 304 def pad_items(
305 self, items: list[VlpnDataItem], num_class_images: int = 1
306 ) -> list[VlpnDataItem]:
289 image_multiplier = max(num_class_images, 1) 307 image_multiplier = max(num_class_images, 1)
290 308
291 return [ 309 return [
292 VlpnDataItem( 310 VlpnDataItem(
293 item.instance_image_path, 311 item.instance_image_path,
294 self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", 312 self.class_root
313 / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
295 item.keywords, 314 item.keywords,
296 item.prompt, 315 item.prompt,
297 item.cprompt, 316 item.cprompt,
@@ -303,7 +322,7 @@ class VlpnDataModule():
303 ] 322 ]
304 323
305 def setup(self): 324 def setup(self):
306 with open(self.data_file, 'rt') as f: 325 with open(self.data_file, "rt") as f:
307 metadata = json.load(f) 326 metadata = json.load(f)
308 template = metadata[self.template_key] if self.template_key in metadata else {} 327 template = metadata[self.template_key] if self.template_key in metadata else {}
309 expansions = metadata["expansions"] if "expansions" in metadata else {} 328 expansions = metadata["expansions"] if "expansions" in metadata else {}
@@ -312,25 +331,41 @@ class VlpnDataModule():
312 items = self.prepare_items(template, expansions, items) 331 items = self.prepare_items(template, expansions, items)
313 items = self.filter_items(items) 332 items = self.filter_items(items)
314 self.npgenerator.shuffle(items) 333 self.npgenerator.shuffle(items)
315 334
316 if self.constant_prompt_length: 335 if self.constant_prompt_length:
317 all_input_ids = unify_input_ids( 336 all_input_ids = unify_input_ids(
318 self.tokenizer, 337 self.tokenizer,
319 [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items] 338 [
339 self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids
340 for item in items
341 ],
320 ).input_ids 342 ).input_ids
321 self.max_token_id_length = all_input_ids.shape[1] 343 self.max_token_id_length = all_input_ids.shape[1]
322 344
323 num_images = len(items) 345 num_images = len(items)
324 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 346 valid_set_size = (
347 min(self.valid_set_size, num_images)
348 if self.valid_set_size is not None
349 else num_images // 10
350 )
325 train_set_size = max(num_images - valid_set_size, 1) 351 train_set_size = max(num_images - valid_set_size, 1)
326 valid_set_size = num_images - train_set_size 352 valid_set_size = num_images - train_set_size
327 353
328 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0) 354 collate_fn_ = partial(
355 collate_fn,
356 self.dtype,
357 self.tokenizer,
358 self.max_token_id_length,
359 self.with_guidance,
360 self.num_class_images != 0,
361 )
329 362
330 if valid_set_size == 0: 363 if valid_set_size == 0:
331 data_train, data_val = items, items 364 data_train, data_val = items, items
332 else: 365 else:
333 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator) 366 data_train, data_val = random_split(
367 items, [train_set_size, valid_set_size], generator=self.generator
368 )
334 369
335 data_train = self.pad_items(data_train, self.num_class_images) 370 data_train = self.pad_items(data_train, self.num_class_images)
336 371
@@ -338,17 +373,25 @@ class VlpnDataModule():
338 data_train *= math.ceil(self.train_set_pad / len(data_train)) 373 data_train *= math.ceil(self.train_set_pad / len(data_train))
339 374
340 self.train_dataset = VlpnDataset( 375 self.train_dataset = VlpnDataset(
341 data_train, self.tokenizer, 376 data_train,
342 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 377 self.tokenizer,
343 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 378 num_buckets=self.num_buckets,
344 batch_size=self.batch_size, fill_batch=True, generator=self.generator, 379 progressive_buckets=self.progressive_buckets,
345 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, 380 bucket_step_size=self.bucket_step_size,
346 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, 381 bucket_max_pixels=self.bucket_max_pixels,
382 batch_size=self.batch_size,
383 fill_batch=True,
384 generator=self.generator,
385 size=self.size,
386 interpolation=self.interpolation,
387 color_jitter=self.color_jitter,
388 num_class_images=self.num_class_images,
389 dropout=self.dropout,
390 shuffle=self.shuffle,
347 ) 391 )
348 392
349 self.train_dataloader = DataLoader( 393 self.train_dataloader = DataLoader(
350 self.train_dataset, 394 self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_
351 batch_size=None, pin_memory=True, collate_fn=collate_fn_
352 ) 395 )
353 396
354 if len(data_val) != 0: 397 if len(data_val) != 0:
@@ -358,16 +401,24 @@ class VlpnDataModule():
358 data_val *= math.ceil(self.valid_set_pad / len(data_val)) 401 data_val *= math.ceil(self.valid_set_pad / len(data_val))
359 402
360 self.val_dataset = VlpnDataset( 403 self.val_dataset = VlpnDataset(
361 data_val, self.tokenizer, 404 data_val,
362 num_buckets=self.num_buckets, progressive_buckets=True, 405 self.tokenizer,
363 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 406 num_buckets=self.num_buckets,
364 batch_size=self.batch_size, generator=self.generator, 407 progressive_buckets=True,
365 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, 408 bucket_step_size=self.bucket_step_size,
409 bucket_max_pixels=self.bucket_max_pixels,
410 batch_size=self.batch_size,
411 generator=self.generator,
412 size=self.size,
413 interpolation=self.interpolation,
414 color_jitter=self.color_jitter,
366 ) 415 )
367 416
368 self.val_dataloader = DataLoader( 417 self.val_dataloader = DataLoader(
369 self.val_dataset, 418 self.val_dataset,
370 batch_size=None, pin_memory=True, collate_fn=collate_fn_ 419 batch_size=None,
420 pin_memory=True,
421 collate_fn=collate_fn_,
371 ) 422 )
372 else: 423 else:
373 self.val_dataloader = None 424 self.val_dataloader = None
@@ -418,7 +469,13 @@ class VlpnDataset(IterableDataset):
418 469
419 self.bucket_item_range = torch.arange(len(self.bucket_items)) 470 self.bucket_item_range = torch.arange(len(self.bucket_items))
420 471
421 self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() 472 self.length_ = (
473 (self.bucket_assignments.bincount() / self.batch_size)
474 .ceil()
475 .long()
476 .sum()
477 .item()
478 )
422 479
423 def get_input_ids(self, text: str): 480 def get_input_ids(self, text: str):
424 return self.tokenizer(text, padding="do_not_pad").input_ids 481 return self.tokenizer(text, padding="do_not_pad").input_ids
@@ -430,7 +487,9 @@ class VlpnDataset(IterableDataset):
430 worker_info = torch.utils.data.get_worker_info() 487 worker_info = torch.utils.data.get_worker_info()
431 488
432 if self.shuffle: 489 if self.shuffle:
433 perm = torch.randperm(len(self.bucket_assignments), generator=self.generator) 490 perm = torch.randperm(
491 len(self.bucket_assignments), generator=self.generator
492 )
434 self.bucket_items = self.bucket_items[perm] 493 self.bucket_items = self.bucket_items[perm]
435 self.bucket_assignments = self.bucket_assignments[perm] 494 self.bucket_assignments = self.bucket_assignments[perm]
436 495
@@ -488,7 +547,9 @@ class VlpnDataset(IterableDataset):
488 547
489 if len(bucket_items) == 0: 548 if len(bucket_items) == 0:
490 bucket_items = self.bucket_items[self.bucket_assignments == bucket] 549 bucket_items = self.bucket_items[self.bucket_assignments == bucket]
491 item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)] 550 item_index = bucket_items[
551 torch.randint(len(bucket_items), (1,), generator=self.generator)
552 ]
492 else: 553 else:
493 item_index = bucket_items[0] 554 item_index = bucket_items[0]
494 mask[self.bucket_item_range[bucket_mask][0]] = False 555 mask[self.bucket_item_range[bucket_mask][0]] = False
@@ -500,12 +561,18 @@ class VlpnDataset(IterableDataset):
500 example["prompt_ids"] = self.get_input_ids(item.full_prompt()) 561 example["prompt_ids"] = self.get_input_ids(item.full_prompt())
501 example["nprompt_ids"] = self.get_input_ids(item.nprompt) 562 example["nprompt_ids"] = self.get_input_ids(item.nprompt)
502 563
503 example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator)) 564 example["instance_prompt_ids"] = self.get_input_ids(
565 item.full_prompt(self.dropout, True, self.npgenerator)
566 )
504 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) 567 example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
505 example["instance_images"] = image_transforms(get_image(item.instance_image_path)) 568 example["instance_images"] = image_transforms(
569 get_image(item.instance_image_path)
570 )
506 571
507 if self.num_class_images != 0: 572 if self.num_class_images != 0:
508 example["class_prompt_ids"] = self.get_input_ids(item.cprompt) 573 example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
509 example["class_images"] = image_transforms(get_image(item.class_image_path)) 574 example["class_images"] = image_transforms(
575 get_image(item.class_image_path)
576 )
510 577
511 batch.append(example) 578 batch.append(example)
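
The length computation Black re-wrapped in VlpnDataset determines how many batches an epoch yields: each bucket contributes ceil(items_in_bucket / batch_size) batches. A minimal sketch with a toy assignment tensor (values are illustrative, not from the dataset):

import torch

# Toy bucket assignments: three items in bucket 0, two in bucket 1, one in bucket 2.
bucket_assignments = torch.tensor([0, 0, 0, 1, 1, 2])
batch_size = 2

# Each bucket yields ceil(count / batch_size) batches; sum over buckets.
length = (
    (bucket_assignments.bincount() / batch_size).ceil().long().sum().item()
)
print(length)  # 4 batches: 2 + 1 + 1
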
diff --git a/data/keywords.py b/data/keywords.py
index 8632d67..83fe9ff 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -8,7 +8,7 @@ def keywords_to_str(
8 undroppable_keywords: list[str] = [], 8 undroppable_keywords: list[str] = [],
9 dropout: float = 0, 9 dropout: float = 0,
10 shuffle: bool = False, 10 shuffle: bool = False,
11 npgenerator: Optional[np.random.Generator] = None 11 npgenerator: Optional[np.random.Generator] = None,
12) -> str: 12) -> str:
13 if dropout != 0: 13 if dropout != 0:
14 keywords = [keyword for keyword in keywords if np.random.random() > dropout] 14 keywords = [keyword for keyword in keywords if np.random.random() > dropout]
@@ -23,7 +23,11 @@ def keywords_to_str(
23 23
24def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: 24def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]:
25 def expand_keyword(keyword: str) -> list[str]: 25 def expand_keyword(keyword: str) -> list[str]:
26 return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] 26 return (
27 [keyword] + expansions[keyword].split(", ")
28 if keyword in expansions
29 else [keyword]
30 )
27 31
28 return [ 32 return [
29 kw 33 kw
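
One thing worth flagging while this file is touched: keywords_to_str accepts an npgenerator argument, but the dropout line still draws from the module-level np.random.random(), so dropout is not governed by the passed-in generator. A generator-based variant might look like this sketch (drop_and_shuffle is a hypothetical helper, not part of the module):

import numpy as np

def drop_and_shuffle(keywords, dropout=0.0, shuffle=False, rng=None):
    # Hypothetical variant: route both dropout and shuffle through one generator.
    rng = rng if rng is not None else np.random.default_rng()
    if dropout != 0:
        keywords = [kw for kw in keywords if rng.random() > dropout]
    if shuffle:
        keywords = list(keywords)
        rng.shuffle(keywords)
    return ", ".join(keywords)

print(drop_and_shuffle(["portrait", "ink", "sketch"], dropout=0.5, shuffle=True))
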
diff --git a/environment.yaml b/environment.yaml
index 1a55967..2c81a90 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -14,16 +14,17 @@ dependencies:
14 - numpy=1.24.3 14 - numpy=1.24.3
15 - pip=22.3.1 15 - pip=22.3.1
16 - python=3.10.8 16 - python=3.10.8
17 - pytorch=2.0.0=*cuda11.8* 17 - pytorch=2.0.1=*cuda11.8*
18 - torchvision=0.15.0 18 - scipy=1.10.1
19 - xformers=0.0.20.dev528 19 - torchvision=0.15.2
20 - xformers=0.0.21.dev542+git.a205b24
20 - pip: 21 - pip:
21 - -e . 22 - -e .
22 - -e git+https://github.com/huggingface/accelerate#egg=accelerate 23 - -e git+https://github.com/huggingface/accelerate#egg=accelerate
23 - -e git+https://github.com/huggingface/diffusers#egg=diffusers 24 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
24 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation 25 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
25 - --pre --extra-index-url https://download.hidet.org/whl hidet 26 - --pre --extra-index-url https://download.hidet.org/whl hidet
26 - bitsandbytes==0.38.1 27 - bitsandbytes==0.39.1
27 - lion-pytorch==0.0.7 28 - lion-pytorch==0.0.7
28 - peft==0.3.0 29 - peft==0.3.0
29 - python-slugify>=6.1.2 30 - python-slugify>=6.1.2
@@ -31,4 +32,4 @@ dependencies:
31 - setuptools==65.6.3 32 - setuptools==65.6.3
32 - test-tube>=0.7.5 33 - test-tube>=0.7.5
33 - timm==0.9.2 34 - timm==0.9.2
34 - transformers==4.29.1 35 - transformers==4.30.1
diff --git a/infer.py b/infer.py
index 7346de9..3b3b595 100644
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,7 @@ from diffusers import (
24 KDPM2DiscreteScheduler, 24 KDPM2DiscreteScheduler,
25 KDPM2AncestralDiscreteScheduler, 25 KDPM2AncestralDiscreteScheduler,
26 DEISMultistepScheduler, 26 DEISMultistepScheduler,
27 UniPCMultistepScheduler 27 UniPCMultistepScheduler,
28) 28)
29from peft import LoraConfig, LoraModel, set_peft_model_state_dict 29from peft import LoraConfig, LoraModel, set_peft_model_state_dict
30from safetensors.torch import load_file 30from safetensors.torch import load_file
@@ -61,7 +61,7 @@ default_cmds = {
61 "negative_prompt": None, 61 "negative_prompt": None,
62 "shuffle": False, 62 "shuffle": False,
63 "image": None, 63 "image": None,
64 "image_noise": .7, 64 "image_noise": 0.7,
65 "width": 768, 65 "width": 768,
66 "height": 768, 66 "height": 768,
67 "batch_size": 1, 67 "batch_size": 1,
@@ -69,7 +69,6 @@ default_cmds = {
69 "steps": 30, 69 "steps": 30,
70 "guidance_scale": 7.0, 70 "guidance_scale": 7.0,
71 "sag_scale": 0, 71 "sag_scale": 0,
72 "brightness_offset": 0,
73 "seed": None, 72 "seed": None,
74 "config": None, 73 "config": None,
75} 74}
@@ -85,9 +84,7 @@ def merge_dicts(d1, *args):
85 84
86 85
87def create_args_parser(): 86def create_args_parser():
88 parser = argparse.ArgumentParser( 87 parser = argparse.ArgumentParser(description="Simple example of a training script.")
89 description="Simple example of a training script."
90 )
91 parser.add_argument( 88 parser.add_argument(
92 "--model", 89 "--model",
93 type=str, 90 type=str,
@@ -118,9 +115,7 @@ def create_args_parser():
118 115
119 116
120def create_cmd_parser(): 117def create_cmd_parser():
121 parser = argparse.ArgumentParser( 118 parser = argparse.ArgumentParser(description="Simple example of a training script.")
122 description="Simple example of a training script."
123 )
124 parser.add_argument( 119 parser.add_argument(
125 "--project", 120 "--project",
126 type=str, 121 type=str,
@@ -130,13 +125,34 @@ def create_cmd_parser():
130 parser.add_argument( 125 parser.add_argument(
131 "--scheduler", 126 "--scheduler",
132 type=str, 127 type=str,
133 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"], 128 choices=[
129 "plms",
130 "ddim",
131 "klms",
132 "dpmsm",
133 "dpmss",
134 "euler_a",
135 "kdpm2",
136 "kdpm2_a",
137 "deis",
138 "unipc",
139 ],
134 ) 140 )
135 parser.add_argument( 141 parser.add_argument(
136 "--subscheduler", 142 "--subscheduler",
137 type=str, 143 type=str,
138 default=None, 144 default=None,
139 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"], 145 choices=[
146 "plms",
147 "ddim",
148 "klms",
149 "dpmsm",
150 "dpmss",
151 "euler_a",
152 "kdpm2",
153 "kdpm2_a",
154 "deis",
155 ],
140 ) 156 )
141 parser.add_argument( 157 parser.add_argument(
142 "--template", 158 "--template",
@@ -193,10 +209,6 @@ def create_cmd_parser():
193 type=float, 209 type=float,
194 ) 210 )
195 parser.add_argument( 211 parser.add_argument(
196 "--brightness_offset",
197 type=float,
198 )
199 parser.add_argument(
200 "--seed", 212 "--seed",
201 type=int, 213 type=int,
202 ) 214 )
@@ -214,7 +226,9 @@ def run_parser(parser, defaults, input=None):
214 226
215 if args.config is not None: 227 if args.config is not None:
216 conf_args = load_config(args.config) 228 conf_args = load_config(args.config)
217 conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0] 229 conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[
230 0
231 ]
218 232
219 res = defaults.copy() 233 res = defaults.copy()
220 for dict in [vars(conf_args), vars(args)]: 234 for dict in [vars(conf_args), vars(args)]:
@@ -234,10 +248,12 @@ def load_embeddings_dir(pipeline, embeddings_dir):
234 added_tokens, added_ids = load_embeddings_from_dir( 248 added_tokens, added_ids = load_embeddings_from_dir(
235 pipeline.tokenizer, 249 pipeline.tokenizer,
236 pipeline.text_encoder.text_model.embeddings, 250 pipeline.text_encoder.text_model.embeddings,
237 Path(embeddings_dir) 251 Path(embeddings_dir),
238 ) 252 )
239 pipeline.text_encoder.text_model.embeddings.persist() 253 pipeline.text_encoder.text_model.embeddings.persist()
240 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 254 print(
255 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
256 )
241 257
242 258
243def load_lora(pipeline, path): 259def load_lora(pipeline, path):
@@ -255,9 +271,13 @@ def load_lora(pipeline, path):
255 return 271 return
256 272
257 lora_checkpoint_sd = load_file(path / tensor_files[0]) 273 lora_checkpoint_sd = load_file(path / tensor_files[0])
258 unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} 274 unet_lora_ds = {
275 k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k
276 }
259 text_encoder_lora_ds = { 277 text_encoder_lora_ds = {
260 k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k 278 k.replace("text_encoder_", ""): v
279 for k, v in lora_checkpoint_sd.items()
280 if "text_encoder_" in k
261 } 281 }
262 ti_lora_ds = { 282 ti_lora_ds = {
263 k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k 283 k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
@@ -282,7 +302,9 @@ def load_lora(pipeline, path):
282 token_embeddings=token_embeddings, 302 token_embeddings=token_embeddings,
283 ) 303 )
284 pipeline.text_encoder.text_model.embeddings.persist() 304 pipeline.text_encoder.text_model.embeddings.persist()
285 print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}") 305 print(
306 f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}"
307 )
286 308
287 return 309 return
288 310
@@ -315,17 +337,25 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
315 solver_p=create_scheduler(config, subscheduler), 337 solver_p=create_scheduler(config, subscheduler),
316 ) 338 )
317 else: 339 else:
318 raise ValueError(f"Unknown scheduler \"{scheduler}\"") 340 raise ValueError(f'Unknown scheduler "{scheduler}"')
319 341
320 342
321def create_pipeline(model, dtype): 343def create_pipeline(model, dtype):
322 print("Loading Stable Diffusion pipeline...") 344 print("Loading Stable Diffusion pipeline...")
323 345
324 tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) 346 tokenizer = MultiCLIPTokenizer.from_pretrained(
325 text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) 347 model, subfolder="tokenizer", torch_dtype=dtype
326 vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) 348 )
327 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 349 text_encoder = CLIPTextModel.from_pretrained(
328 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) 350 model, subfolder="text_encoder", torch_dtype=dtype
351 )
352 vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype)
353 unet = UNet2DConditionModel.from_pretrained(
354 model, subfolder="unet", torch_dtype=dtype
355 )
356 scheduler = DDIMScheduler.from_pretrained(
357 model, subfolder="scheduler", torch_dtype=dtype
358 )
329 359
330 patch_managed_embeddings(text_encoder) 360 patch_managed_embeddings(text_encoder)
331 361
@@ -347,7 +377,9 @@ def create_pipeline(model, dtype):
347 377
348 378
349def shuffle_prompts(prompts: list[str]) -> list[str]: 379def shuffle_prompts(prompts: list[str]) -> list[str]:
350 return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts] 380 return [
381 keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts
382 ]
351 383
352 384
353@torch.inference_mode() 385@torch.inference_mode()
@@ -386,12 +418,13 @@ def generate(output_dir: Path, pipeline, args):
386 else: 418 else:
387 init_image = None 419 init_image = None
388 420
389 pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler) 421 pipeline.scheduler = create_scheduler(
422 pipeline.scheduler.config, args.scheduler, args.subscheduler
423 )
390 424
391 for i in range(args.batch_num): 425 for i in range(args.batch_num):
392 pipeline.set_progress_bar_config( 426 pipeline.set_progress_bar_config(
393 desc=f"Batch {i + 1} of {args.batch_num}", 427 desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True
394 dynamic_ncols=True
395 ) 428 )
396 429
397 seed = args.seed + i 430 seed = args.seed + i
@@ -409,7 +442,6 @@ def generate(output_dir: Path, pipeline, args):
409 generator=generator, 442 generator=generator,
410 image=init_image, 443 image=init_image,
411 strength=args.image_noise, 444 strength=args.image_noise,
412 brightness_offset=args.brightness_offset,
413 ).images 445 ).images
414 446
415 for j, image in enumerate(images): 447 for j, image in enumerate(images):
@@ -418,7 +450,7 @@ def generate(output_dir: Path, pipeline, args):
418 450
419 image.save(dir / f"{basename}.png") 451 image.save(dir / f"{basename}.png")
420 image.save(dir / f"{basename}.jpg", quality=85) 452 image.save(dir / f"{basename}.jpg", quality=85)
421 with open(dir / f"{basename}.txt", 'w') as f: 453 with open(dir / f"{basename}.txt", "w") as f:
422 f.write(prompt[j % len(args.prompt)]) 454 f.write(prompt[j % len(args.prompt)])
423 455
424 if torch.cuda.is_available(): 456 if torch.cuda.is_available():
@@ -426,10 +458,12 @@ def generate(output_dir: Path, pipeline, args):
426 458
427 459
428class CmdParse(cmd.Cmd): 460class CmdParse(cmd.Cmd):
429 prompt = 'dream> ' 461 prompt = "dream> "
430 commands = [] 462 commands = []
431 463
432 def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): 464 def __init__(
465 self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser
466 ):
433 super().__init__() 467 super().__init__()
434 468
435 self.output_dir = output_dir 469 self.output_dir = output_dir
@@ -447,10 +481,10 @@ class CmdParse(cmd.Cmd):
447 print(str(e)) 481 print(str(e))
448 return 482 return
449 483
450 if elements[0] == 'q': 484 if elements[0] == "q":
451 return True 485 return True
452 486
453 if elements[0] == 'reload_embeddings': 487 if elements[0] == "reload_embeddings":
454 load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) 488 load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
455 return 489 return
456 490
@@ -458,7 +492,7 @@ class CmdParse(cmd.Cmd):
458 args = run_parser(self.parser, default_cmds, elements) 492 args = run_parser(self.parser, default_cmds, elements)
459 493
460 if len(args.prompt) == 0: 494 if len(args.prompt) == 0:
461 print('Try again with a prompt!') 495 print("Try again with a prompt!")
462 return 496 return
463 except SystemExit: 497 except SystemExit:
464 traceback.print_exc() 498 traceback.print_exc()
@@ -471,7 +505,7 @@ class CmdParse(cmd.Cmd):
471 try: 505 try:
472 generate(self.output_dir, self.pipeline, args) 506 generate(self.output_dir, self.pipeline, args)
473 except KeyboardInterrupt: 507 except KeyboardInterrupt:
474 print('Generation cancelled.') 508 print("Generation cancelled.")
475 except Exception as e: 509 except Exception as e:
476 traceback.print_exc() 510 traceback.print_exc()
477 return 511 return
@@ -487,7 +521,9 @@ def main():
487 args = run_parser(args_parser, default_args) 521 args = run_parser(args_parser, default_args)
488 522
489 output_dir = Path(args.output_dir) 523 output_dir = Path(args.output_dir)
490 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] 524 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[
525 args.precision
526 ]
491 527
492 pipeline = create_pipeline(args.model, dtype) 528 pipeline = create_pipeline(args.model, dtype)
493 529
@@ -496,7 +532,13 @@ def main():
496 # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) 532 # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
497 533
498 cmd_parser = create_cmd_parser() 534 cmd_parser = create_cmd_parser()
499 cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) 535 cmd_prompt = CmdParse(
536 output_dir,
537 args.ti_embeddings_dir,
538 args.lora_embeddings_dir,
539 pipeline,
540 cmd_parser,
541 )
500 cmd_prompt.cmdloop() 542 cmd_prompt.cmdloop()
501 543
502 544
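
For readers following load_lora: the checkpoint is a flat {name: tensor} mapping and the split is purely prefix-based. A toy illustration with a fabricated state dict:

# Fabricated state dict, just to show the routing.
lora_checkpoint_sd = {
    "unet_down_blocks.lora_A.weight": "unet tensor",
    "text_encoder_layers.lora_A.weight": "text encoder tensor",
    "ti_<token>": "embedding tensor",
}

unet_lora_ds = {
    k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k
}
text_encoder_lora_ds = {
    k.replace("text_encoder_", ""): v
    for k, v in lora_checkpoint_sd.items()
    if "text_encoder_" in k
}
ti_lora_ds = {
    k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
}

As written, keys with the ti_ prefix also pass the UNet filter (they contain no "text_encoder_"), so they land in both unet_lora_ds and ti_lora_ds; presumably the UNet-side loader ignores the unknown keys.
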
diff --git a/models/attention/control.py b/models/attention/control.py
index 248bd9f..ec378c4 100644
--- a/models/attention/control.py
+++ b/models/attention/control.py
@@ -23,7 +23,7 @@ class AttentionControl(abc.ABC):
23 attn = self.forward(attn, is_cross, place_in_unet) 23 attn = self.forward(attn, is_cross, place_in_unet)
24 else: 24 else:
25 h = attn.shape[0] 25 h = attn.shape[0]
26 attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 26 attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
27 self.cur_att_layer += 1 27 self.cur_att_layer += 1
28 if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 28 if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
29 self.cur_att_layer = 0 29 self.cur_att_layer = 0
@@ -49,12 +49,18 @@ class EmptyControl(AttentionControl):
49class AttentionStore(AttentionControl): 49class AttentionStore(AttentionControl):
50 @staticmethod 50 @staticmethod
51 def get_empty_store(): 51 def get_empty_store():
52 return {"down_cross": [], "mid_cross": [], "up_cross": [], 52 return {
53 "down_self": [], "mid_self": [], "up_self": []} 53 "down_cross": [],
54 "mid_cross": [],
55 "up_cross": [],
56 "down_self": [],
57 "mid_self": [],
58 "up_self": [],
59 }
54 60
55 def forward(self, attn, is_cross: bool, place_in_unet: str): 61 def forward(self, attn, is_cross: bool, place_in_unet: str):
56 key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 62 key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
57 if attn.shape[1] <= 32 ** 2: # avoid memory overhead 63 if attn.shape[1] <= 32**2: # avoid memory overhead
58 self.step_store[key].append(attn) 64 self.step_store[key].append(attn)
59 return attn 65 return attn
60 66
@@ -68,8 +74,10 @@ class AttentionStore(AttentionControl):
68 self.step_store = self.get_empty_store() 74 self.step_store = self.get_empty_store()
69 75
70 def get_average_attention(self): 76 def get_average_attention(self):
71 average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] 77 average_attention = {
72 for key in self.attention_store} 78 key: [item / self.cur_step for item in self.attention_store[key]]
79 for key in self.attention_store
80 }
73 return average_attention 81 return average_attention
74 82
75 def reset(self): 83 def reset(self):
@@ -90,7 +98,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC):
90 return x_t 98 return x_t
91 99
92 def replace_self_attention(self, attn_base, att_replace): 100 def replace_self_attention(self, attn_base, att_replace):
93 if att_replace.shape[2] <= 16 ** 2: 101 if att_replace.shape[2] <= 16**2:
94 return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 102 return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
95 else: 103 else:
96 return att_replace 104 return att_replace
@@ -101,41 +109,62 @@ class AttentionControlEdit(AttentionStore, abc.ABC):
101 109
102 def forward(self, attn, is_cross: bool, place_in_unet: str): 110 def forward(self, attn, is_cross: bool, place_in_unet: str):
103 super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 111 super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
104 if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 112 if is_cross or (
113 self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
114 ):
105 h = attn.shape[0] // (self.batch_size) 115 h = attn.shape[0] // (self.batch_size)
106 attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 116 attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
107 attn_base, attn_repalce = attn[0], attn[1:] 117 attn_base, attn_repalce = attn[0], attn[1:]
108 if is_cross: 118 if is_cross:
109 alpha_words = self.cross_replace_alpha[self.cur_step] 119 alpha_words = self.cross_replace_alpha[self.cur_step]
110 attn_repalce_new = self.replace_cross_attention( 120 attn_repalce_new = (
111 attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce 121 self.replace_cross_attention(attn_base, attn_repalce) * alpha_words
122 + (1 - alpha_words) * attn_repalce
123 )
112 attn[1:] = attn_repalce_new 124 attn[1:] = attn_repalce_new
113 else: 125 else:
114 attn[1:] = self.replace_self_attention(attn_base, attn_repalce) 126 attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
115 attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 127 attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
116 return attn 128 return attn
117 129
118 def __init__(self, prompts, num_steps: int, 130 def __init__(
119 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 131 self,
120 self_replace_steps: Union[float, Tuple[float, float]], 132 prompts,
121 local_blend: Optional[LocalBlend]): 133 num_steps: int,
134 cross_replace_steps: Union[
135 float, Tuple[float, float], Dict[str, Tuple[float, float]]
136 ],
137 self_replace_steps: Union[float, Tuple[float, float]],
138 local_blend: Optional[LocalBlend],
139 ):
122 super(AttentionControlEdit, self).__init__() 140 super(AttentionControlEdit, self).__init__()
123 self.batch_size = len(prompts) 141 self.batch_size = len(prompts)
124 self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha( 142 self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(
125 prompts, num_steps, cross_replace_steps, tokenizer).to(device) 143 prompts, num_steps, cross_replace_steps, tokenizer
144 ).to(device)
126 if type(self_replace_steps) is float: 145 if type(self_replace_steps) is float:
127 self_replace_steps = 0, self_replace_steps 146 self_replace_steps = 0, self_replace_steps
128 self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 147 self.num_self_replace = int(num_steps * self_replace_steps[0]), int(
148 num_steps * self_replace_steps[1]
149 )
129 self.local_blend = local_blend 150 self.local_blend = local_blend
130 151
131 152
132class AttentionReplace(AttentionControlEdit): 153class AttentionReplace(AttentionControlEdit):
133 def replace_cross_attention(self, attn_base, att_replace): 154 def replace_cross_attention(self, attn_base, att_replace):
134 return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) 155 return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
135 156
136 def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 157 def __init__(
137 local_blend: Optional[LocalBlend] = None): 158 self,
138 super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 159 prompts,
160 num_steps: int,
161 cross_replace_steps: float,
162 self_replace_steps: float,
163 local_blend: Optional[LocalBlend] = None,
164 ):
165 super(AttentionReplace, self).__init__(
166 prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
167 )
139 self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) 168 self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
140 169
141 170
@@ -145,9 +174,17 @@ class AttentionRefine(AttentionControlEdit):
145 attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 174 attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
146 return attn_replace 175 return attn_replace
147 176
148 def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 177 def __init__(
149 local_blend: Optional[LocalBlend] = None): 178 self,
150 super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 179 prompts,
180 num_steps: int,
181 cross_replace_steps: float,
182 self_replace_steps: float,
183 local_blend: Optional[LocalBlend] = None,
184 ):
185 super(AttentionRefine, self).__init__(
186 prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
187 )
151 self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) 188 self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
152 self.mapper, alphas = self.mapper.to(device), alphas.to(device) 189 self.mapper, alphas = self.mapper.to(device), alphas.to(device)
153 self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 190 self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
@@ -156,13 +193,24 @@ class AttentionRefine(AttentionControlEdit):
156class AttentionReweight(AttentionControlEdit): 193class AttentionReweight(AttentionControlEdit):
157 def replace_cross_attention(self, attn_base, att_replace): 194 def replace_cross_attention(self, attn_base, att_replace):
158 if self.prev_controller is not None: 195 if self.prev_controller is not None:
159 attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) 196 attn_base = self.prev_controller.replace_cross_attention(
197 attn_base, att_replace
198 )
160 attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 199 attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
161 return attn_replace 200 return attn_replace
162 201
163 def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, 202 def __init__(
164 local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): 203 self,
165 super(AttentionReweight, self).__init__(prompts, num_steps, 204 prompts,
166 cross_replace_steps, self_replace_steps, local_blend) 205 num_steps: int,
206 cross_replace_steps: float,
207 self_replace_steps: float,
208 equalizer,
209 local_blend: Optional[LocalBlend] = None,
210 controller: Optional[AttentionControlEdit] = None,
211 ):
212 super(AttentionReweight, self).__init__(
213 prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
214 )
167 self.equalizer = equalizer.to(device) 215 self.equalizer = equalizer.to(device)
168 self.prev_controller = controller 216 self.prev_controller = controller
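
The heart of AttentionReplace is the einsum re-quoted above: mapper redistributes attention mass from source-prompt tokens onto target-prompt tokens. A shape check under toy dimensions (all sizes illustrative; an identity mapper leaves the attention unchanged):

import torch

heads, pixels, src_tokens, dst_tokens = 2, 4, 3, 3
attn_base = torch.rand(heads, pixels, src_tokens)

# Identity mapper: every target token keeps its source token's attention.
mapper = torch.eye(src_tokens).unsqueeze(0)  # (batch=1, src_tokens, dst_tokens)

out = torch.einsum("hpw,bwn->bhpn", attn_base, mapper)
assert out.shape == (1, heads, pixels, dst_tokens)
assert torch.allclose(out[0], attn_base)  # identity mapping is a no-op
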
diff --git a/models/attention/hook.py b/models/attention/hook.py
index 903de02..6b5fb68 100644
--- a/models/attention/hook.py
+++ b/models/attention/hook.py
@@ -3,6 +3,7 @@ import torch
3 3
4try: 4try:
5 import xformers.ops 5 import xformers.ops
6
6 xformers._is_functorch_available = True 7 xformers._is_functorch_available = True
7 MEM_EFFICIENT_ATTN = True 8 MEM_EFFICIENT_ATTN = True
8except ImportError: 9except ImportError:
@@ -42,10 +43,10 @@ def register_attention_control(model, controller):
42 return forward 43 return forward
43 44
44 def register_recr(net_, count, place_in_unet): 45 def register_recr(net_, count, place_in_unet):
45 if net_.__class__.__name__ == 'CrossAttention': 46 if net_.__class__.__name__ == "CrossAttention":
46 net_.forward = ca_forward(net_, place_in_unet) 47 net_.forward = ca_forward(net_, place_in_unet)
47 return count + 1 48 return count + 1
48 elif hasattr(net_, 'children'): 49 elif hasattr(net_, "children"):
49 for net__ in net_.children(): 50 for net__ in net_.children():
50 count = register_recr(net__, count, place_in_unet) 51 count = register_recr(net__, count, place_in_unet)
51 return count 52 return count
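
The registration walk patches every cross-attention module it finds and returns the count, so the controller knows how many layers make up one diffusion step. A self-contained sketch of the same pattern against a dummy module tree (FakeCrossAttention stands in for diffusers' CrossAttention):

import torch.nn as nn

class FakeCrossAttention(nn.Module):
    def forward(self, x):
        return x

def register_recr(net_, count):
    if net_.__class__.__name__ == "FakeCrossAttention":
        inner = net_.forward
        net_.forward = lambda x: inner(x)  # the real code swaps in ca_forward
        return count + 1
    for child in net_.children():
        count = register_recr(child, count)
    return count

model = nn.Sequential(FakeCrossAttention(), nn.Sequential(FakeCrossAttention()))
print(register_recr(model, 0))  # 2 patched layers
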
diff --git a/models/attention/structured.py b/models/attention/structured.py
index 24d889f..5bbbc06 100644
--- a/models/attention/structured.py
+++ b/models/attention/structured.py
@@ -16,7 +16,9 @@ class StructuredAttentionControl(AttentionControl):
16 if self.struct_attn: 16 if self.struct_attn:
17 out = self.struct_qkv(q, context, mask) 17 out = self.struct_qkv(q, context, mask)
18 else: 18 else:
19 context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context 19 context = torch.cat(
20 [context[0], context[1]["k"][0]], dim=0
21 ) # use key tensor for context
20 out = self.normal_qkv(q, context, mask) 22 out = self.normal_qkv(q, context, mask)
21 else: 23 else:
22 context = default(context, x) 24 context = default(context, x)
@@ -29,11 +31,13 @@ class StructuredAttentionControl(AttentionControl):
29 context: list of [uc, list of conditional context] 31 context: list of [uc, list of conditional context]
30 """ 32 """
31 uc_context = context[0] 33 uc_context = context[0]
32 context_k, context_v = context[1]['k'], context[1]['v'] 34 context_k, context_v = context[1]["k"], context[1]["v"]
33 35
34 if isinstance(context_k, list) and isinstance(context_v, list): 36 if isinstance(context_k, list) and isinstance(context_v, list):
35 out = self.multi_qkv(q, uc_context, context_k, context_v, mask) 37 out = self.multi_qkv(q, uc_context, context_k, context_v, mask)
36 elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor): 38 elif isinstance(context_k, torch.Tensor) and isinstance(
39 context_v, torch.Tensor
40 ):
37 out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) 41 out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask)
38 else: 42 else:
39 raise NotImplementedError 43 raise NotImplementedError
@@ -50,36 +54,45 @@ class StructuredAttentionControl(AttentionControl):
50 k_c = [self.to_k(c_k) for c_k in context_k] 54 k_c = [self.to_k(c_k) for c_k in context_k]
51 v_c = [self.to_v(c_v) for c_v in context_v] 55 v_c = [self.to_v(c_v) for c_v in context_v]
52 56
53 q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) 57 q = rearrange(q, "b n (h d) -> (b h) n d", h=h)
54 58
55 k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h) 59 k_uc = rearrange(k_uc, "b n (h d) -> (b h) n d", h=h)
56 v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h) 60 v_uc = rearrange(v_uc, "b n (h d) -> (b h) n d", h=h)
57 61
58 k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point 62 k_c = [
59 v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c] 63 rearrange(k, "b n (h d) -> (b h) n d", h=h) for k in k_c
64 ] # NOTE: modification point
65 v_c = [rearrange(v, "b n (h d) -> (b h) n d", h=h) for v in v_c]
60 66
61 # get composition 67 # get composition
62 sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale 68 sim_uc = einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * self.scale
63 sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c] 69 sim_c = [
70 einsum("b i d, b j d -> b i j", q[true_bs:], k) * self.scale for k in k_c
71 ]
64 72
65 attn_uc = sim_uc.softmax(dim=-1) 73 attn_uc = sim_uc.softmax(dim=-1)
66 attn_c = [sim.softmax(dim=-1) for sim in sim_c] 74 attn_c = [sim.softmax(dim=-1) for sim in sim_c]
67 75
68 # get uc output 76 # get uc output
69 out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc) 77 out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc)
70 78
71 # get c output 79 # get c output
72 if len(v_c) == 1: 80 if len(v_c) == 1:
73 out_c_collect = [] 81 out_c_collect = []
74 for attn in attn_c: 82 for attn in attn_c:
75 for v in v_c: 83 for v in v_c:
76 out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v)) 84 out_c_collect.append(einsum("b i j, b j d -> b i d", attn, v))
77 out_c = sum(out_c_collect) / len(out_c_collect) 85 out_c = sum(out_c_collect) / len(out_c_collect)
78 else: 86 else:
79 out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c) 87 out_c = sum(
88 [
89 einsum("b i j, b j d -> b i d", attn, v)
90 for attn, v in zip(attn_c, v_c)
91 ]
92 ) / len(v_c)
80 93
81 out = torch.cat([out_uc, out_c], dim=0) 94 out = torch.cat([out_uc, out_c], dim=0)
82 out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 95 out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
83 96
84 return out 97 return out
85 98
@@ -88,21 +101,21 @@ class StructuredAttentionControl(AttentionControl):
88 k = self.to_k(context) 101 k = self.to_k(context)
89 v = self.to_v(context) 102 v = self.to_v(context)
90 103
91 q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 104 q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
92 105
93 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 106 sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
94 107
95 if exists(mask): 108 if exists(mask):
96 mask = rearrange(mask, 'b ... -> b (...)') 109 mask = rearrange(mask, "b ... -> b (...)")
97 max_neg_value = -torch.finfo(sim.dtype).max 110 max_neg_value = -torch.finfo(sim.dtype).max
98 mask = repeat(mask, 'b j -> (b h) () j', h=h) 111 mask = repeat(mask, "b j -> (b h) () j", h=h)
99 sim.masked_fill_(~mask, max_neg_value) 112 sim.masked_fill_(~mask, max_neg_value)
100 113
101 # attention, what we cannot get enough of 114 # attention, what we cannot get enough of
102 attn = sim.softmax(dim=-1) 115 attn = sim.softmax(dim=-1)
103 116
104 out = einsum('b i j, b j d -> b i d', attn, v) 117 out = einsum("b i j, b j d -> b i d", attn, v)
105 out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 118 out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
106 119
107 return out 120 return out
108 121
@@ -111,21 +124,21 @@ class StructuredAttentionControl(AttentionControl):
111 k = self.to_k(torch.cat([uc_context, context_k], dim=0)) 124 k = self.to_k(torch.cat([uc_context, context_k], dim=0))
112 v = self.to_v(torch.cat([uc_context, context_v], dim=0)) 125 v = self.to_v(torch.cat([uc_context, context_v], dim=0))
113 126
114 q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 127 q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
115 128
116 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 129 sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
117 130
118 if exists(mask): 131 if exists(mask):
119 mask = rearrange(mask, 'b ... -> b (...)') 132 mask = rearrange(mask, "b ... -> b (...)")
120 max_neg_value = -torch.finfo(sim.dtype).max 133 max_neg_value = -torch.finfo(sim.dtype).max
121 mask = repeat(mask, 'b j -> (b h) () j', h=h) 134 mask = repeat(mask, "b j -> (b h) () j", h=h)
122 sim.masked_fill_(~mask, max_neg_value) 135 sim.masked_fill_(~mask, max_neg_value)
123 136
124 # attention, what we cannot get enough of 137 # attention, what we cannot get enough of
125 attn = sim.softmax(dim=-1) 138 attn = sim.softmax(dim=-1)
126 139
127 out = einsum('b i j, b j d -> b i d', attn, v) 140 out = einsum("b i j, b j d -> b i d", attn, v)
128 out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 141 out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
129 return out 142 return out
130 143
131 def get_kv(self, context): 144 def get_kv(self, context):
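
All three qkv paths in this file share the same einsum attention kernel. A self-contained sketch of that kernel, with the optional boolean mask handled the same way as in normal_qkv (argument names are illustrative):

import torch
from torch import einsum
from einops import rearrange, repeat

def attention(q, k, v, heads, scale, mask=None):
    # Fold heads into the batch dimension, exactly as normal_qkv does.
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=heads), (q, k, v))
    sim = einsum("b i d, b j d -> b i j", q, k) * scale
    if mask is not None:
        # mask: (batch, keys) bool, True = keep; broadcast over heads and queries.
        mask = repeat(rearrange(mask, "b ... -> b (...)"), "b j -> (b h) () j", h=heads)
        sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)
    attn = sim.softmax(dim=-1)
    out = einsum("b i j, b j d -> b i d", attn, v)
    return rearrange(out, "(b h) n d -> b n (h d)", h=heads)

q = torch.rand(1, 16, 64)
k = v = torch.rand(1, 9, 64)
out = attention(q, k, v, heads=8, scale=(64 // 8) ** -0.5)
assert out.shape == (1, 16, 64)
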
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 7c7f2ac..8c3c6d4 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding
14 14
15 15
16class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): 16class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
17 def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0): 17 def __init__(
18 self,
19 config: CLIPTextConfig,
20 embeddings: CLIPTextEmbeddings,
21 alpha: int = 8,
22 dropout: float = 0.0,
23 ):
18 super().__init__(config) 24 super().__init__(config)
19 25
20 self.position_embedding = embeddings.position_embedding 26 self.position_embedding = embeddings.position_embedding
@@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
28 self.token_embedding.weight = embeddings.token_embedding.weight 34 self.token_embedding.weight = embeddings.token_embedding.weight
29 35
30 def resize(self, size: int): 36 def resize(self, size: int):
31 self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) 37 self.token_embedding = self.token_embedding.new_resized(
38 size, self.initializer_factor
39 )
32 40
33 def add_embed( 41 def add_embed(
34 self, 42 self,
@@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
46 initializer = [initializer] 54 initializer = [initializer]
47 55
48 if isinstance(initializer, list): 56 if isinstance(initializer, list):
49 initializer = (initializer * len(token_ids))[:len(token_ids)] 57 initializer = (initializer * len(token_ids))[: len(token_ids)]
50 58
51 with torch.no_grad(): 59 with torch.no_grad():
52 initializer = self.get_embed(initializer) 60 initializer = self.get_embed(initializer)
@@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
76 84
77 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): 85 def get_embed(self, input_ids: Union[list[int], torch.LongTensor]):
78 if isinstance(input_ids, list): 86 if isinstance(input_ids, list):
79 input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) 87 input_ids = torch.tensor(
88 input_ids, device=self.token_embedding.weight.device, dtype=torch.long
89 )
80 90
81 return self.token_embedding(input_ids) 91 return self.token_embedding(input_ids)
82 92
83 93
84def patch_managed_embeddings( 94def patch_managed_embeddings(
85 text_encoder: CLIPTextModel, 95 text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0
86 alpha: int = 8,
87 dropout: float = 0.0
88) -> ManagedCLIPTextEmbeddings: 96) -> ManagedCLIPTextEmbeddings:
89 if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): 97 if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings):
90 return text_encoder.text_model.embeddings 98 return text_encoder.text_model.embeddings
91 99
92 text_embeddings = ManagedCLIPTextEmbeddings( 100 text_embeddings = ManagedCLIPTextEmbeddings(
93 text_encoder.config, 101 text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout
94 text_encoder.text_model.embeddings,
95 alpha,
96 dropout
97 ) 102 )
98 text_encoder.text_model.embeddings = text_embeddings 103 text_encoder.text_model.embeddings = text_embeddings
99 return text_embeddings 104 return text_embeddings
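
A small idiom in add_embed deserves a note: (initializer * len(token_ids))[: len(token_ids)] recycles a short initializer list until it covers every new token id. For example:

token_ids = [101, 102, 103, 104, 105]   # five new placeholder vectors
initializer = ["portrait", "photo"]     # only two seed embeddings given

# Tile the list, then cut it to length: one initializer per token id.
initializer = (initializer * len(token_ids))[: len(token_ids)]
print(initializer)  # ['portrait', 'photo', 'portrait', 'photo', 'portrait']
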
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 789b525..a866641 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer):
91 self.vector_shuffle = shuffle_none 91 self.vector_shuffle = shuffle_none
92 92
93 def add_multi_tokens( 93 def add_multi_tokens(
94 self, 94 self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1
95 new_tokens: Union[str, list[str]],
96 num_vectors: Union[int, list[int]] = 1
97 ) -> Union[list[int], list[list[int]]]: 95 ) -> Union[list[int], list[list[int]]]:
98 if isinstance(new_tokens, list): 96 if isinstance(new_tokens, list):
99 if isinstance(num_vectors, int): 97 if isinstance(num_vectors, int):
100 num_vectors = [num_vectors] * len(new_tokens) 98 num_vectors = [num_vectors] * len(new_tokens)
101 99
102 if len(num_vectors) != len(new_tokens): 100 if len(num_vectors) != len(new_tokens):
103 raise ValueError("Expected new_tokens and num_vectors to have the same len") 101 raise ValueError(
102 "Expected new_tokens and num_vectors to have the same len"
103 )
104 104
105 return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)] 105 return [
106 self.add_multi_tokens(new_token, vecs)
107 for new_token, vecs in zip(new_tokens, num_vectors)
108 ]
106 109
107 if isinstance(num_vectors, list): 110 if isinstance(num_vectors, list):
108 raise ValueError("Expected num_vectors to be int for single token") 111 raise ValueError("Expected num_vectors to be int for single token")
@@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer):
129 return [id] 132 return [id]
130 133
131 def expand_ids(self, ids: list[int]): 134 def expand_ids(self, ids: list[int]):
132 return [ 135 return [new_id for id in ids for new_id in self.expand_id(id)]
133 new_id
134 for id in ids
135 for new_id in self.expand_id(id)
136 ]
137 136
138 def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): 137 def expand_batched_ids(
138 self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]
139 ):
139 if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): 140 if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list):
140 return [self.expand_ids(batch) for batch in input_ids] 141 return [self.expand_ids(batch) for batch in input_ids]
141 else: 142 else:
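
A short sketch of the multi-vector token flow after this cleanup; the placeholder strings and vector counts are illustrative, and the prompt-expansion step shows roughly how prompts pick up the extra positions:

from models.clip.tokenizer import MultiCLIPTokenizer

# Assumed base checkpoint; any CLIPTokenizer weights behave the same.
tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Two placeholders, backed by 3 and 1 embedding vectors respectively.
ids = tokenizer.add_multi_tokens(["<concept>", "<style>"], num_vectors=[3, 1])
print(ids)  # e.g. [[49408, 49409, 49410], [49411]]

# expand_ids widens each multi-vector id to its full list of vector ids,
# so a prompt containing <concept> ends up spanning three token positions.
prompt_ids = tokenizer("a photo of <concept>", padding="do_not_pad").input_ids
print(tokenizer.expand_ids(prompt_ids))
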
diff --git a/models/clip/util.py b/models/clip/util.py
index f94fbc7..7196bb6 100644
--- a/models/clip/util.py
+++ b/models/clip/util.py
@@ -5,27 +5,32 @@ import torch
5from transformers import CLIPTokenizer, CLIPTextModel 5from transformers import CLIPTokenizer, CLIPTextModel
6 6
7 7
8def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None): 8def unify_input_ids(
9 tokenizer: CLIPTokenizer,
10 input_ids: list[list[int]],
11 max_length: Optional[int] = None,
12):
9 if max_length is None: 13 if max_length is None:
10 return tokenizer.pad( 14 return tokenizer.pad(
11 {"input_ids": input_ids}, 15 {"input_ids": input_ids},
12 padding=True, 16 padding=True,
13 pad_to_multiple_of=tokenizer.model_max_length, 17 pad_to_multiple_of=tokenizer.model_max_length,
14 return_tensors="pt" 18 return_tensors="pt",
15 ) 19 )
16 else: 20 else:
17 return tokenizer.pad( 21 return tokenizer.pad(
18 {"input_ids": input_ids}, 22 {"input_ids": input_ids},
19 padding="max_length", 23 padding="max_length",
20 max_length=max_length, 24 max_length=max_length,
21 return_tensors="pt" 25 return_tensors="pt",
22 ) 26 )
23 27
28
24def get_extended_embeddings( 29def get_extended_embeddings(
25 text_encoder: CLIPTextModel, 30 text_encoder: CLIPTextModel,
26 input_ids: torch.LongTensor, 31 input_ids: torch.LongTensor,
27 position_ids: Optional[torch.LongTensor] = None, 32 position_ids: Optional[torch.LongTensor] = None,
28 attention_mask=None 33 attention_mask=None,
29): 34):
30 model_max_length = text_encoder.config.max_position_embeddings 35 model_max_length = text_encoder.config.max_position_embeddings
31 prompts = input_ids.shape[0] 36 prompts = input_ids.shape[0]
@@ -36,6 +41,8 @@ def get_extended_embeddings(
36 if attention_mask is not None: 41 if attention_mask is not None:
37 attention_mask = attention_mask.view((-1, model_max_length)) 42 attention_mask = attention_mask.view((-1, model_max_length))
38 43
39 text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] 44 text_embeddings = text_encoder(
45 input_ids, position_ids=position_ids, attention_mask=attention_mask
46 )[0]
40 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) 47 text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2]))
41 return text_embeddings 48 return text_embeddings
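
How these two helpers combine, as a sketch (checkpoint name assumed): unify_input_ids pads every prompt to a multiple of the tokenizer's model_max_length, which is what lets get_extended_embeddings split each prompt into model_max_length windows, encode them separately, and stitch the results back together:

import torch
from transformers import CLIPTokenizer, CLIPTextModel

from models.clip.util import unify_input_ids, get_extended_embeddings

name = "openai/clip-vit-large-patch14"  # assumed checkpoint
tokenizer = CLIPTokenizer.from_pretrained(name)
text_encoder = CLIPTextModel.from_pretrained(name)

ids = tokenizer(["a very long prompt", "short"], padding="do_not_pad").input_ids
inputs = unify_input_ids(tokenizer, ids)  # padded to k * model_max_length

with torch.no_grad():
    embeds = get_extended_embeddings(text_encoder, inputs.input_ids)
print(embeds.shape)  # (2, k * model_max_length, hidden_size)
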
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py
index 571b915..5798bcf 100644
--- a/models/convnext/discriminator.py
+++ b/models/convnext/discriminator.py
@@ -5,7 +5,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
5from torch.nn import functional as F 5from torch.nn import functional as F
6 6
7 7
8class ConvNeXtDiscriminator(): 8class ConvNeXtDiscriminator:
9 def __init__(self, model: ConvNeXt, input_size: int) -> None: 9 def __init__(self, model: ConvNeXt, input_size: int) -> None:
10 self.net = model 10 self.net = model
11 11
@@ -22,8 +22,13 @@ class ConvNeXtDiscriminator():
22 img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) 22 img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
23 img_std = self.img_std.to(device=img.device, dtype=img.dtype) 23 img_std = self.img_std.to(device=img.device, dtype=img.dtype)
24 24
25 img = ((img + 1.) / 2.).sub(img_mean).div(img_std) 25 img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std)
26 26
27 img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True) 27 img = F.interpolate(
28 img,
29 size=(self.input_size, self.input_size),
30 mode="bicubic",
31 align_corners=True,
32 )
28 pred = self.net(img) 33 pred = self.net(img)
29 return pred 34 return pred
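
The normalization line above converts diffusion-space images in [-1, 1] to the ImageNet statistics the ConvNeXt backbone expects. A standalone sketch of just that transform, using the same timm constants this file imports:

import torch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

def to_imagenet_space(img: torch.Tensor) -> torch.Tensor:
    # [-1, 1] -> [0, 1] -> channel-wise ImageNet normalization
    mean = img.new_tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
    std = img.new_tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)
    return ((img + 1.0) / 2.0).sub(mean).div(std)

x = torch.rand(1, 3, 256, 256) * 2 - 1  # stand-in for a decoded VAE image
print(to_imagenet_space(x).mean().item())
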
diff --git a/models/sparse.py b/models/sparse.py
index bd45696..e5897c9 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding):
15 ): 15 ):
16 nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) 16 nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
17 17
18 self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) 18 self.register_buffer(
19 "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1
20 )
19 21
20 self.trainable = nn.ParameterList() 22 self.trainable = nn.ParameterList()
21 self.scaling = alpha 23 self.scaling = alpha
22 self.dropout_p = dropout 24 self.dropout_p = dropout
23 self.weight.requires_grad = False 25 self.weight.requires_grad = False
24 26
25 if dropout > 0.: 27 if dropout > 0.0:
26 self.dropout = nn.Dropout(p=dropout) 28 self.dropout = nn.Dropout(p=dropout)
27 else: 29 else:
28 self.dropout = nn.Identity() 30 self.dropout = nn.Identity()
29 31
30 self.reset_parameters() 32 self.reset_parameters()
31 33
32 def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): 34 def new_resized(
35 self, new_num_embeddings: int, initializer_factor: Optional[float] = None
36 ):
33 n = min(self.num_embeddings, new_num_embeddings) 37 n = min(self.num_embeddings, new_num_embeddings)
34 38
35 new_emb = SparseEmbedding( 39 new_emb = SparseEmbedding(
@@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding):
38 self.scaling, 42 self.scaling,
39 self.dropout_p, 43 self.dropout_p,
40 device=self.weight.device, 44 device=self.weight.device,
41 dtype=self.weight.dtype 45 dtype=self.weight.dtype,
42 ) 46 )
43 if initializer_factor is not None: 47 if initializer_factor is not None:
44 new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) 48 new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02)
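
A sketch of how the reformatted new_resized is meant to be used, e.g. after the tokenizer grows. Sizes are illustrative, and the positional alpha/dropout arguments mirror how new_resized itself constructs the new instance:

import torch

from models.sparse import SparseEmbedding

emb = SparseEmbedding(49408, 768, 8, 0.0)  # num_embeddings, dim, alpha, dropout

# trainable_ids starts at -1 everywhere: no token has a trainable override
# until one is explicitly registered.
print(int((emb.trainable_ids >= 0).sum()))  # 0

# Resizing keeps the frozen-weight semantics but changes the row count.
bigger = emb.new_resized(49412, initializer_factor=1.0)
print(bigger.weight.shape)  # torch.Size([49412, 768])
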
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index aa446ec..16b8456 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -21,7 +21,9 @@ from diffusers import (
21 LMSDiscreteScheduler, 21 LMSDiscreteScheduler,
22 PNDMScheduler, 22 PNDMScheduler,
23) 23)
24from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 24from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
25 StableDiffusionPipelineOutput,
26)
25from diffusers.utils import logging, randn_tensor 27from diffusers.utils import logging, randn_tensor
26from transformers import CLIPTextModel, CLIPTokenizer 28from transformers import CLIPTextModel, CLIPTokenizer
27 29
@@ -62,13 +64,35 @@ def gaussian_blur_2d(img, kernel_size, sigma):
62 return img 64 return img
63 65
64 66
67def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
68 """
69 Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
70 Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
71 """
72 std_text = noise_pred_text.std(
73 dim=list(range(1, noise_pred_text.ndim)), keepdim=True
74 )
75 std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
76 # rescale the results from guidance (fixes overexposure)
77 noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
78 # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
79 noise_cfg = (
80 guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
81 )
82 return noise_cfg
83
84
65class CrossAttnStoreProcessor: 85class CrossAttnStoreProcessor:
66 def __init__(self): 86 def __init__(self):
67 self.attention_probs = None 87 self.attention_probs = None
68 88
69 def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): 89 def __call__(
90 self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None
91 ):
70 batch_size, sequence_length, _ = hidden_states.shape 92 batch_size, sequence_length, _ = hidden_states.shape
71 attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 93 attention_mask = attn.prepare_attention_mask(
94 attention_mask, sequence_length, batch_size
95 )
72 query = attn.to_q(hidden_states) 96 query = attn.to_q(hidden_states)
73 97
74 if encoder_hidden_states is None: 98 if encoder_hidden_states is None:
@@ -113,7 +137,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
113 ): 137 ):
114 super().__init__() 138 super().__init__()
115 139
116 if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 140 if (
141 hasattr(scheduler.config, "steps_offset")
142 and scheduler.config.steps_offset != 1
143 ):
117 warnings.warn( 144 warnings.warn(
118 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 145 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
119 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 146 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -179,7 +206,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
179 206
180 device = torch.device("cuda") 207 device = torch.device("cuda")
181 208
182 for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: 209 for cpu_offloaded_model in [
210 self.unet,
211 self.text_encoder,
212 self.vae,
213 self.safety_checker,
214 ]:
183 if cpu_offloaded_model is not None: 215 if cpu_offloaded_model is not None:
184 cpu_offload(cpu_offloaded_model, device) 216 cpu_offload(cpu_offloaded_model, device)
185 217
@@ -223,35 +255,47 @@ class VlpnStableDiffusion(DiffusionPipeline):
223 width: int, 255 width: int,
224 height: int, 256 height: int,
225 strength: float, 257 strength: float,
226 callback_steps: Optional[int] 258 callback_steps: Optional[int],
227 ): 259 ):
228 if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): 260 if isinstance(prompt, str) or (
261 isinstance(prompt, list) and isinstance(prompt[0], int)
262 ):
229 prompt = [prompt] 263 prompt = [prompt]
230 264
231 if negative_prompt is None: 265 if negative_prompt is None:
232 negative_prompt = "" 266 negative_prompt = ""
233 267
234 if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): 268 if isinstance(negative_prompt, str) or (
269 isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)
270 ):
235 negative_prompt = [negative_prompt] * len(prompt) 271 negative_prompt = [negative_prompt] * len(prompt)
236 272
237 if not isinstance(prompt, list): 273 if not isinstance(prompt, list):
238 raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 274 raise ValueError(
275 f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
276 )
239 277
240 if not isinstance(negative_prompt, list): 278 if not isinstance(negative_prompt, list):
241 raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") 279 raise ValueError(
280 f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}"
281 )
242 282
243 if len(negative_prompt) != len(prompt): 283 if len(negative_prompt) != len(prompt):
244 raise ValueError( 284 raise ValueError(
245 f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") 285 f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}"
286 )
246 287
247 if strength < 0 or strength > 1: 288 if strength < 0 or strength > 1:
248 raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}") 289 raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}")
249 290
250 if height % 8 != 0 or width % 8 != 0: 291 if height % 8 != 0 or width % 8 != 0:
251 raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 292 raise ValueError(
293 f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
294 )
252 295
253 if (callback_steps is None) or ( 296 if (callback_steps is None) or (
254 callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 297 callback_steps is not None
298 and (not isinstance(callback_steps, int) or callback_steps <= 0)
255 ): 299 ):
256 raise ValueError( 300 raise ValueError(
257 f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 301 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -266,7 +310,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
266 negative_prompt: Union[List[str], List[List[int]]], 310 negative_prompt: Union[List[str], List[List[int]]],
267 num_images_per_prompt: int, 311 num_images_per_prompt: int,
268 do_classifier_free_guidance: bool, 312 do_classifier_free_guidance: bool,
269 device 313 device,
270 ): 314 ):
271 if isinstance(prompt[0], str): 315 if isinstance(prompt[0], str):
272 text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids 316 text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids
@@ -277,7 +321,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
277 321
278 if do_classifier_free_guidance: 322 if do_classifier_free_guidance:
279 if isinstance(prompt[0], str): 323 if isinstance(prompt[0], str):
280 unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids 324 unconditional_input_ids = self.tokenizer(
325 negative_prompt, padding="do_not_pad"
326 ).input_ids
281 else: 327 else:
282 unconditional_input_ids = negative_prompt 328 unconditional_input_ids = negative_prompt
283 unconditional_input_ids *= num_images_per_prompt 329 unconditional_input_ids *= num_images_per_prompt
@@ -286,12 +332,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
286 text_inputs = unify_input_ids(self.tokenizer, text_input_ids) 332 text_inputs = unify_input_ids(self.tokenizer, text_input_ids)
287 text_input_ids = text_inputs.input_ids 333 text_input_ids = text_inputs.input_ids
288 334
289 if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 335 if (
336 hasattr(self.text_encoder.config, "use_attention_mask")
337 and self.text_encoder.config.use_attention_mask
338 ):
290 attention_mask = text_inputs.attention_mask.to(device) 339 attention_mask = text_inputs.attention_mask.to(device)
291 else: 340 else:
292 attention_mask = None 341 attention_mask = None
293 342
294 prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) 343 prompt_embeds = get_extended_embeddings(
344 self.text_encoder, text_input_ids.to(device), attention_mask
345 )
295 prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 346 prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
296 347
297 return prompt_embeds 348 return prompt_embeds
@@ -301,25 +352,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
301 init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 352 init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
302 353
303 t_start = max(num_inference_steps - init_timestep, 0) 354 t_start = max(num_inference_steps - init_timestep, 0)
304 timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] 355 timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
305 356
306 timesteps = timesteps.to(device) 357 timesteps = timesteps.to(device)
307 358
308 return timesteps, num_inference_steps - t_start 359 return timesteps, num_inference_steps - t_start
309 360
310 def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): 361 def prepare_latents_from_image(
311 offset_image = perlin_noise( 362 self,
312 (batch_size, 1, width, height), 363 init_image,
313 res=1, 364 timestep,
314 generator=generator, 365 batch_size,
315 dtype=dtype, 366 dtype,
316 device=device 367 device,
317 ) 368 generator=None,
318 offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) 369 ):
319 offset_latents = self.vae.config.scaling_factor * offset_latents
320 return offset_latents
321
322 def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None):
323 init_image = init_image.to(device=device, dtype=dtype) 370 init_image = init_image.to(device=device, dtype=dtype)
324 latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) 371 latents = self.vae.encode(init_image).latent_dist.sample(generator=generator)
325 latents = self.vae.config.scaling_factor * latents 372 latents = self.vae.config.scaling_factor * latents
@@ -333,20 +380,32 @@ class VlpnStableDiffusion(DiffusionPipeline):
333 latents = torch.cat([latents] * batch_multiplier, dim=0) 380 latents = torch.cat([latents] * batch_multiplier, dim=0)
334 381
335 # add noise to latents using the timesteps 382 # add noise to latents using the timesteps
336 noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) 383 noise = torch.randn(
337 384 latents.shape, generator=generator, device=device, dtype=dtype
338 if brightness_offset != 0: 385 )
339 noise += brightness_offset * self.prepare_brightness_offset(
340 batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator
341 )
342 386
343 # get latents 387 # get latents
344 latents = self.scheduler.add_noise(latents, noise, timestep) 388 latents = self.scheduler.add_noise(latents, noise, timestep)
345 389
346 return latents 390 return latents
347 391
348 def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): 392 def prepare_latents(
349 shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 393 self,
394 batch_size,
395 num_channels_latents,
396 height,
397 width,
398 dtype,
399 device,
400 generator,
401 latents=None,
402 ):
403 shape = (
404 batch_size,
405 num_channels_latents,
406 height // self.vae_scale_factor,
407 width // self.vae_scale_factor,
408 )
350 if isinstance(generator, list) and len(generator) != batch_size: 409 if isinstance(generator, list) and len(generator) != batch_size:
351 raise ValueError( 410 raise ValueError(
352 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 411 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -354,15 +413,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
354 ) 413 )
355 414
356 if latents is None: 415 if latents is None:
357 latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 416 latents = randn_tensor(
417 shape, generator=generator, device=device, dtype=dtype
418 )
358 else: 419 else:
359 latents = latents.to(device) 420 latents = latents.to(device)
360 421
361 if brightness_offset != 0:
362 latents += brightness_offset * self.prepare_brightness_offset(
363 batch_size, height, width, dtype, device, generator
364 )
365
366 # scale the initial noise by the standard deviation required by the scheduler 422 # scale the initial noise by the standard deviation required by the scheduler
367 latents = latents * self.scheduler.init_noise_sigma 423 latents = latents * self.scheduler.init_noise_sigma
368 return latents 424 return latents
@@ -373,13 +429,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
373 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 429 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
374 # and should be between [0, 1] 430 # and should be between [0, 1]
375 431
376 accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 432 accepts_eta = "eta" in set(
433 inspect.signature(self.scheduler.step).parameters.keys()
434 )
377 extra_step_kwargs = {} 435 extra_step_kwargs = {}
378 if accepts_eta: 436 if accepts_eta:
379 extra_step_kwargs["eta"] = eta 437 extra_step_kwargs["eta"] = eta
380 438
381 # check if the scheduler accepts generator 439 # check if the scheduler accepts generator
382 accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 440 accepts_generator = "generator" in set(
441 inspect.signature(self.scheduler.step).parameters.keys()
442 )
383 if accepts_generator: 443 if accepts_generator:
384 extra_step_kwargs["generator"] = generator 444 extra_step_kwargs["generator"] = generator
385 return extra_step_kwargs 445 return extra_step_kwargs
@@ -396,7 +456,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
396 def __call__( 456 def __call__(
397 self, 457 self,
398 prompt: Union[str, List[str], List[int], List[List[int]]], 458 prompt: Union[str, List[str], List[int], List[List[int]]],
399 negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, 459 negative_prompt: Optional[
460 Union[str, List[str], List[int], List[List[int]]]
461 ] = None,
400 num_images_per_prompt: int = 1, 462 num_images_per_prompt: int = 1,
401 strength: float = 1.0, 463 strength: float = 1.0,
402 height: Optional[int] = None, 464 height: Optional[int] = None,
@@ -407,12 +469,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
407 eta: float = 0.0, 469 eta: float = 0.0,
408 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 470 generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
409 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, 471 image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
410 brightness_offset: Union[float, torch.FloatTensor] = 0,
411 output_type: str = "pil", 472 output_type: str = "pil",
412 return_dict: bool = True, 473 return_dict: bool = True,
413 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 474 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
414 callback_steps: int = 1, 475 callback_steps: int = 1,
415 cross_attention_kwargs: Optional[Dict[str, Any]] = None, 476 cross_attention_kwargs: Optional[Dict[str, Any]] = None,
477 guidance_rescale: float = 0.0,
416 ): 478 ):
417 r""" 479 r"""
418 Function invoked when calling the pipeline for generation. 480 Function invoked when calling the pipeline for generation.
@@ -472,7 +534,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
472 width = width or self.unet.config.sample_size * self.vae_scale_factor 534 width = width or self.unet.config.sample_size * self.vae_scale_factor
473 535
474 # 1. Check inputs. Raise error if not correct 536 # 1. Check inputs. Raise error if not correct
475 prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) 537 prompt, negative_prompt = self.check_inputs(
538 prompt, negative_prompt, width, height, strength, callback_steps
539 )
476 540
477 # 2. Define call parameters 541 # 2. Define call parameters
478 batch_size = len(prompt) 542 batch_size = len(prompt)
@@ -488,7 +552,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
488 negative_prompt, 552 negative_prompt,
489 num_images_per_prompt, 553 num_images_per_prompt,
490 do_classifier_free_guidance, 554 do_classifier_free_guidance,
491 device 555 device,
492 ) 556 )
493 557
494 # 4. Prepare latent variables 558 # 4. Prepare latent variables
@@ -497,7 +561,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
497 561
498 # 5. Prepare timesteps 562 # 5. Prepare timesteps
499 self.scheduler.set_timesteps(num_inference_steps, device=device) 563 self.scheduler.set_timesteps(num_inference_steps, device=device)
500 timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) 564 timesteps, num_inference_steps = self.get_timesteps(
565 num_inference_steps, strength, device
566 )
501 latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 567 latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
502 568
503 # 6. Prepare latent variables 569 # 6. Prepare latent variables
@@ -506,7 +572,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
506 image, 572 image,
507 latent_timestep, 573 latent_timestep,
508 batch_size * num_images_per_prompt, 574 batch_size * num_images_per_prompt,
509 brightness_offset,
510 prompt_embeds.dtype, 575 prompt_embeds.dtype,
511 device, 576 device,
512 generator, 577 generator,
@@ -517,7 +582,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
517 num_channels_latents, 582 num_channels_latents,
518 height, 583 height,
519 width, 584 width,
520 brightness_offset,
521 prompt_embeds.dtype, 585 prompt_embeds.dtype,
522 device, 586 device,
523 generator, 587 generator,
@@ -530,14 +594,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
530 # 8. Denoising loop 594 # 8. Denoising loop
531 if do_self_attention_guidance: 595 if do_self_attention_guidance:
532 store_processor = CrossAttnStoreProcessor() 596 store_processor = CrossAttnStoreProcessor()
533 self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor 597 self.unet.mid_block.attentions[0].transformer_blocks[
598 0
599 ].attn1.processor = store_processor
534 600
535 num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 601 num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
536 with self.progress_bar(total=num_inference_steps) as progress_bar: 602 with self.progress_bar(total=num_inference_steps) as progress_bar:
537 for i, t in enumerate(timesteps): 603 for i, t in enumerate(timesteps):
538 # expand the latents if we are doing classifier free guidance 604 # expand the latents if we are doing classifier free guidance
539 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 605 latent_model_input = (
540 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 606 torch.cat([latents] * 2) if do_classifier_free_guidance else latents
607 )
608 latent_model_input = self.scheduler.scale_model_input(
609 latent_model_input, t
610 )
541 611
542 # predict the noise residual 612 # predict the noise residual
543 noise_pred = self.unet( 613 noise_pred = self.unet(
@@ -551,7 +621,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
551 # perform guidance 621 # perform guidance
552 if do_classifier_free_guidance: 622 if do_classifier_free_guidance:
553 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 623 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
554 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 624 noise_pred = noise_pred_uncond + guidance_scale * (
625 noise_pred_text - noise_pred_uncond
626 )
627 noise_pred = rescale_noise_cfg(
628 noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
629 )
555 630
556 if do_self_attention_guidance: 631 if do_self_attention_guidance:
557 # classifier-free guidance produces two chunks of attention map 632 # classifier-free guidance produces two chunks of attention map
@@ -561,15 +636,24 @@ class VlpnStableDiffusion(DiffusionPipeline):
561 # DDIM-like prediction of x0 636 # DDIM-like prediction of x0
562 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) 637 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
563 # get the stored attention maps 638 # get the stored attention maps
564 uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) 639 uncond_attn, cond_attn = store_processor.attention_probs.chunk(
640 2
641 )
565 # self-attention-based degrading of latents 642 # self-attention-based degrading of latents
566 degraded_latents = self.sag_masking( 643 degraded_latents = self.sag_masking(
567 pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) 644 pred_x0,
645 uncond_attn,
646 t,
647 self.pred_epsilon(latents, noise_pred_uncond, t),
568 ) 648 )
569 uncond_emb, _ = prompt_embeds.chunk(2) 649 uncond_emb, _ = prompt_embeds.chunk(2)
570 # forward and give guidance 650 # forward and give guidance
571 degraded_pred = self.unet( 651 degraded_pred = self.unet(
572 degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0] 652 degraded_latents,
653 t,
654 encoder_hidden_states=uncond_emb,
655 return_dict=False,
656 )[0]
573 noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) 657 noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
574 else: 658 else:
575 # DDIM-like prediction of x0 659 # DDIM-like prediction of x0
@@ -578,18 +662,29 @@ class VlpnStableDiffusion(DiffusionPipeline):
578 cond_attn = store_processor.attention_probs 662 cond_attn = store_processor.attention_probs
579 # self-attention-based degrading of latents 663 # self-attention-based degrading of latents
580 degraded_latents = self.sag_masking( 664 degraded_latents = self.sag_masking(
581 pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) 665 pred_x0,
666 cond_attn,
667 t,
668 self.pred_epsilon(latents, noise_pred, t),
582 ) 669 )
583 # forward and give guidance 670 # forward and give guidance
584 degraded_pred = self.unet( 671 degraded_pred = self.unet(
585 degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0] 672 degraded_latents,
673 t,
674 encoder_hidden_states=prompt_embeds,
675 return_dict=False,
676 )[0]
586 noise_pred += sag_scale * (noise_pred - degraded_pred) 677 noise_pred += sag_scale * (noise_pred - degraded_pred)
587 678
588 # compute the previous noisy sample x_t -> x_t-1 679 # compute the previous noisy sample x_t -> x_t-1
589 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 680 latents = self.scheduler.step(
681 noise_pred, t, latents, **extra_step_kwargs, return_dict=False
682 )[0]
590 683
591 # call the callback, if provided 684 # call the callback, if provided
592 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 685 if i == len(timesteps) - 1 or (
686 (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
687 ):
593 progress_bar.update() 688 progress_bar.update()
594 if callback is not None and i % callback_steps == 0: 689 if callback is not None and i % callback_steps == 0:
595 callback(i, t, latents) 690 callback(i, t, latents)
@@ -615,7 +710,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
615 if not return_dict: 710 if not return_dict:
616 return (image, has_nsfw_concept) 711 return (image, has_nsfw_concept)
617 712
618 return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 713 return StableDiffusionPipelineOutput(
714 images=image, nsfw_content_detected=has_nsfw_concept
715 )
619 716
620 # Self-Attention-Guided (SAG) Stable Diffusion 717 # Self-Attention-Guided (SAG) Stable Diffusion
621 718
@@ -632,16 +729,23 @@ class VlpnStableDiffusion(DiffusionPipeline):
632 attn_map = attn_map.reshape(b, h, hw1, hw2) 729 attn_map = attn_map.reshape(b, h, hw1, hw2)
633 attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 730 attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
634 attn_mask = ( 731 attn_mask = (
635 attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype) 732 attn_mask.reshape(b, map_size, map_size)
733 .unsqueeze(1)
734 .repeat(1, latent_channel, 1, 1)
735 .type(attn_map.dtype)
636 ) 736 )
637 attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) 737 attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
638 738
639 # Blur according to the self-attention mask 739 # Blur according to the self-attention mask
640 degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) 740 degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
641 degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) 741 degraded_latents = degraded_latents * attn_mask + original_latents * (
742 1 - attn_mask
743 )
642 744
643 # Noise it again to match the noise level 745 # Noise it again to match the noise level
644 degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) 746 degraded_latents = self.scheduler.add_noise(
747 degraded_latents, noise=eps, timesteps=t
748 )
645 749
646 return degraded_latents 750 return degraded_latents
647 751
@@ -652,13 +756,19 @@ class VlpnStableDiffusion(DiffusionPipeline):
652 756
653 beta_prod_t = 1 - alpha_prod_t 757 beta_prod_t = 1 - alpha_prod_t
654 if self.scheduler.config.prediction_type == "epsilon": 758 if self.scheduler.config.prediction_type == "epsilon":
655 pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 759 pred_original_sample = (
760 sample - beta_prod_t ** (0.5) * model_output
761 ) / alpha_prod_t ** (0.5)
656 elif self.scheduler.config.prediction_type == "sample": 762 elif self.scheduler.config.prediction_type == "sample":
657 pred_original_sample = model_output 763 pred_original_sample = model_output
658 elif self.scheduler.config.prediction_type == "v_prediction": 764 elif self.scheduler.config.prediction_type == "v_prediction":
659 pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 765 pred_original_sample = (alpha_prod_t**0.5) * sample - (
766 beta_prod_t**0.5
767 ) * model_output
660 # predict V 768 # predict V
661 model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 769 model_output = (alpha_prod_t**0.5) * model_output + (
770 beta_prod_t**0.5
771 ) * sample
662 else: 772 else:
663 raise ValueError( 773 raise ValueError(
664 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," 774 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
@@ -674,9 +784,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
674 if self.scheduler.config.prediction_type == "epsilon": 784 if self.scheduler.config.prediction_type == "epsilon":
675 pred_eps = model_output 785 pred_eps = model_output
676 elif self.scheduler.config.prediction_type == "sample": 786 elif self.scheduler.config.prediction_type == "sample":
677 pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) 787 pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (
788 beta_prod_t**0.5
789 )
678 elif self.scheduler.config.prediction_type == "v_prediction": 790 elif self.scheduler.config.prediction_type == "v_prediction":
679 pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output 791 pred_eps = (beta_prod_t**0.5) * sample + (
792 alpha_prod_t**0.5
793 ) * model_output
680 else: 794 else:
681 raise ValueError( 795 raise ValueError(
682 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," 796 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2aca1e7..659b84c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -5,34 +5,70 @@ import itertools
5from pathlib import Path 5from pathlib import Path
6from functools import partial 6from functools import partial
7import math 7import math
8import warnings
8 9
9import torch 10import torch
11import torch._dynamo
10import torch.utils.checkpoint 12import torch.utils.checkpoint
13import hidet
11 14
12from accelerate import Accelerator 15from accelerate import Accelerator
13from accelerate.logging import get_logger 16from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
15from slugify import slugify 18
19# from diffusers.models.attention_processor import AttnProcessor
20from diffusers.utils.import_utils import is_xformers_available
16import transformers 21import transformers
17 22
18from util.files import load_config, load_embeddings_from_dir 23import numpy as np
24from slugify import slugify
25
19from data.csv import VlpnDataModule, keyword_filter 26from data.csv import VlpnDataModule, keyword_filter
20from training.functional import train, get_models 27from models.clip.embeddings import patch_managed_embeddings
28from training.functional import train, add_placeholder_tokens, get_models
21from training.strategy.dreambooth import dreambooth_strategy 29from training.strategy.dreambooth import dreambooth_strategy
22from training.optimization import get_scheduler 30from training.optimization import get_scheduler
23from training.util import save_args 31from training.sampler import create_named_schedule_sampler
32from training.util import AverageMeter, save_args
33from util.files import load_config, load_embeddings_from_dir
34
24 35
25logger = get_logger(__name__) 36logger = get_logger(__name__)
26 37
38warnings.filterwarnings("ignore")
39
27 40
28torch.backends.cuda.matmul.allow_tf32 = True 41torch.backends.cuda.matmul.allow_tf32 = True
29torch.backends.cudnn.benchmark = True 42torch.backends.cudnn.benchmark = True
30 43
44# torch._dynamo.config.log_level = logging.WARNING
45torch._dynamo.config.suppress_errors = True
46
47hidet.torch.dynamo_config.use_tensor_core(True)
48hidet.torch.dynamo_config.search_space(0)
49
50
51def patch_xformers(dtype):
52 if is_xformers_available():
53 import xformers
54 import xformers.ops
55
56 orig_xformers_memory_efficient_attention = (
57 xformers.ops.memory_efficient_attention
58 )
59
60 def xformers_memory_efficient_attention(
61 query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
62 ):
63 return orig_xformers_memory_efficient_attention(
64 query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
65 )
66
67 xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
68
31 69
32def parse_args(): 70def parse_args():
33 parser = argparse.ArgumentParser( 71 parser = argparse.ArgumentParser(description="Simple example of a training script.")
34 description="Simple example of a training script."
35 )
36 parser.add_argument( 72 parser.add_argument(
37 "--pretrained_model_name_or_path", 73 "--pretrained_model_name_or_path",
38 type=str, 74 type=str,
@@ -49,7 +85,7 @@ def parse_args():
49 "--train_data_file", 85 "--train_data_file",
50 type=str, 86 type=str,
51 default=None, 87 default=None,
52 help="A folder containing the training data." 88 help="A folder containing the training data.",
53 ) 89 )
54 parser.add_argument( 90 parser.add_argument(
55 "--train_data_template", 91 "--train_data_template",
@@ -60,13 +96,13 @@ def parse_args():
60 "--train_set_pad", 96 "--train_set_pad",
61 type=int, 97 type=int,
62 default=None, 98 default=None,
63 help="The number to fill train dataset items up to." 99 help="The number to fill train dataset items up to.",
64 ) 100 )
65 parser.add_argument( 101 parser.add_argument(
66 "--valid_set_pad", 102 "--valid_set_pad",
67 type=int, 103 type=int,
68 default=None, 104 default=None,
69 help="The number to fill validation dataset items up to." 105 help="The number to fill validation dataset items up to.",
70 ) 106 )
71 parser.add_argument( 107 parser.add_argument(
72 "--project", 108 "--project",
@@ -75,20 +111,58 @@ def parse_args():
75 help="The name of the current project.", 111 help="The name of the current project.",
76 ) 112 )
77 parser.add_argument( 113 parser.add_argument(
78 "--exclude_collections", 114 "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
115 )
116 parser.add_argument(
117 "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
118 )
119 parser.add_argument(
120 "--placeholder_tokens",
79 type=str, 121 type=str,
80 nargs='*', 122 nargs="*",
81 help="Exclude all items with a listed collection.", 123 help="A token to use as a placeholder for the concept.",
82 ) 124 )
83 parser.add_argument( 125 parser.add_argument(
84 "--train_text_encoder_epochs", 126 "--initializer_tokens",
85 default=999999, 127 type=str,
86 help="Number of epochs the text encoder will be trained." 128 nargs="*",
129 help="A token to use as initializer word.",
130 )
131 parser.add_argument(
132 "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
133 )
134 parser.add_argument(
135 "--initializer_noise",
136 type=float,
137 default=0,
138 help="Noise to apply to the initializer word",
139 )
140 parser.add_argument(
141 "--alias_tokens",
142 type=str,
143 nargs="*",
144 default=[],
145 help="Tokens to create an alias for.",
146 )
147 parser.add_argument(
148 "--inverted_initializer_tokens",
149 type=str,
150 nargs="*",
151 help="A token to use as initializer word.",
152 )
153 parser.add_argument(
154 "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
155 )
156 parser.add_argument(
157 "--exclude_collections",
158 type=str,
159 nargs="*",
160 help="Exclude all items with a listed collection.",
87 ) 161 )
88 parser.add_argument( 162 parser.add_argument(
89 "--num_buckets", 163 "--num_buckets",
90 type=int, 164 type=int,
91 default=0, 165 default=2,
92 help="Number of aspect ratio buckets in either direction.", 166 help="Number of aspect ratio buckets in either direction.",
93 ) 167 )
94 parser.add_argument( 168 parser.add_argument(
@@ -120,19 +194,6 @@ def parse_args():
120 help="Shuffle tags.", 194 help="Shuffle tags.",
121 ) 195 )
122 parser.add_argument( 196 parser.add_argument(
123 "--vector_dropout",
124 type=int,
125 default=0,
126 help="Vector dropout probability.",
127 )
128 parser.add_argument(
129 "--vector_shuffle",
130 type=str,
131 default="auto",
132 choices=["all", "trailing", "leading", "between", "auto", "off"],
133 help='Vector shuffling algorithm.',
134 )
135 parser.add_argument(
136 "--guidance_scale", 197 "--guidance_scale",
137 type=float, 198 type=float,
138 default=0, 199 default=0,
@@ -141,7 +202,7 @@ def parse_args():
141 "--num_class_images", 202 "--num_class_images",
142 type=int, 203 type=int,
143 default=0, 204 default=0,
144 help="How many class images to generate." 205 help="How many class images to generate.",
145 ) 206 )
146 parser.add_argument( 207 parser.add_argument(
147 "--class_image_dir", 208 "--class_image_dir",
@@ -162,16 +223,18 @@ def parse_args():
162 help="The embeddings directory where Textual Inversion embeddings are stored.", 223 help="The embeddings directory where Textual Inversion embeddings are stored.",
163 ) 224 )
164 parser.add_argument( 225 parser.add_argument(
226 "--train_dir_embeddings",
227 action="store_true",
228 help="Train embeddings loaded from embeddings directory.",
229 )
230 parser.add_argument(
165 "--collection", 231 "--collection",
166 type=str, 232 type=str,
167 nargs='*', 233 nargs="*",
168 help="A collection to filter the dataset.", 234 help="A collection to filter the dataset.",
169 ) 235 )
170 parser.add_argument( 236 parser.add_argument(
171 "--seed", 237 "--seed", type=int, default=None, help="A seed for reproducible training."
172 type=int,
173 default=None,
174 help="A seed for reproducible training."
175 ) 238 )
176 parser.add_argument( 239 parser.add_argument(
177 "--resolution", 240 "--resolution",
@@ -189,15 +252,13 @@ def parse_args():
189 help="Perlin offset noise strength.", 252 help="Perlin offset noise strength.",
190 ) 253 )
191 parser.add_argument( 254 parser.add_argument(
192 "--num_train_epochs", 255 "--input_pertubation",
193 type=int, 256 type=float,
194 default=None 257 default=0,
195 ) 258 help="The scale of input perturbation. Recommended 0.1.",
196 parser.add_argument(
197 "--num_train_steps",
198 type=int,
199 default=2000
200 ) 259 )
260 parser.add_argument("--num_train_epochs", type=int, default=None)
261 parser.add_argument("--num_train_steps", type=int, default=2000)
201 parser.add_argument( 262 parser.add_argument(
202 "--gradient_accumulation_steps", 263 "--gradient_accumulation_steps",
203 type=int, 264 type=int,
@@ -205,9 +266,9 @@ def parse_args():
205 help="Number of updates steps to accumulate before performing a backward/update pass.", 266 help="Number of updates steps to accumulate before performing a backward/update pass.",
206 ) 267 )
207 parser.add_argument( 268 parser.add_argument(
208 "--gradient_checkpointing", 269 "--train_text_encoder_cycles",
209 action="store_true", 270 default=999999,
210 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 271 help="Number of epochs the text encoder will be trained.",
211 ) 272 )
212 parser.add_argument( 273 parser.add_argument(
213 "--find_lr", 274 "--find_lr",
@@ -215,9 +276,15 @@ def parse_args():
215 help="Automatically find a learning rate (no training).", 276 help="Automatically find a learning rate (no training).",
216 ) 277 )
217 parser.add_argument( 278 parser.add_argument(
218 "--learning_rate", 279 "--learning_rate_unet",
280 type=float,
281 default=1e-4,
282 help="Initial learning rate (after the potential warmup period) to use.",
283 )
284 parser.add_argument(
285 "--learning_rate_text",
219 type=float, 286 type=float,
220 default=2e-6, 287 default=5e-5,
221 help="Initial learning rate (after the potential warmup period) to use.", 288 help="Initial learning rate (after the potential warmup period) to use.",
222 ) 289 )
223 parser.add_argument( 290 parser.add_argument(
@@ -229,27 +296,31 @@ def parse_args():
229 "--lr_scheduler", 296 "--lr_scheduler",
230 type=str, 297 type=str,
231 default="one_cycle", 298 default="one_cycle",
232 choices=["linear", "cosine", "cosine_with_restarts", "polynomial", 299 choices=[
233 "constant", "constant_with_warmup", "one_cycle"], 300 "linear",
234 help='The scheduler type to use.', 301 "cosine",
302 "cosine_with_restarts",
303 "polynomial",
304 "constant",
305 "constant_with_warmup",
306 "one_cycle",
307 ],
308 help="The scheduler type to use.",
235 ) 309 )
236 parser.add_argument( 310 parser.add_argument(
237 "--lr_warmup_epochs", 311 "--lr_warmup_epochs",
238 type=int, 312 type=int,
239 default=10, 313 default=10,
240 help="Number of steps for the warmup in the lr scheduler." 314 help="Number of steps for the warmup in the lr scheduler.",
241 ) 315 )
242 parser.add_argument( 316 parser.add_argument(
243 "--lr_mid_point", 317 "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point."
244 type=float,
245 default=0.3,
246 help="OneCycle schedule mid point."
247 ) 318 )
248 parser.add_argument( 319 parser.add_argument(
249 "--lr_cycles", 320 "--lr_cycles",
250 type=int, 321 type=int,
251 default=None, 322 default=None,
252 help="Number of restart cycles in the lr scheduler (if supported)." 323 help="Number of restart cycles in the lr scheduler (if supported).",
253 ) 324 )
254 parser.add_argument( 325 parser.add_argument(
255 "--lr_warmup_func", 326 "--lr_warmup_func",
@@ -261,7 +332,7 @@ def parse_args():
261 "--lr_warmup_exp", 332 "--lr_warmup_exp",
262 type=int, 333 type=int,
263 default=1, 334 default=1,
264 help='If lr_warmup_func is "cos", exponent to modify the function' 335 help='If lr_warmup_func is "cos", exponent to modify the function',
265 ) 336 )
266 parser.add_argument( 337 parser.add_argument(
267 "--lr_annealing_func", 338 "--lr_annealing_func",
@@ -273,76 +344,76 @@ def parse_args():
273 "--lr_annealing_exp", 344 "--lr_annealing_exp",
274 type=int, 345 type=int,
275 default=3, 346 default=3,
276 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' 347 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
277 ) 348 )
278 parser.add_argument( 349 parser.add_argument(
279 "--lr_min_lr", 350 "--lr_min_lr",
280 type=float, 351 type=float,
281 default=0.04, 352 default=0.04,
282 help="Minimum learning rate in the lr scheduler." 353 help="Minimum learning rate in the lr scheduler.",
283 )
284 parser.add_argument(
285 "--use_ema",
286 action="store_true",
287 help="Whether to use EMA model."
288 )
289 parser.add_argument(
290 "--ema_inv_gamma",
291 type=float,
292 default=1.0
293 ) 354 )
355 parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
294 parser.add_argument( 356 parser.add_argument(
295 "--ema_power", 357 "--schedule_sampler",
296 type=float, 358 type=str,
297 default=6/7 359 default="uniform",
298 ) 360 choices=["uniform", "loss-second-moment"],
299 parser.add_argument( 361 help="Noise schedule sampler.",
300 "--ema_max_decay",
301 type=float,
302 default=0.9999
303 ) 362 )
304 parser.add_argument( 363 parser.add_argument(
305 "--optimizer", 364 "--optimizer",
306 type=str, 365 type=str,
307 default="dadan", 366 default="adan",
308 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 367 choices=[
309 help='Optimizer to use' 368 "adam",
369 "adam8bit",
370 "adan",
371 "lion",
372 "dadam",
373 "dadan",
374 "dlion",
375 "adafactor",
376 ],
377 help="Optimizer to use",
310 ) 378 )
311 parser.add_argument( 379 parser.add_argument(
312 "--dadaptation_d0", 380 "--dadaptation_d0",
313 type=float, 381 type=float,
314 default=1e-6, 382 default=1e-6,
315 help="The d0 parameter for Dadaptation optimizers." 383 help="The d0 parameter for Dadaptation optimizers.",
384 )
385 parser.add_argument(
386 "--dadaptation_growth_rate",
387 type=float,
388 default=math.inf,
389 help="The growth_rate parameter for Dadaptation optimizers.",
316 ) 390 )
317 parser.add_argument( 391 parser.add_argument(
318 "--adam_beta1", 392 "--adam_beta1",
319 type=float, 393 type=float,
320 default=None, 394 default=None,
321 help="The beta1 parameter for the Adam optimizer." 395 help="The beta1 parameter for the Adam optimizer.",
322 ) 396 )
323 parser.add_argument( 397 parser.add_argument(
324 "--adam_beta2", 398 "--adam_beta2",
325 type=float, 399 type=float,
326 default=None, 400 default=None,
327 help="The beta2 parameter for the Adam optimizer." 401 help="The beta2 parameter for the Adam optimizer.",
328 ) 402 )
329 parser.add_argument( 403 parser.add_argument(
330 "--adam_weight_decay", 404 "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
331 type=float,
332 default=1e-2,
333 help="Weight decay to use."
334 ) 405 )
335 parser.add_argument( 406 parser.add_argument(
336 "--adam_epsilon", 407 "--adam_epsilon",
337 type=float, 408 type=float,
338 default=1e-08, 409 default=1e-08,
339 help="Epsilon value for the Adam optimizer" 410 help="Epsilon value for the Adam optimizer",
340 ) 411 )
341 parser.add_argument( 412 parser.add_argument(
342 "--adam_amsgrad", 413 "--adam_amsgrad",
343 type=bool, 414 type=bool,
344 default=False, 415 default=False,
345 help="Amsgrad value for the Adam optimizer" 416 help="Amsgrad value for the Adam optimizer",
346 ) 417 )
347 parser.add_argument( 418 parser.add_argument(
348 "--mixed_precision", 419 "--mixed_precision",
@@ -356,12 +427,28 @@ def parse_args():
356 ), 427 ),
357 ) 428 )
358 parser.add_argument( 429 parser.add_argument(
430 "--compile_unet",
431 action="store_true",
432 help="Compile UNet with Torch Dynamo.",
433 )
434 parser.add_argument(
435 "--use_xformers",
436 action="store_true",
437 help="Use xformers.",
438 )
439 parser.add_argument(
359 "--sample_frequency", 440 "--sample_frequency",
360 type=int, 441 type=int,
361 default=1, 442 default=1,
362 help="How often to save a checkpoint and sample image", 443 help="How often to save a checkpoint and sample image",
363 ) 444 )
364 parser.add_argument( 445 parser.add_argument(
446 "--sample_num",
447 type=int,
448 default=None,
449 help="How often to save a checkpoint and sample image (in number of samples)",
450 )
451 parser.add_argument(
365 "--sample_image_size", 452 "--sample_image_size",
366 type=int, 453 type=int,
367 default=768, 454 default=768,
@@ -383,19 +470,19 @@ def parse_args():
383 "--valid_set_size", 470 "--valid_set_size",
384 type=int, 471 type=int,
385 default=None, 472 default=None,
386 help="Number of images in the validation dataset." 473 help="Number of images in the validation dataset.",
387 ) 474 )
388 parser.add_argument( 475 parser.add_argument(
389 "--valid_set_repeat", 476 "--valid_set_repeat",
390 type=int, 477 type=int,
391 default=1, 478 default=1,
392 help="Times the images in the validation dataset are repeated." 479 help="Times the images in the validation dataset are repeated.",
393 ) 480 )
394 parser.add_argument( 481 parser.add_argument(
395 "--train_batch_size", 482 "--train_batch_size",
396 type=int, 483 type=int,
397 default=1, 484 default=1,
398 help="Batch size (per device) for the training dataloader." 485 help="Batch size (per device) for the training dataloader.",
399 ) 486 )
400 parser.add_argument( 487 parser.add_argument(
401 "--sample_steps", 488 "--sample_steps",
@@ -407,13 +494,18 @@ def parse_args():
407 "--prior_loss_weight", 494 "--prior_loss_weight",
408 type=float, 495 type=float,
409 default=1.0, 496 default=1.0,
410 help="The weight of prior preservation loss." 497 help="The weight of prior preservation loss.",
411 ) 498 )
499 parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
500 parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
412 parser.add_argument( 501 parser.add_argument(
413 "--max_grad_norm", 502 "--emb_dropout",
414 default=1.0,
415 type=float, 503 type=float,
416 help="Max gradient norm." 504 default=0,
505 help="Embedding dropout probability.",
506 )
507 parser.add_argument(
508 "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
417 ) 509 )
418 parser.add_argument( 510 parser.add_argument(
419 "--noise_timesteps", 511 "--noise_timesteps",
@@ -424,7 +516,7 @@ def parse_args():
424 "--config", 516 "--config",
425 type=str, 517 type=str,
426 default=None, 518 default=None,
427 help="Path to a JSON configuration file containing arguments for invoking this script." 519 help="Path to a JSON configuration file containing arguments for invoking this script.",
428 ) 520 )
429 521
430 args = parser.parse_args() 522 args = parser.parse_args()
@@ -441,6 +533,67 @@ def parse_args():
441 if args.project is None: 533 if args.project is None:
442 raise ValueError("You must specify --project") 534 raise ValueError("You must specify --project")
443 535
536 if args.initializer_tokens is None:
537 args.initializer_tokens = []
538
539 if args.placeholder_tokens is None:
540 args.placeholder_tokens = []
541
542 if isinstance(args.placeholder_tokens, str):
543 args.placeholder_tokens = [args.placeholder_tokens]
544
545 if isinstance(args.initializer_tokens, str):
546 args.initializer_tokens = [args.initializer_tokens] * len(
547 args.placeholder_tokens
548 )
549
550 if len(args.placeholder_tokens) == 0:
551 args.placeholder_tokens = [
552 f"<*{i}>" for i in range(len(args.initializer_tokens))
553 ]
554
555 if len(args.initializer_tokens) == 0:
556 args.initializer_tokens = args.placeholder_tokens.copy()
557
558 if len(args.placeholder_tokens) != len(args.initializer_tokens):
559 raise ValueError(
560 "--placeholder_tokens and --initializer_tokens must have the same number of items"
561 )
562
563 if isinstance(args.inverted_initializer_tokens, str):
564 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
565 args.placeholder_tokens
566 )
567
568 if (
569 isinstance(args.inverted_initializer_tokens, list)
570 and len(args.inverted_initializer_tokens) != 0
571 ):
572 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
573 args.initializer_tokens += args.inverted_initializer_tokens
574
575 if isinstance(args.num_vectors, int):
576 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
577
578 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
579 args.num_vectors
580 ):
581 raise ValueError(
582 "--placeholder_tokens and --num_vectors must have the same number of items"
583 )
584
585 if args.alias_tokens is None:
586 args.alias_tokens = []
587
588 if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
589 raise ValueError("--alias_tokens must be a list with an even number of items")
590
591 if args.filter_tokens is None:
592 args.filter_tokens = args.placeholder_tokens.copy()
593
594 if isinstance(args.filter_tokens, str):
595 args.filter_tokens = [args.filter_tokens]
596
444 if isinstance(args.collection, str): 597 if isinstance(args.collection, str):
445 args.collection = [args.collection] 598 args.collection = [args.collection]
446 599
@@ -451,15 +604,15 @@ def parse_args():
451 raise ValueError("You must specify --output_dir") 604 raise ValueError("You must specify --output_dir")
452 605
453 if args.adam_beta1 is None: 606 if args.adam_beta1 is None:
454 if args.optimizer in ('adam', 'adam8bit'): 607 if args.optimizer in ("adam", "adam8bit", "dadam"):
455 args.adam_beta1 = 0.9 608 args.adam_beta1 = 0.9
456 elif args.optimizer == 'lion': 609 elif args.optimizer in ("lion", "dlion"):
457 args.adam_beta1 = 0.95 610 args.adam_beta1 = 0.95
458 611
459 if args.adam_beta2 is None: 612 if args.adam_beta2 is None:
460 if args.optimizer in ('adam', 'adam8bit'): 613 if args.optimizer in ("adam", "adam8bit", "dadam"):
461 args.adam_beta2 = 0.999 614 args.adam_beta2 = 0.999
462 elif args.optimizer == 'lion': 615 elif args.optimizer in ("lion", "dlion"):
463 args.adam_beta2 = 0.98 616 args.adam_beta2 = 0.98
464 617
465 return args 618 return args
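
The beta defaults now group the d-adaptation wrappers with the optimizers they wrap. The same rule as a lookup table (a hypothetical refactor, not part of the commit; optimizer_name stands for args.optimizer):

    BETA_DEFAULTS = {
        "adam": (0.9, 0.999), "adam8bit": (0.9, 0.999), "dadam": (0.9, 0.999),
        "lion": (0.95, 0.98), "dlion": (0.95, 0.98),
    }
    # adan and adafactor are deliberately absent: their betas stay None.
    beta1, beta2 = BETA_DEFAULTS.get(optimizer_name, (None, None))
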
@@ -475,7 +628,7 @@ def main():
475 accelerator = Accelerator( 628 accelerator = Accelerator(
476 log_with=LoggerType.TENSORBOARD, 629 log_with=LoggerType.TENSORBOARD,
477 project_dir=f"{output_dir}", 630 project_dir=f"{output_dir}",
478 mixed_precision=args.mixed_precision 631 mixed_precision=args.mixed_precision,
479 ) 632 )
480 633
481 weight_dtype = torch.float32 634 weight_dtype = torch.float32
@@ -484,6 +637,8 @@ def main():
484 elif args.mixed_precision == "bf16": 637 elif args.mixed_precision == "bf16":
485 weight_dtype = torch.bfloat16 638 weight_dtype = torch.bfloat16
486 639
640 patch_xformers(weight_dtype)
641
487 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) 642 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
488 643
489 if args.seed is None: 644 if args.seed is None:
@@ -493,44 +648,125 @@ def main():
493 648
494 save_args(output_dir, args) 649 save_args(output_dir, args)
495 650
496 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 651 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
497 args.pretrained_model_name_or_path) 652 args.pretrained_model_name_or_path
498 653 )
499 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 654 embeddings = patch_managed_embeddings(
500 tokenizer.set_dropout(args.vector_dropout) 655 text_encoder, args.emb_alpha, args.emb_dropout
656 )
657 schedule_sampler = create_named_schedule_sampler(
658 args.schedule_sampler, noise_scheduler.config.num_train_timesteps
659 )
501 660
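
create_named_schedule_sampler follows the improved-diffusion resampling scheme: "uniform" draws timesteps evenly, while "loss-second-moment" importance-samples them from recent per-timestep loss statistics. The contract both samplers share, as a simplified sketch (not the library code):

    import numpy as np

    class UniformSampler:
        def __init__(self, num_timesteps):
            self.w = np.ones(num_timesteps)  # loss-aware samplers update these weights

        def sample(self, batch_size, rng=np.random):
            p = self.w / self.w.sum()
            t = rng.choice(len(p), size=batch_size, p=p)
            weights = 1.0 / (len(p) * p[t])  # importance weights keep the loss unbiased
            return t, weights
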
502 vae.enable_slicing() 661 vae.enable_slicing()
503 vae.set_use_memory_efficient_attention_xformers(True) 662
504 unet.enable_xformers_memory_efficient_attention() 663 if args.use_xformers:
664 vae.set_use_memory_efficient_attention_xformers(True)
665 unet.enable_xformers_memory_efficient_attention()
666 # elif args.compile_unet:
667 # unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
668 #
669 # proc = AttnProcessor()
670 #
671 # def fn_recursive_set_proc(module: torch.nn.Module):
672 # if hasattr(module, "processor"):
673 # module.processor = proc
674 #
675 # for child in module.children():
676 # fn_recursive_set_proc(child)
677 #
678 # fn_recursive_set_proc(unet)
505 679
506 if args.gradient_checkpointing: 680 if args.gradient_checkpointing:
507 unet.enable_gradient_checkpointing() 681 unet.enable_gradient_checkpointing()
508 text_encoder.gradient_checkpointing_enable() 682
683 if len(args.alias_tokens) != 0:
684 alias_placeholder_tokens = args.alias_tokens[::2]
685 alias_initializer_tokens = args.alias_tokens[1::2]
686
687 added_tokens, added_ids = add_placeholder_tokens(
688 tokenizer=tokenizer,
689 embeddings=embeddings,
690 placeholder_tokens=alias_placeholder_tokens,
691 initializer_tokens=alias_initializer_tokens,
692 )
693 embeddings.persist()
694 print(
695 f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
696 )
697
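
--alias_tokens is one flat, even-length list that alternates alias and target, which is why parse_args rejects odd counts; the stride-2 slices above split it apart, e.g. with illustrative tokens:

    alias_tokens = ["<sks>", "dog", "<sty>", "oil painting"]
    alias_placeholder_tokens = alias_tokens[::2]   # ["<sks>", "<sty>"]
    alias_initializer_tokens = alias_tokens[1::2]  # ["dog", "oil painting"]
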
698 placeholder_tokens = []
699 placeholder_token_ids = []
509 700
510 if args.embeddings_dir is not None: 701 if args.embeddings_dir is not None:
511 embeddings_dir = Path(args.embeddings_dir) 702 embeddings_dir = Path(args.embeddings_dir)
512 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 703 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
513 raise ValueError("--embeddings_dir must point to an existing directory") 704 raise ValueError("--embeddings_dir must point to an existing directory")
514 705
515 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 706 added_tokens, added_ids = load_embeddings_from_dir(
516 embeddings.persist() 707 tokenizer, embeddings, embeddings_dir
517 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 708 )
709
710 placeholder_tokens = added_tokens
711 placeholder_token_ids = added_ids
712
713 print(
714 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
715 )
716
717 if args.train_dir_embeddings:
718 print("Training embeddings from embeddings dir")
719 else:
720 embeddings.persist()
721
722 if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
723 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
724 tokenizer=tokenizer,
725 embeddings=embeddings,
726 placeholder_tokens=args.placeholder_tokens,
727 initializer_tokens=args.initializer_tokens,
728 num_vectors=args.num_vectors,
729 initializer_noise=args.initializer_noise,
730 )
731
732 placeholder_tokens = args.placeholder_tokens
733
734 stats = list(
735 zip(
736 placeholder_tokens,
737 placeholder_token_ids,
738 args.initializer_tokens,
739 initializer_token_ids,
740 )
741 )
742 print(f"Training embeddings: {stats}")
518 743
519 if args.scale_lr: 744 if args.scale_lr:
520 args.learning_rate = ( 745 args.learning_rate_unet = (
521 args.learning_rate * args.gradient_accumulation_steps * 746 args.learning_rate_unet
522 args.train_batch_size * accelerator.num_processes 747 * args.gradient_accumulation_steps
748 * args.train_batch_size
749 * accelerator.num_processes
750 )
751 args.learning_rate_text = (
752 args.learning_rate_text
753 * args.gradient_accumulation_steps
754 * args.train_batch_size
755 * accelerator.num_processes
523 ) 756 )
524 757
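
--scale_lr applies the usual linear scaling rule to both rates: the base learning rate grows with the effective batch size. Worked through with illustrative numbers:

    # base lr 1e-4, 4 accumulation steps, batch size 2, 1 process:
    effective_batch = 4 * 2 * 1         # = 8
    scaled_lr = 1e-4 * effective_batch  # = 8e-4
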
525 if args.find_lr: 758 if args.find_lr:
526 args.learning_rate = 1e-6 759 args.learning_rate_unet = 1e-6
760 args.learning_rate_text = 1e-6
527 args.lr_scheduler = "exponential_growth" 761 args.lr_scheduler = "exponential_growth"
528 762
529 if args.optimizer == 'adam8bit': 763 if args.optimizer == "adam8bit":
530 try: 764 try:
531 import bitsandbytes as bnb 765 import bitsandbytes as bnb
532 except ImportError: 766 except ImportError:
533 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") 767 raise ImportError(
768 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
769 )
534 770
535 create_optimizer = partial( 771 create_optimizer = partial(
536 bnb.optim.AdamW8bit, 772 bnb.optim.AdamW8bit,
@@ -539,7 +775,7 @@ def main():
539 eps=args.adam_epsilon, 775 eps=args.adam_epsilon,
540 amsgrad=args.adam_amsgrad, 776 amsgrad=args.adam_amsgrad,
541 ) 777 )
542 elif args.optimizer == 'adam': 778 elif args.optimizer == "adam":
543 create_optimizer = partial( 779 create_optimizer = partial(
544 torch.optim.AdamW, 780 torch.optim.AdamW,
545 betas=(args.adam_beta1, args.adam_beta2), 781 betas=(args.adam_beta1, args.adam_beta2),
@@ -547,22 +783,27 @@ def main():
547 eps=args.adam_epsilon, 783 eps=args.adam_epsilon,
548 amsgrad=args.adam_amsgrad, 784 amsgrad=args.adam_amsgrad,
549 ) 785 )
550 elif args.optimizer == 'adan': 786 elif args.optimizer == "adan":
551 try: 787 try:
552 import timm.optim 788 import timm.optim
553 except ImportError: 789 except ImportError:
554 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") 790 raise ImportError(
791 "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
792 )
555 793
556 create_optimizer = partial( 794 create_optimizer = partial(
557 timm.optim.Adan, 795 timm.optim.Adan,
558 weight_decay=args.adam_weight_decay, 796 weight_decay=args.adam_weight_decay,
559 eps=args.adam_epsilon, 797 eps=args.adam_epsilon,
798 no_prox=True,
560 ) 799 )
561 elif args.optimizer == 'lion': 800 elif args.optimizer == "lion":
562 try: 801 try:
563 import lion_pytorch 802 import lion_pytorch
564 except ImportError: 803 except ImportError:
565 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") 804 raise ImportError(
805 "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
806 )
566 807
567 create_optimizer = partial( 808 create_optimizer = partial(
568 lion_pytorch.Lion, 809 lion_pytorch.Lion,
@@ -570,7 +811,7 @@ def main():
570 weight_decay=args.adam_weight_decay, 811 weight_decay=args.adam_weight_decay,
571 use_triton=True, 812 use_triton=True,
572 ) 813 )
573 elif args.optimizer == 'adafactor': 814 elif args.optimizer == "adafactor":
574 create_optimizer = partial( 815 create_optimizer = partial(
575 transformers.optimization.Adafactor, 816 transformers.optimization.Adafactor,
576 weight_decay=args.adam_weight_decay, 817 weight_decay=args.adam_weight_decay,
@@ -580,13 +821,16 @@ def main():
580 ) 821 )
581 822
582 args.lr_scheduler = "adafactor" 823 args.lr_scheduler = "adafactor"
583 args.lr_min_lr = args.learning_rate 824 args.lr_min_lr = args.learning_rate_unet
584 args.learning_rate = None 825 args.learning_rate_unet = None
585 elif args.optimizer == 'dadam': 826 args.learning_rate_text = None
827 elif args.optimizer == "dadam":
586 try: 828 try:
587 import dadaptation 829 import dadaptation
588 except ImportError: 830 except ImportError:
589 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") 831 raise ImportError(
832 "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
833 )
590 834
591 create_optimizer = partial( 835 create_optimizer = partial(
592 dadaptation.DAdaptAdam, 836 dadaptation.DAdaptAdam,
@@ -595,46 +839,65 @@ def main():
595 eps=args.adam_epsilon, 839 eps=args.adam_epsilon,
596 decouple=True, 840 decouple=True,
597 d0=args.dadaptation_d0, 841 d0=args.dadaptation_d0,
842 growth_rate=args.dadaptation_growth_rate,
598 ) 843 )
599 844
600 args.learning_rate = 1.0 845 args.learning_rate_unet = 1.0
601 elif args.optimizer == 'dadan': 846 args.learning_rate_text = 1.0
847 elif args.optimizer == "dadan":
602 try: 848 try:
603 import dadaptation 849 import dadaptation
604 except ImportError: 850 except ImportError:
605 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") 851 raise ImportError(
852 "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
853 )
606 854
607 create_optimizer = partial( 855 create_optimizer = partial(
608 dadaptation.DAdaptAdan, 856 dadaptation.DAdaptAdan,
609 weight_decay=args.adam_weight_decay, 857 weight_decay=args.adam_weight_decay,
610 eps=args.adam_epsilon, 858 eps=args.adam_epsilon,
611 d0=args.dadaptation_d0, 859 d0=args.dadaptation_d0,
860 growth_rate=args.dadaptation_growth_rate,
612 ) 861 )
613 862
614 args.learning_rate = 1.0 863 args.learning_rate_unet = 1.0
864 args.learning_rate_text = 1.0
865 elif args.optimizer == "dlion":
866 raise ImportError("DLion has not been merged into dadaptation yet")
615 else: 867 else:
616 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 868 raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
617 869
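
The new d0/growth_rate knobs map directly onto the dadaptation API, and the nominal learning rate is pinned to 1.0 because D-Adaptation estimates the step size itself. A minimal usage sketch (assumes the dadaptation package is installed):

    import torch
    import dadaptation

    model = torch.nn.Linear(8, 8)
    optimizer = dadaptation.DAdaptAdam(
        model.parameters(),
        lr=1.0,                    # a multiplier on the learned step size
        decouple=True,             # AdamW-style decoupled weight decay
        d0=1e-6,                   # initial step-size estimate (--dadaptation_d0)
        growth_rate=float("inf"),  # cap on how fast d may grow (--dadaptation_growth_rate)
    )
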
618 trainer = partial( 870 trainer = partial(
619 train, 871 train,
620 accelerator=accelerator, 872 accelerator=accelerator,
621 unet=unet, 873 unet=unet,
622 text_encoder=text_encoder, 874 text_encoder=text_encoder,
875 tokenizer=tokenizer,
623 vae=vae, 876 vae=vae,
624 noise_scheduler=noise_scheduler, 877 noise_scheduler=noise_scheduler,
878 schedule_sampler=schedule_sampler,
879 min_snr_gamma=args.min_snr_gamma,
625 dtype=weight_dtype, 880 dtype=weight_dtype,
881 seed=args.seed,
882 compile_unet=args.compile_unet,
626 guidance_scale=args.guidance_scale, 883 guidance_scale=args.guidance_scale,
627 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, 884 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
628 no_val=args.valid_set_size == 0, 885 sample_scheduler=sample_scheduler,
886 sample_batch_size=args.sample_batch_size,
887 sample_num_batches=args.sample_batches,
888 sample_num_steps=args.sample_steps,
889 sample_image_size=args.sample_image_size,
890 max_grad_norm=args.max_grad_norm,
629 ) 891 )
630 892
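
trainer and the create_datamodule factory below share the functools.partial pattern: everything common is bound once, and each training phase later supplies only what differs. In miniature (stand-in names and values):

    from functools import partial

    def train(*, strategy, optimizer, unet, dtype):  # stand-in signature
        ...

    trainer = partial(train, unet="unet", dtype="fp32")  # shared kwargs, bound once
    trainer(strategy="dreambooth", optimizer="adamw")    # per-phase extras at call time
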
631 checkpoint_output_dir = output_dir / "model" 893 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
632 sample_output_dir = output_dir / "samples" 894 data_npgenerator = np.random.default_rng(args.seed)
633 895
634 datamodule = VlpnDataModule( 896 create_datamodule = partial(
897 VlpnDataModule,
635 data_file=args.train_data_file, 898 data_file=args.train_data_file,
636 batch_size=args.train_batch_size,
637 tokenizer=tokenizer, 899 tokenizer=tokenizer,
900 constant_prompt_length=args.compile_unet,
638 class_subdir=args.class_image_dir, 901 class_subdir=args.class_image_dir,
639 with_guidance=args.guidance_scale != 0, 902 with_guidance=args.guidance_scale != 0,
640 num_class_images=args.num_class_images, 903 num_class_images=args.num_class_images,
@@ -643,83 +906,186 @@ def main():
643 progressive_buckets=args.progressive_buckets, 906 progressive_buckets=args.progressive_buckets,
644 bucket_step_size=args.bucket_step_size, 907 bucket_step_size=args.bucket_step_size,
645 bucket_max_pixels=args.bucket_max_pixels, 908 bucket_max_pixels=args.bucket_max_pixels,
646 dropout=args.tag_dropout,
647 shuffle=not args.no_tag_shuffle, 909 shuffle=not args.no_tag_shuffle,
648 template_key=args.train_data_template, 910 template_key=args.train_data_template,
649 valid_set_size=args.valid_set_size,
650 train_set_pad=args.train_set_pad, 911 train_set_pad=args.train_set_pad,
651 valid_set_pad=args.valid_set_pad, 912 valid_set_pad=args.valid_set_pad,
652 seed=args.seed, 913 dtype=weight_dtype,
653 filter=partial(keyword_filter, None, args.collection, args.exclude_collections), 914 generator=data_generator,
654 dtype=weight_dtype 915 npgenerator=data_npgenerator,
655 )
656 datamodule.setup()
657
658 num_train_epochs = args.num_train_epochs
659 sample_frequency = args.sample_frequency
660 if num_train_epochs is None:
661 num_train_epochs = math.ceil(
662 args.num_train_steps / len(datamodule.train_dataset)
663 ) * args.gradient_accumulation_steps
664 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
665
666 params_to_optimize = (unet.parameters(), )
667 if args.train_text_encoder_epochs != 0:
668 params_to_optimize += (
669 text_encoder.text_model.encoder.parameters(),
670 text_encoder.text_model.final_layer_norm.parameters(),
671 )
672
673 optimizer = create_optimizer(
674 itertools.chain(*params_to_optimize),
675 lr=args.learning_rate,
676 ) 916 )
677 917
678 lr_scheduler = get_scheduler( 918 create_lr_scheduler = partial(
679 args.lr_scheduler, 919 get_scheduler,
680 optimizer=optimizer,
681 num_training_steps_per_epoch=len(datamodule.train_dataloader),
682 gradient_accumulation_steps=args.gradient_accumulation_steps,
683 min_lr=args.lr_min_lr, 920 min_lr=args.lr_min_lr,
684 warmup_func=args.lr_warmup_func, 921 warmup_func=args.lr_warmup_func,
685 annealing_func=args.lr_annealing_func, 922 annealing_func=args.lr_annealing_func,
686 warmup_exp=args.lr_warmup_exp, 923 warmup_exp=args.lr_warmup_exp,
687 annealing_exp=args.lr_annealing_exp, 924 annealing_exp=args.lr_annealing_exp,
688 cycles=args.lr_cycles,
689 end_lr=1e2, 925 end_lr=1e2,
690 train_epochs=num_train_epochs,
691 warmup_epochs=args.lr_warmup_epochs,
692 mid_point=args.lr_mid_point, 926 mid_point=args.lr_mid_point,
693 ) 927 )
694 928
695 trainer( 929 # Dreambooth
696 strategy=dreambooth_strategy, 930 # --------------------------------------------------------------------------------
697 project="dreambooth", 931
698 train_dataloader=datamodule.train_dataloader, 932 dreambooth_datamodule = create_datamodule(
699 val_dataloader=datamodule.val_dataloader, 933 valid_set_size=args.valid_set_size,
700 seed=args.seed, 934 batch_size=args.train_batch_size,
701 optimizer=optimizer, 935 dropout=args.tag_dropout,
702 lr_scheduler=lr_scheduler, 936 filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
703 num_train_epochs=num_train_epochs, 937 )
704 gradient_accumulation_steps=args.gradient_accumulation_steps, 938 dreambooth_datamodule.setup()
705 sample_frequency=sample_frequency, 939
706 offset_noise_strength=args.offset_noise_strength, 940 num_train_epochs = args.num_train_epochs
707 # -- 941 dreambooth_sample_frequency = args.sample_frequency
708 tokenizer=tokenizer, 942 if num_train_epochs is None:
709 sample_scheduler=sample_scheduler, 943 num_train_epochs = (
710 sample_output_dir=sample_output_dir, 944 math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset))
711 checkpoint_output_dir=checkpoint_output_dir, 945 * args.gradient_accumulation_steps
712 train_text_encoder_epochs=args.train_text_encoder_epochs, 946 )
713 max_grad_norm=args.max_grad_norm, 947 dreambooth_sample_frequency = math.ceil(
714 use_ema=args.use_ema, 948 num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps)
715 ema_inv_gamma=args.ema_inv_gamma, 949 )
716 ema_power=args.ema_power, 950 num_training_steps_per_epoch = math.ceil(
717 ema_max_decay=args.ema_max_decay, 951 len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps
718 sample_batch_size=args.sample_batch_size,
719 sample_num_batches=args.sample_batches,
720 sample_num_steps=args.sample_steps,
721 sample_image_size=args.sample_image_size,
722 ) 952 )
953 num_train_steps = num_training_steps_per_epoch * num_train_epochs
954 if args.sample_num is not None:
955 dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
956
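
The step/epoch conversion above, worked through with illustrative numbers (2000 target steps, 500 training items, gradient accumulation 4):

    # num_train_epochs             = ceil(2000 / 500) * 4    = 16
    # num_training_steps_per_epoch = ceil(500 / 4)           = 125
    # num_train_steps              = 125 * 16                = 2000
    # with --sample_num 8: sample every ceil(16 / 8)         = 2 epochs
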
957 dreambooth_project = "dreambooth"
958
959 if accelerator.is_main_process:
960 accelerator.init_trackers(dreambooth_project)
961
962 dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples"
963
964 training_iter = 0
965 auto_cycles = list(args.auto_cycles)
966 learning_rate_unet = args.learning_rate_unet
967 learning_rate_text = args.learning_rate_text
968 lr_scheduler = args.lr_scheduler
969 lr_warmup_epochs = args.lr_warmup_epochs
970 lr_cycles = args.lr_cycles
971
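
list() turns the --auto_cycles string into a queue of one-letter actions that the loop below pops before falling back to the interactive prompt; e.g. with an illustrative value:

    auto_cycles = list("owd")  # ["o", "w", "d"]: a one_cycle run, then a
    # constant-with-warmup cycle, then a cosine decay cycle; once the queue
    # is exhausted the "[o] one_cycle, ..." prompt takes over. "os" would
    # run a single cycle and stop without waiting for input.
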
972 avg_loss = AverageMeter()
973 avg_acc = AverageMeter()
974 avg_loss_val = AverageMeter()
975 avg_acc_val = AverageMeter()
976
977 params_to_optimize = [
978 {
979 "params": (param for param in unet.parameters() if param.requires_grad),
980 "lr": learning_rate_unet,
981 },
982 {
983 "params": (
984 param for param in text_encoder.parameters() if param.requires_grad
985 ),
986 "lr": learning_rate_text,
987 },
988 ]
989 group_labels = ["unet", "text"]
990
991 dreambooth_optimizer = create_optimizer(params_to_optimize)
992
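
The per-group dict form is what makes separate unet/text learning rates (and the in-place per-cycle updates further down) possible; the same mechanism in miniature:

    import torch

    net_a, net_b = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
    opt = torch.optim.AdamW([
        {"params": net_a.parameters(), "lr": 1e-4},  # "unet" group
        {"params": net_b.parameters(), "lr": 5e-5},  # "text" group
    ])
    opt.param_groups[1]["lr"] = 2.5e-5  # groups stay addressable by position
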
993 while True:
994 if len(auto_cycles) != 0:
995 response = auto_cycles.pop(0)
996 else:
997 response = input(
998 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
999 )
1000
1001 if response.lower().strip() == "o":
1002 if args.learning_rate_unet is not None:
1003 learning_rate_unet = (
1004 args.learning_rate_unet * 2 * (args.cycle_decay**training_iter)
1005 )
1006 if args.learning_rate_text is not None:
1007 learning_rate_text = (
1008 args.learning_rate_text * 2 * (args.cycle_decay**training_iter)
1009 )
1010 else:
1011 learning_rate_unet = args.learning_rate_unet * (
1012 args.cycle_decay**training_iter
1013 )
1014 learning_rate_text = args.learning_rate_text * (
1015 args.cycle_decay**training_iter
1016 )
1017
1018 if response.lower().strip() == "o":
1019 lr_scheduler = "one_cycle"
1020 lr_warmup_epochs = args.lr_warmup_epochs
1021 lr_cycles = args.lr_cycles
1022 elif response.lower().strip() == "w":
1023 lr_scheduler = "constant_with_warmup"
1024 lr_warmup_epochs = num_train_epochs
1025 elif response.lower().strip() == "c":
1026 lr_scheduler = "constant"
1027 elif response.lower().strip() == "d":
1028 lr_scheduler = "cosine"
1029 lr_warmup_epochs = 0
1030 lr_cycles = 1
1031 elif response.lower().strip() == "s":
1032 break
1033 else:
1034 continue
1035
1036 print("")
1037 print(
1038 f"============ Dreambooth cycle {training_iter + 1}: {response} ============"
1039 )
1040 print("")
1041
1042 for group, lr in zip(
1043 dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text]
1044 ):
1045 group["lr"] = lr
1046
1047 dreambooth_lr_scheduler = create_lr_scheduler(
1048 lr_scheduler,
1049 gradient_accumulation_steps=args.gradient_accumulation_steps,
1050 optimizer=dreambooth_optimizer,
1051 num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader),
1052 train_epochs=num_train_epochs,
1053 cycles=lr_cycles,
1054 warmup_epochs=lr_warmup_epochs,
1055 )
1056
1057 dreambooth_checkpoint_output_dir = (
1058 output_dir / dreambooth_project / f"model_{training_iter}"
1059 )
1060
1061 trainer(
1062 strategy=dreambooth_strategy,
1063 train_dataloader=dreambooth_datamodule.train_dataloader,
1064 val_dataloader=dreambooth_datamodule.val_dataloader,
1065 optimizer=dreambooth_optimizer,
1066 lr_scheduler=dreambooth_lr_scheduler,
1067 num_train_epochs=num_train_epochs,
1068 gradient_accumulation_steps=args.gradient_accumulation_steps,
1069 global_step_offset=training_iter * num_train_steps,
1070 cycle=training_iter,
1071 train_text_encoder_cycles=args.train_text_encoder_cycles,
1072 # --
1073 group_labels=group_labels,
1074 sample_output_dir=dreambooth_sample_output_dir,
1075 checkpoint_output_dir=dreambooth_checkpoint_output_dir,
1076 sample_frequency=dreambooth_sample_frequency,
1077 offset_noise_strength=args.offset_noise_strength,
1078 input_pertubation=args.input_pertubation,
1079 no_val=args.valid_set_size == 0,
1080 avg_loss=avg_loss,
1081 avg_acc=avg_acc,
1082 avg_loss_val=avg_loss_val,
1083 avg_acc_val=avg_acc_val,
1084 )
1085
1086 training_iter += 1
1087
1088 accelerator.end_training()
723 1089
724 1090
725if __name__ == "__main__": 1091if __name__ == "__main__":
diff --git a/train_lora.py b/train_lora.py
index c74dd8f..fccf48d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from peft import LoraConfig, get_peft_model 18from peft import LoraConfig, get_peft_model
19
19# from diffusers.models.attention_processor import AttnProcessor 20# from diffusers.models.attention_processor import AttnProcessor
20from diffusers.utils.import_utils import is_xformers_available 21from diffusers.utils.import_utils import is_xformers_available
21import transformers 22import transformers
@@ -34,15 +35,20 @@ from util.files import load_config, load_embeddings_from_dir
34 35
35# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py 36# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
36UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] 37UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"]
37UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"] 38UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0", "to_k", "key"] # []
38TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] 39TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"]
39TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] 40TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + [
40TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] 41 "out_proj",
42 "k_proj",
43] # []
44TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + [
45 "token_embedding"
46]
41 47
42 48
43logger = get_logger(__name__) 49logger = get_logger(__name__)
44 50
45warnings.filterwarnings('ignore') 51warnings.filterwarnings("ignore")
46 52
47 53
48torch.backends.cuda.matmul.allow_tf32 = True 54torch.backends.cuda.matmul.allow_tf32 = True
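
These lists name the submodules that peft wraps with low-rank adapters; with this commit every attention projection (query, key, value, output) is covered. A minimal sketch of how they are consumed (assumes peft and an already-loaded diffusers UNet):

    from peft import LoraConfig, get_peft_model

    unet_config = LoraConfig(
        r=8,            # --lora_r
        lora_alpha=32,  # --lora_alpha
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        lora_dropout=0.0,
    )
    unet = get_peft_model(unet, unet_config)  # swaps matching layers for LoRA-wrapped ones
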
@@ -55,20 +61,27 @@ hidet.torch.dynamo_config.use_tensor_core(True)
55hidet.torch.dynamo_config.search_space(0) 61hidet.torch.dynamo_config.search_space(0)
56 62
57 63
58if is_xformers_available(): 64def patch_xformers(dtype):
59 import xformers 65 if is_xformers_available():
60 import xformers.ops 66 import xformers
61 67 import xformers.ops
62 orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention 68
63 def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs): 69 orig_xformers_memory_efficient_attention = (
64 return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs) 70 xformers.ops.memory_efficient_attention
65 xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention 71 )
72
73 def xformers_memory_efficient_attention(
74 query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
75 ):
76 return orig_xformers_memory_efficient_attention(
77 query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
78 )
79
80 xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
66 81
67 82
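
Once patch_xformers(dtype) has run, every call goes through the wrapper, so mixed-dtype query/key/value tensors no longer trip xformers' same-dtype check. A usage sketch (illustrative shapes, assumes a CUDA device and xformers installed):

    import torch
    import xformers.ops

    patch_xformers(torch.float16)
    q = torch.randn(2, 4096, 8, 40, device="cuda")  # fp32 by default
    k = torch.randn(2, 77, 8, 40, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 77, 8, 40, device="cuda", dtype=torch.float16)
    out = xformers.ops.memory_efficient_attention(q, k, v)  # all cast to fp16 inside
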
68def parse_args(): 83def parse_args():
69 parser = argparse.ArgumentParser( 84 parser = argparse.ArgumentParser(description="Simple example of a training script.")
70 description="Simple example of a training script."
71 )
72 parser.add_argument( 85 parser.add_argument(
73 "--pretrained_model_name_or_path", 86 "--pretrained_model_name_or_path",
74 type=str, 87 type=str,
@@ -85,7 +98,7 @@ def parse_args():
85 "--train_data_file", 98 "--train_data_file",
86 type=str, 99 type=str,
87 default=None, 100 default=None,
88 help="A folder containing the training data." 101 help="A folder containing the training data.",
89 ) 102 )
90 parser.add_argument( 103 parser.add_argument(
91 "--train_data_template", 104 "--train_data_template",
@@ -96,13 +109,13 @@ def parse_args():
96 "--train_set_pad", 109 "--train_set_pad",
97 type=int, 110 type=int,
98 default=None, 111 default=None,
 99 help="Pad the train dataset up to this number of items." 112 help="Pad the train dataset up to this number of items.",
100 ) 113 )
101 parser.add_argument( 114 parser.add_argument(
102 "--valid_set_pad", 115 "--valid_set_pad",
103 type=int, 116 type=int,
104 default=None, 117 default=None,
105 help="Pad the validation dataset up to this number of items." 118 help="Pad the validation dataset up to this number of items.",
106 ) 119 )
107 parser.add_argument( 120 parser.add_argument(
108 "--project", 121 "--project",
@@ -111,64 +124,52 @@ def parse_args():
111 help="The name of the current project.", 124 help="The name of the current project.",
112 ) 125 )
113 parser.add_argument( 126 parser.add_argument(
114 "--auto_cycles", 127 "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
115 type=str,
116 default="o",
117 help="Cycles to run automatically."
118 ) 128 )
119 parser.add_argument( 129 parser.add_argument(
120 "--cycle_decay", 130 "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
121 type=float,
122 default=1.0,
123 help="Learning rate decay per cycle."
124 ) 131 )
125 parser.add_argument( 132 parser.add_argument(
126 "--placeholder_tokens", 133 "--placeholder_tokens",
127 type=str, 134 type=str,
128 nargs='*', 135 nargs="*",
129 help="A token to use as a placeholder for the concept.", 136 help="A token to use as a placeholder for the concept.",
130 ) 137 )
131 parser.add_argument( 138 parser.add_argument(
132 "--initializer_tokens", 139 "--initializer_tokens",
133 type=str, 140 type=str,
134 nargs='*', 141 nargs="*",
135 help="A token to use as an initializer word." 142 help="A token to use as an initializer word.",
136 ) 143 )
137 parser.add_argument( 144 parser.add_argument(
138 "--filter_tokens", 145 "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
139 type=str,
140 nargs='*',
141 help="Tokens to filter the dataset by."
142 ) 146 )
143 parser.add_argument( 147 parser.add_argument(
144 "--initializer_noise", 148 "--initializer_noise",
145 type=float, 149 type=float,
146 default=0, 150 default=0,
147 help="Noise to apply to the initializer word" 151 help="Noise to apply to the initializer word",
148 ) 152 )
149 parser.add_argument( 153 parser.add_argument(
150 "--alias_tokens", 154 "--alias_tokens",
151 type=str, 155 type=str,
152 nargs='*', 156 nargs="*",
153 default=[], 157 default=[],
154 help="Tokens to create an alias for." 158 help="Tokens to create an alias for.",
155 ) 159 )
156 parser.add_argument( 160 parser.add_argument(
157 "--inverted_initializer_tokens", 161 "--inverted_initializer_tokens",
158 type=str, 162 type=str,
159 nargs='*', 163 nargs="*",
160 help="A token to use as an initializer word." 164 help="A token to use as an initializer word.",
161 ) 165 )
162 parser.add_argument( 166 parser.add_argument(
163 "--num_vectors", 167 "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
164 type=int,
165 nargs='*',
166 help="Number of vectors per embedding."
167 ) 168 )
168 parser.add_argument( 169 parser.add_argument(
169 "--exclude_collections", 170 "--exclude_collections",
170 type=str, 171 type=str,
171 nargs='*', 172 nargs="*",
172 help="Exclude all items with a listed collection.", 173 help="Exclude all items with a listed collection.",
173 ) 174 )
174 parser.add_argument( 175 parser.add_argument(
@@ -214,7 +215,7 @@ def parse_args():
214 "--num_class_images", 215 "--num_class_images",
215 type=int, 216 type=int,
216 default=0, 217 default=0,
217 help="How many class images to generate." 218 help="How many class images to generate.",
218 ) 219 )
219 parser.add_argument( 220 parser.add_argument(
220 "--class_image_dir", 221 "--class_image_dir",
@@ -242,14 +243,11 @@ def parse_args():
242 parser.add_argument( 243 parser.add_argument(
243 "--collection", 244 "--collection",
244 type=str, 245 type=str,
245 nargs='*', 246 nargs="*",
246 help="A collection to filter the dataset.", 247 help="A collection to filter the dataset.",
247 ) 248 )
248 parser.add_argument( 249 parser.add_argument(
249 "--seed", 250 "--seed", type=int, default=None, help="A seed for reproducible training."
250 type=int,
251 default=None,
252 help="A seed for reproducible training."
253 ) 251 )
254 parser.add_argument( 252 parser.add_argument(
255 "--resolution", 253 "--resolution",
@@ -270,18 +268,10 @@ def parse_args():
270 "--input_pertubation", 268 "--input_pertubation",
271 type=float, 269 type=float,
272 default=0, 270 default=0,
273 help="The scale of input perturbation. Recommended 0.1." 271 help="The scale of input perturbation. Recommended 0.1.",
274 )
275 parser.add_argument(
276 "--num_train_epochs",
277 type=int,
278 default=None
279 )
280 parser.add_argument(
281 "--num_train_steps",
282 type=int,
283 default=2000
284 ) 272 )
273 parser.add_argument("--num_train_epochs", type=int, default=None)
274 parser.add_argument("--num_train_steps", type=int, default=2000)
285 parser.add_argument( 275 parser.add_argument(
286 "--gradient_accumulation_steps", 276 "--gradient_accumulation_steps",
287 type=int, 277 type=int,
@@ -289,22 +279,19 @@ def parse_args():
289 help="Number of updates steps to accumulate before performing a backward/update pass.", 279 help="Number of updates steps to accumulate before performing a backward/update pass.",
290 ) 280 )
291 parser.add_argument( 281 parser.add_argument(
292 "--lora_r", 282 "--lora_r", type=int, default=8, help="Lora rank, only used if use_lora is True"
293 type=int,
294 default=8,
295 help="Lora rank, only used if use_lora is True"
296 ) 283 )
297 parser.add_argument( 284 parser.add_argument(
298 "--lora_alpha", 285 "--lora_alpha",
299 type=int, 286 type=int,
300 default=32, 287 default=32,
301 help="Lora alpha, only used if use_lora is True" 288 help="Lora alpha, only used if use_lora is True",
302 ) 289 )
303 parser.add_argument( 290 parser.add_argument(
304 "--lora_dropout", 291 "--lora_dropout",
305 type=float, 292 type=float,
306 default=0.0, 293 default=0.0,
307 help="Lora dropout, only used if use_lora is True" 294 help="Lora dropout, only used if use_lora is True",
308 ) 295 )
309 parser.add_argument( 296 parser.add_argument(
310 "--lora_bias", 297 "--lora_bias",
@@ -344,7 +331,7 @@ def parse_args():
344 parser.add_argument( 331 parser.add_argument(
345 "--train_text_encoder_cycles", 332 "--train_text_encoder_cycles",
346 default=999999, 333 default=999999,
347 help="Number of cycles for which the text encoder will be trained." 334 help="Number of cycles for which the text encoder will be trained.",
348 ) 335 )
349 parser.add_argument( 336 parser.add_argument(
350 "--find_lr", 337 "--find_lr",
@@ -378,27 +365,31 @@ def parse_args():
378 "--lr_scheduler", 365 "--lr_scheduler",
379 type=str, 366 type=str,
380 default="one_cycle", 367 default="one_cycle",
381 choices=["linear", "cosine", "cosine_with_restarts", "polynomial", 368 choices=[
382 "constant", "constant_with_warmup", "one_cycle"], 369 "linear",
383 help='The scheduler type to use.', 370 "cosine",
371 "cosine_with_restarts",
372 "polynomial",
373 "constant",
374 "constant_with_warmup",
375 "one_cycle",
376 ],
377 help="The scheduler type to use.",
384 ) 378 )
385 parser.add_argument( 379 parser.add_argument(
386 "--lr_warmup_epochs", 380 "--lr_warmup_epochs",
387 type=int, 381 type=int,
388 default=10, 382 default=10,
389 help="Number of epochs for the warmup in the lr scheduler." 383 help="Number of epochs for the warmup in the lr scheduler.",
390 ) 384 )
391 parser.add_argument( 385 parser.add_argument(
392 "--lr_mid_point", 386 "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point."
393 type=float,
394 default=0.3,
395 help="OneCycle schedule mid point."
396 ) 387 )
397 parser.add_argument( 388 parser.add_argument(
398 "--lr_cycles", 389 "--lr_cycles",
399 type=int, 390 type=int,
400 default=None, 391 default=None,
401 help="Number of restart cycles in the lr scheduler (if supported)." 392 help="Number of restart cycles in the lr scheduler (if supported).",
402 ) 393 )
403 parser.add_argument( 394 parser.add_argument(
404 "--lr_warmup_func", 395 "--lr_warmup_func",
@@ -410,7 +401,7 @@ def parse_args():
410 "--lr_warmup_exp", 401 "--lr_warmup_exp",
411 type=int, 402 type=int,
412 default=1, 403 default=1,
413 help='If lr_warmup_func is "cos", exponent to modify the function' 404 help='If lr_warmup_func is "cos", exponent to modify the function',
414 ) 405 )
415 parser.add_argument( 406 parser.add_argument(
416 "--lr_annealing_func", 407 "--lr_annealing_func",
@@ -422,69 +413,76 @@ def parse_args():
422 "--lr_annealing_exp", 413 "--lr_annealing_exp",
423 type=int, 414 type=int,
424 default=3, 415 default=3,
425 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' 416 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
426 ) 417 )
427 parser.add_argument( 418 parser.add_argument(
428 "--lr_min_lr", 419 "--lr_min_lr",
429 type=float, 420 type=float,
430 default=0.04, 421 default=0.04,
431 help="Minimum learning rate in the lr scheduler." 422 help="Minimum learning rate in the lr scheduler.",
432 )
433 parser.add_argument(
434 "--min_snr_gamma",
435 type=int,
436 default=5,
437 help="MinSNR gamma."
438 ) 423 )
424 parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
439 parser.add_argument( 425 parser.add_argument(
440 "--schedule_sampler", 426 "--schedule_sampler",
441 type=str, 427 type=str,
442 default="uniform", 428 default="uniform",
443 choices=["uniform", "loss-second-moment"], 429 choices=["uniform", "loss-second-moment"],
444 help="Noise schedule sampler." 430 help="Noise schedule sampler.",
445 ) 431 )
446 parser.add_argument( 432 parser.add_argument(
447 "--optimizer", 433 "--optimizer",
448 type=str, 434 type=str,
449 default="adan", 435 default="adan",
450 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 436 choices=[
451 help='Optimizer to use' 437 "adam",
438 "adam8bit",
439 "adan",
440 "lion",
441 "dadam",
442 "dadan",
443 "dlion",
444 "adafactor",
445 ],
446 help="Optimizer to use",
452 ) 447 )
453 parser.add_argument( 448 parser.add_argument(
454 "--dadaptation_d0", 449 "--dadaptation_d0",
455 type=float, 450 type=float,
456 default=1e-6, 451 default=1e-6,
457 help="The d0 parameter for Dadaptation optimizers." 452 help="The d0 parameter for Dadaptation optimizers.",
453 )
454 parser.add_argument(
455 "--dadaptation_growth_rate",
456 type=float,
457 default=math.inf,
458 help="The growth_rate parameter for Dadaptation optimizers.",
458 ) 459 )
459 parser.add_argument( 460 parser.add_argument(
460 "--adam_beta1", 461 "--adam_beta1",
461 type=float, 462 type=float,
462 default=None, 463 default=None,
463 help="The beta1 parameter for the Adam optimizer." 464 help="The beta1 parameter for the Adam optimizer.",
464 ) 465 )
465 parser.add_argument( 466 parser.add_argument(
466 "--adam_beta2", 467 "--adam_beta2",
467 type=float, 468 type=float,
468 default=None, 469 default=None,
469 help="The beta2 parameter for the Adam optimizer." 470 help="The beta2 parameter for the Adam optimizer.",
470 ) 471 )
471 parser.add_argument( 472 parser.add_argument(
472 "--adam_weight_decay", 473 "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
473 type=float,
474 default=2e-2,
475 help="Weight decay to use."
476 ) 474 )
477 parser.add_argument( 475 parser.add_argument(
478 "--adam_epsilon", 476 "--adam_epsilon",
479 type=float, 477 type=float,
480 default=1e-08, 478 default=1e-08,
481 help="Epsilon value for the Adam optimizer" 479 help="Epsilon value for the Adam optimizer",
482 ) 480 )
483 parser.add_argument( 481 parser.add_argument(
484 "--adam_amsgrad", 482 "--adam_amsgrad",
485 type=bool, 483 type=bool,
486 default=False, 484 default=False,
487 help="Whether to use AMSGrad with the Adam optimizer." 485 help="Whether to use AMSGrad with the Adam optimizer.",
488 ) 486 )
489 parser.add_argument( 487 parser.add_argument(
490 "--mixed_precision", 488 "--mixed_precision",
@@ -547,19 +545,19 @@ def parse_args():
547 "--valid_set_size", 545 "--valid_set_size",
548 type=int, 546 type=int,
549 default=None, 547 default=None,
550 help="Number of images in the validation dataset." 548 help="Number of images in the validation dataset.",
551 ) 549 )
552 parser.add_argument( 550 parser.add_argument(
553 "--valid_set_repeat", 551 "--valid_set_repeat",
554 type=int, 552 type=int,
555 default=1, 553 default=1,
556 help="Times the images in the validation dataset are repeated." 554 help="Times the images in the validation dataset are repeated.",
557 ) 555 )
558 parser.add_argument( 556 parser.add_argument(
559 "--train_batch_size", 557 "--train_batch_size",
560 type=int, 558 type=int,
561 default=1, 559 default=1,
562 help="Batch size (per device) for the training dataloader." 560 help="Batch size (per device) for the training dataloader.",
563 ) 561 )
564 parser.add_argument( 562 parser.add_argument(
565 "--sample_steps", 563 "--sample_steps",
@@ -571,19 +569,10 @@ def parse_args():
571 "--prior_loss_weight", 569 "--prior_loss_weight",
572 type=float, 570 type=float,
573 default=1.0, 571 default=1.0,
574 help="The weight of prior preservation loss." 572 help="The weight of prior preservation loss.",
575 )
576 parser.add_argument(
577 "--run_pti",
578 action="store_true",
579 help="Whether to run PTI."
580 )
581 parser.add_argument(
582 "--emb_alpha",
583 type=float,
584 default=1.0,
585 help="Embedding alpha"
586 ) 573 )
574 parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
575 parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
587 parser.add_argument( 576 parser.add_argument(
588 "--emb_dropout", 577 "--emb_dropout",
589 type=float, 578 type=float,
@@ -591,27 +580,16 @@ def parse_args():
591 help="Embedding dropout probability.", 580 help="Embedding dropout probability.",
592 ) 581 )
593 parser.add_argument( 582 parser.add_argument(
594 "--use_emb_decay", 583 "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
595 action="store_true",
596 help="Whether to use embedding decay."
597 ) 584 )
598 parser.add_argument( 585 parser.add_argument(
599 "--emb_decay_target", 586 "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
600 default=0.4,
601 type=float,
602 help="Embedding decay target."
603 ) 587 )
604 parser.add_argument( 588 parser.add_argument(
605 "--emb_decay", 589 "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
606 default=1e+2,
607 type=float,
608 help="Embedding decay factor."
609 ) 590 )
610 parser.add_argument( 591 parser.add_argument(
611 "--max_grad_norm", 592 "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
612 default=1.0,
613 type=float,
614 help="Max gradient norm."
615 ) 593 )
616 parser.add_argument( 594 parser.add_argument(
617 "--noise_timesteps", 595 "--noise_timesteps",
@@ -622,7 +600,7 @@ def parse_args():
622 "--config", 600 "--config",
623 type=str, 601 type=str,
624 default=None, 602 default=None,
625 help="Path to a JSON configuration file containing arguments for invoking this script." 603 help="Path to a JSON configuration file containing arguments for invoking this script.",
626 ) 604 )
627 605
628 args = parser.parse_args() 606 args = parser.parse_args()
@@ -649,29 +627,44 @@ def parse_args():
649 args.placeholder_tokens = [args.placeholder_tokens] 627 args.placeholder_tokens = [args.placeholder_tokens]
650 628
651 if isinstance(args.initializer_tokens, str): 629 if isinstance(args.initializer_tokens, str):
652 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) 630 args.initializer_tokens = [args.initializer_tokens] * len(
631 args.placeholder_tokens
632 )
653 633
654 if len(args.placeholder_tokens) == 0: 634 if len(args.placeholder_tokens) == 0:
655 args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] 635 args.placeholder_tokens = [
636 f"<*{i}>" for i in range(len(args.initializer_tokens))
637 ]
656 638
657 if len(args.initializer_tokens) == 0: 639 if len(args.initializer_tokens) == 0:
658 args.initializer_tokens = args.placeholder_tokens.copy() 640 args.initializer_tokens = args.placeholder_tokens.copy()
659 641
660 if len(args.placeholder_tokens) != len(args.initializer_tokens): 642 if len(args.placeholder_tokens) != len(args.initializer_tokens):
661 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") 643 raise ValueError(
644 "--placeholder_tokens and --initializer_tokens must have the same number of items"
645 )
662 646
663 if isinstance(args.inverted_initializer_tokens, str): 647 if isinstance(args.inverted_initializer_tokens, str):
664 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) 648 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
649 args.placeholder_tokens
650 )
665 651
666 if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: 652 if (
653 isinstance(args.inverted_initializer_tokens, list)
654 and len(args.inverted_initializer_tokens) != 0
655 ):
667 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] 656 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
668 args.initializer_tokens += args.inverted_initializer_tokens 657 args.initializer_tokens += args.inverted_initializer_tokens
669 658
670 if isinstance(args.num_vectors, int): 659 if isinstance(args.num_vectors, int):
671 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) 660 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
672 661
673 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): 662 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
674 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") 663 args.num_vectors
664 ):
665 raise ValueError(
666 "--placeholder_tokens and --num_vectors must have the same number of items"
667 )
675 668
676 if args.alias_tokens is None: 669 if args.alias_tokens is None:
677 args.alias_tokens = [] 670 args.alias_tokens = []
@@ -695,15 +688,15 @@ def parse_args():
695 raise ValueError("You must specify --output_dir") 688 raise ValueError("You must specify --output_dir")
696 689
697 if args.adam_beta1 is None: 690 if args.adam_beta1 is None:
698 if args.optimizer in ('adam', 'adam8bit'): 691 if args.optimizer in ("adam", "adam8bit", "dadam"):
699 args.adam_beta1 = 0.9 692 args.adam_beta1 = 0.9
700 elif args.optimizer == 'lion': 693 elif args.optimizer in ("lion", "dlion"):
701 args.adam_beta1 = 0.95 694 args.adam_beta1 = 0.95
702 695
703 if args.adam_beta2 is None: 696 if args.adam_beta2 is None:
704 if args.optimizer in ('adam', 'adam8bit'): 697 if args.optimizer in ("adam", "adam8bit", "dadam"):
705 args.adam_beta2 = 0.999 698 args.adam_beta2 = 0.999
706 elif args.optimizer == 'lion': 699 elif args.optimizer in ("lion", "dlion"):
707 args.adam_beta2 = 0.98 700 args.adam_beta2 = 0.98
708 701
709 return args 702 return args
@@ -719,7 +712,7 @@ def main():
719 accelerator = Accelerator( 712 accelerator = Accelerator(
720 log_with=LoggerType.TENSORBOARD, 713 log_with=LoggerType.TENSORBOARD,
721 project_dir=f"{output_dir}", 714 project_dir=f"{output_dir}",
722 mixed_precision=args.mixed_precision 715 mixed_precision=args.mixed_precision,
723 ) 716 )
724 717
725 weight_dtype = torch.float32 718 weight_dtype = torch.float32
@@ -728,6 +721,8 @@ def main():
728 elif args.mixed_precision == "bf16": 721 elif args.mixed_precision == "bf16":
729 weight_dtype = torch.bfloat16 722 weight_dtype = torch.bfloat16
730 723
724 patch_xformers(weight_dtype)
725
731 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) 726 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
732 727
733 if args.seed is None: 728 if args.seed is None:
@@ -737,12 +732,18 @@ def main():
737 732
738 save_args(output_dir, args) 733 save_args(output_dir, args)
739 734
740 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) 735 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
741 schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) 736 args.pretrained_model_name_or_path
742 737 )
738 schedule_sampler = create_named_schedule_sampler(
739 args.schedule_sampler, noise_scheduler.config.num_train_timesteps
740 )
741
743 def ensure_embeddings(): 742 def ensure_embeddings():
744 if args.lora_text_encoder_emb: 743 if args.lora_text_encoder_emb:
745 raise ValueError("Can't use TI options when training token embeddings with LoRA") 744 raise ValueError(
745 "Can't use TI options when training token embeddings with LoRA"
746 )
746 return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) 747 return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
747 748
748 unet_config = LoraConfig( 749 unet_config = LoraConfig(
@@ -757,7 +758,9 @@ def main():
757 text_encoder_config = LoraConfig( 758 text_encoder_config = LoraConfig(
758 r=args.lora_text_encoder_r, 759 r=args.lora_text_encoder_r,
759 lora_alpha=args.lora_text_encoder_alpha, 760 lora_alpha=args.lora_text_encoder_alpha,
760 target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES, 761 target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING
762 if args.lora_text_encoder_emb
763 else TEXT_ENCODER_TARGET_MODULES,
761 lora_dropout=args.lora_text_encoder_dropout, 764 lora_dropout=args.lora_text_encoder_dropout,
762 bias=args.lora_text_encoder_bias, 765 bias=args.lora_text_encoder_bias,
763 ) 766 )
@@ -787,7 +790,7 @@ def main():
787 790
788 if len(args.alias_tokens) != 0: 791 if len(args.alias_tokens) != 0:
789 embeddings = ensure_embeddings() 792 embeddings = ensure_embeddings()
790 793
791 alias_placeholder_tokens = args.alias_tokens[::2] 794 alias_placeholder_tokens = args.alias_tokens[::2]
792 alias_initializer_tokens = args.alias_tokens[1::2] 795 alias_initializer_tokens = args.alias_tokens[1::2]
793 796
@@ -795,27 +798,33 @@ def main():
795 tokenizer=tokenizer, 798 tokenizer=tokenizer,
796 embeddings=embeddings, 799 embeddings=embeddings,
797 placeholder_tokens=alias_placeholder_tokens, 800 placeholder_tokens=alias_placeholder_tokens,
798 initializer_tokens=alias_initializer_tokens 801 initializer_tokens=alias_initializer_tokens,
799 ) 802 )
800 embeddings.persist() 803 embeddings.persist()
801 print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") 804 print(
805 f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
806 )
802 807
803 placeholder_tokens = [] 808 placeholder_tokens = []
804 placeholder_token_ids = [] 809 placeholder_token_ids = []
805 810
806 if args.embeddings_dir is not None: 811 if args.embeddings_dir is not None:
807 embeddings = ensure_embeddings() 812 embeddings = ensure_embeddings()
808 813
809 embeddings_dir = Path(args.embeddings_dir) 814 embeddings_dir = Path(args.embeddings_dir)
810 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 815 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
811 raise ValueError("--embeddings_dir must point to an existing directory") 816 raise ValueError("--embeddings_dir must point to an existing directory")
812 817
813 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 818 added_tokens, added_ids = load_embeddings_from_dir(
819 tokenizer, embeddings, embeddings_dir
820 )
814 821
815 placeholder_tokens = added_tokens 822 placeholder_tokens = added_tokens
816 placeholder_token_ids = added_ids 823 placeholder_token_ids = added_ids
817 824
818 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 825 print(
826 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
827 )
819 828
820 if args.train_dir_embeddings: 829 if args.train_dir_embeddings:
821 print("Training embeddings from embeddings dir") 830 print("Training embeddings from embeddings dir")
@@ -824,7 +833,7 @@ def main():
824 833
825 if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: 834 if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
826 embeddings = ensure_embeddings() 835 embeddings = ensure_embeddings()
827 836
828 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 837 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
829 tokenizer=tokenizer, 838 tokenizer=tokenizer,
830 embeddings=embeddings, 839 embeddings=embeddings,
@@ -836,23 +845,34 @@ def main():
836 845
837 placeholder_tokens = args.placeholder_tokens 846 placeholder_tokens = args.placeholder_tokens
838 847
839 stats = list(zip( 848 stats = list(
840 placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids 849 zip(
841 )) 850 placeholder_tokens,
851 placeholder_token_ids,
852 args.initializer_tokens,
853 initializer_token_ids,
854 )
855 )
842 print(f"Training embeddings: {stats}") 856 print(f"Training embeddings: {stats}")
843 857
844 if args.scale_lr: 858 if args.scale_lr:
845 args.learning_rate_unet = ( 859 args.learning_rate_unet = (
846 args.learning_rate_unet * args.gradient_accumulation_steps * 860 args.learning_rate_unet
847 args.train_batch_size * accelerator.num_processes 861 * args.gradient_accumulation_steps
862 * args.train_batch_size
863 * accelerator.num_processes
848 ) 864 )
849 args.learning_rate_text = ( 865 args.learning_rate_text = (
850 args.learning_rate_text * args.gradient_accumulation_steps * 866 args.learning_rate_text
851 args.train_batch_size * accelerator.num_processes 867 * args.gradient_accumulation_steps
868 * args.train_batch_size
869 * accelerator.num_processes
852 ) 870 )
853 args.learning_rate_emb = ( 871 args.learning_rate_emb = (
854 args.learning_rate_emb * args.gradient_accumulation_steps * 872 args.learning_rate_emb
855 args.train_batch_size * accelerator.num_processes 873 * args.gradient_accumulation_steps
874 * args.train_batch_size
875 * accelerator.num_processes
856 ) 876 )
857 877
858 if args.find_lr: 878 if args.find_lr:
@@ -861,11 +881,13 @@ def main():
861 args.learning_rate_emb = 1e-6 881 args.learning_rate_emb = 1e-6
862 args.lr_scheduler = "exponential_growth" 882 args.lr_scheduler = "exponential_growth"
863 883
864 if args.optimizer == 'adam8bit': 884 if args.optimizer == "adam8bit":
865 try: 885 try:
866 import bitsandbytes as bnb 886 import bitsandbytes as bnb
867 except ImportError: 887 except ImportError:
868 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") 888 raise ImportError(
889 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
890 )
869 891
870 create_optimizer = partial( 892 create_optimizer = partial(
871 bnb.optim.AdamW8bit, 893 bnb.optim.AdamW8bit,
@@ -874,7 +896,7 @@ def main():
874 eps=args.adam_epsilon, 896 eps=args.adam_epsilon,
875 amsgrad=args.adam_amsgrad, 897 amsgrad=args.adam_amsgrad,
876 ) 898 )
877 elif args.optimizer == 'adam': 899 elif args.optimizer == "adam":
878 create_optimizer = partial( 900 create_optimizer = partial(
879 torch.optim.AdamW, 901 torch.optim.AdamW,
880 betas=(args.adam_beta1, args.adam_beta2), 902 betas=(args.adam_beta1, args.adam_beta2),
@@ -882,11 +904,13 @@ def main():
882 eps=args.adam_epsilon, 904 eps=args.adam_epsilon,
883 amsgrad=args.adam_amsgrad, 905 amsgrad=args.adam_amsgrad,
884 ) 906 )
885 elif args.optimizer == 'adan': 907 elif args.optimizer == "adan":
886 try: 908 try:
887 import timm.optim 909 import timm.optim
888 except ImportError: 910 except ImportError:
889 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") 911 raise ImportError(
912 "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
913 )
890 914
891 create_optimizer = partial( 915 create_optimizer = partial(
892 timm.optim.Adan, 916 timm.optim.Adan,
@@ -894,11 +918,13 @@ def main():
894 eps=args.adam_epsilon, 918 eps=args.adam_epsilon,
895 no_prox=True, 919 no_prox=True,
896 ) 920 )
897 elif args.optimizer == 'lion': 921 elif args.optimizer == "lion":
898 try: 922 try:
899 import lion_pytorch 923 import lion_pytorch
900 except ImportError: 924 except ImportError:
901 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") 925 raise ImportError(
926 "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
927 )
902 928
903 create_optimizer = partial( 929 create_optimizer = partial(
904 lion_pytorch.Lion, 930 lion_pytorch.Lion,
@@ -906,7 +932,7 @@ def main():
906 weight_decay=args.adam_weight_decay, 932 weight_decay=args.adam_weight_decay,
907 use_triton=True, 933 use_triton=True,
908 ) 934 )
909 elif args.optimizer == 'adafactor': 935 elif args.optimizer == "adafactor":
910 create_optimizer = partial( 936 create_optimizer = partial(
911 transformers.optimization.Adafactor, 937 transformers.optimization.Adafactor,
912 weight_decay=args.adam_weight_decay, 938 weight_decay=args.adam_weight_decay,
@@ -920,11 +946,13 @@ def main():
920 args.learning_rate_unet = None 946 args.learning_rate_unet = None
921 args.learning_rate_text = None 947 args.learning_rate_text = None
922 args.learning_rate_emb = None 948 args.learning_rate_emb = None
923 elif args.optimizer == 'dadam': 949 elif args.optimizer == "dadam":
924 try: 950 try:
925 import dadaptation 951 import dadaptation
926 except ImportError: 952 except ImportError:
927 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") 953 raise ImportError(
954 "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
955 )
928 956
929 create_optimizer = partial( 957 create_optimizer = partial(
930 dadaptation.DAdaptAdam, 958 dadaptation.DAdaptAdam,
@@ -933,29 +961,35 @@ def main():
933 eps=args.adam_epsilon, 961 eps=args.adam_epsilon,
934 decouple=True, 962 decouple=True,
935 d0=args.dadaptation_d0, 963 d0=args.dadaptation_d0,
964 growth_rate=args.dadaptation_growth_rate,
936 ) 965 )
937 966
938 args.learning_rate_unet = 1.0 967 args.learning_rate_unet = 1.0
939 args.learning_rate_text = 1.0 968 args.learning_rate_text = 1.0
940 args.learning_rate_emb = 1.0 969 args.learning_rate_emb = 1.0
941 elif args.optimizer == 'dadan': 970 elif args.optimizer == "dadan":
942 try: 971 try:
943 import dadaptation 972 import dadaptation
944 except ImportError: 973 except ImportError:
945 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") 974 raise ImportError(
975 "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
976 )
946 977
947 create_optimizer = partial( 978 create_optimizer = partial(
948 dadaptation.DAdaptAdan, 979 dadaptation.DAdaptAdan,
949 weight_decay=args.adam_weight_decay, 980 weight_decay=args.adam_weight_decay,
950 eps=args.adam_epsilon, 981 eps=args.adam_epsilon,
951 d0=args.dadaptation_d0, 982 d0=args.dadaptation_d0,
983 growth_rate=args.dadaptation_growth_rate,
952 ) 984 )
953 985
954 args.learning_rate_unet = 1.0 986 args.learning_rate_unet = 1.0
955 args.learning_rate_text = 1.0 987 args.learning_rate_text = 1.0
956 args.learning_rate_emb = 1.0 988 args.learning_rate_emb = 1.0
989 elif args.optimizer == "dlion":
990 raise ImportError("DLion has not been merged into dadaptation yet")
957 else: 991 else:
958 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 992 raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
959 993
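[note] Every branch in the chain above follows one pattern: the optimizer class is wrapped in functools.partial so hyperparameters are bound once and the parameter groups can be supplied later, per training phase. A minimal sketch of that pattern, with invented names (not the script's own helpers):

import torch
from functools import partial

def select_optimizer_factory(name: str, lr: float, weight_decay: float):
    # Hyperparameters are bound now; params are supplied when the factory is called.
    if name == "adam":
        return partial(torch.optim.AdamW, lr=lr, weight_decay=weight_decay)
    if name == "sgd":
        return partial(torch.optim.SGD, lr=lr, momentum=0.9, weight_decay=weight_decay)
    raise ValueError(f'Unknown optimizer "{name}"')

create_optimizer = select_optimizer_factory("adam", lr=1e-4, weight_decay=2e-2)
model = torch.nn.Linear(4, 4)
optimizer = create_optimizer(model.parameters())  # params bound last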
960 trainer = partial( 994 trainer = partial(
961 train, 995 train,
@@ -1026,25 +1060,33 @@ def main():
1026 1060
1027 if args.run_pti and len(placeholder_tokens) != 0: 1061 if args.run_pti and len(placeholder_tokens) != 0:
1028 embeddings = ensure_embeddings() 1062 embeddings = ensure_embeddings()
1029 1063
1030 filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] 1064 filter_tokens = [
1065 token for token in args.filter_tokens if token in placeholder_tokens
1066 ]
1031 1067
1032 pti_datamodule = create_datamodule( 1068 pti_datamodule = create_datamodule(
1033 valid_set_size=0, 1069 valid_set_size=0,
1034 batch_size=args.train_batch_size, 1070 batch_size=args.train_batch_size,
1035 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), 1071 filter=partial(
1072 keyword_filter, filter_tokens, args.collection, args.exclude_collections
1073 ),
1036 ) 1074 )
1037 pti_datamodule.setup() 1075 pti_datamodule.setup()
1038 1076
1039 num_train_epochs = args.num_train_epochs 1077 num_train_epochs = args.num_train_epochs
1040 pti_sample_frequency = args.sample_frequency 1078 pti_sample_frequency = args.sample_frequency
1041 if num_train_epochs is None: 1079 if num_train_epochs is None:
1042 num_train_epochs = math.ceil( 1080 num_train_epochs = (
1043 args.num_train_steps / len(pti_datamodule.train_dataset) 1081 math.ceil(args.num_train_steps / len(pti_datamodule.train_dataset))
1044 ) * args.gradient_accumulation_steps 1082 * args.gradient_accumulation_steps
1045 pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps)) 1083 )
1084 pti_sample_frequency = math.ceil(
1085 num_train_epochs * (pti_sample_frequency / args.num_train_steps)
1086 )
1046 num_training_steps_per_epoch = math.ceil( 1087 num_training_steps_per_epoch = math.ceil(
1047 len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps) 1088 len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps
1089 )
1048 num_train_steps = num_training_steps_per_epoch * num_train_epochs 1090 num_train_steps = num_training_steps_per_epoch * num_train_epochs
1049 if args.sample_num is not None: 1091 if args.sample_num is not None:
1050 pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) 1092 pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
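[note] The epoch/step conversion above can be checked with a small worked example (numbers invented for illustration):

import math

dataset_len, num_train_steps, accum = 200, 2000, 4
num_train_epochs = math.ceil(num_train_steps / dataset_len) * accum  # 10 * 4 = 40
steps_per_epoch = math.ceil(dataset_len / accum)                     # 50
assert steps_per_epoch * num_train_epochs == num_train_steps         # 50 * 40 = 2000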
@@ -1060,11 +1102,15 @@ def main():
1060 print(f"============ PTI ============") 1102 print(f"============ PTI ============")
1061 print("") 1103 print("")
1062 1104
1063 pti_optimizer = create_optimizer([{ 1105 pti_optimizer = create_optimizer(
1064 "params": text_encoder.text_model.embeddings.token_embedding.parameters(), 1106 [
1065 "lr": args.learning_rate_emb, 1107 {
1066 "weight_decay": 0, 1108 "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
1067 }]) 1109 "lr": args.learning_rate_emb,
1110 "weight_decay": 0,
1111 }
1112 ]
1113 )
1068 1114
1069 pti_lr_scheduler = create_lr_scheduler( 1115 pti_lr_scheduler = create_lr_scheduler(
1070 "constant_with_warmup", 1116 "constant_with_warmup",
@@ -1113,11 +1159,16 @@ def main():
1113 num_train_epochs = args.num_train_epochs 1159 num_train_epochs = args.num_train_epochs
1114 lora_sample_frequency = args.sample_frequency 1160 lora_sample_frequency = args.sample_frequency
1115 if num_train_epochs is None: 1161 if num_train_epochs is None:
1116 num_train_epochs = math.ceil( 1162 num_train_epochs = (
1117 args.num_train_steps / len(lora_datamodule.train_dataset) 1163 math.ceil(args.num_train_steps / len(lora_datamodule.train_dataset))
1118 ) * args.gradient_accumulation_steps 1164 * args.gradient_accumulation_steps
1119 lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) 1165 )
1120 num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps) 1166 lora_sample_frequency = math.ceil(
1167 num_train_epochs * (lora_sample_frequency / args.num_train_steps)
1168 )
1169 num_training_steps_per_epoch = math.ceil(
1170 len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps
1171 )
1121 num_train_steps = num_training_steps_per_epoch * num_train_epochs 1172 num_train_steps = num_training_steps_per_epoch * num_train_epochs
1122 if args.sample_num is not None: 1173 if args.sample_num is not None:
1123 lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) 1174 lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
@@ -1131,7 +1182,6 @@ def main():
1131 1182
1132 training_iter = 0 1183 training_iter = 0
1133 auto_cycles = list(args.auto_cycles) 1184 auto_cycles = list(args.auto_cycles)
1134 learning_rate_emb = args.learning_rate_emb
1135 learning_rate_unet = args.learning_rate_unet 1185 learning_rate_unet = args.learning_rate_unet
1136 learning_rate_text = args.learning_rate_text 1186 learning_rate_text = args.learning_rate_text
1137 lr_scheduler = args.lr_scheduler 1187 lr_scheduler = args.lr_scheduler
@@ -1145,21 +1195,15 @@ def main():
1145 1195
1146 params_to_optimize = [ 1196 params_to_optimize = [
1147 { 1197 {
1148 "params": ( 1198 "params": (param for param in unet.parameters() if param.requires_grad),
1149 param
1150 for param in unet.parameters()
1151 if param.requires_grad
1152 ),
1153 "lr": learning_rate_unet, 1199 "lr": learning_rate_unet,
1154 }, 1200 },
1155 { 1201 {
1156 "params": ( 1202 "params": (
1157 param 1203 param for param in text_encoder.parameters() if param.requires_grad
1158 for param in text_encoder.parameters()
1159 if param.requires_grad
1160 ), 1204 ),
1161 "lr": learning_rate_text, 1205 "lr": learning_rate_text,
1162 } 1206 },
1163 ] 1207 ]
1164 group_labels = ["unet", "text"] 1208 group_labels = ["unet", "text"]
1165 1209
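[note] The "params" entries above are generator expressions; PyTorch materializes them once when the optimizer is constructed, so they must not be reused afterwards (a second iteration would be empty). A small demonstration of the pitfall:

import torch

model = torch.nn.Linear(2, 2)
trainable = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.AdamW([{"params": trainable, "lr": 0.1}])
print(len(optimizer.param_groups[0]["params"]))  # 2 -- materialized at construction
print(list(trainable))                           # [] -- the generator is now exhausted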
@@ -1169,19 +1213,26 @@ def main():
1169 if len(auto_cycles) != 0: 1213 if len(auto_cycles) != 0:
1170 response = auto_cycles.pop(0) 1214 response = auto_cycles.pop(0)
1171 else: 1215 else:
1172 response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") 1216 response = input(
1217 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
1218 )
1173 1219
1174 if response.lower().strip() == "o": 1220 if response.lower().strip() == "o":
1175 if args.learning_rate_emb is not None:
1176 learning_rate_emb = args.learning_rate_emb * 2
1177 if args.learning_rate_unet is not None: 1221 if args.learning_rate_unet is not None:
1178 learning_rate_unet = args.learning_rate_unet * 2 1222 learning_rate_unet = (
1223 args.learning_rate_unet * 2 * (args.cycle_decay**training_iter)
1224 )
1179 if args.learning_rate_text is not None: 1225 if args.learning_rate_text is not None:
1180 learning_rate_text = args.learning_rate_text * 2 1226 learning_rate_text = (
1227 args.learning_rate_text * 2 * (args.cycle_decay**training_iter)
1228 )
1181 else: 1229 else:
1182 learning_rate_emb = args.learning_rate_emb 1230 learning_rate_unet = args.learning_rate_unet * (
1183 learning_rate_unet = args.learning_rate_unet 1231 args.cycle_decay**training_iter
1184 learning_rate_text = args.learning_rate_text 1232 )
1233 learning_rate_text = args.learning_rate_text * (
1234 args.cycle_decay**training_iter
1235 )
1185 1236
1186 if response.lower().strip() == "o": 1237 if response.lower().strip() == "o":
1187 lr_scheduler = "one_cycle" 1238 lr_scheduler = "one_cycle"
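[note] This hunk moves cycle_decay from a post-cycle multiplication (removed further down in this diff) into the learning-rate computation at cycle start, so each cycle's rate is derived fresh from the configured base instead of being mutated in place. For the one_cycle branch the effective rate is base_lr * 2 * cycle_decay**iter; with illustrative values:

# Illustrative values, not the script's defaults.
base_lr, cycle_decay = 1e-4, 0.9
for it in range(3):
    print(f"cycle {it}: one_cycle lr = {base_lr * 2 * cycle_decay**it:.3e}")
# cycle 0: 2.000e-04, cycle 1: 1.800e-04, cycle 2: 1.620e-04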
@@ -1204,9 +1255,11 @@ def main():
1204 print("") 1255 print("")
1205 print(f"============ LoRA cycle {training_iter + 1}: {response} ============") 1256 print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
1206 print("") 1257 print("")
1207 1258
1208 for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): 1259 for group, lr in zip(
1209 group['lr'] = lr 1260 lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]
1261 ):
1262 group["lr"] = lr
1210 1263
1211 lora_lr_scheduler = create_lr_scheduler( 1264 lora_lr_scheduler = create_lr_scheduler(
1212 lr_scheduler, 1265 lr_scheduler,
@@ -1218,7 +1271,9 @@ def main():
1218 warmup_epochs=lr_warmup_epochs, 1271 warmup_epochs=lr_warmup_epochs,
1219 ) 1272 )
1220 1273
1221 lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" 1274 lora_checkpoint_output_dir = (
1275 output_dir / lora_project / f"model_{training_iter}"
1276 )
1222 1277
1223 trainer( 1278 trainer(
1224 strategy=lora_strategy, 1279 strategy=lora_strategy,
@@ -1246,12 +1301,6 @@ def main():
1246 ) 1301 )
1247 1302
1248 training_iter += 1 1303 training_iter += 1
1249 if learning_rate_emb is not None:
1250 learning_rate_emb *= args.cycle_decay
1251 if learning_rate_unet is not None:
1252 learning_rate_unet *= args.cycle_decay
1253 if learning_rate_text is not None:
1254 learning_rate_text *= args.cycle_decay
1255 1304
1256 accelerator.end_training() 1305 accelerator.end_training()
1257 1306
diff --git a/train_ti.py b/train_ti.py
index f60e3e5..c6f0b3a 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -32,7 +32,7 @@ from util.files import load_config, load_embeddings_from_dir
32 32
33logger = get_logger(__name__) 33logger = get_logger(__name__)
34 34
35warnings.filterwarnings('ignore') 35warnings.filterwarnings("ignore")
36 36
37 37
38torch.backends.cuda.matmul.allow_tf32 = True 38torch.backends.cuda.matmul.allow_tf32 = True
@@ -46,9 +46,7 @@ hidet.torch.dynamo_config.search_space(0)
46 46
47 47
48def parse_args(): 48def parse_args():
49 parser = argparse.ArgumentParser( 49 parser = argparse.ArgumentParser(description="Simple example of a training script.")
50 description="Simple example of a training script."
51 )
52 parser.add_argument( 50 parser.add_argument(
53 "--pretrained_model_name_or_path", 51 "--pretrained_model_name_or_path",
54 type=str, 52 type=str,
@@ -65,12 +63,12 @@ def parse_args():
65 "--train_data_file", 63 "--train_data_file",
66 type=str, 64 type=str,
67 default=None, 65 default=None,
68 help="A CSV file containing the training data." 66 help="A CSV file containing the training data.",
69 ) 67 )
70 parser.add_argument( 68 parser.add_argument(
71 "--train_data_template", 69 "--train_data_template",
72 type=str, 70 type=str,
73 nargs='*', 71 nargs="*",
74 default="template", 72 default="template",
75 ) 73 )
76 parser.add_argument( 74 parser.add_argument(
@@ -80,59 +78,47 @@ def parse_args():
80 help="The name of the current project.", 78 help="The name of the current project.",
81 ) 79 )
82 parser.add_argument( 80 parser.add_argument(
83 "--auto_cycles", 81 "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
84 type=str,
85 default="o",
86 help="Cycles to run automatically."
87 ) 82 )
88 parser.add_argument( 83 parser.add_argument(
89 "--cycle_decay", 84 "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
90 type=float,
91 default=1.0,
92 help="Learning rate decay per cycle."
93 ) 85 )
94 parser.add_argument( 86 parser.add_argument(
95 "--placeholder_tokens", 87 "--placeholder_tokens",
96 type=str, 88 type=str,
97 nargs='*', 89 nargs="*",
98 help="A token to use as a placeholder for the concept.", 90 help="A token to use as a placeholder for the concept.",
99 ) 91 )
100 parser.add_argument( 92 parser.add_argument(
101 "--initializer_tokens", 93 "--initializer_tokens",
102 type=str, 94 type=str,
103 nargs='*', 95 nargs="*",
104 help="A token to use as initializer word." 96 help="A token to use as initializer word.",
105 ) 97 )
106 parser.add_argument( 98 parser.add_argument(
107 "--filter_tokens", 99 "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
108 type=str,
109 nargs='*',
110 help="Tokens to filter the dataset by."
111 ) 100 )
112 parser.add_argument( 101 parser.add_argument(
113 "--initializer_noise", 102 "--initializer_noise",
114 type=float, 103 type=float,
115 default=0, 104 default=0,
116 help="Noise to apply to the initializer word" 105 help="Noise to apply to the initializer word",
117 ) 106 )
118 parser.add_argument( 107 parser.add_argument(
119 "--alias_tokens", 108 "--alias_tokens",
120 type=str, 109 type=str,
121 nargs='*', 110 nargs="*",
122 default=[], 111 default=[],
123 help="Tokens to create an alias for." 112 help="Tokens to create an alias for.",
124 ) 113 )
125 parser.add_argument( 114 parser.add_argument(
126 "--inverted_initializer_tokens", 115 "--inverted_initializer_tokens",
127 type=str, 116 type=str,
128 nargs='*', 117 nargs="*",
129 help="A token to use as initializer word." 118 help="A token to use as initializer word.",
130 ) 119 )
131 parser.add_argument( 120 parser.add_argument(
132 "--num_vectors", 121 "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
133 type=int,
134 nargs='*',
135 help="Number of vectors per embedding."
136 ) 122 )
137 parser.add_argument( 123 parser.add_argument(
138 "--sequential", 124 "--sequential",
@@ -147,7 +133,7 @@ def parse_args():
147 "--num_class_images", 133 "--num_class_images",
148 type=int, 134 type=int,
149 default=0, 135 default=0,
150 help="How many class images to generate." 136 help="How many class images to generate.",
151 ) 137 )
152 parser.add_argument( 138 parser.add_argument(
153 "--class_image_dir", 139 "--class_image_dir",
@@ -158,7 +144,7 @@ def parse_args():
158 parser.add_argument( 144 parser.add_argument(
159 "--exclude_collections", 145 "--exclude_collections",
160 type=str, 146 type=str,
161 nargs='*', 147 nargs="*",
162 help="Exclude all items with a listed collection.", 148 help="Exclude all items with a listed collection.",
163 ) 149 )
164 parser.add_argument( 150 parser.add_argument(
@@ -181,14 +167,11 @@ def parse_args():
181 parser.add_argument( 167 parser.add_argument(
182 "--collection", 168 "--collection",
183 type=str, 169 type=str,
184 nargs='*', 170 nargs="*",
185 help="A collection to filter the dataset.", 171 help="A collection to filter the dataset.",
186 ) 172 )
187 parser.add_argument( 173 parser.add_argument(
188 "--seed", 174 "--seed", type=int, default=None, help="A seed for reproducible training."
189 type=int,
190 default=None,
191 help="A seed for reproducible training."
192 ) 175 )
193 parser.add_argument( 176 parser.add_argument(
194 "--resolution", 177 "--resolution",
@@ -244,7 +227,7 @@ def parse_args():
244 type=str, 227 type=str,
245 default="auto", 228 default="auto",
246 choices=["all", "trailing", "leading", "between", "auto", "off"], 229 choices=["all", "trailing", "leading", "between", "auto", "off"],
247 help='Vector shuffling algorithm.', 230 help="Vector shuffling algorithm.",
248 ) 231 )
249 parser.add_argument( 232 parser.add_argument(
250 "--offset_noise_strength", 233 "--offset_noise_strength",
@@ -256,18 +239,10 @@ def parse_args():
256 "--input_pertubation", 239 "--input_pertubation",
257 type=float, 240 type=float,
258 default=0, 241 default=0,
259 help="The scale of input pretubation. Recommended 0.1." 242 help="The scale of input pretubation. Recommended 0.1.",
260 )
261 parser.add_argument(
262 "--num_train_epochs",
263 type=int,
264 default=None
265 )
266 parser.add_argument(
267 "--num_train_steps",
268 type=int,
269 default=2000
270 ) 243 )
244 parser.add_argument("--num_train_epochs", type=int, default=None)
245 parser.add_argument("--num_train_steps", type=int, default=2000)
271 parser.add_argument( 246 parser.add_argument(
272 "--gradient_accumulation_steps", 247 "--gradient_accumulation_steps",
273 type=int, 248 type=int,
@@ -299,27 +274,31 @@ def parse_args():
299 "--lr_scheduler", 274 "--lr_scheduler",
300 type=str, 275 type=str,
301 default="one_cycle", 276 default="one_cycle",
302 choices=["linear", "cosine", "cosine_with_restarts", "polynomial", 277 choices=[
303 "constant", "constant_with_warmup", "one_cycle"], 278 "linear",
304 help='The scheduler type to use.', 279 "cosine",
280 "cosine_with_restarts",
281 "polynomial",
282 "constant",
283 "constant_with_warmup",
284 "one_cycle",
285 ],
286 help="The scheduler type to use.",
305 ) 287 )
306 parser.add_argument( 288 parser.add_argument(
307 "--lr_warmup_epochs", 289 "--lr_warmup_epochs",
308 type=int, 290 type=int,
309 default=10, 291 default=10,
310 help="Number of steps for the warmup in the lr scheduler." 292 help="Number of steps for the warmup in the lr scheduler.",
311 ) 293 )
312 parser.add_argument( 294 parser.add_argument(
313 "--lr_mid_point", 295 "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point."
314 type=float,
315 default=0.3,
316 help="OneCycle schedule mid point."
317 ) 296 )
318 parser.add_argument( 297 parser.add_argument(
319 "--lr_cycles", 298 "--lr_cycles",
320 type=int, 299 type=int,
321 default=None, 300 default=None,
322 help="Number of restart cycles in the lr scheduler." 301 help="Number of restart cycles in the lr scheduler.",
323 ) 302 )
324 parser.add_argument( 303 parser.add_argument(
325 "--lr_warmup_func", 304 "--lr_warmup_func",
@@ -331,7 +310,7 @@ def parse_args():
331 "--lr_warmup_exp", 310 "--lr_warmup_exp",
332 type=int, 311 type=int,
333 default=1, 312 default=1,
334 help='If lr_warmup_func is "cos", exponent to modify the function' 313 help='If lr_warmup_func is "cos", exponent to modify the function',
335 ) 314 )
336 parser.add_argument( 315 parser.add_argument(
337 "--lr_annealing_func", 316 "--lr_annealing_func",
@@ -343,89 +322,67 @@ def parse_args():
343 "--lr_annealing_exp", 322 "--lr_annealing_exp",
344 type=int, 323 type=int,
345 default=1, 324 default=1,
346 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' 325 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
347 ) 326 )
348 parser.add_argument( 327 parser.add_argument(
349 "--lr_min_lr", 328 "--lr_min_lr",
350 type=float, 329 type=float,
351 default=0.04, 330 default=0.04,
352 help="Minimum learning rate in the lr scheduler." 331 help="Minimum learning rate in the lr scheduler.",
353 ) 332 )
354 parser.add_argument( 333 parser.add_argument(
355 "--use_ema", 334 "--use_ema", action="store_true", help="Whether to use EMA model."
356 action="store_true",
357 help="Whether to use EMA model."
358 )
359 parser.add_argument(
360 "--ema_inv_gamma",
361 type=float,
362 default=1.0
363 )
364 parser.add_argument(
365 "--ema_power",
366 type=float,
367 default=4/5
368 )
369 parser.add_argument(
370 "--ema_max_decay",
371 type=float,
372 default=0.9999
373 )
374 parser.add_argument(
375 "--min_snr_gamma",
376 type=int,
377 default=5,
378 help="MinSNR gamma."
379 ) 335 )
336 parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
337 parser.add_argument("--ema_power", type=float, default=4 / 5)
338 parser.add_argument("--ema_max_decay", type=float, default=0.9999)
339 parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
380 parser.add_argument( 340 parser.add_argument(
381 "--schedule_sampler", 341 "--schedule_sampler",
382 type=str, 342 type=str,
383 default="uniform", 343 default="uniform",
384 choices=["uniform", "loss-second-moment"], 344 choices=["uniform", "loss-second-moment"],
385 help="Noise schedule sampler." 345 help="Noise schedule sampler.",
386 ) 346 )
387 parser.add_argument( 347 parser.add_argument(
388 "--optimizer", 348 "--optimizer",
389 type=str, 349 type=str,
390 default="adan", 350 default="adan",
391 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 351 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
392 help='Optimizer to use' 352 help="Optimizer to use",
393 ) 353 )
394 parser.add_argument( 354 parser.add_argument(
395 "--dadaptation_d0", 355 "--dadaptation_d0",
396 type=float, 356 type=float,
397 default=1e-6, 357 default=1e-6,
398 help="The d0 parameter for Dadaptation optimizers." 358 help="The d0 parameter for Dadaptation optimizers.",
399 ) 359 )
400 parser.add_argument( 360 parser.add_argument(
401 "--adam_beta1", 361 "--adam_beta1",
402 type=float, 362 type=float,
403 default=None, 363 default=None,
404 help="The beta1 parameter for the Adam optimizer." 364 help="The beta1 parameter for the Adam optimizer.",
405 ) 365 )
406 parser.add_argument( 366 parser.add_argument(
407 "--adam_beta2", 367 "--adam_beta2",
408 type=float, 368 type=float,
409 default=None, 369 default=None,
410 help="The beta2 parameter for the Adam optimizer." 370 help="The beta2 parameter for the Adam optimizer.",
411 ) 371 )
412 parser.add_argument( 372 parser.add_argument(
413 "--adam_weight_decay", 373 "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
414 type=float,
415 default=2e-2,
416 help="Weight decay to use."
417 ) 374 )
418 parser.add_argument( 375 parser.add_argument(
419 "--adam_epsilon", 376 "--adam_epsilon",
420 type=float, 377 type=float,
421 default=1e-08, 378 default=1e-08,
422 help="Epsilon value for the Adam optimizer" 379 help="Epsilon value for the Adam optimizer",
423 ) 380 )
424 parser.add_argument( 381 parser.add_argument(
425 "--adam_amsgrad", 382 "--adam_amsgrad",
426 type=bool, 383 type=bool,
427 default=False, 384 default=False,
428 help="Amsgrad value for the Adam optimizer" 385 help="Amsgrad value for the Adam optimizer",
429 ) 386 )
430 parser.add_argument( 387 parser.add_argument(
431 "--mixed_precision", 388 "--mixed_precision",
@@ -456,7 +413,7 @@ def parse_args():
456 ) 413 )
457 parser.add_argument( 414 parser.add_argument(
458 "--no_milestone_checkpoints", 415 "--no_milestone_checkpoints",
459 action='store_true', 416 action="store_true",
460 help="If checkpoints are saved on maximum accuracy", 417 help="If checkpoints are saved on maximum accuracy",
461 ) 418 )
462 parser.add_argument( 419 parser.add_argument(
@@ -493,25 +450,25 @@ def parse_args():
493 "--valid_set_size", 450 "--valid_set_size",
494 type=int, 451 type=int,
495 default=None, 452 default=None,
496 help="Number of images in the validation dataset." 453 help="Number of images in the validation dataset.",
497 ) 454 )
498 parser.add_argument( 455 parser.add_argument(
499 "--train_set_pad", 456 "--train_set_pad",
500 type=int, 457 type=int,
501 default=None, 458 default=None,
502 help="The number to fill train dataset items up to." 459 help="The number to fill train dataset items up to.",
503 ) 460 )
504 parser.add_argument( 461 parser.add_argument(
505 "--valid_set_pad", 462 "--valid_set_pad",
506 type=int, 463 type=int,
507 default=None, 464 default=None,
508 help="The number to fill validation dataset items up to." 465 help="The number to fill validation dataset items up to.",
509 ) 466 )
510 parser.add_argument( 467 parser.add_argument(
511 "--train_batch_size", 468 "--train_batch_size",
512 type=int, 469 type=int,
513 default=1, 470 default=1,
514 help="Batch size (per device) for the training dataloader." 471 help="Batch size (per device) for the training dataloader.",
515 ) 472 )
516 parser.add_argument( 473 parser.add_argument(
517 "--sample_steps", 474 "--sample_steps",
@@ -523,14 +480,9 @@ def parse_args():
523 "--prior_loss_weight", 480 "--prior_loss_weight",
524 type=float, 481 type=float,
525 default=1.0, 482 default=1.0,
526 help="The weight of prior preservation loss." 483 help="The weight of prior preservation loss.",
527 )
528 parser.add_argument(
529 "--emb_alpha",
530 type=float,
531 default=1.0,
532 help="Embedding alpha"
533 ) 484 )
485 parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
534 parser.add_argument( 486 parser.add_argument(
535 "--emb_dropout", 487 "--emb_dropout",
536 type=float, 488 type=float,
@@ -538,21 +490,13 @@ def parse_args():
538 help="Embedding dropout probability.", 490 help="Embedding dropout probability.",
539 ) 491 )
540 parser.add_argument( 492 parser.add_argument(
541 "--use_emb_decay", 493 "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
542 action="store_true",
543 help="Whether to use embedding decay."
544 ) 494 )
545 parser.add_argument( 495 parser.add_argument(
546 "--emb_decay_target", 496 "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
547 default=0.4,
548 type=float,
549 help="Embedding decay target."
550 ) 497 )
551 parser.add_argument( 498 parser.add_argument(
552 "--emb_decay", 499 "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
553 default=1e+2,
554 type=float,
555 help="Embedding decay factor."
556 ) 500 )
557 parser.add_argument( 501 parser.add_argument(
558 "--noise_timesteps", 502 "--noise_timesteps",
@@ -563,7 +507,7 @@ def parse_args():
563 "--resume_from", 507 "--resume_from",
564 type=str, 508 type=str,
565 default=None, 509 default=None,
566 help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" 510 help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)",
567 ) 511 )
568 parser.add_argument( 512 parser.add_argument(
569 "--global_step", 513 "--global_step",
@@ -574,7 +518,7 @@ def parse_args():
574 "--config", 518 "--config",
575 type=str, 519 type=str,
576 default=None, 520 default=None,
577 help="Path to a JSON configuration file containing arguments for invoking this script." 521 help="Path to a JSON configuration file containing arguments for invoking this script.",
578 ) 522 )
579 523
580 args = parser.parse_args() 524 args = parser.parse_args()
@@ -595,29 +539,44 @@ def parse_args():
595 args.placeholder_tokens = [args.placeholder_tokens] 539 args.placeholder_tokens = [args.placeholder_tokens]
596 540
597 if isinstance(args.initializer_tokens, str): 541 if isinstance(args.initializer_tokens, str):
598 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) 542 args.initializer_tokens = [args.initializer_tokens] * len(
543 args.placeholder_tokens
544 )
599 545
600 if len(args.placeholder_tokens) == 0: 546 if len(args.placeholder_tokens) == 0:
601 args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] 547 args.placeholder_tokens = [
548 f"<*{i}>" for i in range(len(args.initializer_tokens))
549 ]
602 550
603 if len(args.initializer_tokens) == 0: 551 if len(args.initializer_tokens) == 0:
604 args.initializer_tokens = args.placeholder_tokens.copy() 552 args.initializer_tokens = args.placeholder_tokens.copy()
605 553
606 if len(args.placeholder_tokens) != len(args.initializer_tokens): 554 if len(args.placeholder_tokens) != len(args.initializer_tokens):
607 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") 555 raise ValueError(
556 "--placeholder_tokens and --initializer_tokens must have the same number of items"
557 )
608 558
609 if isinstance(args.inverted_initializer_tokens, str): 559 if isinstance(args.inverted_initializer_tokens, str):
610 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) 560 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
561 args.placeholder_tokens
562 )
611 563
612 if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: 564 if (
565 isinstance(args.inverted_initializer_tokens, list)
566 and len(args.inverted_initializer_tokens) != 0
567 ):
613 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] 568 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
614 args.initializer_tokens += args.inverted_initializer_tokens 569 args.initializer_tokens += args.inverted_initializer_tokens
615 570
616 if isinstance(args.num_vectors, int): 571 if isinstance(args.num_vectors, int):
617 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) 572 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
618 573
619 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): 574 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
620 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") 575 args.num_vectors
576 ):
577 raise ValueError(
578 "--placeholder_tokens and --num_vectors must have the same number of items"
579 )
621 580
622 if args.alias_tokens is None: 581 if args.alias_tokens is None:
623 args.alias_tokens = [] 582 args.alias_tokens = []
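[note] The argument handling above repeatedly applies one normalization pattern: a scalar is broadcast to a list matching placeholder_tokens, then list lengths are validated. A generic sketch of that pattern (the helper name is invented for illustration):

from typing import Union

def broadcast_arg(value: Union[str, int, list, None], n: int, name: str) -> list:
    """Broadcast a scalar to n items; validate a list's length; expand None."""
    if value is None:
        return [None] * n
    if not isinstance(value, list):
        return [value] * n
    if len(value) != n:
        raise ValueError(f"--{name} must have {n} items, got {len(value)}")
    return value

placeholder_tokens = ["<tok1>", "<tok2>"]
num_vectors = broadcast_arg(3, len(placeholder_tokens), "num_vectors")  # [3, 3]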
@@ -639,16 +598,22 @@ def parse_args():
639 ] 598 ]
640 599
641 if isinstance(args.train_data_template, str): 600 if isinstance(args.train_data_template, str):
642 args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) 601 args.train_data_template = [args.train_data_template] * len(
602 args.placeholder_tokens
603 )
643 604
644 if len(args.placeholder_tokens) != len(args.train_data_template): 605 if len(args.placeholder_tokens) != len(args.train_data_template):
645 raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") 606 raise ValueError(
607 "--placeholder_tokens and --train_data_template must have the same number of items"
608 )
646 609
647 if args.num_vectors is None: 610 if args.num_vectors is None:
648 args.num_vectors = [None] * len(args.placeholder_tokens) 611 args.num_vectors = [None] * len(args.placeholder_tokens)
649 else: 612 else:
650 if isinstance(args.train_data_template, list): 613 if isinstance(args.train_data_template, list):
651 raise ValueError("--train_data_template can't be a list in simultaneous mode") 614 raise ValueError(
615 "--train_data_template can't be a list in simultaneous mode"
616 )
652 617
653 if isinstance(args.collection, str): 618 if isinstance(args.collection, str):
654 args.collection = [args.collection] 619 args.collection = [args.collection]
@@ -660,13 +625,13 @@ def parse_args():
660 raise ValueError("You must specify --output_dir") 625 raise ValueError("You must specify --output_dir")
661 626
662 if args.adam_beta1 is None: 627 if args.adam_beta1 is None:
663 if args.optimizer == 'lion': 628 if args.optimizer == "lion":
664 args.adam_beta1 = 0.95 629 args.adam_beta1 = 0.95
665 else: 630 else:
666 args.adam_beta1 = 0.9 631 args.adam_beta1 = 0.9
667 632
668 if args.adam_beta2 is None: 633 if args.adam_beta2 is None:
669 if args.optimizer == 'lion': 634 if args.optimizer == "lion":
670 args.adam_beta2 = 0.98 635 args.adam_beta2 = 0.98
671 else: 636 else:
672 args.adam_beta2 = 0.999 637 args.adam_beta2 = 0.999
@@ -679,13 +644,13 @@ def main():
679 644
680 global_step_offset = args.global_step 645 global_step_offset = args.global_step
681 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 646 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
682 output_dir = Path(args.output_dir)/slugify(args.project)/now 647 output_dir = Path(args.output_dir) / slugify(args.project) / now
683 output_dir.mkdir(parents=True, exist_ok=True) 648 output_dir.mkdir(parents=True, exist_ok=True)
684 649
685 accelerator = Accelerator( 650 accelerator = Accelerator(
686 log_with=LoggerType.TENSORBOARD, 651 log_with=LoggerType.TENSORBOARD,
687 project_dir=f"{output_dir}", 652 project_dir=f"{output_dir}",
688 mixed_precision=args.mixed_precision 653 mixed_precision=args.mixed_precision,
689 ) 654 )
690 655
691 weight_dtype = torch.float32 656 weight_dtype = torch.float32
@@ -703,9 +668,15 @@ def main():
703 668
704 save_args(output_dir, args) 669 save_args(output_dir, args)
705 670
706 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) 671 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
707 embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) 672 args.pretrained_model_name_or_path
708 schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) 673 )
674 embeddings = patch_managed_embeddings(
675 text_encoder, args.emb_alpha, args.emb_dropout
676 )
677 schedule_sampler = create_named_schedule_sampler(
678 args.schedule_sampler, noise_scheduler.config.num_train_timesteps
679 )
709 680
710 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 681 tokenizer.set_use_vector_shuffle(args.vector_shuffle)
711 tokenizer.set_dropout(args.vector_dropout) 682 tokenizer.set_dropout(args.vector_dropout)
@@ -717,16 +688,16 @@ def main():
717 unet.enable_xformers_memory_efficient_attention() 688 unet.enable_xformers_memory_efficient_attention()
718 elif args.compile_unet: 689 elif args.compile_unet:
719 unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False 690 unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
720 691
721 proc = AttnProcessor() 692 proc = AttnProcessor()
722 693
723 def fn_recursive_set_proc(module: torch.nn.Module): 694 def fn_recursive_set_proc(module: torch.nn.Module):
724 if hasattr(module, "processor"): 695 if hasattr(module, "processor"):
725 module.processor = proc 696 module.processor = proc
726 697
727 for child in module.children(): 698 for child in module.children():
728 fn_recursive_set_proc(child) 699 fn_recursive_set_proc(child)
729 700
730 fn_recursive_set_proc(unet) 701 fn_recursive_set_proc(unet)
731 702
732 if args.gradient_checkpointing: 703 if args.gradient_checkpointing:
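[note] The recursive setter above walks the module tree by hand; an equivalent traversal uses Module.modules(), which already yields the root and every descendant. A sketch under that assumption (the AttnProcessor class here is a stand-in, not diffusers' own):

import torch

class AttnProcessor:  # stand-in for diffusers' AttnProcessor
    pass

def set_processor(root: torch.nn.Module, proc: AttnProcessor) -> int:
    count = 0
    for module in root.modules():  # yields root and every descendant
        if hasattr(module, "processor"):
            module.processor = proc
            count += 1
    return count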
@@ -751,18 +722,24 @@ def main():
751 tokenizer=tokenizer, 722 tokenizer=tokenizer,
752 embeddings=embeddings, 723 embeddings=embeddings,
753 placeholder_tokens=alias_placeholder_tokens, 724 placeholder_tokens=alias_placeholder_tokens,
754 initializer_tokens=alias_initializer_tokens 725 initializer_tokens=alias_initializer_tokens,
755 ) 726 )
756 embeddings.persist() 727 embeddings.persist()
757 print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") 728 print(
729 f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
730 )
758 731
759 if args.embeddings_dir is not None: 732 if args.embeddings_dir is not None:
760 embeddings_dir = Path(args.embeddings_dir) 733 embeddings_dir = Path(args.embeddings_dir)
761 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 734 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
762 raise ValueError("--embeddings_dir must point to an existing directory") 735 raise ValueError("--embeddings_dir must point to an existing directory")
763 736
764 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 737 added_tokens, added_ids = load_embeddings_from_dir(
765 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 738 tokenizer, embeddings, embeddings_dir
739 )
740 print(
741 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
742 )
766 743
767 if args.train_dir_embeddings: 744 if args.train_dir_embeddings:
768 args.placeholder_tokens = added_tokens 745 args.placeholder_tokens = added_tokens
@@ -772,19 +749,23 @@ def main():
772 749
773 if args.scale_lr: 750 if args.scale_lr:
774 args.learning_rate = ( 751 args.learning_rate = (
775 args.learning_rate * args.gradient_accumulation_steps * 752 args.learning_rate
776 args.train_batch_size * accelerator.num_processes 753 * args.gradient_accumulation_steps
754 * args.train_batch_size
755 * accelerator.num_processes
777 ) 756 )
778 757
779 if args.find_lr: 758 if args.find_lr:
780 args.learning_rate = 1e-5 759 args.learning_rate = 1e-5
781 args.lr_scheduler = "exponential_growth" 760 args.lr_scheduler = "exponential_growth"
782 761
783 if args.optimizer == 'adam8bit': 762 if args.optimizer == "adam8bit":
784 try: 763 try:
785 import bitsandbytes as bnb 764 import bitsandbytes as bnb
786 except ImportError: 765 except ImportError:
787 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") 766 raise ImportError(
767 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
768 )
788 769
789 create_optimizer = partial( 770 create_optimizer = partial(
790 bnb.optim.AdamW8bit, 771 bnb.optim.AdamW8bit,
@@ -793,7 +774,7 @@ def main():
793 eps=args.adam_epsilon, 774 eps=args.adam_epsilon,
794 amsgrad=args.adam_amsgrad, 775 amsgrad=args.adam_amsgrad,
795 ) 776 )
796 elif args.optimizer == 'adam': 777 elif args.optimizer == "adam":
797 create_optimizer = partial( 778 create_optimizer = partial(
798 torch.optim.AdamW, 779 torch.optim.AdamW,
799 betas=(args.adam_beta1, args.adam_beta2), 780 betas=(args.adam_beta1, args.adam_beta2),
@@ -801,11 +782,13 @@ def main():
801 eps=args.adam_epsilon, 782 eps=args.adam_epsilon,
802 amsgrad=args.adam_amsgrad, 783 amsgrad=args.adam_amsgrad,
803 ) 784 )
804 elif args.optimizer == 'adan': 785 elif args.optimizer == "adan":
805 try: 786 try:
806 import timm.optim 787 import timm.optim
807 except ImportError: 788 except ImportError:
808 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") 789 raise ImportError(
790 "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
791 )
809 792
810 create_optimizer = partial( 793 create_optimizer = partial(
811 timm.optim.Adan, 794 timm.optim.Adan,
@@ -813,11 +796,13 @@ def main():
813 eps=args.adam_epsilon, 796 eps=args.adam_epsilon,
814 no_prox=True, 797 no_prox=True,
815 ) 798 )
816 elif args.optimizer == 'lion': 799 elif args.optimizer == "lion":
817 try: 800 try:
818 import lion_pytorch 801 import lion_pytorch
819 except ImportError: 802 except ImportError:
820 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") 803 raise ImportError(
804 "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
805 )
821 806
822 create_optimizer = partial( 807 create_optimizer = partial(
823 lion_pytorch.Lion, 808 lion_pytorch.Lion,
@@ -825,7 +810,7 @@ def main():
825 weight_decay=args.adam_weight_decay, 810 weight_decay=args.adam_weight_decay,
826 use_triton=True, 811 use_triton=True,
827 ) 812 )
828 elif args.optimizer == 'adafactor': 813 elif args.optimizer == "adafactor":
829 create_optimizer = partial( 814 create_optimizer = partial(
830 transformers.optimization.Adafactor, 815 transformers.optimization.Adafactor,
831 weight_decay=args.adam_weight_decay, 816 weight_decay=args.adam_weight_decay,
@@ -837,11 +822,13 @@ def main():
837 args.lr_scheduler = "adafactor" 822 args.lr_scheduler = "adafactor"
838 args.lr_min_lr = args.learning_rate 823 args.lr_min_lr = args.learning_rate
839 args.learning_rate = None 824 args.learning_rate = None
840 elif args.optimizer == 'dadam': 825 elif args.optimizer == "dadam":
841 try: 826 try:
842 import dadaptation 827 import dadaptation
843 except ImportError: 828 except ImportError:
844 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") 829 raise ImportError(
830 "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
831 )
845 832
846 create_optimizer = partial( 833 create_optimizer = partial(
847 dadaptation.DAdaptAdam, 834 dadaptation.DAdaptAdam,
@@ -851,11 +838,13 @@ def main():
851 decouple=True, 838 decouple=True,
852 d0=args.dadaptation_d0, 839 d0=args.dadaptation_d0,
853 ) 840 )
854 elif args.optimizer == 'dadan': 841 elif args.optimizer == "dadan":
855 try: 842 try:
856 import dadaptation 843 import dadaptation
857 except ImportError: 844 except ImportError:
858 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") 845 raise ImportError(
846 "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
847 )
859 848
860 create_optimizer = partial( 849 create_optimizer = partial(
861 dadaptation.DAdaptAdan, 850 dadaptation.DAdaptAdan,
@@ -864,7 +853,7 @@ def main():
864 d0=args.dadaptation_d0, 853 d0=args.dadaptation_d0,
865 ) 854 )
866 else: 855 else:
867 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 856 raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
868 857
869 trainer = partial( 858 trainer = partial(
870 train, 859 train,
@@ -904,10 +893,21 @@ def main():
904 sample_image_size=args.sample_image_size, 893 sample_image_size=args.sample_image_size,
905 ) 894 )
906 895
896 optimizer = create_optimizer(
897 text_encoder.text_model.embeddings.token_embedding.parameters(),
898 lr=learning_rate,
899 )
900
907 data_generator = torch.Generator(device="cpu").manual_seed(args.seed) 901 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
908 data_npgenerator = np.random.default_rng(args.seed) 902 data_npgenerator = np.random.default_rng(args.seed)
909 903
910 def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): 904 def run(
905 i: int,
906 placeholder_tokens: list[str],
907 initializer_tokens: list[str],
908 num_vectors: Union[int, list[int]],
909 data_template: str,
910 ):
911 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 911 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
912 tokenizer=tokenizer, 912 tokenizer=tokenizer,
913 embeddings=embeddings, 913 embeddings=embeddings,
@@ -917,14 +917,23 @@ def main():
917 initializer_noise=args.initializer_noise, 917 initializer_noise=args.initializer_noise,
918 ) 918 )
919 919
920 stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) 920 stats = list(
921 zip(
922 placeholder_tokens,
923 placeholder_token_ids,
924 initializer_tokens,
925 initializer_token_ids,
926 )
927 )
921 928
922 print("") 929 print("")
923 print(f"============ TI batch {i + 1} ============") 930 print(f"============ TI batch {i + 1} ============")
924 print("") 931 print("")
925 print(stats) 932 print(stats)
926 933
927 filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] 934 filter_tokens = [
935 token for token in args.filter_tokens if token in placeholder_tokens
936 ]
928 937
929 datamodule = VlpnDataModule( 938 datamodule = VlpnDataModule(
930 data_file=args.train_data_file, 939 data_file=args.train_data_file,
@@ -945,7 +954,9 @@ def main():
945 valid_set_size=args.valid_set_size, 954 valid_set_size=args.valid_set_size,
946 train_set_pad=args.train_set_pad, 955 train_set_pad=args.train_set_pad,
947 valid_set_pad=args.valid_set_pad, 956 valid_set_pad=args.valid_set_pad,
948 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), 957 filter=partial(
958 keyword_filter, filter_tokens, args.collection, args.exclude_collections
959 ),
949 dtype=weight_dtype, 960 dtype=weight_dtype,
950 generator=data_generator, 961 generator=data_generator,
951 npgenerator=data_npgenerator, 962 npgenerator=data_npgenerator,
@@ -955,11 +966,16 @@ def main():
955 num_train_epochs = args.num_train_epochs 966 num_train_epochs = args.num_train_epochs
956 sample_frequency = args.sample_frequency 967 sample_frequency = args.sample_frequency
957 if num_train_epochs is None: 968 if num_train_epochs is None:
958 num_train_epochs = math.ceil( 969 num_train_epochs = (
959 args.num_train_steps / len(datamodule.train_dataset) 970 math.ceil(args.num_train_steps / len(datamodule.train_dataset))
960 ) * args.gradient_accumulation_steps 971 * args.gradient_accumulation_steps
961 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) 972 )
962 num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps) 973 sample_frequency = math.ceil(
974 num_train_epochs * (sample_frequency / args.num_train_steps)
975 )
976 num_training_steps_per_epoch = math.ceil(
977 len(datamodule.train_dataset) / args.gradient_accumulation_steps
978 )
963 num_train_steps = num_training_steps_per_epoch * num_train_epochs 979 num_train_steps = num_training_steps_per_epoch * num_train_epochs
964 if args.sample_num is not None: 980 if args.sample_num is not None:
965 sample_frequency = math.ceil(num_train_epochs / args.sample_num) 981 sample_frequency = math.ceil(num_train_epochs / args.sample_num)
@@ -988,7 +1004,8 @@ def main():
988 response = auto_cycles.pop(0) 1004 response = auto_cycles.pop(0)
989 else: 1005 else:
990 response = input( 1006 response = input(
991 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") 1007 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
1008 )
992 1009
993 if response.lower().strip() == "o": 1010 if response.lower().strip() == "o":
994 if args.learning_rate is not None: 1011 if args.learning_rate is not None:
@@ -1018,10 +1035,8 @@ def main():
1018 print(f"------------ TI cycle {training_iter + 1}: {response} ------------") 1035 print(f"------------ TI cycle {training_iter + 1}: {response} ------------")
1019 print("") 1036 print("")
1020 1037
1021 optimizer = create_optimizer( 1038 for group, lr in zip(optimizer.param_groups, [learning_rate]):
1022 text_encoder.text_model.embeddings.token_embedding.parameters(), 1039 group["lr"] = lr
1023 lr=learning_rate,
1024 )
1025 1040
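[note] With this change the TI optimizer is constructed once (see the hunk further up) and only its group learning rates are rewritten per cycle. A plausible motivation, though the commit message does not say: recreating an Adam-family optimizer each cycle would reset its moment estimates, while updating param_groups in place preserves them. The update pattern itself:

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

new_lrs = [1e-4]  # one value per param group
for group, lr in zip(optimizer.param_groups, new_lrs):
    group["lr"] = lr  # in-place: Adam's moment buffers are preserved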
1026 lr_scheduler = get_scheduler( 1041 lr_scheduler = get_scheduler(
1027 lr_scheduler, 1042 lr_scheduler,
@@ -1040,7 +1055,9 @@ def main():
1040 mid_point=args.lr_mid_point, 1055 mid_point=args.lr_mid_point,
1041 ) 1056 )
1042 1057
1043 checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}" 1058 checkpoint_output_dir = (
1059 output_dir / project / f"checkpoints_{training_iter}"
1060 )
1044 1061
1045 trainer( 1062 trainer(
1046 train_dataloader=datamodule.train_dataloader, 1063 train_dataloader=datamodule.train_dataloader,
@@ -1070,14 +1087,20 @@ def main():
1070 accelerator.end_training() 1087 accelerator.end_training()
1071 1088
1072 if not args.sequential: 1089 if not args.sequential:
1073 run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) 1090 run(
1091 0,
1092 args.placeholder_tokens,
1093 args.initializer_tokens,
1094 args.num_vectors,
1095 args.train_data_template,
1096 )
1074 else: 1097 else:
1075 for i, placeholder_token, initializer_token, num_vectors, data_template in zip( 1098 for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
1076 range(len(args.placeholder_tokens)), 1099 range(len(args.placeholder_tokens)),
1077 args.placeholder_tokens, 1100 args.placeholder_tokens,
1078 args.initializer_tokens, 1101 args.initializer_tokens,
1079 args.num_vectors, 1102 args.num_vectors,
1080 args.train_data_template 1103 args.train_data_template,
1081 ): 1104 ):
1082 run(i, [placeholder_token], [initializer_token], num_vectors, data_template) 1105 run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
1083 embeddings.persist() 1106 embeddings.persist()
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..f68faf9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -14,7 +14,13 @@ import numpy as np
14 14
15from accelerate import Accelerator 15from accelerate import Accelerator
16from transformers import CLIPTextModel 16from transformers import CLIPTextModel
17from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin 17from diffusers import (
18 AutoencoderKL,
19 UNet2DConditionModel,
20 DDPMScheduler,
21 UniPCMultistepScheduler,
22 SchedulerMixin,
23)
18 24
19from tqdm.auto import tqdm 25from tqdm.auto import tqdm
20 26
@@ -33,11 +39,12 @@ from util.noise import perlin_noise
33def const(result=None): 39def const(result=None):
34 def fn(*args, **kwargs): 40 def fn(*args, **kwargs):
35 return result 41 return result
42
36 return fn 43 return fn
37 44
38 45
39@dataclass 46@dataclass
40class TrainingCallbacks(): 47class TrainingCallbacks:
41 on_log: Callable[[], dict[str, Any]] = const({}) 48 on_log: Callable[[], dict[str, Any]] = const({})
42 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 49 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
43 on_before_optimize: Callable[[int], Any] = const() 50 on_before_optimize: Callable[[int], Any] = const()
@@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol):
58 train_dataloader: DataLoader, 65 train_dataloader: DataLoader,
59 val_dataloader: Optional[DataLoader], 66 val_dataloader: Optional[DataLoader],
60 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 67 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
61 **kwargs 68 **kwargs,
62 ) -> Tuple: ... 69 ) -> Tuple:
70 ...
63 71
64 72
65@dataclass 73@dataclass
66class TrainingStrategy(): 74class TrainingStrategy:
67 callbacks: Callable[..., TrainingCallbacks] 75 callbacks: Callable[..., TrainingCallbacks]
68 prepare: TrainingStrategyPrepareCallable 76 prepare: TrainingStrategyPrepareCallable
69 77
70 78
71def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): 79def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
72 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 80 tokenizer = MultiCLIPTokenizer.from_pretrained(
73 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype) 81 pretrained_model_name_or_path, subfolder="tokenizer"
74 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype) 82 )
75 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype) 83 text_encoder = CLIPTextModel.from_pretrained(
76 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 84 pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype
77 sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 85 )
86 vae = AutoencoderKL.from_pretrained(
87 pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype
88 )
89 unet = UNet2DConditionModel.from_pretrained(
90 pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype
91 )
92 noise_scheduler = DDPMScheduler.from_pretrained(
93 pretrained_model_name_or_path, subfolder="scheduler"
94 )
95 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
96 pretrained_model_name_or_path, subfolder="scheduler"
97 )
78 98
79 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler 99 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
80 100
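[note] get_models loads each pipeline component from a shared checkpoint by subfolder, mirroring how diffusers lays out Stable Diffusion repositories. A hedged usage sketch; the checkpoint name is an assumed example, and the import path follows this file's location:

import torch
from training.functional import get_models  # this repo's helper

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)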
@@ -113,7 +133,9 @@ def save_samples(
113 133
114 generator = torch.Generator(device=accelerator.device).manual_seed(seed) 134 generator = torch.Generator(device=accelerator.device).manual_seed(seed)
115 135
116 datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] 136 datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [
137 ("train", train_dataloader, None)
138 ]
117 139
118 if val_dataloader is not None: 140 if val_dataloader is not None:
119 datasets.append(("stable", val_dataloader, generator)) 141 datasets.append(("stable", val_dataloader, generator))
@@ -124,17 +146,11 @@ def save_samples(
124 file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" 146 file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
125 file_path.parent.mkdir(parents=True, exist_ok=True) 147 file_path.parent.mkdir(parents=True, exist_ok=True)
126 148
127 batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) 149 batches = list(
128 prompt_ids = [ 150 itertools.islice(itertools.cycle(data), batch_size * num_batches)
129 prompt 151 )
130 for batch in batches 152 prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]]
131 for prompt in batch["prompt_ids"] 153 nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]]
132 ]
133 nprompt_ids = [
134 prompt
135 for batch in batches
136 for prompt in batch["nprompt_ids"]
137 ]
138 154
139 with torch.inference_mode(): 155 with torch.inference_mode():
140 for i in range(num_batches): 156 for i in range(num_batches):
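[note] The collapsed comprehensions above rely on two small idioms: itertools.islice(itertools.cycle(seq), n) to repeat a dataset until n items are drawn, and a doubly-nested comprehension to flatten one level. A self-contained demo:

import itertools

data = [{"prompt_ids": [1, 2]}, {"prompt_ids": [3]}]
batches = list(itertools.islice(itertools.cycle(data), 5))
prompt_ids = [p for batch in batches for p in batch["prompt_ids"]]
print(prompt_ids)  # [1, 2, 3, 1, 2, 3, 1, 2]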
@@ -165,7 +181,9 @@ def save_samples(
165 pass 181 pass
166 182
167 image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) 183 image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
168 image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] 184 image_grid = pipeline.numpy_to_pil(
185 image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy()
186 )[0]
169 image_grid.save(file_path, quality=85) 187 image_grid.save(file_path, quality=85)
170 188
171 del generator, pipeline 189 del generator, pipeline
@@ -184,15 +202,17 @@ def generate_class_images(
184 train_dataset: VlpnDataset, 202 train_dataset: VlpnDataset,
185 sample_batch_size: int, 203 sample_batch_size: int,
186 sample_image_size: int, 204 sample_image_size: int,
187 sample_steps: int 205 sample_steps: int,
188): 206):
189 missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] 207 missing_data = [
208 item for item in train_dataset.items if not item.class_image_path.exists()
209 ]
190 210
191 if len(missing_data) == 0: 211 if len(missing_data) == 0:
192 return 212 return
193 213
194 batched_data = [ 214 batched_data = [
195 missing_data[i:i+sample_batch_size] 215 missing_data[i : i + sample_batch_size]
196 for i in range(0, len(missing_data), sample_batch_size) 216 for i in range(0, len(missing_data), sample_batch_size)
197 ] 217 ]
198 218
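[note] The slice missing_data[i : i + sample_batch_size] is the standard list-chunking idiom; the last chunk is simply shorter when the length does not divide evenly:

missing_data = list(range(7))
sample_batch_size = 3
batched = [
    missing_data[i : i + sample_batch_size]
    for i in range(0, len(missing_data), sample_batch_size)
]
print(batched)  # [[0, 1, 2], [3, 4, 5], [6]]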
@@ -216,7 +236,7 @@ def generate_class_images(
216 negative_prompt=nprompt, 236 negative_prompt=nprompt,
217 height=sample_image_size, 237 height=sample_image_size,
218 width=sample_image_size, 238 width=sample_image_size,
219 num_inference_steps=sample_steps 239 num_inference_steps=sample_steps,
220 ).images 240 ).images
221 241
222 for i, image in enumerate(images): 242 for i, image in enumerate(images):
@@ -245,8 +265,12 @@ def add_placeholder_tokens(
245 265
246 embeddings.resize(len(tokenizer)) 266 embeddings.resize(len(tokenizer))
247 267
248 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): 268 for placeholder_token_id, initializer_token_id in zip(
249 embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) 269 placeholder_token_ids, initializer_token_ids
270 ):
271 embeddings.add_embed(
272 placeholder_token_id, initializer_token_id, initializer_noise
273 )
250 274
251 return placeholder_token_ids, initializer_token_ids 275 return placeholder_token_ids, initializer_token_ids
252 276
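[note] For context, add_embed presumably initializes each new placeholder vector from its initializer token's embedding plus optional noise. A rough sketch of that idea against a plain nn.Embedding (the repo's managed-embedding class is more involved):

import torch

def add_embed(embedding: torch.nn.Embedding, new_id: int, init_id: int, noise: float = 0.0):
    # Copy the initializer token's vector, optionally perturbed, into the new slot.
    with torch.no_grad():
        vec = embedding.weight[init_id].clone()
        if noise > 0:
            vec += torch.randn_like(vec) * noise
        embedding.weight[new_id] = vec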
@@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler):
261 285
262 # Expand the tensors. 286 # Expand the tensors.
263 # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 287 # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
264 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 288 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
289 timesteps
290 ].float()
265 while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): 291 while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
266 sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 292 sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
267 alpha = sqrt_alphas_cumprod.expand(timesteps.shape) 293 alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
268 294
269 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 295 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
296 device=timesteps.device
297 )[timesteps].float()
270 while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): 298 while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
271 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 299 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
272 sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) 300 sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
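compute_snr gathers sqrt(abar_t) and sqrt(1 - abar_t) per sampled timestep; the final ratio falls past the end of this hunk, but with alpha_t = sqrt(abar_t) and sigma_t = sqrt(1 - abar_t) the signal-to-noise ratio reduces to a one-liner:

    import torch

    def compute_snr_closed_form(timesteps: torch.Tensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
        # SNR(t) = alpha_t^2 / sigma_t^2 = abar_t / (1 - abar_t)
        abar = alphas_cumprod.to(timesteps.device)[timesteps].float()
        return abar / (1.0 - abar)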
@@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler):
277 305
278 306
279def get_original( 307def get_original(
280 noise_scheduler, 308 noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
281 model_output,
282 sample: torch.FloatTensor,
283 timesteps: torch.IntTensor
284): 309):
285 alphas_cumprod = noise_scheduler.alphas_cumprod 310 alphas_cumprod = noise_scheduler.alphas_cumprod
286 sqrt_alphas_cumprod = alphas_cumprod**0.5 311 sqrt_alphas_cumprod = alphas_cumprod**0.5
287 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 312 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
288 313
289 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 314 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
315 timesteps
316 ].float()
290 while len(sqrt_alphas_cumprod.shape) < len(sample.shape): 317 while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
291 sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 318 sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
292 alpha = sqrt_alphas_cumprod.expand(sample.shape) 319 alpha = sqrt_alphas_cumprod.expand(sample.shape)
293 320
294 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() 321 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
322 device=timesteps.device
323 )[timesteps].float()
295 while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): 324 while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
296 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 325 sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
297 sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) 326 sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
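get_original rebuilds the same per-sample alpha and sigma tensors in order to recover the predicted clean latents from the model output; the branch on prediction type sits past the end of this hunk. The standard recovery formulas, as a sketch:

    def predicted_original(model_output, sample, alpha, sigma, prediction_type):
        if prediction_type == "epsilon":
            # x_t = alpha * x0 + sigma * eps  =>  x0 = (x_t - sigma * eps) / alpha
            return (sample - sigma * model_output) / alpha
        if prediction_type == "v_prediction":
            # v = alpha * eps - sigma * x0  =>  x0 = alpha * x_t - sigma * v
            return alpha * sample - sigma * model_output
        raise ValueError(f"Unknown prediction type {prediction_type}")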
@@ -329,7 +358,9 @@ def loss_step(
329 eval: bool = False, 358 eval: bool = False,
330): 359):
331 images = batch["pixel_values"] 360 images = batch["pixel_values"]
332 generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None 361 generator = (
362 torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
363 )
333 bsz = images.shape[0] 364 bsz = images.shape[0]
334 365
335 # Convert images to latent space 366 # Convert images to latent space
@@ -342,7 +373,7 @@ def loss_step(
342 dtype=latents.dtype, 373 dtype=latents.dtype,
343 layout=latents.layout, 374 layout=latents.layout,
344 device=latents.device, 375 device=latents.device,
345 generator=generator 376 generator=generator,
346 ) 377 )
347 applied_noise = noise 378 applied_noise = noise
348 379
@@ -353,7 +384,7 @@ def loss_step(
353 octaves=4, 384 octaves=4,
354 dtype=latents.dtype, 385 dtype=latents.dtype,
355 device=latents.device, 386 device=latents.device,
356 generator=generator 387 generator=generator,
357 ) 388 )
358 389
359 if input_pertubation != 0: 390 if input_pertubation != 0:
@@ -362,7 +393,7 @@ def loss_step(
362 dtype=latents.dtype, 393 dtype=latents.dtype,
363 layout=latents.layout, 394 layout=latents.layout,
364 device=latents.device, 395 device=latents.device,
365 generator=generator 396 generator=generator,
366 ) 397 )
367 398
368 # Sample a random timestep for each image 399 # Sample a random timestep for each image
@@ -375,25 +406,27 @@ def loss_step(
375 406
376 # Get the text embedding for conditioning 407 # Get the text embedding for conditioning
377 encoder_hidden_states = get_extended_embeddings( 408 encoder_hidden_states = get_extended_embeddings(
378 text_encoder, 409 text_encoder, batch["input_ids"], batch["attention_mask"]
379 batch["input_ids"],
380 batch["attention_mask"]
381 ) 410 )
382 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) 411 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
383 412
384 # Predict the noise residual 413 # Predict the noise residual
385 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] 414 model_pred = unet(
415 noisy_latents, timesteps, encoder_hidden_states, return_dict=False
416 )[0]
386 417
387 if guidance_scale != 0: 418 if guidance_scale != 0:
388 uncond_encoder_hidden_states = get_extended_embeddings( 419 uncond_encoder_hidden_states = get_extended_embeddings(
389 text_encoder, 420 text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
390 batch["negative_input_ids"],
391 batch["negative_attention_mask"]
392 ) 421 )
393 uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) 422 uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
394 423
395 model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] 424 model_pred_uncond = unet(
396 model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) 425 noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
426 )[0]
427 model_pred = model_pred_uncond + guidance_scale * (
428 model_pred - model_pred_uncond
429 )
397 430
398 # Get the target for loss depending on the prediction type 431 # Get the target for loss depending on the prediction type
399 if noise_scheduler.config.prediction_type == "epsilon": 432 if noise_scheduler.config.prediction_type == "epsilon":
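When guidance_scale is nonzero, the training step above mirrors classifier-free guidance at inference time: a second forward pass on the negative prompt, then linear extrapolation away from the unconditional prediction. The mixing rule by itself:

    def guide(cond_pred, uncond_pred, guidance_scale):
        # guidance_scale = 1 returns cond_pred unchanged; larger values
        # push the prediction further from the unconditional one.
        return uncond_pred + guidance_scale * (cond_pred - uncond_pred)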
@@ -401,7 +434,9 @@ def loss_step(
401 elif noise_scheduler.config.prediction_type == "v_prediction": 434 elif noise_scheduler.config.prediction_type == "v_prediction":
402 target = noise_scheduler.get_velocity(latents, noise, timesteps) 435 target = noise_scheduler.get_velocity(latents, noise, timesteps)
403 else: 436 else:
404 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 437 raise ValueError(
438 f"Unknown prediction type {noise_scheduler.config.prediction_type}"
439 )
405 440
406 acc = (model_pred == target).float().mean() 441 acc = (model_pred == target).float().mean()
407 442
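For "epsilon" the regression target is the injected noise itself; for "v_prediction", diffusers' get_velocity returns v_t = alpha_t * noise - sigma_t * latents. Written out:

    def velocity_target(latents, noise, alpha, sigma):
        # What noise_scheduler.get_velocity computes per sample:
        # v = alpha * eps - sigma * x0
        return alpha * noise - sigma * latents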
@@ -414,7 +449,9 @@ def loss_step(
414 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 449 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
415 450
416 # Compute prior loss 451 # Compute prior loss
417 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") 452 prior_loss = F.mse_loss(
453 model_pred_prior.float(), target_prior.float(), reduction="none"
454 )
418 455
419 # Add the prior loss to the instance loss. 456 # Add the prior loss to the instance loss.
420 loss = loss + prior_loss_weight * prior_loss 457 loss = loss + prior_loss_weight * prior_loss
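With prior preservation enabled, each batch carries class ("prior") examples alongside instance examples; their predictions are split off above this hunk, and the prior MSE is folded in with a fixed weight. In isolation (a sketch):

    import torch.nn.functional as F

    def dreambooth_loss(model_pred, target, model_pred_prior, target_prior,
                        prior_loss_weight):
        # Instance loss learns the new subject; prior loss anchors the
        # class so the model keeps its generic notion of it.
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
        prior = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
        return loss + prior_loss_weight * prior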
@@ -433,7 +470,10 @@ def loss_step(
433 if min_snr_gamma != 0: 470 if min_snr_gamma != 0:
434 snr = compute_snr(timesteps, noise_scheduler) 471 snr = compute_snr(timesteps, noise_scheduler)
435 mse_loss_weights = ( 472 mse_loss_weights = (
436 torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 473 torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(
474 dim=1
475 )[0]
476 / snr
437 ) 477 )
438 loss = loss * mse_loss_weights 478 loss = loss * mse_loss_weights
439 479
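The Min-SNR weighting above caps each sample's SNR at gamma before dividing by SNR, so easy low-noise timesteps (high SNR) are down-weighted while hard ones keep weight 1. The stack/min construction is equivalent to a clamp:

    import torch

    def min_snr_weights(snr: torch.Tensor, gamma: float) -> torch.Tensor:
        # w(t) = min(SNR(t), gamma) / SNR(t)
        return torch.clamp(snr, max=gamma) / snr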
@@ -447,8 +487,14 @@ def loss_step(
447 487
448 488
449class LossCallable(Protocol): 489class LossCallable(Protocol):
450 def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any], 490 def __call__(
451 eval: bool = False) -> Tuple[Any, Any, int]: ... 491 self,
492 step: int,
493 batch: dict[Any, Any],
494 cache: dict[str, Any],
495 eval: bool = False,
496 ) -> Tuple[Any, Any, int]:
497 ...
452 498
453 499
454def train_loop( 500def train_loop(
@@ -472,9 +518,14 @@ def train_loop(
472 avg_acc_val: AverageMeter = AverageMeter(), 518 avg_acc_val: AverageMeter = AverageMeter(),
473 callbacks: TrainingCallbacks = TrainingCallbacks(), 519 callbacks: TrainingCallbacks = TrainingCallbacks(),
474): 520):
475 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 521 num_training_steps_per_epoch = math.ceil(
476 num_val_steps_per_epoch = math.ceil( 522 len(train_dataloader) / gradient_accumulation_steps
477 len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0 523 )
524 num_val_steps_per_epoch = (
525 math.ceil(len(val_dataloader) / gradient_accumulation_steps)
526 if val_dataloader is not None
527 else 0
528 )
478 529
479 num_training_steps = num_training_steps_per_epoch * num_epochs 530 num_training_steps = num_training_steps_per_epoch * num_epochs
480 num_val_steps = num_val_steps_per_epoch * num_epochs 531 num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -488,14 +539,14 @@ def train_loop(
488 local_progress_bar = tqdm( 539 local_progress_bar = tqdm(
489 range(num_training_steps_per_epoch + num_val_steps_per_epoch), 540 range(num_training_steps_per_epoch + num_val_steps_per_epoch),
490 disable=not accelerator.is_local_main_process, 541 disable=not accelerator.is_local_main_process,
491 dynamic_ncols=True 542 dynamic_ncols=True,
492 ) 543 )
493 local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") 544 local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
494 545
495 global_progress_bar = tqdm( 546 global_progress_bar = tqdm(
496 range(num_training_steps + num_val_steps), 547 range(num_training_steps + num_val_steps),
497 disable=not accelerator.is_local_main_process, 548 disable=not accelerator.is_local_main_process,
498 dynamic_ncols=True 549 dynamic_ncols=True,
499 ) 550 )
500 global_progress_bar.set_description("Total progress") 551 global_progress_bar.set_description("Total progress")
501 552
@@ -513,7 +564,9 @@ def train_loop(
513 try: 564 try:
514 import dadaptation 565 import dadaptation
515 566
516 isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) 567 isDadaptation = isinstance(
568 optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)
569 )
517 except ImportError: 570 except ImportError:
518 pass 571 pass
519 572
@@ -565,7 +618,10 @@ def train_loop(
565 label = group_labels[i] if i < len(group_labels) else f"{i}" 618 label = group_labels[i] if i < len(group_labels) else f"{i}"
566 logs[f"lr/{label}"] = lr 619 logs[f"lr/{label}"] = lr
567 if isDadaptation: 620 if isDadaptation:
568 lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] 621 lr = (
622 optimizer.param_groups[i]["d"]
623 * optimizer.param_groups[i]["lr"]
624 )
569 logs[f"d*lr/{label}"] = lr 625 logs[f"d*lr/{label}"] = lr
570 lrs[label] = lr 626 lrs[label] = lr
571 627
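D-Adaptation optimizers keep the configured lr as a fixed scale and learn the multiplier d, so the effective step size is d * lr; that product is what the loop logs as "d*lr". Reading it off the param groups directly, as the code above does:

    def effective_lrs(optimizer):
        # DAdaptAdam / DAdaptAdan expose the learned scale as "d"
        # in each param group.
        return [group["d"] * group["lr"] for group in optimizer.param_groups]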
@@ -573,8 +629,10 @@ def train_loop(
573 629
574 local_progress_bar.set_postfix(**logs) 630 local_progress_bar.set_postfix(**logs)
575 631
576 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): 632 if ((step + 1) % gradient_accumulation_steps == 0) or (
577 before_optimize_result = on_before_optimize(epoch) 633 (step + 1) == len(train_dataloader)
634 ):
635 before_optimize_result = on_before_optimize(cycle)
578 636
579 optimizer.step() 637 optimizer.step()
580 lr_scheduler.step() 638 lr_scheduler.step()
@@ -614,7 +672,9 @@ def train_loop(
614 } 672 }
615 local_progress_bar.set_postfix(**logs) 673 local_progress_bar.set_postfix(**logs)
616 674
617 if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)): 675 if ((step + 1) % gradient_accumulation_steps == 0) or (
676 (step + 1) == len(val_dataloader)
677 ):
618 local_progress_bar.update(1) 678 local_progress_bar.update(1)
619 global_progress_bar.update(1) 679 global_progress_bar.update(1)
620 680
@@ -634,7 +694,8 @@ def train_loop(
634 global_progress_bar.clear() 694 global_progress_bar.clear()
635 695
636 accelerator.print( 696 accelerator.print(
637 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") 697 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}"
698 )
638 on_checkpoint(global_step, "milestone") 699 on_checkpoint(global_step, "milestone")
639 best_acc_val = avg_acc_val.max 700 best_acc_val = avg_acc_val.max
640 else: 701 else:
@@ -644,7 +705,8 @@ def train_loop(
644 global_progress_bar.clear() 705 global_progress_bar.clear()
645 706
646 accelerator.print( 707 accelerator.print(
647 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") 708 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}"
709 )
648 on_checkpoint(global_step, "milestone") 710 on_checkpoint(global_step, "milestone")
649 best_acc = avg_acc.max 711 best_acc = avg_acc.max
650 712
@@ -700,17 +762,32 @@ def train(
700 avg_acc_val: AverageMeter = AverageMeter(), 762 avg_acc_val: AverageMeter = AverageMeter(),
701 **kwargs, 763 **kwargs,
702): 764):
703 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( 765 (
704 accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) 766 text_encoder,
767 unet,
768 optimizer,
769 train_dataloader,
770 val_dataloader,
771 lr_scheduler,
772 ) = strategy.prepare(
773 accelerator,
774 text_encoder,
775 unet,
776 optimizer,
777 train_dataloader,
778 val_dataloader,
779 lr_scheduler,
780 **kwargs,
781 )
705 782
706 vae.to(accelerator.device, dtype=dtype) 783 vae.to(accelerator.device, dtype=dtype)
707 vae.requires_grad_(False) 784 vae.requires_grad_(False)
708 vae.eval() 785 vae.eval()
709 786
710 vae = torch.compile(vae, backend='hidet') 787 vae = torch.compile(vae, backend="hidet")
711 788
712 if compile_unet: 789 if compile_unet:
713 unet = torch.compile(unet, backend='hidet') 790 unet = torch.compile(unet, backend="hidet")
714 # unet = torch.compile(unet, mode="reduce-overhead") 791 # unet = torch.compile(unet, mode="reduce-overhead")
715 792
716 callbacks = strategy.callbacks( 793 callbacks = strategy.callbacks(
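torch.compile with backend="hidet", as used above for the VAE and optionally the UNet, requires the hidet package, which registers itself as a compile backend on install; compilation is lazy and happens on the first forward pass. Minimal usage under those assumptions (torch >= 2.0, hidet installed, CUDA device):

    import torch

    model = torch.nn.Linear(8, 8).cuda().eval()
    compiled = torch.compile(model, backend="hidet")
    with torch.inference_mode():
        out = compiled(torch.randn(1, 8, device="cuda"))  # first call compiles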
diff --git a/training/lr.py b/training/lr.py
index f5b362f..a75078f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -23,12 +23,12 @@ def plot_metrics(
23 fig, ax_loss = plt.subplots() 23 fig, ax_loss = plt.subplots()
24 ax_acc = ax_loss.twinx() 24 ax_acc = ax_loss.twinx()
25 25
26 ax_loss.plot(lrs, losses, color='red') 26 ax_loss.plot(lrs, losses, color="red")
27 ax_loss.set_xscale("log") 27 ax_loss.set_xscale("log")
28 ax_loss.set_xlabel(f"Learning rate") 28 ax_loss.set_xlabel(f"Learning rate")
29 ax_loss.set_ylabel("Loss") 29 ax_loss.set_ylabel("Loss")
30 30
31 ax_acc.plot(lrs, accs, color='blue') 31 ax_acc.plot(lrs, accs, color="blue")
32 ax_acc.set_xscale("log") 32 ax_acc.set_xscale("log")
33 ax_acc.set_ylabel("Accuracy") 33 ax_acc.set_ylabel("Accuracy")
34 34
diff --git a/training/optimization.py b/training/optimization.py
index d22a900..55531bf 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,7 +5,10 @@ from functools import partial
5import torch 5import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 8from diffusers.optimization import (
9 get_scheduler as get_scheduler_,
10 get_cosine_with_hard_restarts_schedule_with_warmup,
11)
9from transformers.optimization import get_adafactor_schedule 12from transformers.optimization import get_adafactor_schedule
10 13
11 14
@@ -52,7 +55,7 @@ def get_one_cycle_schedule(
52 annealing_exp: int = 1, 55 annealing_exp: int = 1,
53 min_lr: float = 0.04, 56 min_lr: float = 0.04,
54 mid_point: float = 0.3, 57 mid_point: float = 0.3,
55 last_epoch: int = -1 58 last_epoch: int = -1,
56): 59):
57 if warmup == "linear": 60 if warmup == "linear":
58 warmup_func = warmup_linear 61 warmup_func = warmup_linear
@@ -83,12 +86,16 @@ def get_one_cycle_schedule(
83 86
84 def lr_lambda(current_step: int): 87 def lr_lambda(current_step: int):
85 phase = [p for p in phases if current_step >= p.step_min][-1] 88 phase = [p for p in phases if current_step >= p.step_min][-1]
86 return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) 89 return phase.min + phase.func(
90 (current_step - phase.step_min) / (phase.step_max - phase.step_min)
91 ) * (phase.max - phase.min)
87 92
88 return LambdaLR(optimizer, lr_lambda, last_epoch) 93 return LambdaLR(optimizer, lr_lambda, last_epoch)
89 94
90 95
91def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1): 96def get_exponential_growing_schedule(
97 optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1
98):
92 def lr_lambda(base_lr: float, current_step: int): 99 def lr_lambda(base_lr: float, current_step: int):
93 return (end_lr / base_lr) ** (current_step / num_training_steps) 100 return (end_lr / base_lr) ** (current_step / num_training_steps)
94 101
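The exponential-growth schedule interpolates geometrically from the optimizer's base lr to end_lr: lr(t) = base_lr * (end_lr / base_lr) ** (t / T). For example, going from 1e-6 to 1e-2 over 100 steps multiplies the lr by 10 ** (4 / 100), roughly 1.096, each step. The multiplier the LambdaLR needs:

    def growth_multiplier(base_lr, end_lr, step, total_steps):
        # LambdaLR multiplies base_lr by this factor, yielding
        # base_lr at step 0 and end_lr at step total_steps.
        return (end_lr / base_lr) ** (step / total_steps)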
@@ -132,7 +139,14 @@ def get_scheduler(
132 ) 139 )
133 elif id == "exponential_growth": 140 elif id == "exponential_growth":
134 if cycles is None: 141 if cycles is None:
135 cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) 142 cycles = math.ceil(
143 math.sqrt(
144 (
145 (num_training_steps - num_warmup_steps)
146 / num_training_steps_per_epoch
147 )
148 )
149 )
136 150
137 lr_scheduler = get_exponential_growing_schedule( 151 lr_scheduler = get_exponential_growing_schedule(
138 optimizer=optimizer, 152 optimizer=optimizer,
@@ -141,7 +155,14 @@ def get_scheduler(
141 ) 155 )
142 elif id == "cosine_with_restarts": 156 elif id == "cosine_with_restarts":
143 if cycles is None: 157 if cycles is None:
144 cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) 158 cycles = math.ceil(
159 math.sqrt(
160 (
161 (num_training_steps - num_warmup_steps)
162 / num_training_steps_per_epoch
163 )
164 )
165 )
145 166
146 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 167 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
147 optimizer=optimizer, 168 optimizer=optimizer,
@@ -150,10 +171,7 @@ def get_scheduler(
150 num_cycles=cycles, 171 num_cycles=cycles,
151 ) 172 )
152 elif id == "adafactor": 173 elif id == "adafactor":
153 lr_scheduler = get_adafactor_schedule( 174 lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)
154 optimizer,
155 initial_lr=min_lr
156 )
157 else: 175 else:
158 lr_scheduler = get_scheduler_( 176 lr_scheduler = get_scheduler_(
159 id, 177 id,
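Both the exponential_growth and cosine_with_restarts branches above default the cycle count to the square root of the post-warmup epoch count, rounded up, so longer runs restart more often but only sub-linearly:

    import math

    def default_cycles(num_training_steps, num_warmup_steps, steps_per_epoch):
        # e.g. 100 post-warmup epochs -> 10 cycles, 400 -> 20
        return math.ceil(math.sqrt((num_training_steps - num_warmup_steps) / steps_per_epoch))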
diff --git a/training/sampler.py b/training/sampler.py
index bdb3e90..0487d66 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler):
134 def weights(self): 134 def weights(self):
135 if not self._warmed_up(): 135 if not self._warmed_up():
136 return np.ones([self.num_timesteps], dtype=np.float64) 136 return np.ones([self.num_timesteps], dtype=np.float64)
137 weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 137 weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
138 weights /= np.sum(weights) 138 weights /= np.sum(weights)
139 weights *= 1 - self.uniform_prob 139 weights *= 1 - self.uniform_prob
140 weights += self.uniform_prob / len(weights) 140 weights += self.uniform_prob / len(weights)
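Once warmed up, the loss-aware sampler weights each timestep by the square root of its recent mean squared loss, then mixes in a uniform floor so no timestep's sampling probability collapses to zero. The weights method, restated standalone:

    import numpy as np

    def resample_weights(loss_history: np.ndarray, uniform_prob: float) -> np.ndarray:
        # loss_history: [num_timesteps, history_len] recent losses.
        w = np.sqrt(np.mean(loss_history**2, axis=-1))
        w /= np.sum(w)
        w *= 1 - uniform_prob
        w += uniform_prob / len(w)
        return w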
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e6fcc89..88b441b 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks(
29 sample_output_dir: Path, 29 sample_output_dir: Path,
30 checkpoint_output_dir: Path, 30 checkpoint_output_dir: Path,
31 seed: int, 31 seed: int,
32 train_text_encoder_epochs: int, 32 train_text_encoder_cycles: int,
33 max_grad_norm: float = 1.0, 33 max_grad_norm: float = 1.0,
34 use_ema: bool = False, 34 use_ema: bool = False,
35 ema_inv_gamma: float = 1.0, 35 ema_inv_gamma: float = 1.0,
@@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks(
85 return nullcontext() 85 return nullcontext()
86 86
87 @contextmanager 87 @contextmanager
88 def on_train(epoch: int): 88 def on_train(cycle: int):
89 unet.train() 89 unet.train()
90 tokenizer.train() 90 tokenizer.train()
91 91
92 if epoch < train_text_encoder_epochs: 92 if cycle < train_text_encoder_cycles:
93 text_encoder.train() 93 text_encoder.train()
94 elif epoch == train_text_encoder_epochs: 94 tokenizer.train()
95 text_encoder.requires_grad_(False)
96 text_encoder.eval()
97 95
98 yield 96 yield
99 97
@@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks(
106 with ema_context(): 104 with ema_context():
107 yield 105 yield
108 106
109 def on_before_optimize(epoch: int): 107 def on_before_optimize(cycle: int):
110 params_to_clip = [unet.parameters()] 108 params_to_clip = [unet.parameters()]
111 if epoch < train_text_encoder_epochs: 109 if cycle < train_text_encoder_cycles:
112 params_to_clip.append(text_encoder.parameters()) 110 params_to_clip.append(text_encoder.parameters())
113 accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) 111 accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
114 112
@@ -189,8 +187,16 @@ def dreambooth_prepare(
189 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 187 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
190 **kwargs 188 **kwargs
191): 189):
192 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 190 (
193 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) 191 text_encoder,
192 unet,
193 optimizer,
194 train_dataloader,
195 val_dataloader,
196 lr_scheduler,
197 ) = accelerator.prepare(
198 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
199 )
194 200
195 text_encoder.text_model.embeddings.requires_grad_(False) 201 text_encoder.text_model.embeddings.requires_grad_(False)
196 202
@@ -198,6 +204,5 @@ def dreambooth_prepare(
198 204
199 205
200dreambooth_strategy = TrainingStrategy( 206dreambooth_strategy = TrainingStrategy(
201 callbacks=dreambooth_strategy_callbacks, 207 callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare
202 prepare=dreambooth_prepare
203) 208)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index f942b76..14e3384 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -81,7 +81,7 @@ def lora_strategy_callbacks(
81 tokenizer.eval() 81 tokenizer.eval()
82 yield 82 yield
83 83
84 def on_before_optimize(epoch: int): 84 def on_before_optimize(cycle: int):
85 if not pti_mode: 85 if not pti_mode:
86 accelerator.clip_grad_norm_( 86 accelerator.clip_grad_norm_(
87 itertools.chain( 87 itertools.chain(
@@ -89,7 +89,7 @@ def lora_strategy_callbacks(
89 text_encoder.text_model.encoder.parameters(), 89 text_encoder.text_model.encoder.parameters(),
90 text_encoder.text_model.final_layer_norm.parameters(), 90 text_encoder.text_model.final_layer_norm.parameters(),
91 ), 91 ),
92 max_grad_norm 92 max_grad_norm,
93 ) 93 )
94 94
95 if len(placeholder_tokens) != 0 and use_emb_decay: 95 if len(placeholder_tokens) != 0 and use_emb_decay:
@@ -108,7 +108,9 @@ def lora_strategy_callbacks(
108 108
109 if lambda_ != 0: 109 if lambda_ != 0:
110 norm = w[:, :].norm(dim=-1, keepdim=True) 110 norm = w[:, :].norm(dim=-1, keepdim=True)
111 w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) 111 w[:].add_(
112 (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
113 )
112 114
113 @torch.no_grad() 115 @torch.no_grad()
114 def on_checkpoint(step, postfix): 116 def on_checkpoint(step, postfix):
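The embedding-decay update above nudges each placeholder row's L2 norm toward emb_decay_target: rows above the target shrink, rows below it grow, with lambda_ setting the strength. Isolated:

    import torch

    @torch.no_grad()
    def decay_toward_norm(w: torch.Tensor, target: float, lambda_: float):
        # The update direction is the row itself (normalized), scaled by
        # how far its norm is from the target; in-place like the original.
        norm = w.norm(dim=-1, keepdim=True)
        w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (target - norm))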
@@ -128,25 +130,32 @@ def lora_strategy_callbacks(
128 130
129 if not pti_mode: 131 if not pti_mode:
130 lora_config = {} 132 lora_config = {}
131 state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) 133 state_dict = get_peft_model_state_dict(
134 unet_, state_dict=accelerator.get_state_dict(unet_)
135 )
132 lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) 136 lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
133 137
134 text_encoder_state_dict = get_peft_model_state_dict( 138 text_encoder_state_dict = get_peft_model_state_dict(
135 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) 139 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
136 ) 140 )
137 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} 141 text_encoder_state_dict = {
142 f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
143 }
138 state_dict.update(text_encoder_state_dict) 144 state_dict.update(text_encoder_state_dict)
139 lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) 145 lora_config[
146 "text_encoder_peft_config"
147 ] = text_encoder_.get_peft_config_as_dict(inference=True)
140 148
141 if len(placeholder_tokens) != 0: 149 if len(placeholder_tokens) != 0:
142 ti_state_dict = { 150 ti_state_dict = {
143 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) 151 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
144 for (token, ids) 152 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
145 in zip(placeholder_tokens, placeholder_token_ids)
146 } 153 }
147 state_dict.update(ti_state_dict) 154 state_dict.update(ti_state_dict)
148 155
149 save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") 156 save_file(
157 state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
158 )
150 with open(checkpoint_output_dir / "lora_config.json", "w") as f: 159 with open(checkpoint_output_dir / "lora_config.json", "w") as f:
151 json.dump(lora_config, f) 160 json.dump(lora_config, f)
152 161
@@ -185,10 +194,18 @@ def lora_prepare(
185 train_dataloader: DataLoader, 194 train_dataloader: DataLoader,
186 val_dataloader: Optional[DataLoader], 195 val_dataloader: Optional[DataLoader],
187 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 196 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
188 **kwargs 197 **kwargs,
189): 198):
190 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 199 (
191 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) 200 text_encoder,
201 unet,
202 optimizer,
203 train_dataloader,
204 val_dataloader,
205 lr_scheduler,
206 ) = accelerator.prepare(
207 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
208 )
192 209
193 # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) 210 # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
194 211
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bc1d7d..7373982 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
104 yield 104 yield
105 105
106 @torch.no_grad() 106 @torch.no_grad()
107 def on_before_optimize(epoch: int): 107 def on_before_optimize(cycle: int):
108 if use_emb_decay: 108 if use_emb_decay:
109 params = [ 109 params = [
110 p 110 p
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks(
116 @torch.no_grad() 116 @torch.no_grad()
117 def on_after_optimize(w, lrs: dict[str, float]): 117 def on_after_optimize(w, lrs: dict[str, float]):
118 if ema_embeddings is not None: 118 if ema_embeddings is not None:
119 ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) 119 ema_embeddings.step(
120 text_encoder.text_model.embeddings.token_embedding.parameters()
121 )
120 122
121 if use_emb_decay and w is not None: 123 if use_emb_decay and w is not None:
122 lr = lrs["emb"] if "emb" in lrs else lrs["0"] 124 lr = lrs["emb"] if "emb" in lrs else lrs["0"]
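ema_embeddings.step applies an exponential moving average to the token embeddings after every optimizer step; checkpoints later read the shadow copy under ema_context(). The core update (a sketch, not the repo's EMA model):

    import torch

    @torch.no_grad()
    def ema_update(shadow_params, params, decay: float = 0.999):
        # shadow <- decay * shadow + (1 - decay) * param
        for s, p in zip(shadow_params, params):
            s.mul_(decay).add_(p, alpha=1 - decay)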
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks(
124 126
125 if lambda_ != 0: 127 if lambda_ != 0:
126 norm = w[:, :].norm(dim=-1, keepdim=True) 128 norm = w[:, :].norm(dim=-1, keepdim=True)
127 w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) 129 w[:].add_(
130 (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
131 )
128 132
129 def on_log(): 133 def on_log():
130 if ema_embeddings is not None: 134 if ema_embeddings is not None:
@@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks(
136 print(f"Saving checkpoint for step {step}...") 140 print(f"Saving checkpoint for step {step}...")
137 141
138 with ema_context(): 142 with ema_context():
139 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): 143 for token, ids in zip(placeholder_tokens, placeholder_token_ids):
140 text_encoder.text_model.embeddings.save_embed( 144 text_encoder.text_model.embeddings.save_embed(
141 ids, 145 ids,
142 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" 146 checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
143 ) 147 )
144 148
145 @torch.no_grad() 149 @torch.no_grad()
@@ -183,7 +187,7 @@ def textual_inversion_prepare(
183 val_dataloader: Optional[DataLoader], 187 val_dataloader: Optional[DataLoader],
184 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 188 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
185 gradient_checkpointing: bool = False, 189 gradient_checkpointing: bool = False,
186 **kwargs 190 **kwargs,
187): 191):
188 weight_dtype = torch.float32 192 weight_dtype = torch.float32
189 if accelerator.state.mixed_precision == "fp16": 193 if accelerator.state.mixed_precision == "fp16":
@@ -191,8 +195,15 @@ def textual_inversion_prepare(
191 elif accelerator.state.mixed_precision == "bf16": 195 elif accelerator.state.mixed_precision == "bf16":
192 weight_dtype = torch.bfloat16 196 weight_dtype = torch.bfloat16
193 197
194 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 198 (
195 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) 199 text_encoder,
200 optimizer,
201 train_dataloader,
202 val_dataloader,
203 lr_scheduler,
204 ) = accelerator.prepare(
205 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
206 )
196 207
197 unet.to(accelerator.device, dtype=weight_dtype) 208 unet.to(accelerator.device, dtype=weight_dtype)
198 unet.requires_grad_(False) 209 unet.requires_grad_(False)