-rw-r--r-- | .vscode/settings.json                               |   6
-rw-r--r-- | data/csv.py                                         | 183
-rw-r--r-- | data/keywords.py                                    |   8
-rw-r--r-- | environment.yaml                                    |  11
-rw-r--r-- | infer.py                                            | 124
-rw-r--r-- | models/attention/control.py                         | 106
-rw-r--r-- | models/attention/hook.py                            |   5
-rw-r--r-- | models/attention/structured.py                      |  65
-rw-r--r-- | models/clip/embeddings.py                           |  29
-rw-r--r-- | models/clip/tokenizer.py                            |  23
-rw-r--r-- | models/clip/util.py                                 |  17
-rw-r--r-- | models/convnext/discriminator.py                    |  11
-rw-r--r-- | models/sparse.py                                    |  12
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 262
-rw-r--r-- | train_dreambooth.py                                 | 770
-rw-r--r-- | train_lora.py                                       | 489
-rw-r--r-- | train_ti.py                                         | 379
-rw-r--r-- | training/functional.py                              | 221
-rw-r--r-- | training/lr.py                                      |   4
-rw-r--r-- | training/optimization.py                            |  38
-rw-r--r-- | training/sampler.py                                 |   2
-rw-r--r-- | training/strategy/dreambooth.py                     |  29
-rw-r--r-- | training/strategy/lora.py                           |  41
-rw-r--r-- | training/strategy/ti.py                             |  27
24 files changed, 1873 insertions(+), 989 deletions(-)
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..cee7b74
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,6 @@
+{
+    "[python]": {
+        "editor.defaultFormatter": "ms-python.black-formatter"
+    },
+    "python.formatting.provider": "none"
+}
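Aside: this commit is essentially a repository-wide Black reformat (plus a few dependency bumps and the removal of brightness_offset). A minimal sketch of reproducing such a reformat programmatically, assuming the black package is installed; the sample source string is illustrative, not from this repo:

import black

# Black's default Mode gives exactly the style visible in the hunks below:
# 88-column lines, double quotes, trailing commas in exploded calls.
src = "def f( a,b ):\n    return some_call( a,[b],0,False,None )\n"
print(black.format_str(src, mode=black.Mode()))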
diff --git a/data/csv.py b/data/csv.py
index 14380e8..d726033 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -49,7 +49,7 @@ def generate_buckets(
     max_pixels: Optional[int] = None,
     num_buckets: int = 4,
     progressive_buckets: bool = False,
-    return_tensor: bool = True
+    return_tensor: bool = True,
 ):
     if max_pixels is None:
         max_pixels = (base_size + step_size) ** 2
@@ -62,7 +62,11 @@ def generate_buckets(
 
     for i in range(1, num_buckets + 1):
         long_side = base_size + i * step_size
-        short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size)
+        short_side = min(
+            base_size
+            - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
+            base_size,
+        )
         buckets.append(long_side / short_side)
         buckets.append(short_side / long_side)
 
@@ -106,7 +110,7 @@ def collate_fn(
     max_token_id_length: Optional[int],
     with_guidance: bool,
     with_prior_preservation: bool,
-    examples
+    examples,
 ):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
@@ -125,7 +129,9 @@ def collate_fn(
     prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length)
     nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length)
     inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length)
-    negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length)
+    negative_inputs = unify_input_ids(
+        tokenizer, negative_input_ids, max_token_id_length
+    )
 
     batch = {
         "prompt_ids": prompts.input_ids,
@@ -149,35 +155,39 @@ class VlpnDataItem(NamedTuple):
     nprompt: str
     collection: list[str]
 
-    def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None):
-        return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator)
+    def full_prompt(
+        self,
+        dropout: float = 0,
+        shuffle: bool = False,
+        npgenerator: Optional[np.random.Generator] = None,
+    ):
+        return keywords_to_str(
+            self.keywords, [self.prompt], dropout, shuffle, npgenerator
+        )
 
 
 def keyword_filter(
     placeholder_tokens: Optional[list[str]],
     collections: Optional[list[str]],
     exclude_collections: Optional[list[str]],
-    item: VlpnDataItem
+    item: VlpnDataItem,
 ):
     full_prompt = item.full_prompt()
 
     cond1 = placeholder_tokens is None or any(
-        token in full_prompt
-        for token in placeholder_tokens
+        token in full_prompt for token in placeholder_tokens
     )
     cond2 = collections is None or any(
-        collection in item.collection
-        for collection in collections
+        collection in item.collection for collection in collections
     )
     cond3 = exclude_collections is None or not any(
-        collection in item.collection
-        for collection in exclude_collections
+        collection in item.collection for collection in exclude_collections
     )
 
     return cond1 and cond2 and cond3
 
 
-class VlpnDataModule():
+class VlpnDataModule:
     def __init__(
         self,
         batch_size: int,
@@ -222,7 +232,7 @@ class VlpnDataModule():
 
         self.constant_prompt_length = constant_prompt_length
         self.max_token_id_length = None
-
+
         self.tokenizer = tokenizer
         self.size = size
         self.num_buckets = num_buckets
@@ -259,23 +269,29 @@ class VlpnDataModule():
             nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "")
             collection = item["collection"].split(", ") if "collection" in item else []
 
-            saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions)
+            saturated_keywords = str_to_keywords(
+                tpl_keywords.format(**keywords), expansions
+            )
 
-            inverted_tokens = keywords_to_str([
-                f"inv_{token}"
-                for token in self.placeholder_tokens
-                if token in saturated_keywords
-            ])
+            inverted_tokens = keywords_to_str(
+                [
+                    f"inv_{token}"
+                    for token in self.placeholder_tokens
+                    if token in saturated_keywords
+                ]
+            )
 
-            items.append(VlpnDataItem(
-                self.data_root / image,
-                None,
-                saturated_keywords,
-                tpl_prompt.format(**prompt),
-                tpl_cprompt.format(**prompt),
-                tpl_nprompt.format(_inv=inverted_tokens, **nprompt),
-                collection
-            ))
+            items.append(
+                VlpnDataItem(
+                    self.data_root / image,
+                    None,
+                    saturated_keywords,
+                    tpl_prompt.format(**prompt),
+                    tpl_cprompt.format(**prompt),
+                    tpl_nprompt.format(_inv=inverted_tokens, **nprompt),
+                    collection,
+                )
+            )
 
         return items
 
@@ -285,13 +301,16 @@ class VlpnDataModule():
 
         return [item for item in items if self.filter(item)]
 
-    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
+    def pad_items(
+        self, items: list[VlpnDataItem], num_class_images: int = 1
+    ) -> list[VlpnDataItem]:
         image_multiplier = max(num_class_images, 1)
 
         return [
             VlpnDataItem(
                 item.instance_image_path,
-                self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
+                self.class_root
+                / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
                 item.keywords,
                 item.prompt,
                 item.cprompt,
@@ -303,7 +322,7 @@ class VlpnDataModule():
         ]
 
     def setup(self):
-        with open(self.data_file, 'rt') as f:
+        with open(self.data_file, "rt") as f:
            metadata = json.load(f)
         template = metadata[self.template_key] if self.template_key in metadata else {}
         expansions = metadata["expansions"] if "expansions" in metadata else {}
@@ -312,25 +331,41 @@ class VlpnDataModule():
         items = self.prepare_items(template, expansions, items)
         items = self.filter_items(items)
         self.npgenerator.shuffle(items)
 
         if self.constant_prompt_length:
             all_input_ids = unify_input_ids(
                 self.tokenizer,
-                [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items]
+                [
+                    self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids
+                    for item in items
+                ],
             ).input_ids
             self.max_token_id_length = all_input_ids.shape[1]
 
         num_images = len(items)
-        valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
+        valid_set_size = (
+            min(self.valid_set_size, num_images)
+            if self.valid_set_size is not None
+            else num_images // 10
+        )
         train_set_size = max(num_images - valid_set_size, 1)
         valid_set_size = num_images - train_set_size
 
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0)
+        collate_fn_ = partial(
+            collate_fn,
+            self.dtype,
+            self.tokenizer,
+            self.max_token_id_length,
+            self.with_guidance,
+            self.num_class_images != 0,
+        )
 
         if valid_set_size == 0:
             data_train, data_val = items, items
         else:
-            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator)
+            data_train, data_val = random_split(
+                items, [train_set_size, valid_set_size], generator=self.generator
+            )
 
         data_train = self.pad_items(data_train, self.num_class_images)
 
@@ -338,17 +373,25 @@ class VlpnDataModule():
             data_train *= math.ceil(self.train_set_pad / len(data_train))
 
         self.train_dataset = VlpnDataset(
-            data_train, self.tokenizer,
-            num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
-            bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            batch_size=self.batch_size, fill_batch=True, generator=self.generator,
-            size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
-            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
+            data_train,
+            self.tokenizer,
+            num_buckets=self.num_buckets,
+            progressive_buckets=self.progressive_buckets,
+            bucket_step_size=self.bucket_step_size,
+            bucket_max_pixels=self.bucket_max_pixels,
+            batch_size=self.batch_size,
+            fill_batch=True,
+            generator=self.generator,
+            size=self.size,
+            interpolation=self.interpolation,
+            color_jitter=self.color_jitter,
+            num_class_images=self.num_class_images,
+            dropout=self.dropout,
+            shuffle=self.shuffle,
         )
 
         self.train_dataloader = DataLoader(
-            self.train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_
+            self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
         if len(data_val) != 0:
@@ -358,16 +401,24 @@ class VlpnDataModule():
                 data_val *= math.ceil(self.valid_set_pad / len(data_val))
 
             self.val_dataset = VlpnDataset(
-                data_val, self.tokenizer,
-                num_buckets=self.num_buckets, progressive_buckets=True,
-                bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-                batch_size=self.batch_size, generator=self.generator,
-                size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
+                data_val,
+                self.tokenizer,
+                num_buckets=self.num_buckets,
+                progressive_buckets=True,
+                bucket_step_size=self.bucket_step_size,
+                bucket_max_pixels=self.bucket_max_pixels,
+                batch_size=self.batch_size,
+                generator=self.generator,
+                size=self.size,
+                interpolation=self.interpolation,
+                color_jitter=self.color_jitter,
             )
 
             self.val_dataloader = DataLoader(
                 self.val_dataset,
-                batch_size=None, pin_memory=True, collate_fn=collate_fn_
+                batch_size=None,
+                pin_memory=True,
+                collate_fn=collate_fn_,
             )
         else:
             self.val_dataloader = None
@@ -418,7 +469,13 @@ class VlpnDataset(IterableDataset):
 
         self.bucket_item_range = torch.arange(len(self.bucket_items))
 
-        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
+        self.length_ = (
+            (self.bucket_assignments.bincount() / self.batch_size)
+            .ceil()
+            .long()
+            .sum()
+            .item()
+        )
 
     def get_input_ids(self, text: str):
         return self.tokenizer(text, padding="do_not_pad").input_ids
@@ -430,7 +487,9 @@ class VlpnDataset(IterableDataset):
         worker_info = torch.utils.data.get_worker_info()
 
         if self.shuffle:
-            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
+            perm = torch.randperm(
+                len(self.bucket_assignments), generator=self.generator
+            )
             self.bucket_items = self.bucket_items[perm]
             self.bucket_assignments = self.bucket_assignments[perm]
 
@@ -488,7 +547,9 @@ class VlpnDataset(IterableDataset):
 
             if len(bucket_items) == 0:
                 bucket_items = self.bucket_items[self.bucket_assignments == bucket]
-                item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)]
+                item_index = bucket_items[
+                    torch.randint(len(bucket_items), (1,), generator=self.generator)
+                ]
             else:
                 item_index = bucket_items[0]
                 mask[self.bucket_item_range[bucket_mask][0]] = False
@@ -500,12 +561,18 @@ class VlpnDataset(IterableDataset):
             example["prompt_ids"] = self.get_input_ids(item.full_prompt())
             example["nprompt_ids"] = self.get_input_ids(item.nprompt)
 
-            example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator))
+            example["instance_prompt_ids"] = self.get_input_ids(
+                item.full_prompt(self.dropout, True, self.npgenerator)
+            )
             example["negative_prompt_ids"] = self.get_input_ids(item.nprompt)
-            example["instance_images"] = image_transforms(get_image(item.instance_image_path))
+            example["instance_images"] = image_transforms(
+                get_image(item.instance_image_path)
+            )
 
             if self.num_class_images != 0:
                 example["class_prompt_ids"] = self.get_input_ids(item.cprompt)
-                example["class_images"] = image_transforms(get_image(item.class_image_path))
+                example["class_images"] = image_transforms(
+                    get_image(item.class_image_path)
+                )
 
             batch.append(example)
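The reflowed short_side expression in generate_buckets is easier to check outside the diff. A small self-contained sketch of the same arithmetic, with illustrative values for base_size, step_size, and num_buckets (they are not taken from this commit):

import math

# Mirror of the bucket-ratio arithmetic from generate_buckets above,
# assuming base_size=768, step_size=64, num_buckets=4 for illustration.
base_size, step_size, num_buckets = 768, 64, 4
max_pixels = (base_size + step_size) ** 2  # same default as the function

for i in range(1, num_buckets + 1):
    long_side = base_size + i * step_size
    short_side = min(
        base_size
        - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
        base_size,
    )
    # Each bucket stores an aspect ratio; both orientations are kept.
    print(long_side / short_side, short_side / long_side)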
diff --git a/data/keywords.py b/data/keywords.py
index 8632d67..83fe9ff 100644
--- a/data/keywords.py
+++ b/data/keywords.py
@@ -8,7 +8,7 @@ def keywords_to_str(
     undroppable_keywords: list[str] = [],
     dropout: float = 0,
     shuffle: bool = False,
-    npgenerator: Optional[np.random.Generator] = None
+    npgenerator: Optional[np.random.Generator] = None,
 ) -> str:
     if dropout != 0:
         keywords = [keyword for keyword in keywords if np.random.random() > dropout]
@@ -23,7 +23,11 @@ def keywords_to_str(
 
 def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]:
     def expand_keyword(keyword: str) -> list[str]:
-        return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword]
+        return (
+            [keyword] + expansions[keyword].split(", ")
+            if keyword in expansions
+            else [keyword]
+        )
 
     return [
         kw
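For orientation, the two helpers in this file round-trip between a comma-separated prompt string and a keyword list. A usage sketch; the expansions table is made up for illustration and the exact ordering/deduplication follows the implementation:

from data.keywords import keywords_to_str, str_to_keywords

# Hypothetical expansion table: one keyword fans out into several.
expansions = {"portrait": "face, close-up"}

keywords = str_to_keywords("portrait, painting", expansions)
# e.g. ["portrait", "face", "close-up", "painting"]

prompt = keywords_to_str(keywords)  # no dropout, no shuffle
# e.g. "portrait, face, close-up, painting"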
diff --git a/environment.yaml b/environment.yaml
index 1a55967..2c81a90 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -14,16 +14,17 @@ dependencies:
   - numpy=1.24.3
   - pip=22.3.1
   - python=3.10.8
-  - pytorch=2.0.0=*cuda11.8*
-  - torchvision=0.15.0
-  - xformers=0.0.20.dev528
+  - pytorch=2.0.1=*cuda11.8*
+  - scipy=1.10.1
+  - torchvision=0.15.2
+  - xformers=0.0.21.dev542+git.a205b24
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/accelerate#egg=accelerate
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
     - --pre --extra-index-url https://download.hidet.org/whl hidet
-    - bitsandbytes==0.38.1
+    - bitsandbytes==0.39.1
     - lion-pytorch==0.0.7
     - peft==0.3.0
     - python-slugify>=6.1.2
@@ -31,4 +32,4 @@ dependencies:
     - setuptools==65.6.3
     - test-tube>=0.7.5
     - timm==0.9.2
-    - transformers==4.29.1
+    - transformers==4.30.1
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
     DEISMultistepScheduler,
-    UniPCMultistepScheduler
+    UniPCMultistepScheduler,
 )
 from peft import LoraConfig, LoraModel, set_peft_model_state_dict
 from safetensors.torch import load_file
@@ -61,7 +61,7 @@ default_cmds = {
     "negative_prompt": None,
     "shuffle": False,
     "image": None,
-    "image_noise": .7,
+    "image_noise": 0.7,
     "width": 768,
     "height": 768,
     "batch_size": 1,
@@ -69,7 +69,6 @@ default_cmds = {
     "steps": 30,
     "guidance_scale": 7.0,
     "sag_scale": 0,
-    "brightness_offset": 0,
     "seed": None,
     "config": None,
 }
@@ -85,9 +84,7 @@ def merge_dicts(d1, *args):
 
 
 def create_args_parser():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--model",
         type=str,
@@ -118,9 +115,7 @@ def create_cmd_parser():
 
 
 def create_cmd_parser():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--project",
         type=str,
@@ -130,13 +125,34 @@ def create_cmd_parser():
     parser.add_argument(
         "--scheduler",
         type=str,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"],
+        choices=[
+            "plms",
+            "ddim",
+            "klms",
+            "dpmsm",
+            "dpmss",
+            "euler_a",
+            "kdpm2",
+            "kdpm2_a",
+            "deis",
+            "unipc",
+        ],
     )
     parser.add_argument(
         "--subscheduler",
         type=str,
         default=None,
-        choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"],
+        choices=[
+            "plms",
+            "ddim",
+            "klms",
+            "dpmsm",
+            "dpmss",
+            "euler_a",
+            "kdpm2",
+            "kdpm2_a",
+            "deis",
+        ],
     )
     parser.add_argument(
         "--template",
@@ -193,10 +209,6 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
-        "--brightness_offset",
-        type=float,
-    )
-    parser.add_argument(
         "--seed",
         type=int,
     )
@@ -214,7 +226,9 @@ def run_parser(parser, defaults, input=None):
 
     if args.config is not None:
         conf_args = load_config(args.config)
-        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0]
+        conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[
+            0
+        ]
 
     res = defaults.copy()
     for dict in [vars(conf_args), vars(args)]:
@@ -234,10 +248,12 @@ def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
-        Path(embeddings_dir)
+        Path(embeddings_dir),
     )
     pipeline.text_encoder.text_model.embeddings.persist()
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+    print(
+        f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
+    )
 
 
 def load_lora(pipeline, path):
@@ -255,9 +271,13 @@ def load_lora(pipeline, path):
         return
 
     lora_checkpoint_sd = load_file(path / tensor_files[0])
-    unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
+    unet_lora_ds = {
+        k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k
+    }
     text_encoder_lora_ds = {
-        k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
+        k.replace("text_encoder_", ""): v
+        for k, v in lora_checkpoint_sd.items()
+        if "text_encoder_" in k
     }
     ti_lora_ds = {
         k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
@@ -282,7 +302,9 @@ def load_lora(pipeline, path):
         token_embeddings=token_embeddings,
     )
     pipeline.text_encoder.text_model.embeddings.persist()
-    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+    print(
+        f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}"
+    )
 
     return
 
@@ -315,17 +337,25 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
             solver_p=create_scheduler(config, subscheduler),
         )
     else:
-        raise ValueError(f"Unknown scheduler \"{scheduler}\"")
+        raise ValueError(f'Unknown scheduler "{scheduler}"')
 
 
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
-    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
+    tokenizer = MultiCLIPTokenizer.from_pretrained(
+        model, subfolder="tokenizer", torch_dtype=dtype
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        model, subfolder="text_encoder", torch_dtype=dtype
+    )
+    vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(
+        model, subfolder="unet", torch_dtype=dtype
+    )
+    scheduler = DDIMScheduler.from_pretrained(
+        model, subfolder="scheduler", torch_dtype=dtype
+    )
 
     patch_managed_embeddings(text_encoder)
 
@@ -347,7 +377,9 @@ def create_pipeline(model, dtype):
 
 
 def shuffle_prompts(prompts: list[str]) -> list[str]:
-    return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts]
+    return [
+        keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts
+    ]
 
 
 @torch.inference_mode()
@@ -386,12 +418,13 @@ def generate(output_dir: Path, pipeline, args):
     else:
         init_image = None
 
-    pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler)
+    pipeline.scheduler = create_scheduler(
+        pipeline.scheduler.config, args.scheduler, args.subscheduler
+    )
 
     for i in range(args.batch_num):
         pipeline.set_progress_bar_config(
-            desc=f"Batch {i + 1} of {args.batch_num}",
-            dynamic_ncols=True
+            desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True
         )
 
         seed = args.seed + i
@@ -409,7 +442,6 @@ def generate(output_dir: Path, pipeline, args):
             generator=generator,
             image=init_image,
             strength=args.image_noise,
-            brightness_offset=args.brightness_offset,
         ).images
 
         for j, image in enumerate(images):
@@ -418,7 +450,7 @@
 
             image.save(dir / f"{basename}.png")
             image.save(dir / f"{basename}.jpg", quality=85)
-            with open(dir / f"{basename}.txt", 'w') as f:
+            with open(dir / f"{basename}.txt", "w") as f:
                 f.write(prompt[j % len(args.prompt)])
 
     if torch.cuda.is_available():
@@ -426,10 +458,12 @@
 
 
 class CmdParse(cmd.Cmd):
-    prompt = 'dream> '
+    prompt = "dream> "
     commands = []
 
-    def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser):
+    def __init__(
+        self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser
+    ):
         super().__init__()
 
         self.output_dir = output_dir
@@ -447,10 +481,10 @@ class CmdParse(cmd.Cmd):
             print(str(e))
             return
 
-        if elements[0] == 'q':
+        if elements[0] == "q":
             return True
 
-        if elements[0] == 'reload_embeddings':
+        if elements[0] == "reload_embeddings":
             load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
             return
 
@@ -458,7 +492,7 @@ class CmdParse(cmd.Cmd):
             args = run_parser(self.parser, default_cmds, elements)
 
             if len(args.prompt) == 0:
-                print('Try again with a prompt!')
+                print("Try again with a prompt!")
                 return
         except SystemExit:
             traceback.print_exc()
@@ -471,7 +505,7 @@ class CmdParse(cmd.Cmd):
         try:
             generate(self.output_dir, self.pipeline, args)
         except KeyboardInterrupt:
-            print('Generation cancelled.')
+            print("Generation cancelled.")
         except Exception as e:
             traceback.print_exc()
             return
@@ -487,7 +521,9 @@ def main():
     args = run_parser(args_parser, default_args)
 
     output_dir = Path(args.output_dir)
-    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
+    dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[
+        args.precision
+    ]
 
     pipeline = create_pipeline(args.model, dtype)
 
@@ -496,7 +532,13 @@ def main():
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
-    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser)
+    cmd_prompt = CmdParse(
+        output_dir,
+        args.ti_embeddings_dir,
+        args.lora_embeddings_dir,
+        pipeline,
+        cmd_parser,
+    )
     cmd_prompt.cmdloop()
 
 
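The run_parser change above only re-wraps the line that layers a config file between hard-coded defaults and CLI flags. The precedence pattern itself, sketched with a toy parser (names and values are illustrative, not the script's):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int)

defaults = {"steps": 30}
conf = {"steps": 50}  # pretend this came from load_config(args.config)

# Same layering as run_parser: defaults, then config file, then CLI.
# argparse leaves attributes already present on the namespace alone,
# so config values survive unless the flag is actually passed.
args = parser.parse_known_args(
    ["--steps", "20"], namespace=argparse.Namespace(**conf)
)[0]

res = defaults.copy()
res.update({k: v for k, v in vars(args).items() if v is not None})
print(res["steps"])  # 20: CLI wins over config, config wins over defaults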
diff --git a/models/attention/control.py b/models/attention/control.py
index 248bd9f..ec378c4 100644
--- a/models/attention/control.py
+++ b/models/attention/control.py
@@ -23,7 +23,7 @@ class AttentionControl(abc.ABC):
             attn = self.forward(attn, is_cross, place_in_unet)
         else:
             h = attn.shape[0]
-            attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
+            attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet)
         self.cur_att_layer += 1
         if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
             self.cur_att_layer = 0
@@ -49,12 +49,18 @@ class EmptyControl(AttentionControl):
 class AttentionStore(AttentionControl):
     @staticmethod
     def get_empty_store():
-        return {"down_cross": [], "mid_cross": [], "up_cross": [],
-                "down_self": [], "mid_self": [], "up_self": []}
+        return {
+            "down_cross": [],
+            "mid_cross": [],
+            "up_cross": [],
+            "down_self": [],
+            "mid_self": [],
+            "up_self": [],
+        }
 
     def forward(self, attn, is_cross: bool, place_in_unet: str):
         key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
-        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
+        if attn.shape[1] <= 32**2:  # avoid memory overhead
             self.step_store[key].append(attn)
         return attn
 
@@ -68,8 +74,10 @@ class AttentionStore(AttentionControl):
         self.step_store = self.get_empty_store()
 
     def get_average_attention(self):
-        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
-                             for key in self.attention_store}
+        average_attention = {
+            key: [item / self.cur_step for item in self.attention_store[key]]
+            for key in self.attention_store
+        }
         return average_attention
 
     def reset(self):
@@ -90,7 +98,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC):
         return x_t
 
     def replace_self_attention(self, attn_base, att_replace):
-        if att_replace.shape[2] <= 16 ** 2:
+        if att_replace.shape[2] <= 16**2:
             return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
         else:
             return att_replace
@@ -101,41 +109,62 @@
 
     def forward(self, attn, is_cross: bool, place_in_unet: str):
         super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
-        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+        if is_cross or (
+            self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
+        ):
             h = attn.shape[0] // (self.batch_size)
             attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
             attn_base, attn_repalce = attn[0], attn[1:]
             if is_cross:
                 alpha_words = self.cross_replace_alpha[self.cur_step]
-                attn_repalce_new = self.replace_cross_attention(
-                    attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
+                attn_repalce_new = (
+                    self.replace_cross_attention(attn_base, attn_repalce) * alpha_words
+                    + (1 - alpha_words) * attn_repalce
+                )
                 attn[1:] = attn_repalce_new
             else:
                 attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
             attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
         return attn
 
-    def __init__(self, prompts, num_steps: int,
-                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
-                 self_replace_steps: Union[float, Tuple[float, float]],
-                 local_blend: Optional[LocalBlend]):
+    def __init__(
+        self,
+        prompts,
+        num_steps: int,
+        cross_replace_steps: Union[
+            float, Tuple[float, float], Dict[str, Tuple[float, float]]
+        ],
+        self_replace_steps: Union[float, Tuple[float, float]],
+        local_blend: Optional[LocalBlend],
+    ):
         super(AttentionControlEdit, self).__init__()
         self.batch_size = len(prompts)
         self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(
-            prompts, num_steps, cross_replace_steps, tokenizer).to(device)
+            prompts, num_steps, cross_replace_steps, tokenizer
+        ).to(device)
         if type(self_replace_steps) is float:
             self_replace_steps = 0, self_replace_steps
-        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(
+            num_steps * self_replace_steps[1]
+        )
         self.local_blend = local_blend
 
 
 class AttentionReplace(AttentionControlEdit):
     def replace_cross_attention(self, attn_base, att_replace):
-        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
+        return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper)
 
-    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
-                 local_blend: Optional[LocalBlend] = None):
-        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
+    def __init__(
+        self,
+        prompts,
+        num_steps: int,
+        cross_replace_steps: float,
+        self_replace_steps: float,
+        local_blend: Optional[LocalBlend] = None,
+    ):
+        super(AttentionReplace, self).__init__(
+            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
+        )
         self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
 
 
@@ -145,9 +174,17 @@ class AttentionRefine(AttentionControlEdit):
         attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
         return attn_replace
 
-    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
-                 local_blend: Optional[LocalBlend] = None):
-        super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
+    def __init__(
+        self,
+        prompts,
+        num_steps: int,
+        cross_replace_steps: float,
+        self_replace_steps: float,
+        local_blend: Optional[LocalBlend] = None,
+    ):
+        super(AttentionRefine, self).__init__(
+            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
+        )
         self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
         self.mapper, alphas = self.mapper.to(device), alphas.to(device)
         self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
@@ -156,13 +193,24 @@ class AttentionRefine(AttentionControlEdit):
 class AttentionReweight(AttentionControlEdit):
     def replace_cross_attention(self, attn_base, att_replace):
         if self.prev_controller is not None:
-            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
+            attn_base = self.prev_controller.replace_cross_attention(
+                attn_base, att_replace
+            )
         attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
         return attn_replace
 
-    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
-                 local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
-        super(AttentionReweight, self).__init__(prompts, num_steps,
-                                                cross_replace_steps, self_replace_steps, local_blend)
+    def __init__(
+        self,
+        prompts,
+        num_steps: int,
+        cross_replace_steps: float,
+        self_replace_steps: float,
+        equalizer,
+        local_blend: Optional[LocalBlend] = None,
+        controller: Optional[AttentionControlEdit] = None,
+    ):
+        super(AttentionReweight, self).__init__(
+            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend
+        )
         self.equalizer = equalizer.to(device)
         self.prev_controller = controller
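The re-wrapped expression in AttentionControlEdit.forward is the usual prompt-to-prompt blend between edited and original cross-attention maps. In isolation, with illustrative tensor shapes (8 heads, 64x64 latents, 77 tokens) and a stand-in for replace_cross_attention:

import torch

# new = replace(attn_base, attn_replace) * alpha + (1 - alpha) * attn_replace
alpha_words = torch.rand(1, 1, 1, 77)      # per-token blend weights
attn_base = torch.rand(8, 4096, 77)        # attention for the source prompt
attn_replace = torch.rand(1, 8, 4096, 77)  # attention for the edited prompt(s)

# Stand-in for replace_cross_attention: broadcast the base maps.
replaced = attn_base.unsqueeze(0).expand_as(attn_replace)
attn_new = replaced * alpha_words + (1 - alpha_words) * attn_replace
print(attn_new.shape)  # torch.Size([1, 8, 4096, 77])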
diff --git a/models/attention/hook.py b/models/attention/hook.py
index 903de02..6b5fb68 100644
--- a/models/attention/hook.py
+++ b/models/attention/hook.py
@@ -3,6 +3,7 @@ import torch
 
 try:
     import xformers.ops
+
     xformers._is_functorch_available = True
     MEM_EFFICIENT_ATTN = True
 except ImportError:
@@ -42,10 +43,10 @@ def register_attention_control(model, controller):
         return forward
 
     def register_recr(net_, count, place_in_unet):
-        if net_.__class__.__name__ == 'CrossAttention':
+        if net_.__class__.__name__ == "CrossAttention":
             net_.forward = ca_forward(net_, place_in_unet)
             return count + 1
-        elif hasattr(net_, 'children'):
+        elif hasattr(net_, "children"):
             for net__ in net_.children():
                 count = register_recr(net__, count, place_in_unet)
             return count
diff --git a/models/attention/structured.py b/models/attention/structured.py index 24d889f..5bbbc06 100644 --- a/models/attention/structured.py +++ b/models/attention/structured.py | |||
@@ -16,7 +16,9 @@ class StructuredAttentionControl(AttentionControl): | |||
16 | if self.struct_attn: | 16 | if self.struct_attn: |
17 | out = self.struct_qkv(q, context, mask) | 17 | out = self.struct_qkv(q, context, mask) |
18 | else: | 18 | else: |
19 | context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context | 19 | context = torch.cat( |
20 | [context[0], context[1]["k"][0]], dim=0 | ||
21 | ) # use key tensor for context | ||
20 | out = self.normal_qkv(q, context, mask) | 22 | out = self.normal_qkv(q, context, mask) |
21 | else: | 23 | else: |
22 | context = default(context, x) | 24 | context = default(context, x) |
@@ -29,11 +31,13 @@ class StructuredAttentionControl(AttentionControl): | |||
29 | context: list of [uc, list of conditional context] | 31 | context: list of [uc, list of conditional context] |
30 | """ | 32 | """ |
31 | uc_context = context[0] | 33 | uc_context = context[0] |
32 | context_k, context_v = context[1]['k'], context[1]['v'] | 34 | context_k, context_v = context[1]["k"], context[1]["v"] |
33 | 35 | ||
34 | if isinstance(context_k, list) and isinstance(context_v, list): | 36 | if isinstance(context_k, list) and isinstance(context_v, list): |
35 | out = self.multi_qkv(q, uc_context, context_k, context_v, mask) | 37 | out = self.multi_qkv(q, uc_context, context_k, context_v, mask) |
36 | elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor): | 38 | elif isinstance(context_k, torch.Tensor) and isinstance( |
39 | context_v, torch.Tensor | ||
40 | ): | ||
37 | out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) | 41 | out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) |
38 | else: | 42 | else: |
39 | raise NotImplementedError | 43 | raise NotImplementedError |
@@ -50,36 +54,45 @@ class StructuredAttentionControl(AttentionControl): | |||
50 | k_c = [self.to_k(c_k) for c_k in context_k] | 54 | k_c = [self.to_k(c_k) for c_k in context_k] |
51 | v_c = [self.to_v(c_v) for c_v in context_v] | 55 | v_c = [self.to_v(c_v) for c_v in context_v] |
52 | 56 | ||
53 | q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) | 57 | q = rearrange(q, "b n (h d) -> (b h) n d", h=h) |
54 | 58 | ||
55 | k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h) | 59 | k_uc = rearrange(k_uc, "b n (h d) -> (b h) n d", h=h) |
56 | v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h) | 60 | v_uc = rearrange(v_uc, "b n (h d) -> (b h) n d", h=h) |
57 | 61 | ||
58 | k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point | 62 | k_c = [ |
59 | v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c] | 63 | rearrange(k, "b n (h d) -> (b h) n d", h=h) for k in k_c |
64 | ] # NOTE: modification point | ||
65 | v_c = [rearrange(v, "b n (h d) -> (b h) n d", h=h) for v in v_c] | ||
60 | 66 | ||
61 | # get composition | 67 | # get composition |
62 | sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale | 68 | sim_uc = einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * self.scale |
63 | sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c] | 69 | sim_c = [ |
70 | einsum("b i d, b j d -> b i j", q[true_bs:], k) * self.scale for k in k_c | ||
71 | ] | ||
64 | 72 | ||
65 | attn_uc = sim_uc.softmax(dim=-1) | 73 | attn_uc = sim_uc.softmax(dim=-1) |
66 | attn_c = [sim.softmax(dim=-1) for sim in sim_c] | 74 | attn_c = [sim.softmax(dim=-1) for sim in sim_c] |
67 | 75 | ||
68 | # get uc output | 76 | # get uc output |
69 | out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc) | 77 | out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc) |
70 | 78 | ||
71 | # get c output | 79 | # get c output |
72 | if len(v_c) == 1: | 80 | if len(v_c) == 1: |
73 | out_c_collect = [] | 81 | out_c_collect = [] |
74 | for attn in attn_c: | 82 | for attn in attn_c: |
75 | for v in v_c: | 83 | for v in v_c: |
76 | out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v)) | 84 | out_c_collect.append(einsum("b i j, b j d -> b i d", attn, v)) |
77 | out_c = sum(out_c_collect) / len(out_c_collect) | 85 | out_c = sum(out_c_collect) / len(out_c_collect) |
78 | else: | 86 | else: |
79 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c) | 87 | out_c = sum( |
88 | [ | ||
89 | einsum("b i j, b j d -> b i d", attn, v) | ||
90 | for attn, v in zip(attn_c, v_c) | ||
91 | ] | ||
92 | ) / len(v_c) | ||
80 | 93 | ||
81 | out = torch.cat([out_uc, out_c], dim=0) | 94 | out = torch.cat([out_uc, out_c], dim=0) |
82 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | 95 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) |
83 | 96 | ||
84 | return out | 97 | return out |
85 | 98 | ||
@@ -88,21 +101,21 @@ class StructuredAttentionControl(AttentionControl): | |||
88 | k = self.to_k(context) | 101 | k = self.to_k(context) |
89 | v = self.to_v(context) | 102 | v = self.to_v(context) |
90 | 103 | ||
91 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | 104 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) |
92 | 105 | ||
93 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | 106 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale |
94 | 107 | ||
95 | if exists(mask): | 108 | if exists(mask): |
96 | mask = rearrange(mask, 'b ... -> b (...)') | 109 | mask = rearrange(mask, "b ... -> b (...)") |
97 | max_neg_value = -torch.finfo(sim.dtype).max | 110 | max_neg_value = -torch.finfo(sim.dtype).max |
98 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | 111 | mask = repeat(mask, "b j -> (b h) () j", h=h) |
99 | sim.masked_fill_(~mask, max_neg_value) | 112 | sim.masked_fill_(~mask, max_neg_value) |
100 | 113 | ||
101 | # attention, what we cannot get enough of | 114 | # attention, what we cannot get enough of |
102 | attn = sim.softmax(dim=-1) | 115 | attn = sim.softmax(dim=-1) |
103 | 116 | ||
104 | out = einsum('b i j, b j d -> b i d', attn, v) | 117 | out = einsum("b i j, b j d -> b i d", attn, v) |
105 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | 118 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) |
106 | 119 | ||
107 | return out | 120 | return out |
108 | 121 | ||
@@ -111,21 +124,21 @@ class StructuredAttentionControl(AttentionControl): | |||
111 | k = self.to_k(torch.cat([uc_context, context_k], dim=0)) | 124 | k = self.to_k(torch.cat([uc_context, context_k], dim=0)) |
112 | v = self.to_v(torch.cat([uc_context, context_v], dim=0)) | 125 | v = self.to_v(torch.cat([uc_context, context_v], dim=0)) |
113 | 126 | ||
114 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | 127 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) |
115 | 128 | ||
116 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | 129 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale |
117 | 130 | ||
118 | if exists(mask): | 131 | if exists(mask): |
119 | mask = rearrange(mask, 'b ... -> b (...)') | 132 | mask = rearrange(mask, "b ... -> b (...)") |
120 | max_neg_value = -torch.finfo(sim.dtype).max | 133 | max_neg_value = -torch.finfo(sim.dtype).max |
121 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | 134 | mask = repeat(mask, "b j -> (b h) () j", h=h) |
122 | sim.masked_fill_(~mask, max_neg_value) | 135 | sim.masked_fill_(~mask, max_neg_value) |
123 | 136 | ||
124 | # attention, what we cannot get enough of | 137 | # attention, what we cannot get enough of |
125 | attn = sim.softmax(dim=-1) | 138 | attn = sim.softmax(dim=-1) |
126 | 139 | ||
127 | out = einsum('b i j, b j d -> b i d', attn, v) | 140 | out = einsum("b i j, b j d -> b i d", attn, v) |
128 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | 141 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) |
129 | return out | 142 | return out |
130 | 143 | ||
131 | def get_kv(self, context): | 144 | def get_kv(self, context): |
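All of the hunks above follow the same einsum attention pattern: fold the heads into the batch, compute scaled dot-product similarities, optionally mask, softmax, and unfold. A minimal self-contained sketch of that pattern (the `heads` argument and tensor shapes are illustrative assumptions, not values from this repo):

    import torch
    from torch import einsum
    from einops import rearrange, repeat

    def attention(q, k, v, heads, scale, mask=None):
        # fold heads into the batch dimension: b n (h d) -> (b h) n d
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=heads), (q, k, v))
        sim = einsum("b i d, b j d -> b i j", q, k) * scale
        if mask is not None:
            mask = rearrange(mask, "b ... -> b (...)")
            mask = repeat(mask, "b j -> (b h) () j", h=heads)
            sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(dim=-1)
        out = einsum("b i j, b j d -> b i d", attn, v)
        return rearrange(out, "(b h) n d -> b n (h d)", h=heads)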
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 7c7f2ac..8c3c6d4 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding | |||
14 | 14 | ||
15 | 15 | ||
16 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 16 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
17 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0): | 17 | def __init__( |
18 | self, | ||
19 | config: CLIPTextConfig, | ||
20 | embeddings: CLIPTextEmbeddings, | ||
21 | alpha: int = 8, | ||
22 | dropout: float = 0.0, | ||
23 | ): | ||
18 | super().__init__(config) | 24 | super().__init__(config) |
19 | 25 | ||
20 | self.position_embedding = embeddings.position_embedding | 26 | self.position_embedding = embeddings.position_embedding |
@@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
28 | self.token_embedding.weight = embeddings.token_embedding.weight | 34 | self.token_embedding.weight = embeddings.token_embedding.weight |
29 | 35 | ||
30 | def resize(self, size: int): | 36 | def resize(self, size: int): |
31 | self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) | 37 | self.token_embedding = self.token_embedding.new_resized( |
38 | size, self.initializer_factor | ||
39 | ) | ||
32 | 40 | ||
33 | def add_embed( | 41 | def add_embed( |
34 | self, | 42 | self, |
@@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
46 | initializer = [initializer] | 54 | initializer = [initializer] |
47 | 55 | ||
48 | if isinstance(initializer, list): | 56 | if isinstance(initializer, list): |
49 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 57 | initializer = (initializer * len(token_ids))[: len(token_ids)] |
50 | 58 | ||
51 | with torch.no_grad(): | 59 | with torch.no_grad(): |
52 | initializer = self.get_embed(initializer) | 60 | initializer = self.get_embed(initializer) |
@@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
76 | 84 | ||
77 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
78 | if isinstance(input_ids, list): | 86 | if isinstance(input_ids, list): |
79 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 87 | input_ids = torch.tensor( |
88 | input_ids, device=self.token_embedding.weight.device, dtype=torch.long | ||
89 | ) | ||
80 | 90 | ||
81 | return self.token_embedding(input_ids) | 91 | return self.token_embedding(input_ids) |
82 | 92 | ||
83 | 93 | ||
84 | def patch_managed_embeddings( | 94 | def patch_managed_embeddings( |
85 | text_encoder: CLIPTextModel, | 95 | text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0 |
86 | alpha: int = 8, | ||
87 | dropout: float = 0.0 | ||
88 | ) -> ManagedCLIPTextEmbeddings: | 96 | ) -> ManagedCLIPTextEmbeddings: |
89 | if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): | 97 | if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): |
90 | return text_encoder.text_model.embeddings | 98 | return text_encoder.text_model.embeddings |
91 | 99 | ||
92 | text_embeddings = ManagedCLIPTextEmbeddings( | 100 | text_embeddings = ManagedCLIPTextEmbeddings( |
93 | text_encoder.config, | 101 | text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout |
94 | text_encoder.text_model.embeddings, | ||
95 | alpha, | ||
96 | dropout | ||
97 | ) | 102 | ) |
98 | text_encoder.text_model.embeddings = text_embeddings | 103 | text_encoder.text_model.embeddings = text_embeddings |
99 | return text_embeddings | 104 | return text_embeddings |
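The reflowed `patch_managed_embeddings` helper swaps the text model's embedding layer in place and, thanks to the `isinstance` early return above, is idempotent. A hedged usage sketch (the checkpoint name is an illustrative assumption):

    from transformers import CLIPTextModel
    from models.clip.embeddings import patch_managed_embeddings

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)
    # calling it again returns the same ManagedCLIPTextEmbeddings instance
    assert patch_managed_embeddings(text_encoder) is embeddings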
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 789b525..a866641 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
91 | self.vector_shuffle = shuffle_none | 91 | self.vector_shuffle = shuffle_none |
92 | 92 | ||
93 | def add_multi_tokens( | 93 | def add_multi_tokens( |
94 | self, | 94 | self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1 |
95 | new_tokens: Union[str, list[str]], | ||
96 | num_vectors: Union[int, list[int]] = 1 | ||
97 | ) -> Union[list[int], list[list[int]]]: | 95 | ) -> Union[list[int], list[list[int]]]: |
98 | if isinstance(new_tokens, list): | 96 | if isinstance(new_tokens, list): |
99 | if isinstance(num_vectors, int): | 97 | if isinstance(num_vectors, int): |
100 | num_vectors = [num_vectors] * len(new_tokens) | 98 | num_vectors = [num_vectors] * len(new_tokens) |
101 | 99 | ||
102 | if len(num_vectors) != len(new_tokens): | 100 | if len(num_vectors) != len(new_tokens): |
103 | raise ValueError("Expected new_tokens and num_vectors to have the same len") | 101 | raise ValueError( |
102 | "Expected new_tokens and num_vectors to have the same len" | ||
103 | ) | ||
104 | 104 | ||
105 | return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)] | 105 | return [ |
106 | self.add_multi_tokens(new_token, vecs) | ||
107 | for new_token, vecs in zip(new_tokens, num_vectors) | ||
108 | ] | ||
106 | 109 | ||
107 | if isinstance(num_vectors, list): | 110 | if isinstance(num_vectors, list): |
108 | raise ValueError("Expected num_vectors to be int for single token") | 111 | raise ValueError("Expected num_vectors to be int for single token") |
@@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
129 | return [id] | 132 | return [id] |
130 | 133 | ||
131 | def expand_ids(self, ids: list[int]): | 134 | def expand_ids(self, ids: list[int]): |
132 | return [ | 135 | return [new_id for id in ids for new_id in self.expand_id(id)] |
133 | new_id | ||
134 | for id in ids | ||
135 | for new_id in self.expand_id(id) | ||
136 | ] | ||
137 | 136 | ||
138 | def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): | 137 | def expand_batched_ids( |
138 | self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]] | ||
139 | ): | ||
139 | if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): | 140 | if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): |
140 | return [self.expand_ids(batch) for batch in input_ids] | 141 | return [self.expand_ids(batch) for batch in input_ids] |
141 | else: | 142 | else: |
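The broadcasting rules in `add_multi_tokens` above: an int `num_vectors` is expanded to one entry per new token, and a list must match the token list's length. A hedged sketch (the checkpoint name and the concrete ids are illustrative assumptions):

    from models.clip.tokenizer import MultiCLIPTokenizer

    tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    # each placeholder gets its own list of ids, so the result is a list of lists
    ids = tokenizer.add_multi_tokens(["<concept>", "<style>"], num_vectors=2)
    # e.g. [[49408, 49409], [49410, 49411]]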
diff --git a/models/clip/util.py b/models/clip/util.py index f94fbc7..7196bb6 100644 --- a/models/clip/util.py +++ b/models/clip/util.py | |||
@@ -5,27 +5,32 @@ import torch | |||
5 | from transformers import CLIPTokenizer, CLIPTextModel | 5 | from transformers import CLIPTokenizer, CLIPTextModel |
6 | 6 | ||
7 | 7 | ||
8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None): | 8 | def unify_input_ids( |
9 | tokenizer: CLIPTokenizer, | ||
10 | input_ids: list[list[int]], | ||
11 | max_length: Optional[int] = None, | ||
12 | ): | ||
9 | if max_length is None: | 13 | if max_length is None: |
10 | return tokenizer.pad( | 14 | return tokenizer.pad( |
11 | {"input_ids": input_ids}, | 15 | {"input_ids": input_ids}, |
12 | padding=True, | 16 | padding=True, |
13 | pad_to_multiple_of=tokenizer.model_max_length, | 17 | pad_to_multiple_of=tokenizer.model_max_length, |
14 | return_tensors="pt" | 18 | return_tensors="pt", |
15 | ) | 19 | ) |
16 | else: | 20 | else: |
17 | return tokenizer.pad( | 21 | return tokenizer.pad( |
18 | {"input_ids": input_ids}, | 22 | {"input_ids": input_ids}, |
19 | padding="max_length", | 23 | padding="max_length", |
20 | max_length=max_length, | 24 | max_length=max_length, |
21 | return_tensors="pt" | 25 | return_tensors="pt", |
22 | ) | 26 | ) |
23 | 27 | ||
28 | |||
24 | def get_extended_embeddings( | 29 | def get_extended_embeddings( |
25 | text_encoder: CLIPTextModel, | 30 | text_encoder: CLIPTextModel, |
26 | input_ids: torch.LongTensor, | 31 | input_ids: torch.LongTensor, |
27 | position_ids: Optional[torch.LongTensor] = None, | 32 | position_ids: Optional[torch.LongTensor] = None, |
28 | attention_mask=None | 33 | attention_mask=None, |
29 | ): | 34 | ): |
30 | model_max_length = text_encoder.config.max_position_embeddings | 35 | model_max_length = text_encoder.config.max_position_embeddings |
31 | prompts = input_ids.shape[0] | 36 | prompts = input_ids.shape[0] |
@@ -36,6 +41,8 @@ def get_extended_embeddings( | |||
36 | if attention_mask is not None: | 41 | if attention_mask is not None: |
37 | attention_mask = attention_mask.view((-1, model_max_length)) | 42 | attention_mask = attention_mask.view((-1, model_max_length)) |
38 | 43 | ||
39 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | 44 | text_embeddings = text_encoder( |
45 | input_ids, position_ids=position_ids, attention_mask=attention_mask | ||
46 | )[0] | ||
40 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | 47 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) |
41 | return text_embeddings | 48 | return text_embeddings |
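`get_extended_embeddings` folds over-length prompts into the batch dimension so the encoder only ever sees `model_max_length` tokens at a time, then unfolds the result. A shape walk-through under assumed sizes (2 prompts of 154 tokens, hidden size 768):

    import torch

    prompts, model_max_length, hidden = 2, 77, 768
    input_ids = torch.zeros((prompts, 2 * model_max_length), dtype=torch.long)
    chunks = input_ids.view((-1, model_max_length))  # (4, 77): two chunks per prompt
    encoder_out = torch.zeros((chunks.shape[0], model_max_length, hidden))  # stand-in for CLIP
    merged = encoder_out.view((prompts, -1, hidden))  # (2, 154, 768)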
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py index 571b915..5798bcf 100644 --- a/models/convnext/discriminator.py +++ b/models/convnext/discriminator.py | |||
@@ -5,7 +5,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |||
5 | from torch.nn import functional as F | 5 | from torch.nn import functional as F |
6 | 6 | ||
7 | 7 | ||
8 | class ConvNeXtDiscriminator(): | 8 | class ConvNeXtDiscriminator: |
9 | def __init__(self, model: ConvNeXt, input_size: int) -> None: | 9 | def __init__(self, model: ConvNeXt, input_size: int) -> None: |
10 | self.net = model | 10 | self.net = model |
11 | 11 | ||
@@ -22,8 +22,13 @@ class ConvNeXtDiscriminator(): | |||
22 | img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) | 22 | img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) |
23 | img_std = self.img_std.to(device=img.device, dtype=img.dtype) | 23 | img_std = self.img_std.to(device=img.device, dtype=img.dtype) |
24 | 24 | ||
25 | img = ((img + 1.) / 2.).sub(img_mean).div(img_std) | 25 | img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std) |
26 | 26 | ||
27 | img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True) | 27 | img = F.interpolate( |
28 | img, | ||
29 | size=(self.input_size, self.input_size), | ||
30 | mode="bicubic", | ||
31 | align_corners=True, | ||
32 | ) | ||
28 | pred = self.net(img) | 33 | pred = self.net(img) |
29 | return pred | 34 | return pred |
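The discriminator maps diffusion-space images in [-1, 1] onto ImageNet statistics before the ConvNeXt forward pass. The same mapping as a standalone sketch, using the standard ImageNet mean/std that timm's constants resolve to:

    import torch

    IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def to_imagenet_space(img: torch.Tensor) -> torch.Tensor:
        # [-1, 1] -> [0, 1], then normalize: ((img + 1) / 2 - mean) / std
        return ((img + 1.0) / 2.0).sub(IMAGENET_MEAN).div(IMAGENET_STD)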
diff --git a/models/sparse.py b/models/sparse.py index bd45696..e5897c9 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
@@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding): | |||
15 | ): | 15 | ): |
16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | 16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) |
17 | 17 | ||
18 | self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) | 18 | self.register_buffer( |
19 | "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1 | ||
20 | ) | ||
19 | 21 | ||
20 | self.trainable = nn.ParameterList() | 22 | self.trainable = nn.ParameterList() |
21 | self.scaling = alpha | 23 | self.scaling = alpha |
22 | self.dropout_p = dropout | 24 | self.dropout_p = dropout |
23 | self.weight.requires_grad = False | 25 | self.weight.requires_grad = False |
24 | 26 | ||
25 | if dropout > 0.: | 27 | if dropout > 0.0: |
26 | self.dropout = nn.Dropout(p=dropout) | 28 | self.dropout = nn.Dropout(p=dropout) |
27 | else: | 29 | else: |
28 | self.dropout = nn.Identity() | 30 | self.dropout = nn.Identity() |
29 | 31 | ||
30 | self.reset_parameters() | 32 | self.reset_parameters() |
31 | 33 | ||
32 | def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): | 34 | def new_resized( |
35 | self, new_num_embeddings: int, initializer_factor: Optional[float] = None | ||
36 | ): | ||
33 | n = min(self.num_embeddings, new_num_embeddings) | 37 | n = min(self.num_embeddings, new_num_embeddings) |
34 | 38 | ||
35 | new_emb = SparseEmbedding( | 39 | new_emb = SparseEmbedding( |
@@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding): | |||
38 | self.scaling, | 42 | self.scaling, |
39 | self.dropout_p, | 43 | self.dropout_p, |
40 | device=self.weight.device, | 44 | device=self.weight.device, |
41 | dtype=self.weight.dtype | 45 | dtype=self.weight.dtype, |
42 | ) | 46 | ) |
43 | if initializer_factor is not None: | 47 | if initializer_factor is not None: |
44 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | 48 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) |
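A hedged reading of the `trainable_ids` buffer registered above: every vocabulary id starts at -1, i.e. frozen, and a non-negative entry presumably indexes into `self.trainable`. Illustrated with a hypothetical id:

    import torch

    num_embeddings = 8
    trainable_ids = torch.zeros(num_embeddings, dtype=torch.long) - 1  # all frozen
    trainable_ids[5] = 0  # hypothetical: id 5 now points at trainable slot 0
    print((trainable_ids >= 0).nonzero().flatten().tolist())  # [5]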
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index aa446ec..16b8456 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -21,7 +21,9 @@ from diffusers import ( | |||
21 | LMSDiscreteScheduler, | 21 | LMSDiscreteScheduler, |
22 | PNDMScheduler, | 22 | PNDMScheduler, |
23 | ) | 23 | ) |
24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( |
25 | StableDiffusionPipelineOutput, | ||
26 | ) | ||
25 | from diffusers.utils import logging, randn_tensor | 27 | from diffusers.utils import logging, randn_tensor |
26 | from transformers import CLIPTextModel, CLIPTokenizer | 28 | from transformers import CLIPTextModel, CLIPTokenizer |
27 | 29 | ||
@@ -62,13 +64,35 @@ def gaussian_blur_2d(img, kernel_size, sigma): | |||
62 | return img | 64 | return img |
63 | 65 | ||
64 | 66 | ||
67 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): | ||
68 | """ | ||
69 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and | ||
70 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 | ||
71 | """ | ||
72 | std_text = noise_pred_text.std( | ||
73 | dim=list(range(1, noise_pred_text.ndim)), keepdim=True | ||
74 | ) | ||
75 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) | ||
76 | # rescale the results from guidance (fixes overexposure) | ||
77 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) | ||
78 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images | ||
79 | noise_cfg = ( | ||
80 | guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg | ||
81 | ) | ||
82 | return noise_cfg | ||
83 | |||
84 | |||
65 | class CrossAttnStoreProcessor: | 85 | class CrossAttnStoreProcessor: |
66 | def __init__(self): | 86 | def __init__(self): |
67 | self.attention_probs = None | 87 | self.attention_probs = None |
68 | 88 | ||
69 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): | 89 | def __call__( |
90 | self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None | ||
91 | ): | ||
70 | batch_size, sequence_length, _ = hidden_states.shape | 92 | batch_size, sequence_length, _ = hidden_states.shape |
71 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | 93 | attention_mask = attn.prepare_attention_mask( |
94 | attention_mask, sequence_length, batch_size | ||
95 | ) | ||
72 | query = attn.to_q(hidden_states) | 96 | query = attn.to_q(hidden_states) |
73 | 97 | ||
74 | if encoder_hidden_states is None: | 98 | if encoder_hidden_states is None: |
@@ -113,7 +137,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
113 | ): | 137 | ): |
114 | super().__init__() | 138 | super().__init__() |
115 | 139 | ||
116 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | 140 | if ( |
141 | hasattr(scheduler.config, "steps_offset") | ||
142 | and scheduler.config.steps_offset != 1 | ||
143 | ): | ||
117 | warnings.warn( | 144 | warnings.warn( |
118 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | 145 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" |
119 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " | 146 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " |
@@ -179,7 +206,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
179 | 206 | ||
180 | device = torch.device("cuda") | 207 | device = torch.device("cuda") |
181 | 208 | ||
182 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: | 209 | for cpu_offloaded_model in [ |
210 | self.unet, | ||
211 | self.text_encoder, | ||
212 | self.vae, | ||
213 | self.safety_checker, | ||
214 | ]: | ||
183 | if cpu_offloaded_model is not None: | 215 | if cpu_offloaded_model is not None: |
184 | cpu_offload(cpu_offloaded_model, device) | 216 | cpu_offload(cpu_offloaded_model, device) |
185 | 217 | ||
@@ -223,35 +255,47 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
223 | width: int, | 255 | width: int, |
224 | height: int, | 256 | height: int, |
225 | strength: float, | 257 | strength: float, |
226 | callback_steps: Optional[int] | 258 | callback_steps: Optional[int], |
227 | ): | 259 | ): |
228 | if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): | 260 | if isinstance(prompt, str) or ( |
261 | isinstance(prompt, list) and isinstance(prompt[0], int) | ||
262 | ): | ||
229 | prompt = [prompt] | 263 | prompt = [prompt] |
230 | 264 | ||
231 | if negative_prompt is None: | 265 | if negative_prompt is None: |
232 | negative_prompt = "" | 266 | negative_prompt = "" |
233 | 267 | ||
234 | if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): | 268 | if isinstance(negative_prompt, str) or ( |
269 | isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int) | ||
270 | ): | ||
235 | negative_prompt = [negative_prompt] * len(prompt) | 271 | negative_prompt = [negative_prompt] * len(prompt) |
236 | 272 | ||
237 | if not isinstance(prompt, list): | 273 | if not isinstance(prompt, list): |
238 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | 274 | raise ValueError( |
275 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" | ||
276 | ) | ||
239 | 277 | ||
240 | if not isinstance(negative_prompt, list): | 278 | if not isinstance(negative_prompt, list): |
241 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") | 279 | raise ValueError( |
280 | f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}" | ||
281 | ) | ||
242 | 282 | ||
243 | if len(negative_prompt) != len(prompt): | 283 | if len(negative_prompt) != len(prompt): |
244 | raise ValueError( | 284 | raise ValueError( |
245 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | 285 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}" |
286 | ) | ||
246 | 287 | ||
247 | if strength < 0 or strength > 1: | 288 | if strength < 0 or strength > 1: |
248 | raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}") | 289 | raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}") |
249 | 290 | ||
250 | if height % 8 != 0 or width % 8 != 0: | 291 | if height % 8 != 0 or width % 8 != 0: |
251 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | 292 | raise ValueError( |
293 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." | ||
294 | ) | ||
252 | 295 | ||
253 | if (callback_steps is None) or ( | 296 | if (callback_steps is None) or ( |
254 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) | 297 | callback_steps is not None |
298 | and (not isinstance(callback_steps, int) or callback_steps <= 0) | ||
255 | ): | 299 | ): |
256 | raise ValueError( | 300 | raise ValueError( |
257 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" | 301 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
@@ -266,7 +310,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
266 | negative_prompt: Union[List[str], List[List[int]]], | 310 | negative_prompt: Union[List[str], List[List[int]]], |
267 | num_images_per_prompt: int, | 311 | num_images_per_prompt: int, |
268 | do_classifier_free_guidance: bool, | 312 | do_classifier_free_guidance: bool, |
269 | device | 313 | device, |
270 | ): | 314 | ): |
271 | if isinstance(prompt[0], str): | 315 | if isinstance(prompt[0], str): |
272 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids | 316 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids |
@@ -277,7 +321,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
277 | 321 | ||
278 | if do_classifier_free_guidance: | 322 | if do_classifier_free_guidance: |
279 | if isinstance(prompt[0], str): | 323 | if isinstance(prompt[0], str): |
280 | unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids | 324 | unconditional_input_ids = self.tokenizer( |
325 | negative_prompt, padding="do_not_pad" | ||
326 | ).input_ids | ||
281 | else: | 327 | else: |
282 | unconditional_input_ids = negative_prompt | 328 | unconditional_input_ids = negative_prompt |
283 | unconditional_input_ids *= num_images_per_prompt | 329 | unconditional_input_ids *= num_images_per_prompt |
@@ -286,12 +332,17 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
286 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) | 332 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) |
287 | text_input_ids = text_inputs.input_ids | 333 | text_input_ids = text_inputs.input_ids |
288 | 334 | ||
289 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | 335 | if ( |
336 | hasattr(self.text_encoder.config, "use_attention_mask") | ||
337 | and self.text_encoder.config.use_attention_mask | ||
338 | ): | ||
290 | attention_mask = text_inputs.attention_mask.to(device) | 339 | attention_mask = text_inputs.attention_mask.to(device) |
291 | else: | 340 | else: |
292 | attention_mask = None | 341 | attention_mask = None |
293 | 342 | ||
294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) | 343 | prompt_embeds = get_extended_embeddings( |
344 | self.text_encoder, text_input_ids.to(device), attention_mask | ||
345 | ) | ||
295 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) | 346 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) |
296 | 347 | ||
297 | return prompt_embeds | 348 | return prompt_embeds |
@@ -301,25 +352,21 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
301 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) | 352 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) |
302 | 353 | ||
303 | t_start = max(num_inference_steps - init_timestep, 0) | 354 | t_start = max(num_inference_steps - init_timestep, 0) |
304 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] | 355 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] |
305 | 356 | ||
306 | timesteps = timesteps.to(device) | 357 | timesteps = timesteps.to(device) |
307 | 358 | ||
308 | return timesteps, num_inference_steps - t_start | 359 | return timesteps, num_inference_steps - t_start |
309 | 360 | ||
310 | def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): | 361 | def prepare_latents_from_image( |
311 | offset_image = perlin_noise( | 362 | self, |
312 | (batch_size, 1, width, height), | 363 | init_image, |
313 | res=1, | 364 | timestep, |
314 | generator=generator, | 365 | batch_size, |
315 | dtype=dtype, | 366 | dtype, |
316 | device=device | 367 | device, |
317 | ) | 368 | generator=None, |
318 | offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) | 369 | ): |
319 | offset_latents = self.vae.config.scaling_factor * offset_latents | ||
320 | return offset_latents | ||
321 | |||
322 | def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None): | ||
323 | init_image = init_image.to(device=device, dtype=dtype) | 370 | init_image = init_image.to(device=device, dtype=dtype) |
324 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) | 371 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) |
325 | latents = self.vae.config.scaling_factor * latents | 372 | latents = self.vae.config.scaling_factor * latents |
@@ -333,20 +380,32 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
333 | latents = torch.cat([latents] * batch_multiplier, dim=0) | 380 | latents = torch.cat([latents] * batch_multiplier, dim=0) |
334 | 381 | ||
335 | # add noise to latents using the timesteps | 382 | # add noise to latents using the timesteps |
336 | noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) | 383 | noise = torch.randn( |
337 | 384 | latents.shape, generator=generator, device=device, dtype=dtype | |
338 | if brightness_offset != 0: | 385 | ) |
339 | noise += brightness_offset * self.prepare_brightness_offset( | ||
340 | batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator | ||
341 | ) | ||
342 | 386 | ||
343 | # get latents | 387 | # get latents |
344 | latents = self.scheduler.add_noise(latents, noise, timestep) | 388 | latents = self.scheduler.add_noise(latents, noise, timestep) |
345 | 389 | ||
346 | return latents | 390 | return latents |
347 | 391 | ||
348 | def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): | 392 | def prepare_latents( |
349 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | 393 | self, |
394 | batch_size, | ||
395 | num_channels_latents, | ||
396 | height, | ||
397 | width, | ||
398 | dtype, | ||
399 | device, | ||
400 | generator, | ||
401 | latents=None, | ||
402 | ): | ||
403 | shape = ( | ||
404 | batch_size, | ||
405 | num_channels_latents, | ||
406 | height // self.vae_scale_factor, | ||
407 | width // self.vae_scale_factor, | ||
408 | ) | ||
350 | if isinstance(generator, list) and len(generator) != batch_size: | 409 | if isinstance(generator, list) and len(generator) != batch_size: |
351 | raise ValueError( | 410 | raise ValueError( |
352 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | 411 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
@@ -354,15 +413,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
354 | ) | 413 | ) |
355 | 414 | ||
356 | if latents is None: | 415 | if latents is None: |
357 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | 416 | latents = randn_tensor( |
417 | shape, generator=generator, device=device, dtype=dtype | ||
418 | ) | ||
358 | else: | 419 | else: |
359 | latents = latents.to(device) | 420 | latents = latents.to(device) |
360 | 421 | ||
361 | if brightness_offset != 0: | ||
362 | latents += brightness_offset * self.prepare_brightness_offset( | ||
363 | batch_size, height, width, dtype, device, generator | ||
364 | ) | ||
365 | |||
366 | # scale the initial noise by the standard deviation required by the scheduler | 422 | # scale the initial noise by the standard deviation required by the scheduler |
367 | latents = latents * self.scheduler.init_noise_sigma | 423 | latents = latents * self.scheduler.init_noise_sigma |
368 | return latents | 424 | return latents |
@@ -373,13 +429,17 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
373 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | 429 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 |
374 | # and should be between [0, 1] | 430 | # and should be between [0, 1] |
375 | 431 | ||
376 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 432 | accepts_eta = "eta" in set( |
433 | inspect.signature(self.scheduler.step).parameters.keys() | ||
434 | ) | ||
377 | extra_step_kwargs = {} | 435 | extra_step_kwargs = {} |
378 | if accepts_eta: | 436 | if accepts_eta: |
379 | extra_step_kwargs["eta"] = eta | 437 | extra_step_kwargs["eta"] = eta |
380 | 438 | ||
381 | # check if the scheduler accepts generator | 439 | # check if the scheduler accepts generator |
382 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 440 | accepts_generator = "generator" in set( |
441 | inspect.signature(self.scheduler.step).parameters.keys() | ||
442 | ) | ||
383 | if accepts_generator: | 443 | if accepts_generator: |
384 | extra_step_kwargs["generator"] = generator | 444 | extra_step_kwargs["generator"] = generator |
385 | return extra_step_kwargs | 445 | return extra_step_kwargs |
@@ -396,7 +456,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
396 | def __call__( | 456 | def __call__( |
397 | self, | 457 | self, |
398 | prompt: Union[str, List[str], List[int], List[List[int]]], | 458 | prompt: Union[str, List[str], List[int], List[List[int]]], |
399 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, | 459 | negative_prompt: Optional[ |
460 | Union[str, List[str], List[int], List[List[int]]] | ||
461 | ] = None, | ||
400 | num_images_per_prompt: int = 1, | 462 | num_images_per_prompt: int = 1, |
401 | strength: float = 1.0, | 463 | strength: float = 1.0, |
402 | height: Optional[int] = None, | 464 | height: Optional[int] = None, |
@@ -407,12 +469,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
407 | eta: float = 0.0, | 469 | eta: float = 0.0, |
408 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 470 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
409 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 471 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
410 | brightness_offset: Union[float, torch.FloatTensor] = 0, | ||
411 | output_type: str = "pil", | 472 | output_type: str = "pil", |
412 | return_dict: bool = True, | 473 | return_dict: bool = True, |
413 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 474 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
414 | callback_steps: int = 1, | 475 | callback_steps: int = 1, |
415 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, | 476 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
477 | guidance_rescale: float = 0.0, | ||
416 | ): | 478 | ): |
417 | r""" | 479 | r""" |
418 | Function invoked when calling the pipeline for generation. | 480 | Function invoked when calling the pipeline for generation. |
@@ -472,7 +534,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
472 | width = width or self.unet.config.sample_size * self.vae_scale_factor | 534 | width = width or self.unet.config.sample_size * self.vae_scale_factor |
473 | 535 | ||
474 | # 1. Check inputs. Raise error if not correct | 536 | # 1. Check inputs. Raise error if not correct |
475 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) | 537 | prompt, negative_prompt = self.check_inputs( |
538 | prompt, negative_prompt, width, height, strength, callback_steps | ||
539 | ) | ||
476 | 540 | ||
477 | # 2. Define call parameters | 541 | # 2. Define call parameters |
478 | batch_size = len(prompt) | 542 | batch_size = len(prompt) |
@@ -488,7 +552,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
488 | negative_prompt, | 552 | negative_prompt, |
489 | num_images_per_prompt, | 553 | num_images_per_prompt, |
490 | do_classifier_free_guidance, | 554 | do_classifier_free_guidance, |
491 | device | 555 | device, |
492 | ) | 556 | ) |
493 | 557 | ||
494 | # 4. Prepare latent variables | 558 | # 4. Prepare latent variables |
@@ -497,7 +561,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
497 | 561 | ||
498 | # 5. Prepare timesteps | 562 | # 5. Prepare timesteps |
499 | self.scheduler.set_timesteps(num_inference_steps, device=device) | 563 | self.scheduler.set_timesteps(num_inference_steps, device=device) |
500 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) | 564 | timesteps, num_inference_steps = self.get_timesteps( |
565 | num_inference_steps, strength, device | ||
566 | ) | ||
501 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | 567 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) |
502 | 568 | ||
503 | # 6. Prepare latent variables | 569 | # 6. Prepare latent variables |
@@ -506,7 +572,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
506 | image, | 572 | image, |
507 | latent_timestep, | 573 | latent_timestep, |
508 | batch_size * num_images_per_prompt, | 574 | batch_size * num_images_per_prompt, |
509 | brightness_offset, | ||
510 | prompt_embeds.dtype, | 575 | prompt_embeds.dtype, |
511 | device, | 576 | device, |
512 | generator, | 577 | generator, |
@@ -517,7 +582,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
517 | num_channels_latents, | 582 | num_channels_latents, |
518 | height, | 583 | height, |
519 | width, | 584 | width, |
520 | brightness_offset, | ||
521 | prompt_embeds.dtype, | 585 | prompt_embeds.dtype, |
522 | device, | 586 | device, |
523 | generator, | 587 | generator, |
@@ -530,14 +594,20 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
530 | # 8. Denoising loop | 594 | # 8. Denoising loop |
531 | if do_self_attention_guidance: | 595 | if do_self_attention_guidance: |
532 | store_processor = CrossAttnStoreProcessor() | 596 | store_processor = CrossAttnStoreProcessor() |
533 | self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor | 597 | self.unet.mid_block.attentions[0].transformer_blocks[ |
598 | 0 | ||
599 | ].attn1.processor = store_processor | ||
534 | 600 | ||
535 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | 601 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
536 | with self.progress_bar(total=num_inference_steps) as progress_bar: | 602 | with self.progress_bar(total=num_inference_steps) as progress_bar: |
537 | for i, t in enumerate(timesteps): | 603 | for i, t in enumerate(timesteps): |
538 | # expand the latents if we are doing classifier free guidance | 604 | # expand the latents if we are doing classifier free guidance |
539 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 605 | latent_model_input = ( |
540 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 606 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
607 | ) | ||
608 | latent_model_input = self.scheduler.scale_model_input( | ||
609 | latent_model_input, t | ||
610 | ) | ||
541 | 611 | ||
542 | # predict the noise residual | 612 | # predict the noise residual |
543 | noise_pred = self.unet( | 613 | noise_pred = self.unet( |
@@ -551,7 +621,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
551 | # perform guidance | 621 | # perform guidance |
552 | if do_classifier_free_guidance: | 622 | if do_classifier_free_guidance: |
553 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | 623 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
554 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 624 | noise_pred = noise_pred_uncond + guidance_scale * ( |
625 | noise_pred_text - noise_pred_uncond | ||
626 | ) | ||
627 | noise_pred = rescale_noise_cfg( | ||
628 | noise_pred, noise_pred_text, guidance_rescale=guidance_rescale | ||
629 | ) | ||
555 | 630 | ||
556 | if do_self_attention_guidance: | 631 | if do_self_attention_guidance: |
557 | # classifier-free guidance produces two chunks of attention map | 632 | # classifier-free guidance produces two chunks of attention map |
@@ -561,15 +636,24 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
561 | # DDIM-like prediction of x0 | 636 | # DDIM-like prediction of x0 |
562 | pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) | 637 | pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) |
563 | # get the stored attention maps | 638 | # get the stored attention maps |
564 | uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) | 639 | uncond_attn, cond_attn = store_processor.attention_probs.chunk( |
640 | 2 | ||
641 | ) | ||
565 | # self-attention-based degrading of latents | 642 | # self-attention-based degrading of latents |
566 | degraded_latents = self.sag_masking( | 643 | degraded_latents = self.sag_masking( |
567 | pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) | 644 | pred_x0, |
645 | uncond_attn, | ||
646 | t, | ||
647 | self.pred_epsilon(latents, noise_pred_uncond, t), | ||
568 | ) | 648 | ) |
569 | uncond_emb, _ = prompt_embeds.chunk(2) | 649 | uncond_emb, _ = prompt_embeds.chunk(2) |
570 | # forward and give guidance | 650 | # forward and give guidance |
571 | degraded_pred = self.unet( | 651 | degraded_pred = self.unet( |
572 | degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0] | 652 | degraded_latents, |
653 | t, | ||
654 | encoder_hidden_states=uncond_emb, | ||
655 | return_dict=False, | ||
656 | )[0] | ||
573 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) | 657 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) |
574 | else: | 658 | else: |
575 | # DDIM-like prediction of x0 | 659 | # DDIM-like prediction of x0 |
@@ -578,18 +662,29 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
578 | cond_attn = store_processor.attention_probs | 662 | cond_attn = store_processor.attention_probs |
579 | # self-attention-based degrading of latents | 663 | # self-attention-based degrading of latents |
580 | degraded_latents = self.sag_masking( | 664 | degraded_latents = self.sag_masking( |
581 | pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) | 665 | pred_x0, |
666 | cond_attn, | ||
667 | t, | ||
668 | self.pred_epsilon(latents, noise_pred, t), | ||
582 | ) | 669 | ) |
583 | # forward and give guidance | 670 | # forward and give guidance |
584 | degraded_pred = self.unet( | 671 | degraded_pred = self.unet( |
585 | degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0] | 672 | degraded_latents, |
673 | t, | ||
674 | encoder_hidden_states=prompt_embeds, | ||
675 | return_dict=False, | ||
676 | )[0] | ||
586 | noise_pred += sag_scale * (noise_pred - degraded_pred) | 677 | noise_pred += sag_scale * (noise_pred - degraded_pred) |
587 | 678 | ||
588 | # compute the previous noisy sample x_t -> x_t-1 | 679 | # compute the previous noisy sample x_t -> x_t-1 |
589 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] | 680 | latents = self.scheduler.step( |
681 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False | ||
682 | )[0] | ||
590 | 683 | ||
591 | # call the callback, if provided | 684 | # call the callback, if provided |
592 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | 685 | if i == len(timesteps) - 1 or ( |
686 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 | ||
687 | ): | ||
593 | progress_bar.update() | 688 | progress_bar.update() |
594 | if callback is not None and i % callback_steps == 0: | 689 | if callback is not None and i % callback_steps == 0: |
595 | callback(i, t, latents) | 690 | callback(i, t, latents) |
@@ -615,7 +710,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
615 | if not return_dict: | 710 | if not return_dict: |
616 | return (image, has_nsfw_concept) | 711 | return (image, has_nsfw_concept) |
617 | 712 | ||
618 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | 713 | return StableDiffusionPipelineOutput( |
714 | images=image, nsfw_content_detected=has_nsfw_concept | ||
715 | ) | ||
619 | 716 | ||
620 | # Self-Attention-Guided (SAG) Stable Diffusion | 717 | # Self-Attention-Guided (SAG) Stable Diffusion |
621 | 718 | ||
@@ -632,16 +729,23 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
632 | attn_map = attn_map.reshape(b, h, hw1, hw2) | 729 | attn_map = attn_map.reshape(b, h, hw1, hw2) |
633 | attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 | 730 | attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 |
634 | attn_mask = ( | 731 | attn_mask = ( |
635 | attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype) | 732 | attn_mask.reshape(b, map_size, map_size) |
733 | .unsqueeze(1) | ||
734 | .repeat(1, latent_channel, 1, 1) | ||
735 | .type(attn_map.dtype) | ||
636 | ) | 736 | ) |
637 | attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) | 737 | attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) |
638 | 738 | ||
639 | # Blur according to the self-attention mask | 739 | # Blur according to the self-attention mask |
640 | degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) | 740 | degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) |
641 | degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) | 741 | degraded_latents = degraded_latents * attn_mask + original_latents * ( |
742 | 1 - attn_mask | ||
743 | ) | ||
642 | 744 | ||
643 | # Noise it again to match the noise level | 745 | # Noise it again to match the noise level |
644 | degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) | 746 | degraded_latents = self.scheduler.add_noise( |
747 | degraded_latents, noise=eps, timesteps=t | ||
748 | ) | ||
645 | 749 | ||
646 | return degraded_latents | 750 | return degraded_latents |
647 | 751 | ||
@@ -652,13 +756,19 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
652 | 756 | ||
653 | beta_prod_t = 1 - alpha_prod_t | 757 | beta_prod_t = 1 - alpha_prod_t |
654 | if self.scheduler.config.prediction_type == "epsilon": | 758 | if self.scheduler.config.prediction_type == "epsilon": |
655 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) | 759 | pred_original_sample = ( |
760 | sample - beta_prod_t ** (0.5) * model_output | ||
761 | ) / alpha_prod_t ** (0.5) | ||
656 | elif self.scheduler.config.prediction_type == "sample": | 762 | elif self.scheduler.config.prediction_type == "sample": |
657 | pred_original_sample = model_output | 763 | pred_original_sample = model_output |
658 | elif self.scheduler.config.prediction_type == "v_prediction": | 764 | elif self.scheduler.config.prediction_type == "v_prediction": |
659 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output | 765 | pred_original_sample = (alpha_prod_t**0.5) * sample - ( |
766 | beta_prod_t**0.5 | ||
767 | ) * model_output | ||
660 | # predict V | 768 | # predict V |
661 | model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample | 769 | model_output = (alpha_prod_t**0.5) * model_output + ( |
770 | beta_prod_t**0.5 | ||
771 | ) * sample | ||
662 | else: | 772 | else: |
663 | raise ValueError( | 773 | raise ValueError( |
664 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," | 774 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," |
@@ -674,9 +784,13 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
674 | if self.scheduler.config.prediction_type == "epsilon": | 784 | if self.scheduler.config.prediction_type == "epsilon": |
675 | pred_eps = model_output | 785 | pred_eps = model_output |
676 | elif self.scheduler.config.prediction_type == "sample": | 786 | elif self.scheduler.config.prediction_type == "sample": |
677 | pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) | 787 | pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / ( |
788 | beta_prod_t**0.5 | ||
789 | ) | ||
678 | elif self.scheduler.config.prediction_type == "v_prediction": | 790 | elif self.scheduler.config.prediction_type == "v_prediction": |
679 | pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output | 791 | pred_eps = (beta_prod_t**0.5) * sample + ( |
792 | alpha_prod_t**0.5 | ||
793 | ) * model_output | ||
680 | else: | 794 | else: |
681 | raise ValueError( | 795 | raise ValueError( |
682 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," | 796 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2aca1e7..659b84c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -5,34 +5,70 @@ import itertools | |||
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from functools import partial | 6 | from functools import partial |
7 | import math | 7 | import math |
8 | import warnings | ||
8 | 9 | ||
9 | import torch | 10 | import torch |
11 | import torch._dynamo | ||
10 | import torch.utils.checkpoint | 12 | import torch.utils.checkpoint |
13 | import hidet | ||
11 | 14 | ||
12 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
13 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
15 | from slugify import slugify | 18 | |
19 | # from diffusers.models.attention_processor import AttnProcessor | ||
20 | from diffusers.utils.import_utils import is_xformers_available | ||
16 | import transformers | 21 | import transformers |
17 | 22 | ||
18 | from util.files import load_config, load_embeddings_from_dir | 23 | import numpy as np |
24 | from slugify import slugify | ||
25 | |||
19 | from data.csv import VlpnDataModule, keyword_filter | 26 | from data.csv import VlpnDataModule, keyword_filter |
20 | from training.functional import train, get_models | 27 | from models.clip.embeddings import patch_managed_embeddings |
28 | from training.functional import train, add_placeholder_tokens, get_models | ||
21 | from training.strategy.dreambooth import dreambooth_strategy | 29 | from training.strategy.dreambooth import dreambooth_strategy |
22 | from training.optimization import get_scheduler | 30 | from training.optimization import get_scheduler |
23 | from training.util import save_args | 31 | from training.sampler import create_named_schedule_sampler |
32 | from training.util import AverageMeter, save_args | ||
33 | from util.files import load_config, load_embeddings_from_dir | ||
34 | |||
24 | 35 | ||
25 | logger = get_logger(__name__) | 36 | logger = get_logger(__name__) |
26 | 37 | ||
38 | warnings.filterwarnings("ignore") | ||
39 | |||
27 | 40 | ||
28 | torch.backends.cuda.matmul.allow_tf32 = True | 41 | torch.backends.cuda.matmul.allow_tf32 = True |
29 | torch.backends.cudnn.benchmark = True | 42 | torch.backends.cudnn.benchmark = True |
30 | 43 | ||
44 | # torch._dynamo.config.log_level = logging.WARNING | ||
45 | torch._dynamo.config.suppress_errors = True | ||
46 | |||
47 | hidet.torch.dynamo_config.use_tensor_core(True) | ||
48 | hidet.torch.dynamo_config.search_space(0) | ||
49 | |||
50 | |||
51 | def patch_xformers(dtype): | ||
52 | if is_xformers_available(): | ||
53 | import xformers | ||
54 | import xformers.ops | ||
55 | |||
56 | orig_xformers_memory_efficient_attention = ( | ||
57 | xformers.ops.memory_efficient_attention | ||
58 | ) | ||
59 | |||
60 | def xformers_memory_efficient_attention( | ||
61 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
62 | ): | ||
63 | return orig_xformers_memory_efficient_attention( | ||
64 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
65 | ) | ||
66 | |||
67 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
68 | |||
31 | 69 | ||
32 | def parse_args(): | 70 | def parse_args(): |
33 | parser = argparse.ArgumentParser( | 71 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
34 | description="Simple example of a training script." | ||
35 | ) | ||
36 | parser.add_argument( | 72 | parser.add_argument( |
37 | "--pretrained_model_name_or_path", | 73 | "--pretrained_model_name_or_path", |
38 | type=str, | 74 | type=str, |
@@ -49,7 +85,7 @@ def parse_args(): | |||
49 | "--train_data_file", | 85 | "--train_data_file", |
50 | type=str, | 86 | type=str, |
51 | default=None, | 87 | default=None, |
52 | help="A folder containing the training data." | 88 | help="A folder containing the training data.", |
53 | ) | 89 | ) |
54 | parser.add_argument( | 90 | parser.add_argument( |
55 | "--train_data_template", | 91 | "--train_data_template", |
@@ -60,13 +96,13 @@ def parse_args(): | |||
60 | "--train_set_pad", | 96 | "--train_set_pad", |
61 | type=int, | 97 | type=int, |
62 | default=None, | 98 | default=None, |
63 | help="The number to fill train dataset items up to." | 99 | help="The number to fill train dataset items up to.", |
64 | ) | 100 | ) |
65 | parser.add_argument( | 101 | parser.add_argument( |
66 | "--valid_set_pad", | 102 | "--valid_set_pad", |
67 | type=int, | 103 | type=int, |
68 | default=None, | 104 | default=None, |
69 | help="The number to fill validation dataset items up to." | 105 | help="The number to fill validation dataset items up to.", |
70 | ) | 106 | ) |
71 | parser.add_argument( | 107 | parser.add_argument( |
72 | "--project", | 108 | "--project", |
@@ -75,20 +111,58 @@ def parse_args(): | |||
75 | help="The name of the current project.", | 111 | help="The name of the current project.", |
76 | ) | 112 | ) |
77 | parser.add_argument( | 113 | parser.add_argument( |
78 | "--exclude_collections", | 114 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
115 | ) | ||
116 | parser.add_argument( | ||
117 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." | ||
118 | ) | ||
119 | parser.add_argument( | ||
120 | "--placeholder_tokens", | ||
79 | type=str, | 121 | type=str, |
80 | nargs='*', | 122 | nargs="*", |
81 | help="Exclude all items with a listed collection.", | 123 | help="A token to use as a placeholder for the concept.", |
82 | ) | 124 | ) |
83 | parser.add_argument( | 125 | parser.add_argument( |
84 | "--train_text_encoder_epochs", | 126 | "--initializer_tokens", |
85 | default=999999, | 127 | type=str, |
86 | help="Number of epochs the text encoder will be trained." | 128 | nargs="*", |
129 | help="A token to use as initializer word.", | ||
130 | ) | ||
131 | parser.add_argument( | ||
132 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." | ||
133 | ) | ||
134 | parser.add_argument( | ||
135 | "--initializer_noise", | ||
136 | type=float, | ||
137 | default=0, | ||
138 | help="Noise to apply to the initializer word", | ||
139 | ) | ||
140 | parser.add_argument( | ||
141 | "--alias_tokens", | ||
142 | type=str, | ||
143 | nargs="*", | ||
144 | default=[], | ||
145 | help="Tokens to create an alias for.", | ||
146 | ) | ||
147 | parser.add_argument( | ||
148 | "--inverted_initializer_tokens", | ||
149 | type=str, | ||
150 | nargs="*", | ||
151 | help="A token to use as initializer word.", | ||
152 | ) | ||
153 | parser.add_argument( | ||
154 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." | ||
155 | ) | ||
156 | parser.add_argument( | ||
157 | "--exclude_collections", | ||
158 | type=str, | ||
159 | nargs="*", | ||
160 | help="Exclude all items with a listed collection.", | ||
87 | ) | 161 | ) |
88 | parser.add_argument( | 162 | parser.add_argument( |
89 | "--num_buckets", | 163 | "--num_buckets", |
90 | type=int, | 164 | type=int, |
91 | default=0, | 165 | default=2, |
92 | help="Number of aspect ratio buckets in either direction.", | 166 | help="Number of aspect ratio buckets in either direction.", |
93 | ) | 167 | ) |
94 | parser.add_argument( | 168 | parser.add_argument( |
@@ -120,19 +194,6 @@ def parse_args(): | |||
120 | help="Shuffle tags.", | 194 | help="Shuffle tags.", |
121 | ) | 195 | ) |
122 | parser.add_argument( | 196 | parser.add_argument( |
123 | "--vector_dropout", | ||
124 | type=int, | ||
125 | default=0, | ||
126 | help="Vector dropout probability.", | ||
127 | ) | ||
128 | parser.add_argument( | ||
129 | "--vector_shuffle", | ||
130 | type=str, | ||
131 | default="auto", | ||
132 | choices=["all", "trailing", "leading", "between", "auto", "off"], | ||
133 | help='Vector shuffling algorithm.', | ||
134 | ) | ||
135 | parser.add_argument( | ||
136 | "--guidance_scale", | 197 | "--guidance_scale", |
137 | type=float, | 198 | type=float, |
138 | default=0, | 199 | default=0, |
@@ -141,7 +202,7 @@ def parse_args(): | |||
141 | "--num_class_images", | 202 | "--num_class_images", |
142 | type=int, | 203 | type=int, |
143 | default=0, | 204 | default=0, |
144 | help="How many class images to generate." | 205 | help="How many class images to generate.", |
145 | ) | 206 | ) |
146 | parser.add_argument( | 207 | parser.add_argument( |
147 | "--class_image_dir", | 208 | "--class_image_dir", |
@@ -162,16 +223,18 @@ def parse_args(): | |||
162 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 223 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
163 | ) | 224 | ) |
164 | parser.add_argument( | 225 | parser.add_argument( |
226 | "--train_dir_embeddings", | ||
227 | action="store_true", | ||
228 | help="Train embeddings loaded from embeddings directory.", | ||
229 | ) | ||
230 | parser.add_argument( | ||
165 | "--collection", | 231 | "--collection", |
166 | type=str, | 232 | type=str, |
167 | nargs='*', | 233 | nargs="*", |
168 | help="A collection to filter the dataset.", | 234 | help="A collection to filter the dataset.", |
169 | ) | 235 | ) |
170 | parser.add_argument( | 236 | parser.add_argument( |
171 | "--seed", | 237 | "--seed", type=int, default=None, help="A seed for reproducible training." |
172 | type=int, | ||
173 | default=None, | ||
174 | help="A seed for reproducible training." | ||
175 | ) | 238 | ) |
176 | parser.add_argument( | 239 | parser.add_argument( |
177 | "--resolution", | 240 | "--resolution", |
@@ -189,15 +252,13 @@ def parse_args(): | |||
189 | help="Perlin offset noise strength.", | 252 | help="Perlin offset noise strength.", |
190 | ) | 253 | ) |
191 | parser.add_argument( | 254 | parser.add_argument( |
192 | "--num_train_epochs", | 255 | "--input_pertubation", |
193 | type=int, | 256 | type=float, |
194 | default=None | 257 | default=0, |
195 | ) | 258 | help="The scale of input perturbation. Recommended 0.1.", |
196 | parser.add_argument( | ||
197 | "--num_train_steps", | ||
198 | type=int, | ||
199 | default=2000 | ||
200 | ) | 259 | ) |
260 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
261 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
201 | parser.add_argument( | 262 | parser.add_argument( |
202 | "--gradient_accumulation_steps", | 263 | "--gradient_accumulation_steps", |
203 | type=int, | 264 | type=int, |
@@ -205,9 +266,9 @@ def parse_args(): | |||
205 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 266 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
206 | ) | 267 | ) |
207 | parser.add_argument( | 268 | parser.add_argument( |
208 | "--gradient_checkpointing", | 269 | "--train_text_encoder_cycles", |
209 | action="store_true", | 270 | type=int, default=999999, |
210 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | 271 | help="Number of cycles during which the text encoder will be trained.", |
211 | ) | 272 | ) |
212 | parser.add_argument( | 273 | parser.add_argument( |
213 | "--find_lr", | 274 | "--find_lr", |
@@ -215,9 +276,15 @@ def parse_args(): | |||
215 | help="Automatically find a learning rate (no training).", | 276 | help="Automatically find a learning rate (no training).", |
216 | ) | 277 | ) |
217 | parser.add_argument( | 278 | parser.add_argument( |
218 | "--learning_rate", | 279 | "--learning_rate_unet", |
280 | type=float, | ||
281 | default=1e-4, | ||
282 | help="Initial learning rate (after the potential warmup period) to use.", | ||
283 | ) | ||
284 | parser.add_argument( | ||
285 | "--learning_rate_text", | ||
219 | type=float, | 286 | type=float, |
220 | default=2e-6, | 287 | default=5e-5, |
221 | help="Initial learning rate (after the potential warmup period) to use.", | 288 | help="Initial learning rate (after the potential warmup period) to use.", |
222 | ) | 289 | ) |
223 | parser.add_argument( | 290 | parser.add_argument( |
@@ -229,27 +296,31 @@ def parse_args(): | |||
229 | "--lr_scheduler", | 296 | "--lr_scheduler", |
230 | type=str, | 297 | type=str, |
231 | default="one_cycle", | 298 | default="one_cycle", |
232 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 299 | choices=[ |
233 | "constant", "constant_with_warmup", "one_cycle"], | 300 | "linear", |
234 | help='The scheduler type to use.', | 301 | "cosine", |
302 | "cosine_with_restarts", | ||
303 | "polynomial", | ||
304 | "constant", | ||
305 | "constant_with_warmup", | ||
306 | "one_cycle", | ||
307 | ], | ||
308 | help="The scheduler type to use.", | ||
235 | ) | 309 | ) |
236 | parser.add_argument( | 310 | parser.add_argument( |
237 | "--lr_warmup_epochs", | 311 | "--lr_warmup_epochs", |
238 | type=int, | 312 | type=int, |
239 | default=10, | 313 | default=10, |
240 | help="Number of steps for the warmup in the lr scheduler." | 314 | help="Number of steps for the warmup in the lr scheduler.", |
241 | ) | 315 | ) |
242 | parser.add_argument( | 316 | parser.add_argument( |
243 | "--lr_mid_point", | 317 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
244 | type=float, | ||
245 | default=0.3, | ||
246 | help="OneCycle schedule mid point." | ||
247 | ) | 318 | ) |
248 | parser.add_argument( | 319 | parser.add_argument( |
249 | "--lr_cycles", | 320 | "--lr_cycles", |
250 | type=int, | 321 | type=int, |
251 | default=None, | 322 | default=None, |
252 | help="Number of restart cycles in the lr scheduler (if supported)." | 323 | help="Number of restart cycles in the lr scheduler (if supported).", |
253 | ) | 324 | ) |
254 | parser.add_argument( | 325 | parser.add_argument( |
255 | "--lr_warmup_func", | 326 | "--lr_warmup_func", |
@@ -261,7 +332,7 @@ def parse_args(): | |||
261 | "--lr_warmup_exp", | 332 | "--lr_warmup_exp", |
262 | type=int, | 333 | type=int, |
263 | default=1, | 334 | default=1, |
264 | help='If lr_warmup_func is "cos", exponent to modify the function' | 335 | help='If lr_warmup_func is "cos", exponent to modify the function', |
265 | ) | 336 | ) |
266 | parser.add_argument( | 337 | parser.add_argument( |
267 | "--lr_annealing_func", | 338 | "--lr_annealing_func", |
@@ -273,76 +344,76 @@ def parse_args(): | |||
273 | "--lr_annealing_exp", | 344 | "--lr_annealing_exp", |
274 | type=int, | 345 | type=int, |
275 | default=3, | 346 | default=3, |
276 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 347 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
277 | ) | 348 | ) |
278 | parser.add_argument( | 349 | parser.add_argument( |
279 | "--lr_min_lr", | 350 | "--lr_min_lr", |
280 | type=float, | 351 | type=float, |
281 | default=0.04, | 352 | default=0.04, |
282 | help="Minimum learning rate in the lr scheduler." | 353 | help="Minimum learning rate in the lr scheduler.", |
283 | ) | ||
284 | parser.add_argument( | ||
285 | "--use_ema", | ||
286 | action="store_true", | ||
287 | help="Whether to use EMA model." | ||
288 | ) | ||
289 | parser.add_argument( | ||
290 | "--ema_inv_gamma", | ||
291 | type=float, | ||
292 | default=1.0 | ||
293 | ) | 354 | ) |
355 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
294 | parser.add_argument( | 356 | parser.add_argument( |
295 | "--ema_power", | 357 | "--schedule_sampler", |
296 | type=float, | 358 | type=str, |
297 | default=6/7 | 359 | default="uniform", |
298 | ) | 360 | choices=["uniform", "loss-second-moment"], |
299 | parser.add_argument( | 361 | help="Noise schedule sampler.", |
300 | "--ema_max_decay", | ||
301 | type=float, | ||
302 | default=0.9999 | ||
303 | ) | 362 | ) |
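The two arguments above wire Min-SNR loss weighting and timestep sampling into the trainer. A minimal sketch of the Min-SNR idea (Hang et al., 2023), assuming epsilon-prediction and a DDPM-style scheduler; the repo's actual implementation lives in training/functional.py and may differ in detail:

    import torch

    def min_snr_weights(timesteps: torch.Tensor, alphas_cumprod: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); clamping at gamma keeps
        # easy (low-noise) timesteps from dominating the loss.
        alpha_bar = alphas_cumprod[timesteps]
        snr = alpha_bar / (1.0 - alpha_bar)
        return torch.clamp(snr, max=gamma) / snr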
304 | parser.add_argument( | 363 | parser.add_argument( |
305 | "--optimizer", | 364 | "--optimizer", |
306 | type=str, | 365 | type=str, |
307 | default="dadan", | 366 | default="adan", |
308 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 367 | choices=[ |
309 | help='Optimizer to use' | 368 | "adam", |
369 | "adam8bit", | ||
370 | "adan", | ||
371 | "lion", | ||
372 | "dadam", | ||
373 | "dadan", | ||
374 | "dlion", | ||
375 | "adafactor", | ||
376 | ], | ||
377 | help="Optimizer to use", | ||
310 | ) | 378 | ) |
311 | parser.add_argument( | 379 | parser.add_argument( |
312 | "--dadaptation_d0", | 380 | "--dadaptation_d0", |
313 | type=float, | 381 | type=float, |
314 | default=1e-6, | 382 | default=1e-6, |
315 | help="The d0 parameter for Dadaptation optimizers." | 383 | help="The d0 parameter for Dadaptation optimizers.", |
384 | ) | ||
385 | parser.add_argument( | ||
386 | "--dadaptation_growth_rate", | ||
387 | type=float, | ||
388 | default=math.inf, | ||
389 | help="The growth_rate parameter for Dadaptation optimizers.", | ||
316 | ) | 390 | ) |
317 | parser.add_argument( | 391 | parser.add_argument( |
318 | "--adam_beta1", | 392 | "--adam_beta1", |
319 | type=float, | 393 | type=float, |
320 | default=None, | 394 | default=None, |
321 | help="The beta1 parameter for the Adam optimizer." | 395 | help="The beta1 parameter for the Adam optimizer.", |
322 | ) | 396 | ) |
323 | parser.add_argument( | 397 | parser.add_argument( |
324 | "--adam_beta2", | 398 | "--adam_beta2", |
325 | type=float, | 399 | type=float, |
326 | default=None, | 400 | default=None, |
327 | help="The beta2 parameter for the Adam optimizer." | 401 | help="The beta2 parameter for the Adam optimizer.", |
328 | ) | 402 | ) |
329 | parser.add_argument( | 403 | parser.add_argument( |
330 | "--adam_weight_decay", | 404 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
331 | type=float, | ||
332 | default=1e-2, | ||
333 | help="Weight decay to use." | ||
334 | ) | 405 | ) |
335 | parser.add_argument( | 406 | parser.add_argument( |
336 | "--adam_epsilon", | 407 | "--adam_epsilon", |
337 | type=float, | 408 | type=float, |
338 | default=1e-08, | 409 | default=1e-08, |
339 | help="Epsilon value for the Adam optimizer" | 410 | help="Epsilon value for the Adam optimizer", |
340 | ) | 411 | ) |
341 | parser.add_argument( | 412 | parser.add_argument( |
342 | "--adam_amsgrad", | 413 | "--adam_amsgrad", |
343 | type=bool, | 414 | type=bool, |
344 | default=False, | 415 | default=False, |
345 | help="Amsgrad value for the Adam optimizer" | 416 | help="Amsgrad value for the Adam optimizer", |
346 | ) | 417 | ) |
347 | parser.add_argument( | 418 | parser.add_argument( |
348 | "--mixed_precision", | 419 | "--mixed_precision", |
@@ -356,12 +427,28 @@ def parse_args(): | |||
356 | ), | 427 | ), |
357 | ) | 428 | ) |
358 | parser.add_argument( | 429 | parser.add_argument( |
430 | "--compile_unet", | ||
431 | action="store_true", | ||
432 | help="Compile UNet with Torch Dynamo.", | ||
433 | ) | ||
434 | parser.add_argument( | ||
435 | "--use_xformers", | ||
436 | action="store_true", | ||
437 | help="Use xformers.", | ||
438 | ) | ||
439 | parser.add_argument( | ||
359 | "--sample_frequency", | 440 | "--sample_frequency", |
360 | type=int, | 441 | type=int, |
361 | default=1, | 442 | default=1, |
362 | help="How often to save a checkpoint and sample image", | 443 | help="How often to save a checkpoint and sample image", |
363 | ) | 444 | ) |
364 | parser.add_argument( | 445 | parser.add_argument( |
446 | "--sample_num", | ||
447 | type=int, | ||
448 | default=None, | ||
449 | help="How often to save a checkpoint and sample image (in number of samples)", | ||
450 | ) | ||
451 | parser.add_argument( | ||
365 | "--sample_image_size", | 452 | "--sample_image_size", |
366 | type=int, | 453 | type=int, |
367 | default=768, | 454 | default=768, |
@@ -383,19 +470,19 @@ def parse_args(): | |||
383 | "--valid_set_size", | 470 | "--valid_set_size", |
384 | type=int, | 471 | type=int, |
385 | default=None, | 472 | default=None, |
386 | help="Number of images in the validation dataset." | 473 | help="Number of images in the validation dataset.", |
387 | ) | 474 | ) |
388 | parser.add_argument( | 475 | parser.add_argument( |
389 | "--valid_set_repeat", | 476 | "--valid_set_repeat", |
390 | type=int, | 477 | type=int, |
391 | default=1, | 478 | default=1, |
392 | help="Times the images in the validation dataset are repeated." | 479 | help="Times the images in the validation dataset are repeated.", |
393 | ) | 480 | ) |
394 | parser.add_argument( | 481 | parser.add_argument( |
395 | "--train_batch_size", | 482 | "--train_batch_size", |
396 | type=int, | 483 | type=int, |
397 | default=1, | 484 | default=1, |
398 | help="Batch size (per device) for the training dataloader." | 485 | help="Batch size (per device) for the training dataloader.", |
399 | ) | 486 | ) |
400 | parser.add_argument( | 487 | parser.add_argument( |
401 | "--sample_steps", | 488 | "--sample_steps", |
@@ -407,13 +494,18 @@ def parse_args(): | |||
407 | "--prior_loss_weight", | 494 | "--prior_loss_weight", |
408 | type=float, | 495 | type=float, |
409 | default=1.0, | 496 | default=1.0, |
410 | help="The weight of prior preservation loss." | 497 | help="The weight of prior preservation loss.", |
411 | ) | 498 | ) |
499 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
500 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
412 | parser.add_argument( | 501 | parser.add_argument( |
413 | "--max_grad_norm", | 502 | "--emb_dropout", |
414 | default=1.0, | ||
415 | type=float, | 503 | type=float, |
416 | help="Max gradient norm." | 504 | default=0, |
505 | help="Embedding dropout probability.", | ||
506 | ) | ||
507 | parser.add_argument( | ||
508 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." | ||
417 | ) | 509 | ) |
418 | parser.add_argument( | 510 | parser.add_argument( |
419 | "--noise_timesteps", | 511 | "--noise_timesteps", |
@@ -424,7 +516,7 @@ def parse_args(): | |||
424 | "--config", | 516 | "--config", |
425 | type=str, | 517 | type=str, |
426 | default=None, | 518 | default=None, |
427 | help="Path to a JSON configuration file containing arguments for invoking this script." | 519 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
428 | ) | 520 | ) |
429 | 521 | ||
430 | args = parser.parse_args() | 522 | args = parser.parse_args() |
@@ -441,6 +533,67 @@ def parse_args(): | |||
441 | if args.project is None: | 533 | if args.project is None: |
442 | raise ValueError("You must specify --project") | 534 | raise ValueError("You must specify --project") |
443 | 535 | ||
536 | if args.initializer_tokens is None: | ||
537 | args.initializer_tokens = [] | ||
538 | |||
539 | if args.placeholder_tokens is None: | ||
540 | args.placeholder_tokens = [] | ||
541 | |||
542 | if isinstance(args.placeholder_tokens, str): | ||
543 | args.placeholder_tokens = [args.placeholder_tokens] | ||
544 | |||
545 | if isinstance(args.initializer_tokens, str): | ||
546 | args.initializer_tokens = [args.initializer_tokens] * len( | ||
547 | args.placeholder_tokens | ||
548 | ) | ||
549 | |||
550 | if len(args.placeholder_tokens) == 0: | ||
551 | args.placeholder_tokens = [ | ||
552 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
553 | ] | ||
554 | |||
555 | if len(args.initializer_tokens) == 0: | ||
556 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
557 | |||
558 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | ||
559 | raise ValueError( | ||
560 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
561 | ) | ||
562 | |||
563 | if isinstance(args.inverted_initializer_tokens, str): | ||
564 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( | ||
565 | args.placeholder_tokens | ||
566 | ) | ||
567 | |||
568 | if ( | ||
569 | isinstance(args.inverted_initializer_tokens, list) | ||
570 | and len(args.inverted_initializer_tokens) != 0 | ||
571 | ): | ||
572 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
573 | args.initializer_tokens += args.inverted_initializer_tokens | ||
574 | |||
575 | if isinstance(args.num_vectors, int): | ||
576 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | ||
577 | |||
578 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( | ||
579 | args.num_vectors | ||
580 | ): | ||
581 | raise ValueError( | ||
582 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
583 | ) | ||
584 | |||
585 | if args.alias_tokens is None: | ||
586 | args.alias_tokens = [] | ||
587 | |||
588 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | ||
589 | raise ValueError("--alias_tokens must be a list with an even number of items") | ||
590 | |||
591 | if args.filter_tokens is None: | ||
592 | args.filter_tokens = args.placeholder_tokens.copy() | ||
593 | |||
594 | if isinstance(args.filter_tokens, str): | ||
595 | args.filter_tokens = [args.filter_tokens] | ||
596 | |||
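To make the normalization above concrete (hypothetical values): given --initializer_tokens cat dog and --inverted_initializer_tokens sketch, the placeholders default to <*0> and <*1>, the inverted pass then extends the lists to placeholders [<*0>, <*1>, inv_<*0>, inv_<*1>] with initializers [cat, dog, sketch, sketch], and filter_tokens falls back to that final placeholder list.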
444 | if isinstance(args.collection, str): | 597 | if isinstance(args.collection, str): |
445 | args.collection = [args.collection] | 598 | args.collection = [args.collection] |
446 | 599 | ||
@@ -451,15 +604,15 @@ def parse_args(): | |||
451 | raise ValueError("You must specify --output_dir") | 604 | raise ValueError("You must specify --output_dir") |
452 | 605 | ||
453 | if args.adam_beta1 is None: | 606 | if args.adam_beta1 is None: |
454 | if args.optimizer in ('adam', 'adam8bit'): | 607 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
455 | args.adam_beta1 = 0.9 | 608 | args.adam_beta1 = 0.9 |
456 | elif args.optimizer == 'lion': | 609 | elif args.optimizer in ("lion", "dlion"): |
457 | args.adam_beta1 = 0.95 | 610 | args.adam_beta1 = 0.95 |
458 | 611 | ||
459 | if args.adam_beta2 is None: | 612 | if args.adam_beta2 is None: |
460 | if args.optimizer in ('adam', 'adam8bit'): | 613 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
461 | args.adam_beta2 = 0.999 | 614 | args.adam_beta2 = 0.999 |
462 | elif args.optimizer == 'lion': | 615 | elif args.optimizer in ("lion", "dlion"): |
463 | args.adam_beta2 = 0.98 | 616 | args.adam_beta2 = 0.98 |
464 | 617 | ||
465 | return args | 618 | return args |
@@ -475,7 +628,7 @@ def main(): | |||
475 | accelerator = Accelerator( | 628 | accelerator = Accelerator( |
476 | log_with=LoggerType.TENSORBOARD, | 629 | log_with=LoggerType.TENSORBOARD, |
477 | project_dir=f"{output_dir}", | 630 | project_dir=f"{output_dir}", |
478 | mixed_precision=args.mixed_precision | 631 | mixed_precision=args.mixed_precision, |
479 | ) | 632 | ) |
480 | 633 | ||
481 | weight_dtype = torch.float32 | 634 | weight_dtype = torch.float32 |
@@ -484,6 +637,8 @@ def main(): | |||
484 | elif args.mixed_precision == "bf16": | 637 | elif args.mixed_precision == "bf16": |
485 | weight_dtype = torch.bfloat16 | 638 | weight_dtype = torch.bfloat16 |
486 | 639 | ||
640 | patch_xformers(weight_dtype) | ||
641 | |||
487 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 642 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
488 | 643 | ||
489 | if args.seed is None: | 644 | if args.seed is None: |
@@ -493,44 +648,125 @@ def main(): | |||
493 | 648 | ||
494 | save_args(output_dir, args) | 649 | save_args(output_dir, args) |
495 | 650 | ||
496 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 651 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
497 | args.pretrained_model_name_or_path) | 652 | args.pretrained_model_name_or_path |
498 | 653 | ) | |
499 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 654 | embeddings = patch_managed_embeddings( |
500 | tokenizer.set_dropout(args.vector_dropout) | 655 | text_encoder, args.emb_alpha, args.emb_dropout |
656 | ) | ||
657 | schedule_sampler = create_named_schedule_sampler( | ||
658 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
659 | ) | ||
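create_named_schedule_sampler follows the helper of the same name from OpenAI's improved-diffusion codebase: "uniform" draws timesteps uniformly, while "loss-second-moment" importance-samples timesteps by their recent loss. A minimal sketch of the uniform case, with class and method names assumed rather than taken from the repo's training/sampler.py:

    import torch

    class UniformSampler:
        def __init__(self, num_timesteps: int):
            self.num_timesteps = num_timesteps

        def sample(self, batch_size: int, device: torch.device):
            # Uniform timesteps; unit importance weights keep the loss estimate unbiased.
            t = torch.randint(0, self.num_timesteps, (batch_size,), device=device)
            weights = torch.ones(batch_size, device=device)
            return t, weights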
501 | 660 | ||
502 | vae.enable_slicing() | 661 | vae.enable_slicing() |
503 | vae.set_use_memory_efficient_attention_xformers(True) | 662 | |
504 | unet.enable_xformers_memory_efficient_attention() | 663 | if args.use_xformers: |
664 | vae.set_use_memory_efficient_attention_xformers(True) | ||
665 | unet.enable_xformers_memory_efficient_attention() | ||
666 | # elif args.compile_unet: | ||
667 | # unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False | ||
668 | # | ||
669 | # proc = AttnProcessor() | ||
670 | # | ||
671 | # def fn_recursive_set_proc(module: torch.nn.Module): | ||
672 | # if hasattr(module, "processor"): | ||
673 | # module.processor = proc | ||
674 | # | ||
675 | # for child in module.children(): | ||
676 | # fn_recursive_set_proc(child) | ||
677 | # | ||
678 | # fn_recursive_set_proc(unet) | ||
505 | 679 | ||
506 | if args.gradient_checkpointing: | 680 | if args.gradient_checkpointing: |
507 | unet.enable_gradient_checkpointing() | 681 | unet.enable_gradient_checkpointing() |
508 | text_encoder.gradient_checkpointing_enable() | 682 | |
683 | if len(args.alias_tokens) != 0: | ||
684 | alias_placeholder_tokens = args.alias_tokens[::2] | ||
685 | alias_initializer_tokens = args.alias_tokens[1::2] | ||
686 | |||
687 | added_tokens, added_ids = add_placeholder_tokens( | ||
688 | tokenizer=tokenizer, | ||
689 | embeddings=embeddings, | ||
690 | placeholder_tokens=alias_placeholder_tokens, | ||
691 | initializer_tokens=alias_initializer_tokens, | ||
692 | ) | ||
693 | embeddings.persist() | ||
694 | print( | ||
695 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
696 | ) | ||
697 | |||
698 | placeholder_tokens = [] | ||
699 | placeholder_token_ids = [] | ||
509 | 700 | ||
510 | if args.embeddings_dir is not None: | 701 | if args.embeddings_dir is not None: |
511 | embeddings_dir = Path(args.embeddings_dir) | 702 | embeddings_dir = Path(args.embeddings_dir) |
512 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 703 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
513 | raise ValueError("--embeddings_dir must point to an existing directory") | 704 | raise ValueError("--embeddings_dir must point to an existing directory") |
514 | 705 | ||
515 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 706 | added_tokens, added_ids = load_embeddings_from_dir( |
516 | embeddings.persist() | 707 | tokenizer, embeddings, embeddings_dir |
517 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 708 | ) |
709 | |||
710 | placeholder_tokens = added_tokens | ||
711 | placeholder_token_ids = added_ids | ||
712 | |||
713 | print( | ||
714 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
715 | ) | ||
716 | |||
717 | if args.train_dir_embeddings: | ||
718 | print("Training embeddings from embeddings dir") | ||
719 | else: | ||
720 | embeddings.persist() | ||
721 | |||
722 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | ||
723 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
724 | tokenizer=tokenizer, | ||
725 | embeddings=embeddings, | ||
726 | placeholder_tokens=args.placeholder_tokens, | ||
727 | initializer_tokens=args.initializer_tokens, | ||
728 | num_vectors=args.num_vectors, | ||
729 | initializer_noise=args.initializer_noise, | ||
730 | ) | ||
731 | |||
732 | placeholder_tokens = args.placeholder_tokens | ||
733 | |||
734 | stats = list( | ||
735 | zip( | ||
736 | placeholder_tokens, | ||
737 | placeholder_token_ids, | ||
738 | args.initializer_tokens, | ||
739 | initializer_token_ids, | ||
740 | ) | ||
741 | ) | ||
742 | print(f"Training embeddings: {stats}") | ||
518 | 743 | ||
519 | if args.scale_lr: | 744 | if args.scale_lr: |
520 | args.learning_rate = ( | 745 | args.learning_rate_unet = ( |
521 | args.learning_rate * args.gradient_accumulation_steps * | 746 | args.learning_rate_unet |
522 | args.train_batch_size * accelerator.num_processes | 747 | * args.gradient_accumulation_steps |
748 | * args.train_batch_size | ||
749 | * accelerator.num_processes | ||
750 | ) | ||
751 | args.learning_rate_text = ( | ||
752 | args.learning_rate_text | ||
753 | * args.gradient_accumulation_steps | ||
754 | * args.train_batch_size | ||
755 | * accelerator.num_processes | ||
523 | ) | 756 | ) |
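Worked example with hypothetical values: --learning_rate_unet 1e-4, gradient_accumulation_steps=4, and train_batch_size=2 on a single process give an effective UNet learning rate of 1e-4 * 4 * 2 * 1 = 8e-4 under --scale_lr; the text-encoder rate scales identically.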
524 | 757 | ||
525 | if args.find_lr: | 758 | if args.find_lr: |
526 | args.learning_rate = 1e-6 | 759 | args.learning_rate_unet = 1e-6 |
760 | args.learning_rate_text = 1e-6 | ||
527 | args.lr_scheduler = "exponential_growth" | 761 | args.lr_scheduler = "exponential_growth" |
528 | 762 | ||
529 | if args.optimizer == 'adam8bit': | 763 | if args.optimizer == "adam8bit": |
530 | try: | 764 | try: |
531 | import bitsandbytes as bnb | 765 | import bitsandbytes as bnb |
532 | except ImportError: | 766 | except ImportError: |
533 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 767 | raise ImportError( |
768 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
769 | ) | ||
534 | 770 | ||
535 | create_optimizer = partial( | 771 | create_optimizer = partial( |
536 | bnb.optim.AdamW8bit, | 772 | bnb.optim.AdamW8bit, |
@@ -539,7 +775,7 @@ def main(): | |||
539 | eps=args.adam_epsilon, | 775 | eps=args.adam_epsilon, |
540 | amsgrad=args.adam_amsgrad, | 776 | amsgrad=args.adam_amsgrad, |
541 | ) | 777 | ) |
542 | elif args.optimizer == 'adam': | 778 | elif args.optimizer == "adam": |
543 | create_optimizer = partial( | 779 | create_optimizer = partial( |
544 | torch.optim.AdamW, | 780 | torch.optim.AdamW, |
545 | betas=(args.adam_beta1, args.adam_beta2), | 781 | betas=(args.adam_beta1, args.adam_beta2), |
@@ -547,22 +783,27 @@ def main(): | |||
547 | eps=args.adam_epsilon, | 783 | eps=args.adam_epsilon, |
548 | amsgrad=args.adam_amsgrad, | 784 | amsgrad=args.adam_amsgrad, |
549 | ) | 785 | ) |
550 | elif args.optimizer == 'adan': | 786 | elif args.optimizer == "adan": |
551 | try: | 787 | try: |
552 | import timm.optim | 788 | import timm.optim |
553 | except ImportError: | 789 | except ImportError: |
554 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 790 | raise ImportError( |
791 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
792 | ) | ||
555 | 793 | ||
556 | create_optimizer = partial( | 794 | create_optimizer = partial( |
557 | timm.optim.Adan, | 795 | timm.optim.Adan, |
558 | weight_decay=args.adam_weight_decay, | 796 | weight_decay=args.adam_weight_decay, |
559 | eps=args.adam_epsilon, | 797 | eps=args.adam_epsilon, |
798 | no_prox=True, | ||
560 | ) | 799 | ) |
561 | elif args.optimizer == 'lion': | 800 | elif args.optimizer == "lion": |
562 | try: | 801 | try: |
563 | import lion_pytorch | 802 | import lion_pytorch |
564 | except ImportError: | 803 | except ImportError: |
565 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 804 | raise ImportError( |
805 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
806 | ) | ||
566 | 807 | ||
567 | create_optimizer = partial( | 808 | create_optimizer = partial( |
568 | lion_pytorch.Lion, | 809 | lion_pytorch.Lion, |
@@ -570,7 +811,7 @@ def main(): | |||
570 | weight_decay=args.adam_weight_decay, | 811 | weight_decay=args.adam_weight_decay, |
571 | use_triton=True, | 812 | use_triton=True, |
572 | ) | 813 | ) |
573 | elif args.optimizer == 'adafactor': | 814 | elif args.optimizer == "adafactor": |
574 | create_optimizer = partial( | 815 | create_optimizer = partial( |
575 | transformers.optimization.Adafactor, | 816 | transformers.optimization.Adafactor, |
576 | weight_decay=args.adam_weight_decay, | 817 | weight_decay=args.adam_weight_decay, |
@@ -580,13 +821,16 @@ def main(): | |||
580 | ) | 821 | ) |
581 | 822 | ||
582 | args.lr_scheduler = "adafactor" | 823 | args.lr_scheduler = "adafactor" |
583 | args.lr_min_lr = args.learning_rate | 824 | args.lr_min_lr = args.learning_rate_unet |
584 | args.learning_rate = None | 825 | args.learning_rate_unet = None |
585 | elif args.optimizer == 'dadam': | 826 | args.learning_rate_text = None |
827 | elif args.optimizer == "dadam": | ||
586 | try: | 828 | try: |
587 | import dadaptation | 829 | import dadaptation |
588 | except ImportError: | 830 | except ImportError: |
589 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 831 | raise ImportError( |
832 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
833 | ) | ||
590 | 834 | ||
591 | create_optimizer = partial( | 835 | create_optimizer = partial( |
592 | dadaptation.DAdaptAdam, | 836 | dadaptation.DAdaptAdam, |
@@ -595,46 +839,65 @@ def main(): | |||
595 | eps=args.adam_epsilon, | 839 | eps=args.adam_epsilon, |
596 | decouple=True, | 840 | decouple=True, |
597 | d0=args.dadaptation_d0, | 841 | d0=args.dadaptation_d0, |
842 | growth_rate=args.dadaptation_growth_rate, | ||
598 | ) | 843 | ) |
599 | 844 | ||
600 | args.learning_rate = 1.0 | 845 | args.learning_rate_unet = 1.0 |
601 | elif args.optimizer == 'dadan': | 846 | args.learning_rate_text = 1.0 |
847 | elif args.optimizer == "dadan": | ||
602 | try: | 848 | try: |
603 | import dadaptation | 849 | import dadaptation |
604 | except ImportError: | 850 | except ImportError: |
605 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 851 | raise ImportError( |
852 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
853 | ) | ||
606 | 854 | ||
607 | create_optimizer = partial( | 855 | create_optimizer = partial( |
608 | dadaptation.DAdaptAdan, | 856 | dadaptation.DAdaptAdan, |
609 | weight_decay=args.adam_weight_decay, | 857 | weight_decay=args.adam_weight_decay, |
610 | eps=args.adam_epsilon, | 858 | eps=args.adam_epsilon, |
611 | d0=args.dadaptation_d0, | 859 | d0=args.dadaptation_d0, |
860 | growth_rate=args.dadaptation_growth_rate, | ||
612 | ) | 861 | ) |
613 | 862 | ||
614 | args.learning_rate = 1.0 | 863 | args.learning_rate_unet = 1.0 |
864 | args.learning_rate_text = 1.0 | ||
865 | elif args.optimizer == "dlion": | ||
866 | raise ImportError("DLion has not been merged into dadaptation yet") | ||
615 | else: | 867 | else: |
616 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 868 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
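D-Adaptation optimizers estimate the step size themselves, which is why the branches above pin both learning rates to 1.0 and steer the adaptation through d0 and growth_rate instead. A minimal sketch, assuming the dadaptation package is installed; the kwargs mirror the defaults in this diff and the toy model is illustrative only:

    import math

    import dadaptation
    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = dadaptation.DAdaptAdam(
        model.parameters(),
        lr=1.0,                # convention: acts as a multiplier on the adapted step size
        betas=(0.9, 0.999),
        eps=1e-8,
        decouple=True,         # AdamW-style decoupled weight decay
        d0=1e-6,               # initial step-size estimate (--dadaptation_d0)
        growth_rate=math.inf,  # unconstrained estimate growth (--dadaptation_growth_rate)
    )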
617 | 869 | ||
618 | trainer = partial( | 870 | trainer = partial( |
619 | train, | 871 | train, |
620 | accelerator=accelerator, | 872 | accelerator=accelerator, |
621 | unet=unet, | 873 | unet=unet, |
622 | text_encoder=text_encoder, | 874 | text_encoder=text_encoder, |
875 | tokenizer=tokenizer, | ||
623 | vae=vae, | 876 | vae=vae, |
624 | noise_scheduler=noise_scheduler, | 877 | noise_scheduler=noise_scheduler, |
878 | schedule_sampler=schedule_sampler, | ||
879 | min_snr_gamma=args.min_snr_gamma, | ||
625 | dtype=weight_dtype, | 880 | dtype=weight_dtype, |
881 | seed=args.seed, | ||
882 | compile_unet=args.compile_unet, | ||
626 | guidance_scale=args.guidance_scale, | 883 | guidance_scale=args.guidance_scale, |
627 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 884 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
628 | no_val=args.valid_set_size == 0, | 885 | sample_scheduler=sample_scheduler, |
886 | sample_batch_size=args.sample_batch_size, | ||
887 | sample_num_batches=args.sample_batches, | ||
888 | sample_num_steps=args.sample_steps, | ||
889 | sample_image_size=args.sample_image_size, | ||
890 | max_grad_norm=args.max_grad_norm, | ||
629 | ) | 891 | ) |
630 | 892 | ||
631 | checkpoint_output_dir = output_dir / "model" | 893 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
632 | sample_output_dir = output_dir / "samples" | 894 | data_npgenerator = np.random.default_rng(args.seed) |
633 | 895 | ||
634 | datamodule = VlpnDataModule( | 896 | create_datamodule = partial( |
897 | VlpnDataModule, | ||
635 | data_file=args.train_data_file, | 898 | data_file=args.train_data_file, |
636 | batch_size=args.train_batch_size, | ||
637 | tokenizer=tokenizer, | 899 | tokenizer=tokenizer, |
900 | constant_prompt_length=args.compile_unet, | ||
638 | class_subdir=args.class_image_dir, | 901 | class_subdir=args.class_image_dir, |
639 | with_guidance=args.guidance_scale != 0, | 902 | with_guidance=args.guidance_scale != 0, |
640 | num_class_images=args.num_class_images, | 903 | num_class_images=args.num_class_images, |
@@ -643,83 +906,186 @@ def main(): | |||
643 | progressive_buckets=args.progressive_buckets, | 906 | progressive_buckets=args.progressive_buckets, |
644 | bucket_step_size=args.bucket_step_size, | 907 | bucket_step_size=args.bucket_step_size, |
645 | bucket_max_pixels=args.bucket_max_pixels, | 908 | bucket_max_pixels=args.bucket_max_pixels, |
646 | dropout=args.tag_dropout, | ||
647 | shuffle=not args.no_tag_shuffle, | 909 | shuffle=not args.no_tag_shuffle, |
648 | template_key=args.train_data_template, | 910 | template_key=args.train_data_template, |
649 | valid_set_size=args.valid_set_size, | ||
650 | train_set_pad=args.train_set_pad, | 911 | train_set_pad=args.train_set_pad, |
651 | valid_set_pad=args.valid_set_pad, | 912 | valid_set_pad=args.valid_set_pad, |
652 | seed=args.seed, | 913 | dtype=weight_dtype, |
653 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 914 | generator=data_generator, |
654 | dtype=weight_dtype | 915 | npgenerator=data_npgenerator, |
655 | ) | ||
656 | datamodule.setup() | ||
657 | |||
658 | num_train_epochs = args.num_train_epochs | ||
659 | sample_frequency = args.sample_frequency | ||
660 | if num_train_epochs is None: | ||
661 | num_train_epochs = math.ceil( | ||
662 | args.num_train_steps / len(datamodule.train_dataset) | ||
663 | ) * args.gradient_accumulation_steps | ||
664 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | ||
665 | |||
666 | params_to_optimize = (unet.parameters(), ) | ||
667 | if args.train_text_encoder_epochs != 0: | ||
668 | params_to_optimize += ( | ||
669 | text_encoder.text_model.encoder.parameters(), | ||
670 | text_encoder.text_model.final_layer_norm.parameters(), | ||
671 | ) | ||
672 | |||
673 | optimizer = create_optimizer( | ||
674 | itertools.chain(*params_to_optimize), | ||
675 | lr=args.learning_rate, | ||
676 | ) | 916 | ) |
677 | 917 | ||
678 | lr_scheduler = get_scheduler( | 918 | create_lr_scheduler = partial( |
679 | args.lr_scheduler, | 919 | get_scheduler, |
680 | optimizer=optimizer, | ||
681 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
682 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
683 | min_lr=args.lr_min_lr, | 920 | min_lr=args.lr_min_lr, |
684 | warmup_func=args.lr_warmup_func, | 921 | warmup_func=args.lr_warmup_func, |
685 | annealing_func=args.lr_annealing_func, | 922 | annealing_func=args.lr_annealing_func, |
686 | warmup_exp=args.lr_warmup_exp, | 923 | warmup_exp=args.lr_warmup_exp, |
687 | annealing_exp=args.lr_annealing_exp, | 924 | annealing_exp=args.lr_annealing_exp, |
688 | cycles=args.lr_cycles, | ||
689 | end_lr=1e2, | 925 | end_lr=1e2, |
690 | train_epochs=num_train_epochs, | ||
691 | warmup_epochs=args.lr_warmup_epochs, | ||
692 | mid_point=args.lr_mid_point, | 926 | mid_point=args.lr_mid_point, |
693 | ) | 927 | ) |
694 | 928 | ||
695 | trainer( | 929 | # Dreambooth |
696 | strategy=dreambooth_strategy, | 930 | # -------------------------------------------------------------------------------- |
697 | project="dreambooth", | 931 | |
698 | train_dataloader=datamodule.train_dataloader, | 932 | dreambooth_datamodule = create_datamodule( |
699 | val_dataloader=datamodule.val_dataloader, | 933 | valid_set_size=args.valid_set_size, |
700 | seed=args.seed, | 934 | batch_size=args.train_batch_size, |
701 | optimizer=optimizer, | 935 | dropout=args.tag_dropout, |
702 | lr_scheduler=lr_scheduler, | 936 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
703 | num_train_epochs=num_train_epochs, | 937 | ) |
704 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 938 | dreambooth_datamodule.setup() |
705 | sample_frequency=sample_frequency, | 939 | |
706 | offset_noise_strength=args.offset_noise_strength, | 940 | num_train_epochs = args.num_train_epochs |
707 | # -- | 941 | dreambooth_sample_frequency = args.sample_frequency |
708 | tokenizer=tokenizer, | 942 | if num_train_epochs is None: |
709 | sample_scheduler=sample_scheduler, | 943 | num_train_epochs = ( |
710 | sample_output_dir=sample_output_dir, | 944 | math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset)) |
711 | checkpoint_output_dir=checkpoint_output_dir, | 945 | * args.gradient_accumulation_steps |
712 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 946 | ) |
713 | max_grad_norm=args.max_grad_norm, | 947 | dreambooth_sample_frequency = math.ceil( |
714 | use_ema=args.use_ema, | 948 | num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps) |
715 | ema_inv_gamma=args.ema_inv_gamma, | 949 | ) |
716 | ema_power=args.ema_power, | 950 | num_training_steps_per_epoch = math.ceil( |
717 | ema_max_decay=args.ema_max_decay, | 951 | len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps |
718 | sample_batch_size=args.sample_batch_size, | ||
719 | sample_num_batches=args.sample_batches, | ||
720 | sample_num_steps=args.sample_steps, | ||
721 | sample_image_size=args.sample_image_size, | ||
722 | ) | 952 | ) |
953 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | ||
954 | if args.sample_num is not None: | ||
955 | dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | ||
956 | |||
957 | dreambooth_project = "dreambooth" | ||
958 | |||
959 | if accelerator.is_main_process: | ||
960 | accelerator.init_trackers(dreambooth_project) | ||
961 | |||
962 | dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples" | ||
963 | |||
964 | training_iter = 0 | ||
965 | auto_cycles = list(args.auto_cycles) | ||
966 | learning_rate_unet = args.learning_rate_unet | ||
967 | learning_rate_text = args.learning_rate_text | ||
968 | lr_scheduler = args.lr_scheduler | ||
969 | lr_warmup_epochs = args.lr_warmup_epochs | ||
970 | lr_cycles = args.lr_cycles | ||
971 | |||
972 | avg_loss = AverageMeter() | ||
973 | avg_acc = AverageMeter() | ||
974 | avg_loss_val = AverageMeter() | ||
975 | avg_acc_val = AverageMeter() | ||
976 | |||
977 | params_to_optimize = [ | ||
978 | { | ||
979 | "params": (param for param in unet.parameters() if param.requires_grad), | ||
980 | "lr": learning_rate_unet, | ||
981 | }, | ||
982 | { | ||
983 | "params": ( | ||
984 | param for param in text_encoder.parameters() if param.requires_grad | ||
985 | ), | ||
986 | "lr": learning_rate_text, | ||
987 | }, | ||
988 | ] | ||
989 | group_labels = ["unet", "text"] | ||
990 | |||
991 | dreambooth_optimizer = create_optimizer(params_to_optimize) | ||
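Keeping a single optimizer with two labeled parameter groups is what lets the cycle loop below retune learning rates in place (group["lr"] = lr) without discarding optimizer state between cycles. The same pattern in plain PyTorch, with toy modules standing in for the UNet and text encoder:

    import torch

    unet_like = torch.nn.Linear(4, 4)
    text_like = torch.nn.Linear(4, 4)
    opt = torch.optim.AdamW(
        [
            {"params": unet_like.parameters(), "lr": 1e-4},
            {"params": text_like.parameters(), "lr": 5e-5},
        ]
    )
    # Later, retarget without recreating the optimizer; momentum buffers survive.
    for group, lr in zip(opt.param_groups, [2e-4, 1e-4]):
        group["lr"] = lr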
992 | |||
993 | while True: | ||
994 | if len(auto_cycles) != 0: | ||
995 | response = auto_cycles.pop(0) | ||
996 | else: | ||
997 | response = input( | ||
998 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " | ||
999 | ) | ||
1000 | |||
1001 | if response.lower().strip() == "o": | ||
1002 | if args.learning_rate_unet is not None: | ||
1003 | learning_rate_unet = ( | ||
1004 | args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) | ||
1005 | ) | ||
1006 | if args.learning_rate_text is not None: | ||
1007 | learning_rate_text = ( | ||
1008 | args.learning_rate_text * 2 * (args.cycle_decay**training_iter) | ||
1009 | ) | ||
1010 | else: | ||
1011 | learning_rate_unet = args.learning_rate_unet * ( | ||
1012 | args.cycle_decay**training_iter | ||
1013 | ) | ||
1014 | learning_rate_text = args.learning_rate_text * ( | ||
1015 | args.cycle_decay**training_iter | ||
1016 | ) | ||
1017 | |||
1018 | if response.lower().strip() == "o": | ||
1019 | lr_scheduler = "one_cycle" | ||
1020 | lr_warmup_epochs = args.lr_warmup_epochs | ||
1021 | lr_cycles = args.lr_cycles | ||
1022 | elif response.lower().strip() == "w": | ||
1023 | lr_scheduler = "constant_with_warmup" | ||
1024 | lr_warmup_epochs = num_train_epochs | ||
1025 | elif response.lower().strip() == "c": | ||
1026 | lr_scheduler = "constant" | ||
1027 | elif response.lower().strip() == "d": | ||
1028 | lr_scheduler = "cosine" | ||
1029 | lr_warmup_epochs = 0 | ||
1030 | lr_cycles = 1 | ||
1031 | elif response.lower().strip() == "s": | ||
1032 | break | ||
1033 | else: | ||
1034 | continue | ||
1035 | |||
1036 | print("") | ||
1037 | print( | ||
1038 | f"============ Dreambooth cycle {training_iter + 1}: {response} ============" | ||
1039 | ) | ||
1040 | print("") | ||
1041 | |||
1042 | for group, lr in zip( | ||
1043 | dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text] | ||
1044 | ): | ||
1045 | group["lr"] = lr | ||
1046 | |||
1047 | dreambooth_lr_scheduler = create_lr_scheduler( | ||
1048 | lr_scheduler, | ||
1049 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
1050 | optimizer=dreambooth_optimizer, | ||
1051 | num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader), | ||
1052 | train_epochs=num_train_epochs, | ||
1053 | cycles=lr_cycles, | ||
1054 | warmup_epochs=lr_warmup_epochs, | ||
1055 | ) | ||
1056 | |||
1057 | dreambooth_checkpoint_output_dir = ( | ||
1058 | output_dir / dreambooth_project / f"model_{training_iter}" | ||
1059 | ) | ||
1060 | |||
1061 | trainer( | ||
1062 | strategy=dreambooth_strategy, | ||
1063 | train_dataloader=dreambooth_datamodule.train_dataloader, | ||
1064 | val_dataloader=dreambooth_datamodule.val_dataloader, | ||
1065 | optimizer=dreambooth_optimizer, | ||
1066 | lr_scheduler=dreambooth_lr_scheduler, | ||
1067 | num_train_epochs=num_train_epochs, | ||
1068 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
1069 | global_step_offset=training_iter * num_train_steps, | ||
1070 | cycle=training_iter, | ||
1071 | train_text_encoder_cycles=args.train_text_encoder_cycles, | ||
1072 | # -- | ||
1073 | group_labels=group_labels, | ||
1074 | sample_output_dir=dreambooth_sample_output_dir, | ||
1075 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | ||
1076 | sample_frequency=dreambooth_sample_frequency, | ||
1077 | offset_noise_strength=args.offset_noise_strength, | ||
1078 | input_pertubation=args.input_pertubation, | ||
1079 | no_val=args.valid_set_size == 0, | ||
1080 | avg_loss=avg_loss, | ||
1081 | avg_acc=avg_acc, | ||
1082 | avg_loss_val=avg_loss_val, | ||
1083 | avg_acc_val=avg_acc_val, | ||
1084 | ) | ||
1085 | |||
1086 | training_iter += 1 | ||
1087 | |||
1088 | accelerator.end_training() | ||
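Because auto_cycles is consumed character by character before the loop falls back to the interactive prompt, cycle selection can be scripted. A hypothetical invocation such as

    python train_dreambooth.py --project pet --train_data_file data.json --auto_cycles owds

runs one_cycle, then constant_with_warmup, then cosine decay, then stops, with the learning rates shrunk by cycle_decay**training_iter between cycles.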
723 | 1089 | ||
724 | 1090 | ||
725 | if __name__ == "__main__": | 1091 | if __name__ == "__main__": |
diff --git a/train_lora.py b/train_lora.py index c74dd8f..fccf48d 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -16,6 +16,7 @@ from accelerate import Accelerator | |||
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from peft import LoraConfig, get_peft_model | 18 | from peft import LoraConfig, get_peft_model |
19 | |||
19 | # from diffusers.models.attention_processor import AttnProcessor | 20 | # from diffusers.models.attention_processor import AttnProcessor |
20 | from diffusers.utils.import_utils import is_xformers_available | 21 | from diffusers.utils.import_utils import is_xformers_available |
21 | import transformers | 22 | import transformers |
@@ -34,15 +35,20 @@ from util.files import load_config, load_embeddings_from_dir | |||
34 | 35 | ||
35 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 36 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
36 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] | 37 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] |
37 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"] | 38 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0", "to_k", "key"] # [] |
38 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] | 39 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] |
39 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] | 40 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + [ |
40 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] | 41 | "out_proj", |
42 | "k_proj", | ||
43 | ] # [] | ||
44 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + [ | ||
45 | "token_embedding" | ||
46 | ] | ||
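These module lists feed peft's LoraConfig. A minimal sketch of how they could be wired up, using the rank/alpha/dropout/bias flags defined later in parse_args; this is illustrative, not the script's literal wiring:

    from peft import LoraConfig, get_peft_model

    def wrap_unet_with_lora(unet, r=8, alpha=32, dropout=0.0, bias="none"):
        # Adapter targets match UNET_TARGET_MODULES above.
        config = LoraConfig(
            r=r,                   # --lora_r
            lora_alpha=alpha,      # --lora_alpha
            target_modules=["to_q", "to_v", "query", "value", "to_out.0", "to_k", "key"],
            lora_dropout=dropout,  # --lora_dropout
            bias=bias,             # --lora_bias
        )
        return get_peft_model(unet, config)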
41 | 47 | ||
42 | 48 | ||
43 | logger = get_logger(__name__) | 49 | logger = get_logger(__name__) |
44 | 50 | ||
45 | warnings.filterwarnings('ignore') | 51 | warnings.filterwarnings("ignore") |
46 | 52 | ||
47 | 53 | ||
48 | torch.backends.cuda.matmul.allow_tf32 = True | 54 | torch.backends.cuda.matmul.allow_tf32 = True |
@@ -55,20 +61,27 @@ hidet.torch.dynamo_config.use_tensor_core(True) | |||
55 | hidet.torch.dynamo_config.search_space(0) | 61 | hidet.torch.dynamo_config.search_space(0) |
56 | 62 | ||
57 | 63 | ||
58 | if is_xformers_available(): | 64 | def patch_xformers(dtype): |
59 | import xformers | 65 | if is_xformers_available(): |
60 | import xformers.ops | 66 | import xformers |
61 | 67 | import xformers.ops | |
62 | orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention | 68 | |
63 | def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs): | 69 | orig_xformers_memory_efficient_attention = ( |
64 | return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs) | 70 | xformers.ops.memory_efficient_attention |
65 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | 71 | ) |
72 | |||
73 | def xformers_memory_efficient_attention( | ||
74 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
75 | ): | ||
76 | return orig_xformers_memory_efficient_attention( | ||
77 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
78 | ) | ||
79 | |||
80 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
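In main(), patch_xformers(weight_dtype) runs once before unet.enable_xformers_memory_efficient_attention(), so under fp16/bf16 every attention call is routed through the wrapper and query, key, and value are cast to the training dtype first; this guards against the dtype mismatches that mixed-precision autocast can otherwise hand to xformers.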
66 | 81 | ||
67 | 82 | ||
68 | def parse_args(): | 83 | def parse_args(): |
69 | parser = argparse.ArgumentParser( | 84 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
70 | description="Simple example of a training script." | ||
71 | ) | ||
72 | parser.add_argument( | 85 | parser.add_argument( |
73 | "--pretrained_model_name_or_path", | 86 | "--pretrained_model_name_or_path", |
74 | type=str, | 87 | type=str, |
@@ -85,7 +98,7 @@ def parse_args(): | |||
85 | "--train_data_file", | 98 | "--train_data_file", |
86 | type=str, | 99 | type=str, |
87 | default=None, | 100 | default=None, |
88 | help="A folder containing the training data." | 101 | help="A folder containing the training data.", |
89 | ) | 102 | ) |
90 | parser.add_argument( | 103 | parser.add_argument( |
91 | "--train_data_template", | 104 | "--train_data_template", |
@@ -96,13 +109,13 @@ def parse_args(): | |||
96 | "--train_set_pad", | 109 | "--train_set_pad", |
97 | type=int, | 110 | type=int, |
98 | default=None, | 111 | default=None, |
99 | help="The number to fill train dataset items up to." | 112 | help="The number to fill train dataset items up to.", |
100 | ) | 113 | ) |
101 | parser.add_argument( | 114 | parser.add_argument( |
102 | "--valid_set_pad", | 115 | "--valid_set_pad", |
103 | type=int, | 116 | type=int, |
104 | default=None, | 117 | default=None, |
105 | help="The number to fill validation dataset items up to." | 118 | help="The number to fill validation dataset items up to.", |
106 | ) | 119 | ) |
107 | parser.add_argument( | 120 | parser.add_argument( |
108 | "--project", | 121 | "--project", |
@@ -111,64 +124,52 @@ def parse_args(): | |||
111 | help="The name of the current project.", | 124 | help="The name of the current project.", |
112 | ) | 125 | ) |
113 | parser.add_argument( | 126 | parser.add_argument( |
114 | "--auto_cycles", | 127 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
115 | type=str, | ||
116 | default="o", | ||
117 | help="Cycles to run automatically." | ||
118 | ) | 128 | ) |
119 | parser.add_argument( | 129 | parser.add_argument( |
120 | "--cycle_decay", | 130 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." |
121 | type=float, | ||
122 | default=1.0, | ||
123 | help="Learning rate decay per cycle." | ||
124 | ) | 131 | ) |
125 | parser.add_argument( | 132 | parser.add_argument( |
126 | "--placeholder_tokens", | 133 | "--placeholder_tokens", |
127 | type=str, | 134 | type=str, |
128 | nargs='*', | 135 | nargs="*", |
129 | help="A token to use as a placeholder for the concept.", | 136 | help="A token to use as a placeholder for the concept.", |
130 | ) | 137 | ) |
131 | parser.add_argument( | 138 | parser.add_argument( |
132 | "--initializer_tokens", | 139 | "--initializer_tokens", |
133 | type=str, | 140 | type=str, |
134 | nargs='*', | 141 | nargs="*", |
135 | help="A token to use as initializer word." | 142 | help="A token to use as initializer word.", |
136 | ) | 143 | ) |
137 | parser.add_argument( | 144 | parser.add_argument( |
138 | "--filter_tokens", | 145 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." |
139 | type=str, | ||
140 | nargs='*', | ||
141 | help="Tokens to filter the dataset by." | ||
142 | ) | 146 | ) |
143 | parser.add_argument( | 147 | parser.add_argument( |
144 | "--initializer_noise", | 148 | "--initializer_noise", |
145 | type=float, | 149 | type=float, |
146 | default=0, | 150 | default=0, |
147 | help="Noise to apply to the initializer word" | 151 | help="Noise to apply to the initializer word", |
148 | ) | 152 | ) |
149 | parser.add_argument( | 153 | parser.add_argument( |
150 | "--alias_tokens", | 154 | "--alias_tokens", |
151 | type=str, | 155 | type=str, |
152 | nargs='*', | 156 | nargs="*", |
153 | default=[], | 157 | default=[], |
154 | help="Tokens to create an alias for." | 158 | help="Tokens to create an alias for.", |
155 | ) | 159 | ) |
156 | parser.add_argument( | 160 | parser.add_argument( |
157 | "--inverted_initializer_tokens", | 161 | "--inverted_initializer_tokens", |
158 | type=str, | 162 | type=str, |
159 | nargs='*', | 163 | nargs="*", |
160 | help="A token to use as initializer word." | 164 | help="A token to use as initializer word.", |
161 | ) | 165 | ) |
162 | parser.add_argument( | 166 | parser.add_argument( |
163 | "--num_vectors", | 167 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
164 | type=int, | ||
165 | nargs='*', | ||
166 | help="Number of vectors per embedding." | ||
167 | ) | 168 | ) |
168 | parser.add_argument( | 169 | parser.add_argument( |
169 | "--exclude_collections", | 170 | "--exclude_collections", |
170 | type=str, | 171 | type=str, |
171 | nargs='*', | 172 | nargs="*", |
172 | help="Exclude all items with a listed collection.", | 173 | help="Exclude all items with a listed collection.", |
173 | ) | 174 | ) |
174 | parser.add_argument( | 175 | parser.add_argument( |
@@ -214,7 +215,7 @@ def parse_args(): | |||
214 | "--num_class_images", | 215 | "--num_class_images", |
215 | type=int, | 216 | type=int, |
216 | default=0, | 217 | default=0, |
217 | help="How many class images to generate." | 218 | help="How many class images to generate.", |
218 | ) | 219 | ) |
219 | parser.add_argument( | 220 | parser.add_argument( |
220 | "--class_image_dir", | 221 | "--class_image_dir", |
@@ -242,14 +243,11 @@ def parse_args(): | |||
242 | parser.add_argument( | 243 | parser.add_argument( |
243 | "--collection", | 244 | "--collection", |
244 | type=str, | 245 | type=str, |
245 | nargs='*', | 246 | nargs="*", |
246 | help="A collection to filter the dataset.", | 247 | help="A collection to filter the dataset.", |
247 | ) | 248 | ) |
248 | parser.add_argument( | 249 | parser.add_argument( |
249 | "--seed", | 250 | "--seed", type=int, default=None, help="A seed for reproducible training." |
250 | type=int, | ||
251 | default=None, | ||
252 | help="A seed for reproducible training." | ||
253 | ) | 251 | ) |
254 | parser.add_argument( | 252 | parser.add_argument( |
255 | "--resolution", | 253 | "--resolution", |
@@ -270,18 +268,10 @@ def parse_args(): | |||
270 | "--input_pertubation", | 268 | "--input_pertubation", |
271 | type=float, | 269 | type=float, |
272 | default=0, | 270 | default=0, |
273 | help="The scale of input pretubation. Recommended 0.1." | 271 | help="The scale of input pretubation. Recommended 0.1.", |
274 | ) | ||
275 | parser.add_argument( | ||
276 | "--num_train_epochs", | ||
277 | type=int, | ||
278 | default=None | ||
279 | ) | ||
280 | parser.add_argument( | ||
281 | "--num_train_steps", | ||
282 | type=int, | ||
283 | default=2000 | ||
284 | ) | 272 | ) |
273 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
274 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
285 | parser.add_argument( | 275 | parser.add_argument( |
286 | "--gradient_accumulation_steps", | 276 | "--gradient_accumulation_steps", |
287 | type=int, | 277 | type=int, |
@@ -289,22 +279,19 @@ def parse_args(): | |||
289 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 279 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
290 | ) | 280 | ) |
291 | parser.add_argument( | 281 | parser.add_argument( |
292 | "--lora_r", | 282 | "--lora_r", type=int, default=8, help="Lora rank, only used if use_lora is True" |
293 | type=int, | ||
294 | default=8, | ||
295 | help="Lora rank, only used if use_lora is True" | ||
296 | ) | 283 | ) |
297 | parser.add_argument( | 284 | parser.add_argument( |
298 | "--lora_alpha", | 285 | "--lora_alpha", |
299 | type=int, | 286 | type=int, |
300 | default=32, | 287 | default=32, |
301 | help="Lora alpha, only used if use_lora is True" | 288 | help="Lora alpha, only used if use_lora is True", |
302 | ) | 289 | ) |
303 | parser.add_argument( | 290 | parser.add_argument( |
304 | "--lora_dropout", | 291 | "--lora_dropout", |
305 | type=float, | 292 | type=float, |
306 | default=0.0, | 293 | default=0.0, |
307 | help="Lora dropout, only used if use_lora is True" | 294 | help="Lora dropout, only used if use_lora is True", |
308 | ) | 295 | ) |
309 | parser.add_argument( | 296 | parser.add_argument( |
310 | "--lora_bias", | 297 | "--lora_bias", |
@@ -344,7 +331,7 @@ def parse_args(): | |||
344 | parser.add_argument( | 331 | parser.add_argument( |
345 | "--train_text_encoder_cycles", | 332 | "--train_text_encoder_cycles", |
346 | default=999999, | 333 | type=int, default=999999, |
347 | help="Number of epochs the text encoder will be trained." | 334 | help="Number of cycles during which the text encoder will be trained.", |
348 | ) | 335 | ) |
349 | parser.add_argument( | 336 | parser.add_argument( |
350 | "--find_lr", | 337 | "--find_lr", |
@@ -378,27 +365,31 @@ def parse_args(): | |||
378 | "--lr_scheduler", | 365 | "--lr_scheduler", |
379 | type=str, | 366 | type=str, |
380 | default="one_cycle", | 367 | default="one_cycle", |
381 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 368 | choices=[ |
382 | "constant", "constant_with_warmup", "one_cycle"], | 369 | "linear", |
383 | help='The scheduler type to use.', | 370 | "cosine", |
371 | "cosine_with_restarts", | ||
372 | "polynomial", | ||
373 | "constant", | ||
374 | "constant_with_warmup", | ||
375 | "one_cycle", | ||
376 | ], | ||
377 | help="The scheduler type to use.", | ||
384 | ) | 378 | ) |
385 | parser.add_argument( | 379 | parser.add_argument( |
386 | "--lr_warmup_epochs", | 380 | "--lr_warmup_epochs", |
387 | type=int, | 381 | type=int, |
388 | default=10, | 382 | default=10, |
389 | help="Number of steps for the warmup in the lr scheduler." | 383 | help="Number of steps for the warmup in the lr scheduler.", |
390 | ) | 384 | ) |
391 | parser.add_argument( | 385 | parser.add_argument( |
392 | "--lr_mid_point", | 386 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
393 | type=float, | ||
394 | default=0.3, | ||
395 | help="OneCycle schedule mid point." | ||
396 | ) | 387 | ) |
397 | parser.add_argument( | 388 | parser.add_argument( |
398 | "--lr_cycles", | 389 | "--lr_cycles", |
399 | type=int, | 390 | type=int, |
400 | default=None, | 391 | default=None, |
401 | help="Number of restart cycles in the lr scheduler (if supported)." | 392 | help="Number of restart cycles in the lr scheduler (if supported).", |
402 | ) | 393 | ) |
403 | parser.add_argument( | 394 | parser.add_argument( |
404 | "--lr_warmup_func", | 395 | "--lr_warmup_func", |
@@ -410,7 +401,7 @@ def parse_args(): | |||
410 | "--lr_warmup_exp", | 401 | "--lr_warmup_exp", |
411 | type=int, | 402 | type=int, |
412 | default=1, | 403 | default=1, |
413 | help='If lr_warmup_func is "cos", exponent to modify the function' | 404 | help='If lr_warmup_func is "cos", exponent to modify the function', |
414 | ) | 405 | ) |
415 | parser.add_argument( | 406 | parser.add_argument( |
416 | "--lr_annealing_func", | 407 | "--lr_annealing_func", |
@@ -422,69 +413,76 @@ def parse_args(): | |||
422 | "--lr_annealing_exp", | 413 | "--lr_annealing_exp", |
423 | type=int, | 414 | type=int, |
424 | default=3, | 415 | default=3, |
425 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 416 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
426 | ) | 417 | ) |
427 | parser.add_argument( | 418 | parser.add_argument( |
428 | "--lr_min_lr", | 419 | "--lr_min_lr", |
429 | type=float, | 420 | type=float, |
430 | default=0.04, | 421 | default=0.04, |
431 | help="Minimum learning rate in the lr scheduler." | 422 | help="Minimum learning rate in the lr scheduler.", |
432 | ) | ||
433 | parser.add_argument( | ||
434 | "--min_snr_gamma", | ||
435 | type=int, | ||
436 | default=5, | ||
437 | help="MinSNR gamma." | ||
438 | ) | 423 | ) |
424 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
439 | parser.add_argument( | 425 | parser.add_argument( |
440 | "--schedule_sampler", | 426 | "--schedule_sampler", |
441 | type=str, | 427 | type=str, |
442 | default="uniform", | 428 | default="uniform", |
443 | choices=["uniform", "loss-second-moment"], | 429 | choices=["uniform", "loss-second-moment"], |
444 | help="Noise schedule sampler." | 430 | help="Noise schedule sampler.", |
445 | ) | 431 | ) |
446 | parser.add_argument( | 432 | parser.add_argument( |
447 | "--optimizer", | 433 | "--optimizer", |
448 | type=str, | 434 | type=str, |
449 | default="adan", | 435 | default="adan", |
450 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 436 | choices=[ |
451 | help='Optimizer to use' | 437 | "adam", |
438 | "adam8bit", | ||
439 | "adan", | ||
440 | "lion", | ||
441 | "dadam", | ||
442 | "dadan", | ||
443 | "dlion", | ||
444 | "adafactor", | ||
445 | ], | ||
446 | help="Optimizer to use", | ||
452 | ) | 447 | ) |
453 | parser.add_argument( | 448 | parser.add_argument( |
454 | "--dadaptation_d0", | 449 | "--dadaptation_d0", |
455 | type=float, | 450 | type=float, |
456 | default=1e-6, | 451 | default=1e-6, |
457 | help="The d0 parameter for Dadaptation optimizers." | 452 | help="The d0 parameter for Dadaptation optimizers.", |
453 | ) | ||
454 | parser.add_argument( | ||
455 | "--dadaptation_growth_rate", | ||
456 | type=float, | ||
457 | default=math.inf, | ||
458 | help="The growth_rate parameter for Dadaptation optimizers.", | ||
458 | ) | 459 | ) |
459 | parser.add_argument( | 460 | parser.add_argument( |
460 | "--adam_beta1", | 461 | "--adam_beta1", |
461 | type=float, | 462 | type=float, |
462 | default=None, | 463 | default=None, |
463 | help="The beta1 parameter for the Adam optimizer." | 464 | help="The beta1 parameter for the Adam optimizer.", |
464 | ) | 465 | ) |
465 | parser.add_argument( | 466 | parser.add_argument( |
466 | "--adam_beta2", | 467 | "--adam_beta2", |
467 | type=float, | 468 | type=float, |
468 | default=None, | 469 | default=None, |
469 | help="The beta2 parameter for the Adam optimizer." | 470 | help="The beta2 parameter for the Adam optimizer.", |
470 | ) | 471 | ) |
471 | parser.add_argument( | 472 | parser.add_argument( |
472 | "--adam_weight_decay", | 473 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
473 | type=float, | ||
474 | default=2e-2, | ||
475 | help="Weight decay to use." | ||
476 | ) | 474 | ) |
477 | parser.add_argument( | 475 | parser.add_argument( |
478 | "--adam_epsilon", | 476 | "--adam_epsilon", |
479 | type=float, | 477 | type=float, |
480 | default=1e-08, | 478 | default=1e-08, |
481 | help="Epsilon value for the Adam optimizer" | 479 | help="Epsilon value for the Adam optimizer", |
482 | ) | 480 | ) |
483 | parser.add_argument( | 481 | parser.add_argument( |
484 | "--adam_amsgrad", | 482 | "--adam_amsgrad", |
485 | type=bool, | 483 | type=bool, |
486 | default=False, | 484 | default=False, |
487 | help="Amsgrad value for the Adam optimizer" | 485 | help="Amsgrad value for the Adam optimizer", |
488 | ) | 486 | ) |
489 | parser.add_argument( | 487 | parser.add_argument( |
490 | "--mixed_precision", | 488 | "--mixed_precision", |
@@ -547,19 +545,19 @@ def parse_args(): | |||
547 | "--valid_set_size", | 545 | "--valid_set_size", |
548 | type=int, | 546 | type=int, |
549 | default=None, | 547 | default=None, |
550 | help="Number of images in the validation dataset." | 548 | help="Number of images in the validation dataset.", |
551 | ) | 549 | ) |
552 | parser.add_argument( | 550 | parser.add_argument( |
553 | "--valid_set_repeat", | 551 | "--valid_set_repeat", |
554 | type=int, | 552 | type=int, |
555 | default=1, | 553 | default=1, |
556 | help="Times the images in the validation dataset are repeated." | 554 | help="Times the images in the validation dataset are repeated.", |
557 | ) | 555 | ) |
558 | parser.add_argument( | 556 | parser.add_argument( |
559 | "--train_batch_size", | 557 | "--train_batch_size", |
560 | type=int, | 558 | type=int, |
561 | default=1, | 559 | default=1, |
562 | help="Batch size (per device) for the training dataloader." | 560 | help="Batch size (per device) for the training dataloader.", |
563 | ) | 561 | ) |
564 | parser.add_argument( | 562 | parser.add_argument( |
565 | "--sample_steps", | 563 | "--sample_steps", |
@@ -571,19 +569,10 @@ def parse_args(): | |||
571 | "--prior_loss_weight", | 569 | "--prior_loss_weight", |
572 | type=float, | 570 | type=float, |
573 | default=1.0, | 571 | default=1.0, |
574 | help="The weight of prior preservation loss." | 572 | help="The weight of prior preservation loss.", |
575 | ) | ||
576 | parser.add_argument( | ||
577 | "--run_pti", | ||
578 | action="store_true", | ||
579 | help="Whether to run PTI." | ||
580 | ) | ||
581 | parser.add_argument( | ||
582 | "--emb_alpha", | ||
583 | type=float, | ||
584 | default=1.0, | ||
585 | help="Embedding alpha" | ||
586 | ) | 573 | ) |
574 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
575 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
587 | parser.add_argument( | 576 | parser.add_argument( |
588 | "--emb_dropout", | 577 | "--emb_dropout", |
589 | type=float, | 578 | type=float, |
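`--prior_loss_weight` above follows the DreamBooth prior-preservation recipe: class ("prior") examples travel in the same batch as instance examples and contribute a separately weighted MSE term. A minimal sketch, assuming the batch is stacked instance-first as in the reference DreamBooth script:

    import torch.nn.functional as F

    def combined_loss(model_pred, target, prior_loss_weight: float = 1.0):
        # First half of the batch: instance examples; second half: class (prior) examples.
        pred, prior_pred = model_pred.chunk(2)
        tgt, prior_tgt = target.chunk(2)
        loss = F.mse_loss(pred, tgt)
        prior_loss = F.mse_loss(prior_pred, prior_tgt)
        return loss + prior_loss_weight * prior_loss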
@@ -591,27 +580,16 @@ def parse_args(): | |||
591 | help="Embedding dropout probability.", | 580 | help="Embedding dropout probability.", |
592 | ) | 581 | ) |
593 | parser.add_argument( | 582 | parser.add_argument( |
594 | "--use_emb_decay", | 583 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." |
595 | action="store_true", | ||
596 | help="Whether to use embedding decay." | ||
597 | ) | 584 | ) |
598 | parser.add_argument( | 585 | parser.add_argument( |
599 | "--emb_decay_target", | 586 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." |
600 | default=0.4, | ||
601 | type=float, | ||
602 | help="Embedding decay target." | ||
603 | ) | 587 | ) |
604 | parser.add_argument( | 588 | parser.add_argument( |
605 | "--emb_decay", | 589 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." |
606 | default=1e+2, | ||
607 | type=float, | ||
608 | help="Embedding decay factor." | ||
609 | ) | 590 | ) |
610 | parser.add_argument( | 591 | parser.add_argument( |
611 | "--max_grad_norm", | 592 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." |
612 | default=1.0, | ||
613 | type=float, | ||
614 | help="Max gradient norm." | ||
615 | ) | 593 | ) |
616 | parser.add_argument( | 594 | parser.add_argument( |
617 | "--noise_timesteps", | 595 | "--noise_timesteps", |
@@ -622,7 +600,7 @@ def parse_args(): | |||
622 | "--config", | 600 | "--config", |
623 | type=str, | 601 | type=str, |
624 | default=None, | 602 | default=None, |
625 | help="Path to a JSON configuration file containing arguments for invoking this script." | 603 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
626 | ) | 604 | ) |
627 | 605 | ||
628 | args = parser.parse_args() | 606 | args = parser.parse_args() |
@@ -649,29 +627,44 @@ def parse_args(): | |||
649 | args.placeholder_tokens = [args.placeholder_tokens] | 627 | args.placeholder_tokens = [args.placeholder_tokens] |
650 | 628 | ||
651 | if isinstance(args.initializer_tokens, str): | 629 | if isinstance(args.initializer_tokens, str): |
652 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 630 | args.initializer_tokens = [args.initializer_tokens] * len( |
631 | args.placeholder_tokens | ||
632 | ) | ||
653 | 633 | ||
654 | if len(args.placeholder_tokens) == 0: | 634 | if len(args.placeholder_tokens) == 0: |
655 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 635 | args.placeholder_tokens = [ |
636 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
637 | ] | ||
656 | 638 | ||
657 | if len(args.initializer_tokens) == 0: | 639 | if len(args.initializer_tokens) == 0: |
658 | args.initializer_tokens = args.placeholder_tokens.copy() | 640 | args.initializer_tokens = args.placeholder_tokens.copy() |
659 | 641 | ||
660 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 642 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
661 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 643 | raise ValueError( |
644 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
645 | ) | ||
662 | 646 | ||
663 | if isinstance(args.inverted_initializer_tokens, str): | 647 | if isinstance(args.inverted_initializer_tokens, str): |
664 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | 648 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( |
649 | args.placeholder_tokens | ||
650 | ) | ||
665 | 651 | ||
666 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | 652 | if ( |
653 | isinstance(args.inverted_initializer_tokens, list) | ||
654 | and len(args.inverted_initializer_tokens) != 0 | ||
655 | ): | ||
667 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | 656 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] |
668 | args.initializer_tokens += args.inverted_initializer_tokens | 657 | args.initializer_tokens += args.inverted_initializer_tokens |
669 | 658 | ||
670 | if isinstance(args.num_vectors, int): | 659 | if isinstance(args.num_vectors, int): |
671 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 660 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
672 | 661 | ||
673 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 662 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( |
674 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 663 | args.num_vectors |
664 | ): | ||
665 | raise ValueError( | ||
666 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
667 | ) | ||
675 | 668 | ||
676 | if args.alias_tokens is None: | 669 | if args.alias_tokens is None: |
677 | args.alias_tokens = [] | 670 | args.alias_tokens = [] |
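The normalization logic above broadcasts scalar CLI values across all placeholders, so `--placeholder_tokens "<a>" "<b>" --initializer_tokens girl` trains two tokens both initialized from "girl". The same rule, condensed:

    placeholder_tokens = ["<a>", "<b>"]
    initializer_tokens = "girl"  # a single string arriving from the CLI

    if isinstance(initializer_tokens, str):
        initializer_tokens = [initializer_tokens] * len(placeholder_tokens)

    assert initializer_tokens == ["girl", "girl"]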
@@ -695,15 +688,15 @@ def parse_args(): | |||
695 | raise ValueError("You must specify --output_dir") | 688 | raise ValueError("You must specify --output_dir") |
696 | 689 | ||
697 | if args.adam_beta1 is None: | 690 | if args.adam_beta1 is None: |
698 | if args.optimizer in ('adam', 'adam8bit'): | 691 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
699 | args.adam_beta1 = 0.9 | 692 | args.adam_beta1 = 0.9 |
700 | elif args.optimizer == 'lion': | 693 | elif args.optimizer in ("lion", "dlion"): |
701 | args.adam_beta1 = 0.95 | 694 | args.adam_beta1 = 0.95 |
702 | 695 | ||
703 | if args.adam_beta2 is None: | 696 | if args.adam_beta2 is None: |
704 | if args.optimizer in ('adam', 'adam8bit'): | 697 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
705 | args.adam_beta2 = 0.999 | 698 | args.adam_beta2 = 0.999 |
706 | elif args.optimizer == 'lion': | 699 | elif args.optimizer in ("lion", "dlion"): |
707 | args.adam_beta2 = 0.98 | 700 | args.adam_beta2 = 0.98 |
708 | 701 | ||
709 | return args | 702 | return args |
@@ -719,7 +712,7 @@ def main(): | |||
719 | accelerator = Accelerator( | 712 | accelerator = Accelerator( |
720 | log_with=LoggerType.TENSORBOARD, | 713 | log_with=LoggerType.TENSORBOARD, |
721 | project_dir=f"{output_dir}", | 714 | project_dir=f"{output_dir}", |
722 | mixed_precision=args.mixed_precision | 715 | mixed_precision=args.mixed_precision, |
723 | ) | 716 | ) |
724 | 717 | ||
725 | weight_dtype = torch.float32 | 718 | weight_dtype = torch.float32 |
@@ -728,6 +721,8 @@ def main(): | |||
728 | elif args.mixed_precision == "bf16": | 721 | elif args.mixed_precision == "bf16": |
729 | weight_dtype = torch.bfloat16 | 722 | weight_dtype = torch.bfloat16 |
730 | 723 | ||
724 | patch_xformers(weight_dtype) | ||
725 | |||
731 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 726 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
732 | 727 | ||
733 | if args.seed is None: | 728 | if args.seed is None: |
@@ -737,12 +732,18 @@ def main(): | |||
737 | 732 | ||
738 | save_args(output_dir, args) | 733 | save_args(output_dir, args) |
739 | 734 | ||
740 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) | 735 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
741 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 736 | args.pretrained_model_name_or_path |
742 | 737 | ) | |
738 | schedule_sampler = create_named_schedule_sampler( | ||
739 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
740 | ) | ||
741 | |||
743 | def ensure_embeddings(): | 742 | def ensure_embeddings(): |
744 | if args.lora_text_encoder_emb: | 743 | if args.lora_text_encoder_emb: |
745 | raise ValueError("Can't use TI options when training token embeddings with LoRA") | 744 | raise ValueError( |
745 | "Can't use TI options when training token embeddings with LoRA" | ||
746 | ) | ||
746 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) | 747 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) |
747 | 748 | ||
748 | unet_config = LoraConfig( | 749 | unet_config = LoraConfig( |
@@ -757,7 +758,9 @@ def main(): | |||
757 | text_encoder_config = LoraConfig( | 758 | text_encoder_config = LoraConfig( |
758 | r=args.lora_text_encoder_r, | 759 | r=args.lora_text_encoder_r, |
759 | lora_alpha=args.lora_text_encoder_alpha, | 760 | lora_alpha=args.lora_text_encoder_alpha, |
760 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES, | 761 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING |
762 | if args.lora_text_encoder_emb | ||
763 | else TEXT_ENCODER_TARGET_MODULES, | ||
761 | lora_dropout=args.lora_text_encoder_dropout, | 764 | lora_dropout=args.lora_text_encoder_dropout, |
762 | bias=args.lora_text_encoder_bias, | 765 | bias=args.lora_text_encoder_bias, |
763 | ) | 766 | ) |
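`LoraConfig` comes from the `peft` library; `target_modules` selects which submodules receive low-rank adapters, and the `_WITH_EMBEDDING` constant additionally covers the token embedding when `--lora_text_encoder_emb` is set. A standalone example with made-up module names (the script imports its real constants elsewhere):

    from peft import LoraConfig

    # Hypothetical module list; the script defines its own constants.
    TEXT_ENCODER_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj"]

    config = LoraConfig(
        r=8,                                      # adapter rank
        lora_alpha=32,                            # scaling numerator (alpha / r)
        target_modules=TEXT_ENCODER_TARGET_MODULES,
        lora_dropout=0.0,
        bias="none",
    )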
@@ -787,7 +790,7 @@ def main(): | |||
787 | 790 | ||
788 | if len(args.alias_tokens) != 0: | 791 | if len(args.alias_tokens) != 0: |
789 | embeddings = ensure_embeddings() | 792 | embeddings = ensure_embeddings() |
790 | 793 | ||
791 | alias_placeholder_tokens = args.alias_tokens[::2] | 794 | alias_placeholder_tokens = args.alias_tokens[::2] |
792 | alias_initializer_tokens = args.alias_tokens[1::2] | 795 | alias_initializer_tokens = args.alias_tokens[1::2] |
793 | 796 | ||
@@ -795,27 +798,33 @@ def main(): | |||
795 | tokenizer=tokenizer, | 798 | tokenizer=tokenizer, |
796 | embeddings=embeddings, | 799 | embeddings=embeddings, |
797 | placeholder_tokens=alias_placeholder_tokens, | 800 | placeholder_tokens=alias_placeholder_tokens, |
798 | initializer_tokens=alias_initializer_tokens | 801 | initializer_tokens=alias_initializer_tokens, |
799 | ) | 802 | ) |
800 | embeddings.persist() | 803 | embeddings.persist() |
801 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 804 | print( |
805 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
806 | ) | ||
802 | 807 | ||
803 | placeholder_tokens = [] | 808 | placeholder_tokens = [] |
804 | placeholder_token_ids = [] | 809 | placeholder_token_ids = [] |
805 | 810 | ||
806 | if args.embeddings_dir is not None: | 811 | if args.embeddings_dir is not None: |
807 | embeddings = ensure_embeddings() | 812 | embeddings = ensure_embeddings() |
808 | 813 | ||
809 | embeddings_dir = Path(args.embeddings_dir) | 814 | embeddings_dir = Path(args.embeddings_dir) |
810 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 815 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
811 | raise ValueError("--embeddings_dir must point to an existing directory") | 816 | raise ValueError("--embeddings_dir must point to an existing directory") |
812 | 817 | ||
813 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 818 | added_tokens, added_ids = load_embeddings_from_dir( |
819 | tokenizer, embeddings, embeddings_dir | ||
820 | ) | ||
814 | 821 | ||
815 | placeholder_tokens = added_tokens | 822 | placeholder_tokens = added_tokens |
816 | placeholder_token_ids = added_ids | 823 | placeholder_token_ids = added_ids |
817 | 824 | ||
818 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 825 | print( |
826 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
827 | ) | ||
819 | 828 | ||
820 | if args.train_dir_embeddings: | 829 | if args.train_dir_embeddings: |
821 | print("Training embeddings from embeddings dir") | 830 | print("Training embeddings from embeddings dir") |
@@ -824,7 +833,7 @@ def main(): | |||
824 | 833 | ||
825 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | 834 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: |
826 | embeddings = ensure_embeddings() | 835 | embeddings = ensure_embeddings() |
827 | 836 | ||
828 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 837 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
829 | tokenizer=tokenizer, | 838 | tokenizer=tokenizer, |
830 | embeddings=embeddings, | 839 | embeddings=embeddings, |
@@ -836,23 +845,34 @@ def main(): | |||
836 | 845 | ||
837 | placeholder_tokens = args.placeholder_tokens | 846 | placeholder_tokens = args.placeholder_tokens |
838 | 847 | ||
839 | stats = list(zip( | 848 | stats = list( |
840 | placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | 849 | zip( |
841 | )) | 850 | placeholder_tokens, |
851 | placeholder_token_ids, | ||
852 | args.initializer_tokens, | ||
853 | initializer_token_ids, | ||
854 | ) | ||
855 | ) | ||
842 | print(f"Training embeddings: {stats}") | 856 | print(f"Training embeddings: {stats}") |
843 | 857 | ||
844 | if args.scale_lr: | 858 | if args.scale_lr: |
845 | args.learning_rate_unet = ( | 859 | args.learning_rate_unet = ( |
846 | args.learning_rate_unet * args.gradient_accumulation_steps * | 860 | args.learning_rate_unet |
847 | args.train_batch_size * accelerator.num_processes | 861 | * args.gradient_accumulation_steps |
862 | * args.train_batch_size | ||
863 | * accelerator.num_processes | ||
848 | ) | 864 | ) |
849 | args.learning_rate_text = ( | 865 | args.learning_rate_text = ( |
850 | args.learning_rate_text * args.gradient_accumulation_steps * | 866 | args.learning_rate_text |
851 | args.train_batch_size * accelerator.num_processes | 867 | * args.gradient_accumulation_steps |
868 | * args.train_batch_size | ||
869 | * accelerator.num_processes | ||
852 | ) | 870 | ) |
853 | args.learning_rate_emb = ( | 871 | args.learning_rate_emb = ( |
854 | args.learning_rate_emb * args.gradient_accumulation_steps * | 872 | args.learning_rate_emb |
855 | args.train_batch_size * accelerator.num_processes | 873 | * args.gradient_accumulation_steps |
874 | * args.train_batch_size | ||
875 | * accelerator.num_processes | ||
856 | ) | 876 | ) |
857 | 877 | ||
858 | if args.find_lr: | 878 | if args.find_lr: |
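`--scale_lr` in the hunk above applies the linear scaling rule to all three learning rates. For example, a base rate of 1e-4 with 4 gradient-accumulation steps, batch size 2 and 2 processes becomes 1.6e-3:

    base_lr = 1e-4
    effective_lr = base_lr * 4 * 2 * 2   # grad accum * batch size * num processes
    assert abs(effective_lr - 1.6e-3) < 1e-12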
@@ -861,11 +881,13 @@ def main(): | |||
861 | args.learning_rate_emb = 1e-6 | 881 | args.learning_rate_emb = 1e-6 |
862 | args.lr_scheduler = "exponential_growth" | 882 | args.lr_scheduler = "exponential_growth" |
863 | 883 | ||
864 | if args.optimizer == 'adam8bit': | 884 | if args.optimizer == "adam8bit": |
865 | try: | 885 | try: |
866 | import bitsandbytes as bnb | 886 | import bitsandbytes as bnb |
867 | except ImportError: | 887 | except ImportError: |
868 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 888 | raise ImportError( |
889 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
890 | ) | ||
869 | 891 | ||
870 | create_optimizer = partial( | 892 | create_optimizer = partial( |
871 | bnb.optim.AdamW8bit, | 893 | bnb.optim.AdamW8bit, |
@@ -874,7 +896,7 @@ def main(): | |||
874 | eps=args.adam_epsilon, | 896 | eps=args.adam_epsilon, |
875 | amsgrad=args.adam_amsgrad, | 897 | amsgrad=args.adam_amsgrad, |
876 | ) | 898 | ) |
877 | elif args.optimizer == 'adam': | 899 | elif args.optimizer == "adam": |
878 | create_optimizer = partial( | 900 | create_optimizer = partial( |
879 | torch.optim.AdamW, | 901 | torch.optim.AdamW, |
880 | betas=(args.adam_beta1, args.adam_beta2), | 902 | betas=(args.adam_beta1, args.adam_beta2), |
@@ -882,11 +904,13 @@ def main(): | |||
882 | eps=args.adam_epsilon, | 904 | eps=args.adam_epsilon, |
883 | amsgrad=args.adam_amsgrad, | 905 | amsgrad=args.adam_amsgrad, |
884 | ) | 906 | ) |
885 | elif args.optimizer == 'adan': | 907 | elif args.optimizer == "adan": |
886 | try: | 908 | try: |
887 | import timm.optim | 909 | import timm.optim |
888 | except ImportError: | 910 | except ImportError: |
889 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 911 | raise ImportError( |
912 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
913 | ) | ||
890 | 914 | ||
891 | create_optimizer = partial( | 915 | create_optimizer = partial( |
892 | timm.optim.Adan, | 916 | timm.optim.Adan, |
@@ -894,11 +918,13 @@ def main(): | |||
894 | eps=args.adam_epsilon, | 918 | eps=args.adam_epsilon, |
895 | no_prox=True, | 919 | no_prox=True, |
896 | ) | 920 | ) |
897 | elif args.optimizer == 'lion': | 921 | elif args.optimizer == "lion": |
898 | try: | 922 | try: |
899 | import lion_pytorch | 923 | import lion_pytorch |
900 | except ImportError: | 924 | except ImportError: |
901 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 925 | raise ImportError( |
926 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
927 | ) | ||
902 | 928 | ||
903 | create_optimizer = partial( | 929 | create_optimizer = partial( |
904 | lion_pytorch.Lion, | 930 | lion_pytorch.Lion, |
@@ -906,7 +932,7 @@ def main(): | |||
906 | weight_decay=args.adam_weight_decay, | 932 | weight_decay=args.adam_weight_decay, |
907 | use_triton=True, | 933 | use_triton=True, |
908 | ) | 934 | ) |
909 | elif args.optimizer == 'adafactor': | 935 | elif args.optimizer == "adafactor": |
910 | create_optimizer = partial( | 936 | create_optimizer = partial( |
911 | transformers.optimization.Adafactor, | 937 | transformers.optimization.Adafactor, |
912 | weight_decay=args.adam_weight_decay, | 938 | weight_decay=args.adam_weight_decay, |
@@ -920,11 +946,13 @@ def main(): | |||
920 | args.learning_rate_unet = None | 946 | args.learning_rate_unet = None |
921 | args.learning_rate_text = None | 947 | args.learning_rate_text = None |
922 | args.learning_rate_emb = None | 948 | args.learning_rate_emb = None |
923 | elif args.optimizer == 'dadam': | 949 | elif args.optimizer == "dadam": |
924 | try: | 950 | try: |
925 | import dadaptation | 951 | import dadaptation |
926 | except ImportError: | 952 | except ImportError: |
927 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 953 | raise ImportError( |
954 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
955 | ) | ||
928 | 956 | ||
929 | create_optimizer = partial( | 957 | create_optimizer = partial( |
930 | dadaptation.DAdaptAdam, | 958 | dadaptation.DAdaptAdam, |
@@ -933,29 +961,35 @@ def main(): | |||
933 | eps=args.adam_epsilon, | 961 | eps=args.adam_epsilon, |
934 | decouple=True, | 962 | decouple=True, |
935 | d0=args.dadaptation_d0, | 963 | d0=args.dadaptation_d0, |
964 | growth_rate=args.dadaptation_growth_rate, | ||
936 | ) | 965 | ) |
937 | 966 | ||
938 | args.learning_rate_unet = 1.0 | 967 | args.learning_rate_unet = 1.0 |
939 | args.learning_rate_text = 1.0 | 968 | args.learning_rate_text = 1.0 |
940 | args.learning_rate_emb = 1.0 | 969 | args.learning_rate_emb = 1.0 |
941 | elif args.optimizer == 'dadan': | 970 | elif args.optimizer == "dadan": |
942 | try: | 971 | try: |
943 | import dadaptation | 972 | import dadaptation |
944 | except ImportError: | 973 | except ImportError: |
945 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 974 | raise ImportError( |
975 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
976 | ) | ||
946 | 977 | ||
947 | create_optimizer = partial( | 978 | create_optimizer = partial( |
948 | dadaptation.DAdaptAdan, | 979 | dadaptation.DAdaptAdan, |
949 | weight_decay=args.adam_weight_decay, | 980 | weight_decay=args.adam_weight_decay, |
950 | eps=args.adam_epsilon, | 981 | eps=args.adam_epsilon, |
951 | d0=args.dadaptation_d0, | 982 | d0=args.dadaptation_d0, |
983 | growth_rate=args.dadaptation_growth_rate, | ||
952 | ) | 984 | ) |
953 | 985 | ||
954 | args.learning_rate_unet = 1.0 | 986 | args.learning_rate_unet = 1.0 |
955 | args.learning_rate_text = 1.0 | 987 | args.learning_rate_text = 1.0 |
956 | args.learning_rate_emb = 1.0 | 988 | args.learning_rate_emb = 1.0 |
989 | elif args.optimizer == "dlion": | ||
990 | raise ImportError("DLion has not been merged into dadaptation yet") | ||
957 | else: | 991 | else: |
958 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 992 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
959 | 993 | ||
960 | trainer = partial( | 994 | trainer = partial( |
961 | train, | 995 | train, |
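Both D-Adaptation branches now forward the new `--dadaptation_growth_rate` (default `math.inf`, i.e. uncapped), which limits how quickly the adaptive step-size estimate `d` may grow, while the learning rates are pinned to 1.0 because `d` replaces them. A usage sketch mirroring the call in the hunk above:

    import math
    import torch
    import dadaptation

    params = [torch.nn.Parameter(torch.zeros(4))]
    optimizer = dadaptation.DAdaptAdan(
        params,
        lr=1.0,                  # D-Adaptation expects lr=1.0; d takes over the step size
        weight_decay=2e-2,
        eps=1e-8,
        d0=1e-6,                 # initial estimate of the step size
        growth_rate=math.inf,    # no cap on how fast d may grow
    )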
@@ -1026,25 +1060,33 @@ def main(): | |||
1026 | 1060 | ||
1027 | if args.run_pti and len(placeholder_tokens) != 0: | 1061 | if args.run_pti and len(placeholder_tokens) != 0: |
1028 | embeddings = ensure_embeddings() | 1062 | embeddings = ensure_embeddings() |
1029 | 1063 | ||
1030 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | 1064 | filter_tokens = [ |
1065 | token for token in args.filter_tokens if token in placeholder_tokens | ||
1066 | ] | ||
1031 | 1067 | ||
1032 | pti_datamodule = create_datamodule( | 1068 | pti_datamodule = create_datamodule( |
1033 | valid_set_size=0, | 1069 | valid_set_size=0, |
1034 | batch_size=args.train_batch_size, | 1070 | batch_size=args.train_batch_size, |
1035 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 1071 | filter=partial( |
1072 | keyword_filter, filter_tokens, args.collection, args.exclude_collections | ||
1073 | ), | ||
1036 | ) | 1074 | ) |
1037 | pti_datamodule.setup() | 1075 | pti_datamodule.setup() |
1038 | 1076 | ||
1039 | num_train_epochs = args.num_train_epochs | 1077 | num_train_epochs = args.num_train_epochs |
1040 | pti_sample_frequency = args.sample_frequency | 1078 | pti_sample_frequency = args.sample_frequency |
1041 | if num_train_epochs is None: | 1079 | if num_train_epochs is None: |
1042 | num_train_epochs = math.ceil( | 1080 | num_train_epochs = ( |
1043 | args.num_train_steps / len(pti_datamodule.train_dataset) | 1081 | math.ceil(args.num_train_steps / len(pti_datamodule.train_dataset)) |
1044 | ) * args.gradient_accumulation_steps | 1082 | * args.gradient_accumulation_steps |
1045 | pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps)) | 1083 | ) |
1084 | pti_sample_frequency = math.ceil( | ||
1085 | num_train_epochs * (pti_sample_frequency / args.num_train_steps) | ||
1086 | ) | ||
1046 | num_training_steps_per_epoch = math.ceil( | 1087 | num_training_steps_per_epoch = math.ceil( |
1047 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps) | 1088 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps |
1089 | ) | ||
1048 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 1090 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
1049 | if args.sample_num is not None: | 1091 | if args.sample_num is not None: |
1050 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1092 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
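The step/epoch conversion above is easier to follow with concrete numbers: with `--num_train_steps 2000`, 500 training items and 4 gradient-accumulation steps it yields ceil(2000/500) * 4 = 16 epochs of ceil(500/4) = 125 optimizer steps each:

    import math

    num_train_steps, dataset_len, grad_accum = 2000, 500, 4
    num_train_epochs = math.ceil(num_train_steps / dataset_len) * grad_accum   # 16
    steps_per_epoch = math.ceil(dataset_len / grad_accum)                      # 125
    assert (num_train_epochs, steps_per_epoch) == (16, 125)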
@@ -1060,11 +1102,15 @@ def main(): | |||
1060 | print(f"============ PTI ============") | 1102 | print(f"============ PTI ============") |
1061 | print("") | 1103 | print("") |
1062 | 1104 | ||
1063 | pti_optimizer = create_optimizer([{ | 1105 | pti_optimizer = create_optimizer( |
1064 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | 1106 | [ |
1065 | "lr": args.learning_rate_emb, | 1107 | { |
1066 | "weight_decay": 0, | 1108 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), |
1067 | }]) | 1109 | "lr": args.learning_rate_emb, |
1110 | "weight_decay": 0, | ||
1111 | } | ||
1112 | ] | ||
1113 | ) | ||
1068 | 1114 | ||
1069 | pti_lr_scheduler = create_lr_scheduler( | 1115 | pti_lr_scheduler = create_lr_scheduler( |
1070 | "constant_with_warmup", | 1116 | "constant_with_warmup", |
@@ -1113,11 +1159,16 @@ def main(): | |||
1113 | num_train_epochs = args.num_train_epochs | 1159 | num_train_epochs = args.num_train_epochs |
1114 | lora_sample_frequency = args.sample_frequency | 1160 | lora_sample_frequency = args.sample_frequency |
1115 | if num_train_epochs is None: | 1161 | if num_train_epochs is None: |
1116 | num_train_epochs = math.ceil( | 1162 | num_train_epochs = ( |
1117 | args.num_train_steps / len(lora_datamodule.train_dataset) | 1163 | math.ceil(args.num_train_steps / len(lora_datamodule.train_dataset)) |
1118 | ) * args.gradient_accumulation_steps | 1164 | * args.gradient_accumulation_steps |
1119 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 1165 | ) |
1120 | num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps) | 1166 | lora_sample_frequency = math.ceil( |
1167 | num_train_epochs * (lora_sample_frequency / args.num_train_steps) | ||
1168 | ) | ||
1169 | num_training_steps_per_epoch = math.ceil( | ||
1170 | len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps | ||
1171 | ) | ||
1121 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 1172 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
1122 | if args.sample_num is not None: | 1173 | if args.sample_num is not None: |
1123 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1174 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
@@ -1131,7 +1182,6 @@ def main(): | |||
1131 | 1182 | ||
1132 | training_iter = 0 | 1183 | training_iter = 0 |
1133 | auto_cycles = list(args.auto_cycles) | 1184 | auto_cycles = list(args.auto_cycles) |
1134 | learning_rate_emb = args.learning_rate_emb | ||
1135 | learning_rate_unet = args.learning_rate_unet | 1185 | learning_rate_unet = args.learning_rate_unet |
1136 | learning_rate_text = args.learning_rate_text | 1186 | learning_rate_text = args.learning_rate_text |
1137 | lr_scheduler = args.lr_scheduler | 1187 | lr_scheduler = args.lr_scheduler |
@@ -1145,21 +1195,15 @@ def main(): | |||
1145 | 1195 | ||
1146 | params_to_optimize = [ | 1196 | params_to_optimize = [ |
1147 | { | 1197 | { |
1148 | "params": ( | 1198 | "params": (param for param in unet.parameters() if param.requires_grad), |
1149 | param | ||
1150 | for param in unet.parameters() | ||
1151 | if param.requires_grad | ||
1152 | ), | ||
1153 | "lr": learning_rate_unet, | 1199 | "lr": learning_rate_unet, |
1154 | }, | 1200 | }, |
1155 | { | 1201 | { |
1156 | "params": ( | 1202 | "params": ( |
1157 | param | 1203 | param for param in text_encoder.parameters() if param.requires_grad |
1158 | for param in text_encoder.parameters() | ||
1159 | if param.requires_grad | ||
1160 | ), | 1204 | ), |
1161 | "lr": learning_rate_text, | 1205 | "lr": learning_rate_text, |
1162 | } | 1206 | }, |
1163 | ] | 1207 | ] |
1164 | group_labels = ["unet", "text"] | 1208 | group_labels = ["unet", "text"] |
1165 | 1209 | ||
@@ -1169,19 +1213,26 @@ def main(): | |||
1169 | if len(auto_cycles) != 0: | 1213 | if len(auto_cycles) != 0: |
1170 | response = auto_cycles.pop(0) | 1214 | response = auto_cycles.pop(0) |
1171 | else: | 1215 | else: |
1172 | response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 1216 | response = input( |
1217 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " | ||
1218 | ) | ||
1173 | 1219 | ||
1174 | if response.lower().strip() == "o": | 1220 | if response.lower().strip() == "o": |
1175 | if args.learning_rate_emb is not None: | ||
1176 | learning_rate_emb = args.learning_rate_emb * 2 | ||
1177 | if args.learning_rate_unet is not None: | 1221 | if args.learning_rate_unet is not None: |
1178 | learning_rate_unet = args.learning_rate_unet * 2 | 1222 | learning_rate_unet = ( |
1223 | args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) | ||
1224 | ) | ||
1179 | if args.learning_rate_text is not None: | 1225 | if args.learning_rate_text is not None: |
1180 | learning_rate_text = args.learning_rate_text * 2 | 1226 | learning_rate_text = ( |
1227 | args.learning_rate_text * 2 * (args.cycle_decay**training_iter) | ||
1228 | ) | ||
1181 | else: | 1229 | else: |
1182 | learning_rate_emb = args.learning_rate_emb | 1230 | learning_rate_unet = args.learning_rate_unet * ( |
1183 | learning_rate_unet = args.learning_rate_unet | 1231 | args.cycle_decay**training_iter |
1184 | learning_rate_text = args.learning_rate_text | 1232 | ) |
1233 | learning_rate_text = args.learning_rate_text * ( | ||
1234 | args.cycle_decay**training_iter | ||
1235 | ) | ||
1185 | 1236 | ||
1186 | if response.lower().strip() == "o": | 1237 | if response.lower().strip() == "o": |
1187 | lr_scheduler = "one_cycle" | 1238 | lr_scheduler = "one_cycle" |
@@ -1204,9 +1255,11 @@ def main(): | |||
1204 | print("") | 1255 | print("") |
1205 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 1256 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") |
1206 | print("") | 1257 | print("") |
1207 | 1258 | ||
1208 | for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): | 1259 | for group, lr in zip( |
1209 | group['lr'] = lr | 1260 | lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text] |
1261 | ): | ||
1262 | group["lr"] = lr | ||
1210 | 1263 | ||
1211 | lora_lr_scheduler = create_lr_scheduler( | 1264 | lora_lr_scheduler = create_lr_scheduler( |
1212 | lr_scheduler, | 1265 | lr_scheduler, |
@@ -1218,7 +1271,9 @@ def main(): | |||
1218 | warmup_epochs=lr_warmup_epochs, | 1271 | warmup_epochs=lr_warmup_epochs, |
1219 | ) | 1272 | ) |
1220 | 1273 | ||
1221 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" | 1274 | lora_checkpoint_output_dir = ( |
1275 | output_dir / lora_project / f"model_{training_iter}" | ||
1276 | ) | ||
1222 | 1277 | ||
1223 | trainer( | 1278 | trainer( |
1224 | strategy=lora_strategy, | 1279 | strategy=lora_strategy, |
@@ -1246,12 +1301,6 @@ def main(): | |||
1246 | ) | 1301 | ) |
1247 | 1302 | ||
1248 | training_iter += 1 | 1303 | training_iter += 1 |
1249 | if learning_rate_emb is not None: | ||
1250 | learning_rate_emb *= args.cycle_decay | ||
1251 | if learning_rate_unet is not None: | ||
1252 | learning_rate_unet *= args.cycle_decay | ||
1253 | if learning_rate_text is not None: | ||
1254 | learning_rate_text *= args.cycle_decay | ||
1255 | 1304 | ||
1256 | accelerator.end_training() | 1305 | accelerator.end_training() |
1257 | 1306 | ||
diff --git a/train_ti.py b/train_ti.py index f60e3e5..c6f0b3a 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -32,7 +32,7 @@ from util.files import load_config, load_embeddings_from_dir | |||
32 | 32 | ||
33 | logger = get_logger(__name__) | 33 | logger = get_logger(__name__) |
34 | 34 | ||
35 | warnings.filterwarnings('ignore') | 35 | warnings.filterwarnings("ignore") |
36 | 36 | ||
37 | 37 | ||
38 | torch.backends.cuda.matmul.allow_tf32 = True | 38 | torch.backends.cuda.matmul.allow_tf32 = True |
@@ -46,9 +46,7 @@ hidet.torch.dynamo_config.search_space(0) | |||
46 | 46 | ||
47 | 47 | ||
48 | def parse_args(): | 48 | def parse_args(): |
49 | parser = argparse.ArgumentParser( | 49 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
50 | description="Simple example of a training script." | ||
51 | ) | ||
52 | parser.add_argument( | 50 | parser.add_argument( |
53 | "--pretrained_model_name_or_path", | 51 | "--pretrained_model_name_or_path", |
54 | type=str, | 52 | type=str, |
@@ -65,12 +63,12 @@ def parse_args(): | |||
65 | "--train_data_file", | 63 | "--train_data_file", |
66 | type=str, | 64 | type=str, |
67 | default=None, | 65 | default=None, |
68 | help="A CSV file containing the training data." | 66 | help="A CSV file containing the training data.", |
69 | ) | 67 | ) |
70 | parser.add_argument( | 68 | parser.add_argument( |
71 | "--train_data_template", | 69 | "--train_data_template", |
72 | type=str, | 70 | type=str, |
73 | nargs='*', | 71 | nargs="*", |
74 | default="template", | 72 | default="template", |
75 | ) | 73 | ) |
76 | parser.add_argument( | 74 | parser.add_argument( |
@@ -80,59 +78,47 @@ def parse_args(): | |||
80 | help="The name of the current project.", | 78 | help="The name of the current project.", |
81 | ) | 79 | ) |
82 | parser.add_argument( | 80 | parser.add_argument( |
83 | "--auto_cycles", | 81 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
84 | type=str, | ||
85 | default="o", | ||
86 | help="Cycles to run automatically." | ||
87 | ) | 82 | ) |
88 | parser.add_argument( | 83 | parser.add_argument( |
89 | "--cycle_decay", | 84 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." |
90 | type=float, | ||
91 | default=1.0, | ||
92 | help="Learning rate decay per cycle." | ||
93 | ) | 85 | ) |
94 | parser.add_argument( | 86 | parser.add_argument( |
95 | "--placeholder_tokens", | 87 | "--placeholder_tokens", |
96 | type=str, | 88 | type=str, |
97 | nargs='*', | 89 | nargs="*", |
98 | help="A token to use as a placeholder for the concept.", | 90 | help="A token to use as a placeholder for the concept.", |
99 | ) | 91 | ) |
100 | parser.add_argument( | 92 | parser.add_argument( |
101 | "--initializer_tokens", | 93 | "--initializer_tokens", |
102 | type=str, | 94 | type=str, |
103 | nargs='*', | 95 | nargs="*", |
104 | help="A token to use as initializer word." | 96 | help="A token to use as initializer word.", |
105 | ) | 97 | ) |
106 | parser.add_argument( | 98 | parser.add_argument( |
107 | "--filter_tokens", | 99 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." |
108 | type=str, | ||
109 | nargs='*', | ||
110 | help="Tokens to filter the dataset by." | ||
111 | ) | 100 | ) |
112 | parser.add_argument( | 101 | parser.add_argument( |
113 | "--initializer_noise", | 102 | "--initializer_noise", |
114 | type=float, | 103 | type=float, |
115 | default=0, | 104 | default=0, |
116 | help="Noise to apply to the initializer word" | 105 | help="Noise to apply to the initializer word", |
117 | ) | 106 | ) |
118 | parser.add_argument( | 107 | parser.add_argument( |
119 | "--alias_tokens", | 108 | "--alias_tokens", |
120 | type=str, | 109 | type=str, |
121 | nargs='*', | 110 | nargs="*", |
122 | default=[], | 111 | default=[], |
123 | help="Tokens to create an alias for." | 112 | help="Tokens to create an alias for.", |
124 | ) | 113 | ) |
125 | parser.add_argument( | 114 | parser.add_argument( |
126 | "--inverted_initializer_tokens", | 115 | "--inverted_initializer_tokens", |
127 | type=str, | 116 | type=str, |
128 | nargs='*', | 117 | nargs="*", |
129 | help="A token to use as initializer word." | 118 | help="A token to use as initializer word.", |
130 | ) | 119 | ) |
131 | parser.add_argument( | 120 | parser.add_argument( |
132 | "--num_vectors", | 121 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
133 | type=int, | ||
134 | nargs='*', | ||
135 | help="Number of vectors per embedding." | ||
136 | ) | 122 | ) |
137 | parser.add_argument( | 123 | parser.add_argument( |
138 | "--sequential", | 124 | "--sequential", |
@@ -147,7 +133,7 @@ def parse_args(): | |||
147 | "--num_class_images", | 133 | "--num_class_images", |
148 | type=int, | 134 | type=int, |
149 | default=0, | 135 | default=0, |
150 | help="How many class images to generate." | 136 | help="How many class images to generate.", |
151 | ) | 137 | ) |
152 | parser.add_argument( | 138 | parser.add_argument( |
153 | "--class_image_dir", | 139 | "--class_image_dir", |
@@ -158,7 +144,7 @@ def parse_args(): | |||
158 | parser.add_argument( | 144 | parser.add_argument( |
159 | "--exclude_collections", | 145 | "--exclude_collections", |
160 | type=str, | 146 | type=str, |
161 | nargs='*', | 147 | nargs="*", |
162 | help="Exclude all items with a listed collection.", | 148 | help="Exclude all items with a listed collection.", |
163 | ) | 149 | ) |
164 | parser.add_argument( | 150 | parser.add_argument( |
@@ -181,14 +167,11 @@ def parse_args(): | |||
181 | parser.add_argument( | 167 | parser.add_argument( |
182 | "--collection", | 168 | "--collection", |
183 | type=str, | 169 | type=str, |
184 | nargs='*', | 170 | nargs="*", |
185 | help="A collection to filter the dataset.", | 171 | help="A collection to filter the dataset.", |
186 | ) | 172 | ) |
187 | parser.add_argument( | 173 | parser.add_argument( |
188 | "--seed", | 174 | "--seed", type=int, default=None, help="A seed for reproducible training." |
189 | type=int, | ||
190 | default=None, | ||
191 | help="A seed for reproducible training." | ||
192 | ) | 175 | ) |
193 | parser.add_argument( | 176 | parser.add_argument( |
194 | "--resolution", | 177 | "--resolution", |
@@ -244,7 +227,7 @@ def parse_args(): | |||
244 | type=str, | 227 | type=str, |
245 | default="auto", | 228 | default="auto", |
246 | choices=["all", "trailing", "leading", "between", "auto", "off"], | 229 | choices=["all", "trailing", "leading", "between", "auto", "off"], |
247 | help='Vector shuffling algorithm.', | 230 | help="Vector shuffling algorithm.", |
248 | ) | 231 | ) |
249 | parser.add_argument( | 232 | parser.add_argument( |
250 | "--offset_noise_strength", | 233 | "--offset_noise_strength", |
@@ -256,18 +239,10 @@ def parse_args(): | |||
256 | "--input_pertubation", | 239 | "--input_pertubation", |
257 | type=float, | 240 | type=float, |
258 | default=0, | 241 | default=0, |
259 | help="The scale of input pretubation. Recommended 0.1." | 242 | help="The scale of input pretubation. Recommended 0.1.", |
260 | ) | ||
261 | parser.add_argument( | ||
262 | "--num_train_epochs", | ||
263 | type=int, | ||
264 | default=None | ||
265 | ) | ||
266 | parser.add_argument( | ||
267 | "--num_train_steps", | ||
268 | type=int, | ||
269 | default=2000 | ||
270 | ) | 243 | ) |
244 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
245 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
271 | parser.add_argument( | 246 | parser.add_argument( |
272 | "--gradient_accumulation_steps", | 247 | "--gradient_accumulation_steps", |
273 | type=int, | 248 | type=int, |
@@ -299,27 +274,31 @@ def parse_args(): | |||
299 | "--lr_scheduler", | 274 | "--lr_scheduler", |
300 | type=str, | 275 | type=str, |
301 | default="one_cycle", | 276 | default="one_cycle", |
302 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 277 | choices=[ |
303 | "constant", "constant_with_warmup", "one_cycle"], | 278 | "linear", |
304 | help='The scheduler type to use.', | 279 | "cosine", |
280 | "cosine_with_restarts", | ||
281 | "polynomial", | ||
282 | "constant", | ||
283 | "constant_with_warmup", | ||
284 | "one_cycle", | ||
285 | ], | ||
286 | help="The scheduler type to use.", | ||
305 | ) | 287 | ) |
306 | parser.add_argument( | 288 | parser.add_argument( |
307 | "--lr_warmup_epochs", | 289 | "--lr_warmup_epochs", |
308 | type=int, | 290 | type=int, |
309 | default=10, | 291 | default=10, |
310 | help="Number of steps for the warmup in the lr scheduler." | 292 | help="Number of steps for the warmup in the lr scheduler.", |
311 | ) | 293 | ) |
312 | parser.add_argument( | 294 | parser.add_argument( |
313 | "--lr_mid_point", | 295 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
314 | type=float, | ||
315 | default=0.3, | ||
316 | help="OneCycle schedule mid point." | ||
317 | ) | 296 | ) |
318 | parser.add_argument( | 297 | parser.add_argument( |
319 | "--lr_cycles", | 298 | "--lr_cycles", |
320 | type=int, | 299 | type=int, |
321 | default=None, | 300 | default=None, |
322 | help="Number of restart cycles in the lr scheduler." | 301 | help="Number of restart cycles in the lr scheduler.", |
323 | ) | 302 | ) |
324 | parser.add_argument( | 303 | parser.add_argument( |
325 | "--lr_warmup_func", | 304 | "--lr_warmup_func", |
@@ -331,7 +310,7 @@ def parse_args(): | |||
331 | "--lr_warmup_exp", | 310 | "--lr_warmup_exp", |
332 | type=int, | 311 | type=int, |
333 | default=1, | 312 | default=1, |
334 | help='If lr_warmup_func is "cos", exponent to modify the function' | 313 | help='If lr_warmup_func is "cos", exponent to modify the function', |
335 | ) | 314 | ) |
336 | parser.add_argument( | 315 | parser.add_argument( |
337 | "--lr_annealing_func", | 316 | "--lr_annealing_func", |
@@ -343,89 +322,67 @@ def parse_args(): | |||
343 | "--lr_annealing_exp", | 322 | "--lr_annealing_exp", |
344 | type=int, | 323 | type=int, |
345 | default=1, | 324 | default=1, |
346 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 325 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
347 | ) | 326 | ) |
348 | parser.add_argument( | 327 | parser.add_argument( |
349 | "--lr_min_lr", | 328 | "--lr_min_lr", |
350 | type=float, | 329 | type=float, |
351 | default=0.04, | 330 | default=0.04, |
352 | help="Minimum learning rate in the lr scheduler." | 331 | help="Minimum learning rate in the lr scheduler.", |
353 | ) | 332 | ) |
354 | parser.add_argument( | 333 | parser.add_argument( |
355 | "--use_ema", | 334 | "--use_ema", action="store_true", help="Whether to use EMA model." |
356 | action="store_true", | ||
357 | help="Whether to use EMA model." | ||
358 | ) | ||
359 | parser.add_argument( | ||
360 | "--ema_inv_gamma", | ||
361 | type=float, | ||
362 | default=1.0 | ||
363 | ) | ||
364 | parser.add_argument( | ||
365 | "--ema_power", | ||
366 | type=float, | ||
367 | default=4/5 | ||
368 | ) | ||
369 | parser.add_argument( | ||
370 | "--ema_max_decay", | ||
371 | type=float, | ||
372 | default=0.9999 | ||
373 | ) | ||
374 | parser.add_argument( | ||
375 | "--min_snr_gamma", | ||
376 | type=int, | ||
377 | default=5, | ||
378 | help="MinSNR gamma." | ||
379 | ) | 335 | ) |
336 | parser.add_argument("--ema_inv_gamma", type=float, default=1.0) | ||
337 | parser.add_argument("--ema_power", type=float, default=4 / 5) | ||
338 | parser.add_argument("--ema_max_decay", type=float, default=0.9999) | ||
339 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
380 | parser.add_argument( | 340 | parser.add_argument( |
381 | "--schedule_sampler", | 341 | "--schedule_sampler", |
382 | type=str, | 342 | type=str, |
383 | default="uniform", | 343 | default="uniform", |
384 | choices=["uniform", "loss-second-moment"], | 344 | choices=["uniform", "loss-second-moment"], |
385 | help="Noise schedule sampler." | 345 | help="Noise schedule sampler.", |
386 | ) | 346 | ) |
387 | parser.add_argument( | 347 | parser.add_argument( |
388 | "--optimizer", | 348 | "--optimizer", |
389 | type=str, | 349 | type=str, |
390 | default="adan", | 350 | default="adan", |
391 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 351 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
392 | help='Optimizer to use' | 352 | help="Optimizer to use", |
393 | ) | 353 | ) |
394 | parser.add_argument( | 354 | parser.add_argument( |
395 | "--dadaptation_d0", | 355 | "--dadaptation_d0", |
396 | type=float, | 356 | type=float, |
397 | default=1e-6, | 357 | default=1e-6, |
398 | help="The d0 parameter for Dadaptation optimizers." | 358 | help="The d0 parameter for Dadaptation optimizers.", |
399 | ) | 359 | ) |
400 | parser.add_argument( | 360 | parser.add_argument( |
401 | "--adam_beta1", | 361 | "--adam_beta1", |
402 | type=float, | 362 | type=float, |
403 | default=None, | 363 | default=None, |
404 | help="The beta1 parameter for the Adam optimizer." | 364 | help="The beta1 parameter for the Adam optimizer.", |
405 | ) | 365 | ) |
406 | parser.add_argument( | 366 | parser.add_argument( |
407 | "--adam_beta2", | 367 | "--adam_beta2", |
408 | type=float, | 368 | type=float, |
409 | default=None, | 369 | default=None, |
410 | help="The beta2 parameter for the Adam optimizer." | 370 | help="The beta2 parameter for the Adam optimizer.", |
411 | ) | 371 | ) |
412 | parser.add_argument( | 372 | parser.add_argument( |
413 | "--adam_weight_decay", | 373 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
414 | type=float, | ||
415 | default=2e-2, | ||
416 | help="Weight decay to use." | ||
417 | ) | 374 | ) |
418 | parser.add_argument( | 375 | parser.add_argument( |
419 | "--adam_epsilon", | 376 | "--adam_epsilon", |
420 | type=float, | 377 | type=float, |
421 | default=1e-08, | 378 | default=1e-08, |
422 | help="Epsilon value for the Adam optimizer" | 379 | help="Epsilon value for the Adam optimizer", |
423 | ) | 380 | ) |
424 | parser.add_argument( | 381 | parser.add_argument( |
425 | "--adam_amsgrad", | 382 | "--adam_amsgrad", |
426 | type=bool, | 383 | type=bool, |
427 | default=False, | 384 | default=False, |
428 | help="Amsgrad value for the Adam optimizer" | 385 | help="Amsgrad value for the Adam optimizer", |
429 | ) | 386 | ) |
430 | parser.add_argument( | 387 | parser.add_argument( |
431 | "--mixed_precision", | 388 | "--mixed_precision", |
@@ -456,7 +413,7 @@ def parse_args(): | |||
456 | ) | 413 | ) |
457 | parser.add_argument( | 414 | parser.add_argument( |
458 | "--no_milestone_checkpoints", | 415 | "--no_milestone_checkpoints", |
459 | action='store_true', | 416 | action="store_true", |
460 | help="If checkpoints are saved on maximum accuracy", | 417 | help="If checkpoints are saved on maximum accuracy", |
461 | ) | 418 | ) |
462 | parser.add_argument( | 419 | parser.add_argument( |
@@ -493,25 +450,25 @@ def parse_args(): | |||
493 | "--valid_set_size", | 450 | "--valid_set_size", |
494 | type=int, | 451 | type=int, |
495 | default=None, | 452 | default=None, |
496 | help="Number of images in the validation dataset." | 453 | help="Number of images in the validation dataset.", |
497 | ) | 454 | ) |
498 | parser.add_argument( | 455 | parser.add_argument( |
499 | "--train_set_pad", | 456 | "--train_set_pad", |
500 | type=int, | 457 | type=int, |
501 | default=None, | 458 | default=None, |
502 | help="The number to fill train dataset items up to." | 459 | help="The number to fill train dataset items up to.", |
503 | ) | 460 | ) |
504 | parser.add_argument( | 461 | parser.add_argument( |
505 | "--valid_set_pad", | 462 | "--valid_set_pad", |
506 | type=int, | 463 | type=int, |
507 | default=None, | 464 | default=None, |
508 | help="The number to fill validation dataset items up to." | 465 | help="The number to fill validation dataset items up to.", |
509 | ) | 466 | ) |
510 | parser.add_argument( | 467 | parser.add_argument( |
511 | "--train_batch_size", | 468 | "--train_batch_size", |
512 | type=int, | 469 | type=int, |
513 | default=1, | 470 | default=1, |
514 | help="Batch size (per device) for the training dataloader." | 471 | help="Batch size (per device) for the training dataloader.", |
515 | ) | 472 | ) |
516 | parser.add_argument( | 473 | parser.add_argument( |
517 | "--sample_steps", | 474 | "--sample_steps", |
@@ -523,14 +480,9 @@ def parse_args(): | |||
523 | "--prior_loss_weight", | 480 | "--prior_loss_weight", |
524 | type=float, | 481 | type=float, |
525 | default=1.0, | 482 | default=1.0, |
526 | help="The weight of prior preservation loss." | 483 | help="The weight of prior preservation loss.", |
527 | ) | ||
528 | parser.add_argument( | ||
529 | "--emb_alpha", | ||
530 | type=float, | ||
531 | default=1.0, | ||
532 | help="Embedding alpha" | ||
533 | ) | 484 | ) |
485 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
534 | parser.add_argument( | 486 | parser.add_argument( |
535 | "--emb_dropout", | 487 | "--emb_dropout", |
536 | type=float, | 488 | type=float, |
@@ -538,21 +490,13 @@ def parse_args(): | |||
538 | help="Embedding dropout probability.", | 490 | help="Embedding dropout probability.", |
539 | ) | 491 | ) |
540 | parser.add_argument( | 492 | parser.add_argument( |
541 | "--use_emb_decay", | 493 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." |
542 | action="store_true", | ||
543 | help="Whether to use embedding decay." | ||
544 | ) | 494 | ) |
545 | parser.add_argument( | 495 | parser.add_argument( |
546 | "--emb_decay_target", | 496 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." |
547 | default=0.4, | ||
548 | type=float, | ||
549 | help="Embedding decay target." | ||
550 | ) | 497 | ) |
551 | parser.add_argument( | 498 | parser.add_argument( |
552 | "--emb_decay", | 499 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." |
553 | default=1e+2, | ||
554 | type=float, | ||
555 | help="Embedding decay factor." | ||
556 | ) | 500 | ) |
557 | parser.add_argument( | 501 | parser.add_argument( |
558 | "--noise_timesteps", | 502 | "--noise_timesteps", |
@@ -563,7 +507,7 @@ def parse_args(): | |||
563 | "--resume_from", | 507 | "--resume_from", |
564 | type=str, | 508 | type=str, |
565 | default=None, | 509 | default=None, |
566 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" | 510 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)", |
567 | ) | 511 | ) |
568 | parser.add_argument( | 512 | parser.add_argument( |
569 | "--global_step", | 513 | "--global_step", |
@@ -574,7 +518,7 @@ def parse_args(): | |||
574 | "--config", | 518 | "--config", |
575 | type=str, | 519 | type=str, |
576 | default=None, | 520 | default=None, |
577 | help="Path to a JSON configuration file containing arguments for invoking this script." | 521 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
578 | ) | 522 | ) |
579 | 523 | ||
580 | args = parser.parse_args() | 524 | args = parser.parse_args() |
@@ -595,29 +539,44 @@ def parse_args(): | |||
595 | args.placeholder_tokens = [args.placeholder_tokens] | 539 | args.placeholder_tokens = [args.placeholder_tokens] |
596 | 540 | ||
597 | if isinstance(args.initializer_tokens, str): | 541 | if isinstance(args.initializer_tokens, str): |
598 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 542 | args.initializer_tokens = [args.initializer_tokens] * len( |
543 | args.placeholder_tokens | ||
544 | ) | ||
599 | 545 | ||
600 | if len(args.placeholder_tokens) == 0: | 546 | if len(args.placeholder_tokens) == 0: |
601 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 547 | args.placeholder_tokens = [ |
548 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
549 | ] | ||
602 | 550 | ||
603 | if len(args.initializer_tokens) == 0: | 551 | if len(args.initializer_tokens) == 0: |
604 | args.initializer_tokens = args.placeholder_tokens.copy() | 552 | args.initializer_tokens = args.placeholder_tokens.copy() |
605 | 553 | ||
606 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 554 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
607 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 555 | raise ValueError( |
556 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
557 | ) | ||
608 | 558 | ||
609 | if isinstance(args.inverted_initializer_tokens, str): | 559 | if isinstance(args.inverted_initializer_tokens, str): |
610 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | 560 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( |
561 | args.placeholder_tokens | ||
562 | ) | ||
611 | 563 | ||
612 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | 564 | if ( |
565 | isinstance(args.inverted_initializer_tokens, list) | ||
566 | and len(args.inverted_initializer_tokens) != 0 | ||
567 | ): | ||
613 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | 568 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] |
614 | args.initializer_tokens += args.inverted_initializer_tokens | 569 | args.initializer_tokens += args.inverted_initializer_tokens |
615 | 570 | ||
616 | if isinstance(args.num_vectors, int): | 571 | if isinstance(args.num_vectors, int): |
617 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 572 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
618 | 573 | ||
619 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 574 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( |
620 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 575 | args.num_vectors |
576 | ): | ||
577 | raise ValueError( | ||
578 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
579 | ) | ||
621 | 580 | ||
622 | if args.alias_tokens is None: | 581 | if args.alias_tokens is None: |
623 | args.alias_tokens = [] | 582 | args.alias_tokens = [] |
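The hunk above normalizes the CLI token lists; its broadcasting rules, distilled into a self-contained snippet:

    placeholder_tokens = ["<tok1>", "<tok2>"]
    initializer_tokens = "person"  # a single string on the command line

    if isinstance(initializer_tokens, str):
        initializer_tokens = [initializer_tokens] * len(placeholder_tokens)
    assert initializer_tokens == ["person", "person"]

    # And when only initializers are given, placeholders are auto-named:
    initializers = ["girl", "boy"]
    placeholders = [f"<*{i}>" for i in range(len(initializers))]
    assert placeholders == ["<*0>", "<*1>"]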
@@ -639,16 +598,22 @@ def parse_args(): | |||
639 | ] | 598 | ] |
640 | 599 | ||
641 | if isinstance(args.train_data_template, str): | 600 | if isinstance(args.train_data_template, str): |
642 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 601 | args.train_data_template = [args.train_data_template] * len( |
602 | args.placeholder_tokens | ||
603 | ) | ||
643 | 604 | ||
644 | if len(args.placeholder_tokens) != len(args.train_data_template): | 605 | if len(args.placeholder_tokens) != len(args.train_data_template): |
645 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 606 | raise ValueError( |
607 | "--placeholder_tokens and --train_data_template must have the same number of items" | ||
608 | ) | ||
646 | 609 | ||
647 | if args.num_vectors is None: | 610 | if args.num_vectors is None: |
648 | args.num_vectors = [None] * len(args.placeholder_tokens) | 611 | args.num_vectors = [None] * len(args.placeholder_tokens) |
649 | else: | 612 | else: |
650 | if isinstance(args.train_data_template, list): | 613 | if isinstance(args.train_data_template, list): |
651 | raise ValueError("--train_data_template can't be a list in simultaneous mode") | 614 | raise ValueError( |
615 | "--train_data_template can't be a list in simultaneous mode" | ||
616 | ) | ||
652 | 617 | ||
653 | if isinstance(args.collection, str): | 618 | if isinstance(args.collection, str): |
654 | args.collection = [args.collection] | 619 | args.collection = [args.collection] |
@@ -660,13 +625,13 @@ def parse_args(): | |||
660 | raise ValueError("You must specify --output_dir") | 625 | raise ValueError("You must specify --output_dir") |
661 | 626 | ||
662 | if args.adam_beta1 is None: | 627 | if args.adam_beta1 is None: |
663 | if args.optimizer == 'lion': | 628 | if args.optimizer == "lion": |
664 | args.adam_beta1 = 0.95 | 629 | args.adam_beta1 = 0.95 |
665 | else: | 630 | else: |
666 | args.adam_beta1 = 0.9 | 631 | args.adam_beta1 = 0.9 |
667 | 632 | ||
668 | if args.adam_beta2 is None: | 633 | if args.adam_beta2 is None: |
669 | if args.optimizer == 'lion': | 634 | if args.optimizer == "lion": |
670 | args.adam_beta2 = 0.98 | 635 | args.adam_beta2 = 0.98 |
671 | else: | 636 | else: |
672 | args.adam_beta2 = 0.999 | 637 | args.adam_beta2 = 0.999 |
@@ -679,13 +644,13 @@ def main(): | |||
679 | 644 | ||
680 | global_step_offset = args.global_step | 645 | global_step_offset = args.global_step |
681 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 646 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
682 | output_dir = Path(args.output_dir)/slugify(args.project)/now | 647 | output_dir = Path(args.output_dir) / slugify(args.project) / now |
683 | output_dir.mkdir(parents=True, exist_ok=True) | 648 | output_dir.mkdir(parents=True, exist_ok=True) |
684 | 649 | ||
685 | accelerator = Accelerator( | 650 | accelerator = Accelerator( |
686 | log_with=LoggerType.TENSORBOARD, | 651 | log_with=LoggerType.TENSORBOARD, |
687 | project_dir=f"{output_dir}", | 652 | project_dir=f"{output_dir}", |
688 | mixed_precision=args.mixed_precision | 653 | mixed_precision=args.mixed_precision, |
689 | ) | 654 | ) |
690 | 655 | ||
691 | weight_dtype = torch.float32 | 656 | weight_dtype = torch.float32 |
@@ -703,9 +668,15 @@ def main(): | |||
703 | 668 | ||
704 | save_args(output_dir, args) | 669 | save_args(output_dir, args) |
705 | 670 | ||
706 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) | 671 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
707 | embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) | 672 | args.pretrained_model_name_or_path |
708 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 673 | ) |
674 | embeddings = patch_managed_embeddings( | ||
675 | text_encoder, args.emb_alpha, args.emb_dropout | ||
676 | ) | ||
677 | schedule_sampler = create_named_schedule_sampler( | ||
678 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
679 | ) | ||
709 | 680 | ||
710 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 681 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
711 | tokenizer.set_dropout(args.vector_dropout) | 682 | tokenizer.set_dropout(args.vector_dropout) |
@@ -717,16 +688,16 @@ def main(): | |||
717 | unet.enable_xformers_memory_efficient_attention() | 688 | unet.enable_xformers_memory_efficient_attention() |
718 | elif args.compile_unet: | 689 | elif args.compile_unet: |
719 | unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False | 690 | unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False |
720 | 691 | ||
721 | proc = AttnProcessor() | 692 | proc = AttnProcessor() |
722 | 693 | ||
723 | def fn_recursive_set_proc(module: torch.nn.Module): | 694 | def fn_recursive_set_proc(module: torch.nn.Module): |
724 | if hasattr(module, "processor"): | 695 | if hasattr(module, "processor"): |
725 | module.processor = proc | 696 | module.processor = proc |
726 | 697 | ||
727 | for child in module.children(): | 698 | for child in module.children(): |
728 | fn_recursive_set_proc(child) | 699 | fn_recursive_set_proc(child) |
729 | 700 | ||
730 | fn_recursive_set_proc(unet) | 701 | fn_recursive_set_proc(unet) |
731 | 702 | ||
732 | if args.gradient_checkpointing: | 703 | if args.gradient_checkpointing: |
@@ -751,18 +722,24 @@ def main(): | |||
751 | tokenizer=tokenizer, | 722 | tokenizer=tokenizer, |
752 | embeddings=embeddings, | 723 | embeddings=embeddings, |
753 | placeholder_tokens=alias_placeholder_tokens, | 724 | placeholder_tokens=alias_placeholder_tokens, |
754 | initializer_tokens=alias_initializer_tokens | 725 | initializer_tokens=alias_initializer_tokens, |
755 | ) | 726 | ) |
756 | embeddings.persist() | 727 | embeddings.persist() |
757 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 728 | print( |
729 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
730 | ) | ||
758 | 731 | ||
759 | if args.embeddings_dir is not None: | 732 | if args.embeddings_dir is not None: |
760 | embeddings_dir = Path(args.embeddings_dir) | 733 | embeddings_dir = Path(args.embeddings_dir) |
761 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 734 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
762 | raise ValueError("--embeddings_dir must point to an existing directory") | 735 | raise ValueError("--embeddings_dir must point to an existing directory") |
763 | 736 | ||
764 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 737 | added_tokens, added_ids = load_embeddings_from_dir( |
765 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 738 | tokenizer, embeddings, embeddings_dir |
739 | ) | ||
740 | print( | ||
741 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
742 | ) | ||
766 | 743 | ||
767 | if args.train_dir_embeddings: | 744 | if args.train_dir_embeddings: |
768 | args.placeholder_tokens = added_tokens | 745 | args.placeholder_tokens = added_tokens |
@@ -772,19 +749,23 @@ def main(): | |||
772 | 749 | ||
773 | if args.scale_lr: | 750 | if args.scale_lr: |
774 | args.learning_rate = ( | 751 | args.learning_rate = ( |
775 | args.learning_rate * args.gradient_accumulation_steps * | 752 | args.learning_rate |
776 | args.train_batch_size * accelerator.num_processes | 753 | * args.gradient_accumulation_steps |
754 | * args.train_batch_size | ||
755 | * accelerator.num_processes | ||
777 | ) | 756 | ) |
778 | 757 | ||
779 | if args.find_lr: | 758 | if args.find_lr: |
780 | args.learning_rate = 1e-5 | 759 | args.learning_rate = 1e-5 |
781 | args.lr_scheduler = "exponential_growth" | 760 | args.lr_scheduler = "exponential_growth" |
782 | 761 | ||
783 | if args.optimizer == 'adam8bit': | 762 | if args.optimizer == "adam8bit": |
784 | try: | 763 | try: |
785 | import bitsandbytes as bnb | 764 | import bitsandbytes as bnb |
786 | except ImportError: | 765 | except ImportError: |
787 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 766 | raise ImportError( |
767 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
768 | ) | ||
788 | 769 | ||
789 | create_optimizer = partial( | 770 | create_optimizer = partial( |
790 | bnb.optim.AdamW8bit, | 771 | bnb.optim.AdamW8bit, |
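Worked example of the --scale_lr rule a few lines up: the base LR is scaled linearly with the effective batch size (values illustrative):

    learning_rate = 1e-4
    gradient_accumulation_steps = 4
    train_batch_size = 2
    num_processes = 2  # accelerator.num_processes

    # One optimizer step consumes 4 * 2 * 2 = 16 examples, so the LR scales by 16.
    learning_rate = (
        learning_rate * gradient_accumulation_steps * train_batch_size * num_processes
    )
    assert abs(learning_rate - 1.6e-3) < 1e-12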
@@ -793,7 +774,7 @@ def main(): | |||
793 | eps=args.adam_epsilon, | 774 | eps=args.adam_epsilon, |
794 | amsgrad=args.adam_amsgrad, | 775 | amsgrad=args.adam_amsgrad, |
795 | ) | 776 | ) |
796 | elif args.optimizer == 'adam': | 777 | elif args.optimizer == "adam": |
797 | create_optimizer = partial( | 778 | create_optimizer = partial( |
798 | torch.optim.AdamW, | 779 | torch.optim.AdamW, |
799 | betas=(args.adam_beta1, args.adam_beta2), | 780 | betas=(args.adam_beta1, args.adam_beta2), |
@@ -801,11 +782,13 @@ def main(): | |||
801 | eps=args.adam_epsilon, | 782 | eps=args.adam_epsilon, |
802 | amsgrad=args.adam_amsgrad, | 783 | amsgrad=args.adam_amsgrad, |
803 | ) | 784 | ) |
804 | elif args.optimizer == 'adan': | 785 | elif args.optimizer == "adan": |
805 | try: | 786 | try: |
806 | import timm.optim | 787 | import timm.optim |
807 | except ImportError: | 788 | except ImportError: |
808 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 789 | raise ImportError( |
790 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
791 | ) | ||
809 | 792 | ||
810 | create_optimizer = partial( | 793 | create_optimizer = partial( |
811 | timm.optim.Adan, | 794 | timm.optim.Adan, |
@@ -813,11 +796,13 @@ def main(): | |||
813 | eps=args.adam_epsilon, | 796 | eps=args.adam_epsilon, |
814 | no_prox=True, | 797 | no_prox=True, |
815 | ) | 798 | ) |
816 | elif args.optimizer == 'lion': | 799 | elif args.optimizer == "lion": |
817 | try: | 800 | try: |
818 | import lion_pytorch | 801 | import lion_pytorch |
819 | except ImportError: | 802 | except ImportError: |
820 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 803 | raise ImportError( |
804 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
805 | ) | ||
821 | 806 | ||
822 | create_optimizer = partial( | 807 | create_optimizer = partial( |
823 | lion_pytorch.Lion, | 808 | lion_pytorch.Lion, |
@@ -825,7 +810,7 @@ def main(): | |||
825 | weight_decay=args.adam_weight_decay, | 810 | weight_decay=args.adam_weight_decay, |
826 | use_triton=True, | 811 | use_triton=True, |
827 | ) | 812 | ) |
828 | elif args.optimizer == 'adafactor': | 813 | elif args.optimizer == "adafactor": |
829 | create_optimizer = partial( | 814 | create_optimizer = partial( |
830 | transformers.optimization.Adafactor, | 815 | transformers.optimization.Adafactor, |
831 | weight_decay=args.adam_weight_decay, | 816 | weight_decay=args.adam_weight_decay, |
@@ -837,11 +822,13 @@ def main(): | |||
837 | args.lr_scheduler = "adafactor" | 822 | args.lr_scheduler = "adafactor" |
838 | args.lr_min_lr = args.learning_rate | 823 | args.lr_min_lr = args.learning_rate |
839 | args.learning_rate = None | 824 | args.learning_rate = None |
840 | elif args.optimizer == 'dadam': | 825 | elif args.optimizer == "dadam": |
841 | try: | 826 | try: |
842 | import dadaptation | 827 | import dadaptation |
843 | except ImportError: | 828 | except ImportError: |
844 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 829 | raise ImportError( |
830 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
831 | ) | ||
845 | 832 | ||
846 | create_optimizer = partial( | 833 | create_optimizer = partial( |
847 | dadaptation.DAdaptAdam, | 834 | dadaptation.DAdaptAdam, |
@@ -851,11 +838,13 @@ def main(): | |||
851 | decouple=True, | 838 | decouple=True, |
852 | d0=args.dadaptation_d0, | 839 | d0=args.dadaptation_d0, |
853 | ) | 840 | ) |
854 | elif args.optimizer == 'dadan': | 841 | elif args.optimizer == "dadan": |
855 | try: | 842 | try: |
856 | import dadaptation | 843 | import dadaptation |
857 | except ImportError: | 844 | except ImportError: |
858 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 845 | raise ImportError( |
846 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
847 | ) | ||
859 | 848 | ||
860 | create_optimizer = partial( | 849 | create_optimizer = partial( |
861 | dadaptation.DAdaptAdan, | 850 | dadaptation.DAdaptAdan, |
@@ -864,7 +853,7 @@ def main(): | |||
864 | d0=args.dadaptation_d0, | 853 | d0=args.dadaptation_d0, |
865 | ) | 854 | ) |
866 | else: | 855 | else: |
867 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 856 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
868 | 857 | ||
869 | trainer = partial( | 858 | trainer = partial( |
870 | train, | 859 | train, |
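All optimizer branches above share one pattern: backend-specific hyperparameters are bound once with functools.partial, and the caller supplies parameters and lr later. Distilled sketch with AdamW standing in for any backend:

    from functools import partial

    import torch

    create_optimizer = partial(
        torch.optim.AdamW,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-8,
    )

    # Later, as in the create_optimizer(...) call added before run():
    params = [torch.nn.Parameter(torch.zeros(8))]
    optimizer = create_optimizer(params, lr=1e-4)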
@@ -904,10 +893,21 @@ def main(): | |||
904 | sample_image_size=args.sample_image_size, | 893 | sample_image_size=args.sample_image_size, |
905 | ) | 894 | ) |
906 | 895 | ||
896 | optimizer = create_optimizer( | ||
897 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
898 | lr=learning_rate, | ||
899 | ) | ||
900 | |||
907 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 901 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
908 | data_npgenerator = np.random.default_rng(args.seed) | 902 | data_npgenerator = np.random.default_rng(args.seed) |
909 | 903 | ||
910 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): | 904 | def run( |
905 | i: int, | ||
906 | placeholder_tokens: list[str], | ||
907 | initializer_tokens: list[str], | ||
908 | num_vectors: Union[int, list[int]], | ||
909 | data_template: str, | ||
910 | ): | ||
911 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 911 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
912 | tokenizer=tokenizer, | 912 | tokenizer=tokenizer, |
913 | embeddings=embeddings, | 913 | embeddings=embeddings, |
@@ -917,14 +917,23 @@ def main(): | |||
917 | initializer_noise=args.initializer_noise, | 917 | initializer_noise=args.initializer_noise, |
918 | ) | 918 | ) |
919 | 919 | ||
920 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) | 920 | stats = list( |
921 | zip( | ||
922 | placeholder_tokens, | ||
923 | placeholder_token_ids, | ||
924 | initializer_tokens, | ||
925 | initializer_token_ids, | ||
926 | ) | ||
927 | ) | ||
921 | 928 | ||
922 | print("") | 929 | print("") |
923 | print(f"============ TI batch {i + 1} ============") | 930 | print(f"============ TI batch {i + 1} ============") |
924 | print("") | 931 | print("") |
925 | print(stats) | 932 | print(stats) |
926 | 933 | ||
927 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | 934 | filter_tokens = [ |
935 | token for token in args.filter_tokens if token in placeholder_tokens | ||
936 | ] | ||
928 | 937 | ||
929 | datamodule = VlpnDataModule( | 938 | datamodule = VlpnDataModule( |
930 | data_file=args.train_data_file, | 939 | data_file=args.train_data_file, |
@@ -945,7 +954,9 @@ def main(): | |||
945 | valid_set_size=args.valid_set_size, | 954 | valid_set_size=args.valid_set_size, |
946 | train_set_pad=args.train_set_pad, | 955 | train_set_pad=args.train_set_pad, |
947 | valid_set_pad=args.valid_set_pad, | 956 | valid_set_pad=args.valid_set_pad, |
948 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 957 | filter=partial( |
958 | keyword_filter, filter_tokens, args.collection, args.exclude_collections | ||
959 | ), | ||
949 | dtype=weight_dtype, | 960 | dtype=weight_dtype, |
950 | generator=data_generator, | 961 | generator=data_generator, |
951 | npgenerator=data_npgenerator, | 962 | npgenerator=data_npgenerator, |
@@ -955,11 +966,16 @@ def main(): | |||
955 | num_train_epochs = args.num_train_epochs | 966 | num_train_epochs = args.num_train_epochs |
956 | sample_frequency = args.sample_frequency | 967 | sample_frequency = args.sample_frequency |
957 | if num_train_epochs is None: | 968 | if num_train_epochs is None: |
958 | num_train_epochs = math.ceil( | 969 | num_train_epochs = ( |
959 | args.num_train_steps / len(datamodule.train_dataset) | 970 | math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
960 | ) * args.gradient_accumulation_steps | 971 | * args.gradient_accumulation_steps |
961 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 972 | ) |
962 | num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps) | 973 | sample_frequency = math.ceil( |
974 | num_train_epochs * (sample_frequency / args.num_train_steps) | ||
975 | ) | ||
976 | num_training_steps_per_epoch = math.ceil( | ||
977 | len(datamodule.train_dataset) / args.gradient_accumulation_steps | ||
978 | ) | ||
963 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 979 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
964 | if args.sample_num is not None: | 980 | if args.sample_num is not None: |
965 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 981 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
@@ -988,7 +1004,8 @@ def main(): | |||
988 | response = auto_cycles.pop(0) | 1004 | response = auto_cycles.pop(0) |
989 | else: | 1005 | else: |
990 | response = input( | 1006 | response = input( |
991 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 1007 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " |
1008 | ) | ||
992 | 1009 | ||
993 | if response.lower().strip() == "o": | 1010 | if response.lower().strip() == "o": |
994 | if args.learning_rate is not None: | 1011 | if args.learning_rate is not None: |
@@ -1018,10 +1035,8 @@ def main(): | |||
1018 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") | 1035 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") |
1019 | print("") | 1036 | print("") |
1020 | 1037 | ||
1021 | optimizer = create_optimizer( | 1038 | for group, lr in zip(optimizer.param_groups, [learning_rate]): |
1022 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 1039 | group["lr"] = lr |
1023 | lr=learning_rate, | ||
1024 | ) | ||
1025 | 1040 | ||
1026 | lr_scheduler = get_scheduler( | 1041 | lr_scheduler = get_scheduler( |
1027 | lr_scheduler, | 1042 | lr_scheduler, |
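This is the functional change of the hunk: the optimizer is now built once (see the create_optimizer call moved before run()) and only its learning rate is reset per cycle, so Adam-style moment state survives across TI cycles. Standalone sketch:

    import torch

    param = torch.nn.Parameter(torch.zeros(3))
    optimizer = torch.optim.AdamW([param], lr=1e-4)

    for learning_rate in (1e-4, 5e-5, 1e-5):  # one value per training cycle
        for group, lr in zip(optimizer.param_groups, [learning_rate]):
            group["lr"] = lr  # exp_avg / exp_avg_sq buffers are preserved
        # ... run this cycle's training loop ...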
@@ -1040,7 +1055,9 @@ def main(): | |||
1040 | mid_point=args.lr_mid_point, | 1055 | mid_point=args.lr_mid_point, |
1041 | ) | 1056 | ) |
1042 | 1057 | ||
1043 | checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}" | 1058 | checkpoint_output_dir = ( |
1059 | output_dir / project / f"checkpoints_{training_iter}" | ||
1060 | ) | ||
1044 | 1061 | ||
1045 | trainer( | 1062 | trainer( |
1046 | train_dataloader=datamodule.train_dataloader, | 1063 | train_dataloader=datamodule.train_dataloader, |
@@ -1070,14 +1087,20 @@ def main(): | |||
1070 | accelerator.end_training() | 1087 | accelerator.end_training() |
1071 | 1088 | ||
1072 | if not args.sequential: | 1089 | if not args.sequential: |
1073 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 1090 | run( |
1091 | 0, | ||
1092 | args.placeholder_tokens, | ||
1093 | args.initializer_tokens, | ||
1094 | args.num_vectors, | ||
1095 | args.train_data_template, | ||
1096 | ) | ||
1074 | else: | 1097 | else: |
1075 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 1098 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
1076 | range(len(args.placeholder_tokens)), | 1099 | range(len(args.placeholder_tokens)), |
1077 | args.placeholder_tokens, | 1100 | args.placeholder_tokens, |
1078 | args.initializer_tokens, | 1101 | args.initializer_tokens, |
1079 | args.num_vectors, | 1102 | args.num_vectors, |
1080 | args.train_data_template | 1103 | args.train_data_template, |
1081 | ): | 1104 | ): |
1082 | run(i, [placeholder_token], [initializer_token], num_vectors, data_template) | 1105 | run(i, [placeholder_token], [initializer_token], num_vectors, data_template) |
1083 | embeddings.persist() | 1106 | embeddings.persist() |
diff --git a/training/functional.py b/training/functional.py index fd3f9f4..f68faf9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -14,7 +14,13 @@ import numpy as np | |||
14 | 14 | ||
15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
16 | from transformers import CLIPTextModel | 16 | from transformers import CLIPTextModel |
17 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin | 17 | from diffusers import ( |
18 | AutoencoderKL, | ||
19 | UNet2DConditionModel, | ||
20 | DDPMScheduler, | ||
21 | UniPCMultistepScheduler, | ||
22 | SchedulerMixin, | ||
23 | ) | ||
18 | 24 | ||
19 | from tqdm.auto import tqdm | 25 | from tqdm.auto import tqdm |
20 | 26 | ||
@@ -33,11 +39,12 @@ from util.noise import perlin_noise | |||
33 | def const(result=None): | 39 | def const(result=None): |
34 | def fn(*args, **kwargs): | 40 | def fn(*args, **kwargs): |
35 | return result | 41 | return result |
42 | |||
36 | return fn | 43 | return fn |
37 | 44 | ||
38 | 45 | ||
39 | @dataclass | 46 | @dataclass |
40 | class TrainingCallbacks(): | 47 | class TrainingCallbacks: |
41 | on_log: Callable[[], dict[str, Any]] = const({}) | 48 | on_log: Callable[[], dict[str, Any]] = const({}) |
42 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 49 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
43 | on_before_optimize: Callable[[int], Any] = const() | 50 | on_before_optimize: Callable[[int], Any] = const() |
@@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol): | |||
58 | train_dataloader: DataLoader, | 65 | train_dataloader: DataLoader, |
59 | val_dataloader: Optional[DataLoader], | 66 | val_dataloader: Optional[DataLoader], |
60 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 67 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
61 | **kwargs | 68 | **kwargs, |
62 | ) -> Tuple: ... | 69 | ) -> Tuple: |
70 | ... | ||
63 | 71 | ||
64 | 72 | ||
65 | @dataclass | 73 | @dataclass |
66 | class TrainingStrategy(): | 74 | class TrainingStrategy: |
67 | callbacks: Callable[..., TrainingCallbacks] | 75 | callbacks: Callable[..., TrainingCallbacks] |
68 | prepare: TrainingStrategyPrepareCallable | 76 | prepare: TrainingStrategyPrepareCallable |
69 | 77 | ||
70 | 78 | ||
71 | def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): | 79 | def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): |
72 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 80 | tokenizer = MultiCLIPTokenizer.from_pretrained( |
73 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype) | 81 | pretrained_model_name_or_path, subfolder="tokenizer" |
74 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype) | 82 | ) |
75 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype) | 83 | text_encoder = CLIPTextModel.from_pretrained( |
76 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 84 | pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype |
77 | sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 85 | ) |
86 | vae = AutoencoderKL.from_pretrained( | ||
87 | pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype | ||
88 | ) | ||
89 | unet = UNet2DConditionModel.from_pretrained( | ||
90 | pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype | ||
91 | ) | ||
92 | noise_scheduler = DDPMScheduler.from_pretrained( | ||
93 | pretrained_model_name_or_path, subfolder="scheduler" | ||
94 | ) | ||
95 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | ||
96 | pretrained_model_name_or_path, subfolder="scheduler" | ||
97 | ) | ||
78 | 98 | ||
79 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler | 99 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler |
80 | 100 | ||
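Usage is unchanged by the reflow: every component resolves from a subfolder of a single model repository with the standard Stable Diffusion layout, e.g. (model id illustrative):

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
        "runwayml/stable-diffusion-v1-5"
    )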
@@ -113,7 +133,9 @@ def save_samples( | |||
113 | 133 | ||
114 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) | 134 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) |
115 | 135 | ||
116 | datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] | 136 | datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [ |
137 | ("train", train_dataloader, None) | ||
138 | ] | ||
117 | 139 | ||
118 | if val_dataloader is not None: | 140 | if val_dataloader is not None: |
119 | datasets.append(("stable", val_dataloader, generator)) | 141 | datasets.append(("stable", val_dataloader, generator)) |
@@ -124,17 +146,11 @@ def save_samples( | |||
124 | file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" | 146 | file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" |
125 | file_path.parent.mkdir(parents=True, exist_ok=True) | 147 | file_path.parent.mkdir(parents=True, exist_ok=True) |
126 | 148 | ||
127 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) | 149 | batches = list( |
128 | prompt_ids = [ | 150 | itertools.islice(itertools.cycle(data), batch_size * num_batches) |
129 | prompt | 151 | ) |
130 | for batch in batches | 152 | prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]] |
131 | for prompt in batch["prompt_ids"] | 153 | nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]] |
132 | ] | ||
133 | nprompt_ids = [ | ||
134 | prompt | ||
135 | for batch in batches | ||
136 | for prompt in batch["nprompt_ids"] | ||
137 | ] | ||
138 | 154 | ||
139 | with torch.inference_mode(): | 155 | with torch.inference_mode(): |
140 | for i in range(num_batches): | 156 | for i in range(num_batches): |
@@ -165,7 +181,9 @@ def save_samples( | |||
165 | pass | 181 | pass |
166 | 182 | ||
167 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) | 183 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) |
168 | image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] | 184 | image_grid = pipeline.numpy_to_pil( |
185 | image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() | ||
186 | )[0] | ||
169 | image_grid.save(file_path, quality=85) | 187 | image_grid.save(file_path, quality=85) |
170 | 188 | ||
171 | del generator, pipeline | 189 | del generator, pipeline |
@@ -184,15 +202,17 @@ def generate_class_images( | |||
184 | train_dataset: VlpnDataset, | 202 | train_dataset: VlpnDataset, |
185 | sample_batch_size: int, | 203 | sample_batch_size: int, |
186 | sample_image_size: int, | 204 | sample_image_size: int, |
187 | sample_steps: int | 205 | sample_steps: int, |
188 | ): | 206 | ): |
189 | missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] | 207 | missing_data = [ |
208 | item for item in train_dataset.items if not item.class_image_path.exists() | ||
209 | ] | ||
190 | 210 | ||
191 | if len(missing_data) == 0: | 211 | if len(missing_data) == 0: |
192 | return | 212 | return |
193 | 213 | ||
194 | batched_data = [ | 214 | batched_data = [ |
195 | missing_data[i:i+sample_batch_size] | 215 | missing_data[i : i + sample_batch_size] |
196 | for i in range(0, len(missing_data), sample_batch_size) | 216 | for i in range(0, len(missing_data), sample_batch_size) |
197 | ] | 217 | ] |
198 | 218 | ||
@@ -216,7 +236,7 @@ def generate_class_images( | |||
216 | negative_prompt=nprompt, | 236 | negative_prompt=nprompt, |
217 | height=sample_image_size, | 237 | height=sample_image_size, |
218 | width=sample_image_size, | 238 | width=sample_image_size, |
219 | num_inference_steps=sample_steps | 239 | num_inference_steps=sample_steps, |
220 | ).images | 240 | ).images |
221 | 241 | ||
222 | for i, image in enumerate(images): | 242 | for i, image in enumerate(images): |
@@ -245,8 +265,12 @@ def add_placeholder_tokens( | |||
245 | 265 | ||
246 | embeddings.resize(len(tokenizer)) | 266 | embeddings.resize(len(tokenizer)) |
247 | 267 | ||
248 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | 268 | for placeholder_token_id, initializer_token_id in zip( |
249 | embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) | 269 | placeholder_token_ids, initializer_token_ids |
270 | ): | ||
271 | embeddings.add_embed( | ||
272 | placeholder_token_id, initializer_token_id, initializer_noise | ||
273 | ) | ||
250 | 274 | ||
251 | return placeholder_token_ids, initializer_token_ids | 275 | return placeholder_token_ids, initializer_token_ids |
252 | 276 | ||
@@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler): | |||
261 | 285 | ||
262 | # Expand the tensors. | 286 | # Expand the tensors. |
263 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 | 287 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 |
264 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | 288 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ |
289 | timesteps | ||
290 | ].float() | ||
265 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): | 291 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): |
266 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | 292 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] |
267 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) | 293 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) |
268 | 294 | ||
269 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | 295 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( |
296 | device=timesteps.device | ||
297 | )[timesteps].float() | ||
270 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): | 298 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): |
271 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | 299 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] |
272 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) | 300 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) |
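The gather/broadcast above prepares sqrt(abar_t) and sqrt(1 - abar_t) per timestep; per the Min-SNR reference linked in the comment, the returned quantity is the forward-process signal-to-noise ratio SNR(t) = alpha_t^2 / sigma_t^2 = abar_t / (1 - abar_t). Scalar sanity check:

    abar_t = 0.9  # alphas_cumprod[t]
    alpha = abar_t**0.5
    sigma = (1.0 - abar_t) ** 0.5
    snr = (alpha / sigma) ** 2  # 9.0
    assert abs(snr - abar_t / (1.0 - abar_t)) < 1e-9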
@@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler): | |||
277 | 305 | ||
278 | 306 | ||
279 | def get_original( | 307 | def get_original( |
280 | noise_scheduler, | 308 | noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor |
281 | model_output, | ||
282 | sample: torch.FloatTensor, | ||
283 | timesteps: torch.IntTensor | ||
284 | ): | 309 | ): |
285 | alphas_cumprod = noise_scheduler.alphas_cumprod | 310 | alphas_cumprod = noise_scheduler.alphas_cumprod |
286 | sqrt_alphas_cumprod = alphas_cumprod**0.5 | 311 | sqrt_alphas_cumprod = alphas_cumprod**0.5 |
287 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 | 312 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 |
288 | 313 | ||
289 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | 314 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ |
315 | timesteps | ||
316 | ].float() | ||
290 | while len(sqrt_alphas_cumprod.shape) < len(sample.shape): | 317 | while len(sqrt_alphas_cumprod.shape) < len(sample.shape): |
291 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | 318 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] |
292 | alpha = sqrt_alphas_cumprod.expand(sample.shape) | 319 | alpha = sqrt_alphas_cumprod.expand(sample.shape) |
293 | 320 | ||
294 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | 321 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( |
322 | device=timesteps.device | ||
323 | )[timesteps].float() | ||
295 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): | 324 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): |
296 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | 325 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] |
297 | sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) | 326 | sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) |
@@ -329,7 +358,9 @@ def loss_step( | |||
329 | eval: bool = False, | 358 | eval: bool = False, |
330 | ): | 359 | ): |
331 | images = batch["pixel_values"] | 360 | images = batch["pixel_values"] |
332 | generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None | 361 | generator = ( |
362 | torch.Generator(device=images.device).manual_seed(seed + step) if eval else None | ||
363 | ) | ||
333 | bsz = images.shape[0] | 364 | bsz = images.shape[0] |
334 | 365 | ||
335 | # Convert images to latent space | 366 | # Convert images to latent space |
@@ -342,7 +373,7 @@ def loss_step( | |||
342 | dtype=latents.dtype, | 373 | dtype=latents.dtype, |
343 | layout=latents.layout, | 374 | layout=latents.layout, |
344 | device=latents.device, | 375 | device=latents.device, |
345 | generator=generator | 376 | generator=generator, |
346 | ) | 377 | ) |
347 | applied_noise = noise | 378 | applied_noise = noise |
348 | 379 | ||
@@ -353,7 +384,7 @@ def loss_step( | |||
353 | octaves=4, | 384 | octaves=4, |
354 | dtype=latents.dtype, | 385 | dtype=latents.dtype, |
355 | device=latents.device, | 386 | device=latents.device, |
356 | generator=generator | 387 | generator=generator, |
357 | ) | 388 | ) |
358 | 389 | ||
359 | if input_pertubation != 0: | 390 | if input_pertubation != 0: |
@@ -362,7 +393,7 @@ def loss_step( | |||
362 | dtype=latents.dtype, | 393 | dtype=latents.dtype, |
363 | layout=latents.layout, | 394 | layout=latents.layout, |
364 | device=latents.device, | 395 | device=latents.device, |
365 | generator=generator | 396 | generator=generator, |
366 | ) | 397 | ) |
367 | 398 | ||
368 | # Sample a random timestep for each image | 399 | # Sample a random timestep for each image |
@@ -375,25 +406,27 @@ def loss_step( | |||
375 | 406 | ||
376 | # Get the text embedding for conditioning | 407 | # Get the text embedding for conditioning |
377 | encoder_hidden_states = get_extended_embeddings( | 408 | encoder_hidden_states = get_extended_embeddings( |
378 | text_encoder, | 409 | text_encoder, batch["input_ids"], batch["attention_mask"] |
379 | batch["input_ids"], | ||
380 | batch["attention_mask"] | ||
381 | ) | 410 | ) |
382 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 411 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) |
383 | 412 | ||
384 | # Predict the noise residual | 413 | # Predict the noise residual |
385 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] | 414 | model_pred = unet( |
415 | noisy_latents, timesteps, encoder_hidden_states, return_dict=False | ||
416 | )[0] | ||
386 | 417 | ||
387 | if guidance_scale != 0: | 418 | if guidance_scale != 0: |
388 | uncond_encoder_hidden_states = get_extended_embeddings( | 419 | uncond_encoder_hidden_states = get_extended_embeddings( |
389 | text_encoder, | 420 | text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] |
390 | batch["negative_input_ids"], | ||
391 | batch["negative_attention_mask"] | ||
392 | ) | 421 | ) |
393 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | 422 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) |
394 | 423 | ||
395 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] | 424 | model_pred_uncond = unet( |
396 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 425 | noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False |
426 | )[0] | ||
427 | model_pred = model_pred_uncond + guidance_scale * ( | ||
428 | model_pred - model_pred_uncond | ||
429 | ) | ||
397 | 430 | ||
398 | # Get the target for loss depending on the prediction type | 431 | # Get the target for loss depending on the prediction type |
399 | if noise_scheduler.config.prediction_type == "epsilon": | 432 | if noise_scheduler.config.prediction_type == "epsilon": |
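The guidance_scale branch above applies classifier-free-guidance-style mixing inside the training loss step. The same arithmetic with plain tensors (shapes illustrative):

    import torch

    guidance_scale = 7.5
    model_pred = torch.randn(2, 4, 64, 64)  # conditional prediction
    model_pred_uncond = torch.randn(2, 4, 64, 64)  # unconditional prediction

    model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)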
@@ -401,7 +434,9 @@ def loss_step( | |||
401 | elif noise_scheduler.config.prediction_type == "v_prediction": | 434 | elif noise_scheduler.config.prediction_type == "v_prediction": |
402 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 435 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
403 | else: | 436 | else: |
404 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 437 | raise ValueError( |
438 | f"Unknown prediction type {noise_scheduler.config.prediction_type}" | ||
439 | ) | ||
405 | 440 | ||
406 | acc = (model_pred == target).float().mean() | 441 | acc = (model_pred == target).float().mean() |
407 | 442 | ||
@@ -414,7 +449,9 @@ def loss_step( | |||
414 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 449 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
415 | 450 | ||
416 | # Compute prior loss | 451 | # Compute prior loss |
417 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") | 452 | prior_loss = F.mse_loss( |
453 | model_pred_prior.float(), target_prior.float(), reduction="none" | ||
454 | ) | ||
418 | 455 | ||
419 | # Add the prior loss to the instance loss. | 456 | # Add the prior loss to the instance loss. |
420 | loss = loss + prior_loss_weight * prior_loss | 457 | loss = loss + prior_loss_weight * prior_loss |
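Context for this hunk: with prior preservation the batch stacks instance and class images along dim 0, so predictions and targets are split with chunk(2) and the class half enters the loss weighted by --prior_loss_weight. Self-contained sketch (shapes illustrative):

    import torch
    import torch.nn.functional as F

    prior_loss_weight = 1.0
    model_pred_all = torch.randn(4, 4, 64, 64)  # [instance; class] along dim 0
    target_all = torch.randn(4, 4, 64, 64)

    model_pred, model_pred_prior = model_pred_all.chunk(2, dim=0)
    target, target_prior = target_all.chunk(2, dim=0)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    prior_loss = F.mse_loss(
        model_pred_prior.float(), target_prior.float(), reduction="none"
    )
    loss = loss + prior_loss_weight * prior_loss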
@@ -433,7 +470,10 @@ def loss_step( | |||
433 | if min_snr_gamma != 0: | 470 | if min_snr_gamma != 0: |
434 | snr = compute_snr(timesteps, noise_scheduler) | 471 | snr = compute_snr(timesteps, noise_scheduler) |
435 | mse_loss_weights = ( | 472 | mse_loss_weights = ( |
436 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | 473 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min( |
474 | dim=1 | ||
475 | )[0] | ||
476 | / snr | ||
437 | ) | 477 | ) |
438 | loss = loss * mse_loss_weights | 478 | loss = loss * mse_loss_weights |
439 | 479 | ||
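The reflowed expression above is Min-SNR-gamma loss weighting: each timestep's MSE term is scaled by min(SNR(t), gamma) / SNR(t), capping the influence of low-noise (high-SNR) timesteps. Numeric check:

    import torch

    min_snr_gamma = 5.0
    snr = torch.tensor([0.5, 2.0, 10.0, 80.0])
    mse_loss_weights = (
        torch.stack([snr, min_snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0]
        / snr
    )
    # tensor([1.0000, 1.0000, 0.5000, 0.0625])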
@@ -447,8 +487,14 @@ def loss_step( | |||
447 | 487 | ||
448 | 488 | ||
449 | class LossCallable(Protocol): | 489 | class LossCallable(Protocol): |
450 | def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any], | 490 | def __call__( |
451 | eval: bool = False) -> Tuple[Any, Any, int]: ... | 491 | self, |
492 | step: int, | ||
493 | batch: dict[Any, Any], | ||
494 | cache: dict[str, Any], | ||
495 | eval: bool = False, | ||
496 | ) -> Tuple[Any, Any, int]: | ||
497 | ... | ||
452 | 498 | ||
453 | 499 | ||
454 | def train_loop( | 500 | def train_loop( |
@@ -472,9 +518,14 @@ def train_loop( | |||
472 | avg_acc_val: AverageMeter = AverageMeter(), | 518 | avg_acc_val: AverageMeter = AverageMeter(), |
473 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 519 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
474 | ): | 520 | ): |
475 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | 521 | num_training_steps_per_epoch = math.ceil( |
476 | num_val_steps_per_epoch = math.ceil( | 522 | len(train_dataloader) / gradient_accumulation_steps |
477 | len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0 | 523 | ) |
524 | num_val_steps_per_epoch = ( | ||
525 | math.ceil(len(val_dataloader) / gradient_accumulation_steps) | ||
526 | if val_dataloader is not None | ||
527 | else 0 | ||
528 | ) | ||
478 | 529 | ||
479 | num_training_steps = num_training_steps_per_epoch * num_epochs | 530 | num_training_steps = num_training_steps_per_epoch * num_epochs |
480 | num_val_steps = num_val_steps_per_epoch * num_epochs | 531 | num_val_steps = num_val_steps_per_epoch * num_epochs |
@@ -488,14 +539,14 @@ def train_loop( | |||
488 | local_progress_bar = tqdm( | 539 | local_progress_bar = tqdm( |
489 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 540 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
490 | disable=not accelerator.is_local_main_process, | 541 | disable=not accelerator.is_local_main_process, |
491 | dynamic_ncols=True | 542 | dynamic_ncols=True, |
492 | ) | 543 | ) |
493 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | 544 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") |
494 | 545 | ||
495 | global_progress_bar = tqdm( | 546 | global_progress_bar = tqdm( |
496 | range(num_training_steps + num_val_steps), | 547 | range(num_training_steps + num_val_steps), |
497 | disable=not accelerator.is_local_main_process, | 548 | disable=not accelerator.is_local_main_process, |
498 | dynamic_ncols=True | 549 | dynamic_ncols=True, |
499 | ) | 550 | ) |
500 | global_progress_bar.set_description("Total progress") | 551 | global_progress_bar.set_description("Total progress") |
501 | 552 | ||
@@ -513,7 +564,9 @@ def train_loop( | |||
513 | try: | 564 | try: |
514 | import dadaptation | 565 | import dadaptation |
515 | 566 | ||
516 | isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) | 567 | isDadaptation = isinstance( |
568 | optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan) | ||
569 | ) | ||
517 | except ImportError: | 570 | except ImportError: |
518 | pass | 571 | pass |
519 | 572 | ||
@@ -565,7 +618,10 @@ def train_loop( | |||
565 | label = group_labels[i] if i < len(group_labels) else f"{i}" | 618 | label = group_labels[i] if i < len(group_labels) else f"{i}" |
566 | logs[f"lr/{label}"] = lr | 619 | logs[f"lr/{label}"] = lr |
567 | if isDadaptation: | 620 | if isDadaptation: |
568 | lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] | 621 | lr = ( |
622 | optimizer.param_groups[i]["d"] | ||
623 | * optimizer.param_groups[i]["lr"] | ||
624 | ) | ||
569 | logs[f"d*lr/{label}"] = lr | 625 | logs[f"d*lr/{label}"] = lr |
570 | lrs[label] = lr | 626 | lrs[label] = lr |
571 | 627 | ||
@@ -573,8 +629,10 @@ def train_loop( | |||
573 | 629 | ||
574 | local_progress_bar.set_postfix(**logs) | 630 | local_progress_bar.set_postfix(**logs) |
575 | 631 | ||
576 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 632 | if ((step + 1) % gradient_accumulation_steps == 0) or ( |
577 | before_optimize_result = on_before_optimize(epoch) | 633 | (step + 1) == len(train_dataloader) |
634 | ): | ||
635 | before_optimize_result = on_before_optimize(cycle) | ||
578 | 636 | ||
579 | optimizer.step() | 637 | optimizer.step() |
580 | lr_scheduler.step() | 638 | lr_scheduler.step() |
@@ -614,7 +672,9 @@ def train_loop( | |||
614 | } | 672 | } |
615 | local_progress_bar.set_postfix(**logs) | 673 | local_progress_bar.set_postfix(**logs) |
616 | 674 | ||
617 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)): | 675 | if ((step + 1) % gradient_accumulation_steps == 0) or ( |
676 | (step + 1) == len(val_dataloader) | ||
677 | ): | ||
618 | local_progress_bar.update(1) | 678 | local_progress_bar.update(1) |
619 | global_progress_bar.update(1) | 679 | global_progress_bar.update(1) |
620 | 680 | ||
@@ -634,7 +694,8 @@ def train_loop( | |||
634 | global_progress_bar.clear() | 694 | global_progress_bar.clear() |
635 | 695 | ||
636 | accelerator.print( | 696 | accelerator.print( |
637 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") | 697 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}" |
698 | ) | ||
638 | on_checkpoint(global_step, "milestone") | 699 | on_checkpoint(global_step, "milestone") |
639 | best_acc_val = avg_acc_val.max | 700 | best_acc_val = avg_acc_val.max |
640 | else: | 701 | else: |
@@ -644,7 +705,8 @@ def train_loop( | |||
644 | global_progress_bar.clear() | 705 | global_progress_bar.clear() |
645 | 706 | ||
646 | accelerator.print( | 707 | accelerator.print( |
647 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") | 708 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}" |
709 | ) | ||
648 | on_checkpoint(global_step, "milestone") | 710 | on_checkpoint(global_step, "milestone") |
649 | best_acc = avg_acc.max | 711 | best_acc = avg_acc.max |
650 | 712 | ||
@@ -700,17 +762,32 @@ def train( | |||
700 | avg_acc_val: AverageMeter = AverageMeter(), | 762 | avg_acc_val: AverageMeter = AverageMeter(), |
701 | **kwargs, | 763 | **kwargs, |
702 | ): | 764 | ): |
703 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( | 765 | ( |
704 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) | 766 | text_encoder, |
767 | unet, | ||
768 | optimizer, | ||
769 | train_dataloader, | ||
770 | val_dataloader, | ||
771 | lr_scheduler, | ||
772 | ) = strategy.prepare( | ||
773 | accelerator, | ||
774 | text_encoder, | ||
775 | unet, | ||
776 | optimizer, | ||
777 | train_dataloader, | ||
778 | val_dataloader, | ||
779 | lr_scheduler, | ||
780 | **kwargs, | ||
781 | ) | ||
705 | 782 | ||
706 | vae.to(accelerator.device, dtype=dtype) | 783 | vae.to(accelerator.device, dtype=dtype) |
707 | vae.requires_grad_(False) | 784 | vae.requires_grad_(False) |
708 | vae.eval() | 785 | vae.eval() |
709 | 786 | ||
710 | vae = torch.compile(vae, backend='hidet') | 787 | vae = torch.compile(vae, backend="hidet") |
711 | 788 | ||
712 | if compile_unet: | 789 | if compile_unet: |
713 | unet = torch.compile(unet, backend='hidet') | 790 | unet = torch.compile(unet, backend="hidet") |
714 | # unet = torch.compile(unet, mode="reduce-overhead") | 791 | # unet = torch.compile(unet, mode="reduce-overhead") |
715 | 792 | ||
716 | callbacks = strategy.callbacks( | 793 | callbacks = strategy.callbacks( |
diff --git a/training/lr.py b/training/lr.py index f5b362f..a75078f 100644 --- a/training/lr.py +++ b/training/lr.py | |||
@@ -23,12 +23,12 @@ def plot_metrics( | |||
23 | fig, ax_loss = plt.subplots() | 23 | fig, ax_loss = plt.subplots() |
24 | ax_acc = ax_loss.twinx() | 24 | ax_acc = ax_loss.twinx() |
25 | 25 | ||
26 | ax_loss.plot(lrs, losses, color='red') | 26 | ax_loss.plot(lrs, losses, color="red") |
27 | ax_loss.set_xscale("log") | 27 | ax_loss.set_xscale("log") |
28 | ax_loss.set_xlabel(f"Learning rate") | 28 | ax_loss.set_xlabel(f"Learning rate") |
29 | ax_loss.set_ylabel("Loss") | 29 | ax_loss.set_ylabel("Loss") |
30 | 30 | ||
31 | ax_acc.plot(lrs, accs, color='blue') | 31 | ax_acc.plot(lrs, accs, color="blue") |
32 | ax_acc.set_xscale("log") | 32 | ax_acc.set_xscale("log") |
33 | ax_acc.set_ylabel("Accuracy") | 33 | ax_acc.set_ylabel("Accuracy") |
34 | 34 | ||
diff --git a/training/optimization.py b/training/optimization.py index d22a900..55531bf 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
@@ -5,7 +5,10 @@ from functools import partial | |||
5 | import torch | 5 | import torch |
6 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR |
7 | 7 | ||
8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 8 | from diffusers.optimization import ( |
9 | get_scheduler as get_scheduler_, | ||
10 | get_cosine_with_hard_restarts_schedule_with_warmup, | ||
11 | ) | ||
9 | from transformers.optimization import get_adafactor_schedule | 12 | from transformers.optimization import get_adafactor_schedule |
10 | 13 | ||
11 | 14 | ||
@@ -52,7 +55,7 @@ def get_one_cycle_schedule( | |||
52 | annealing_exp: int = 1, | 55 | annealing_exp: int = 1, |
53 | min_lr: float = 0.04, | 56 | min_lr: float = 0.04, |
54 | mid_point: float = 0.3, | 57 | mid_point: float = 0.3, |
55 | last_epoch: int = -1 | 58 | last_epoch: int = -1, |
56 | ): | 59 | ): |
57 | if warmup == "linear": | 60 | if warmup == "linear": |
58 | warmup_func = warmup_linear | 61 | warmup_func = warmup_linear |
@@ -83,12 +86,16 @@ def get_one_cycle_schedule( | |||
83 | 86 | ||
84 | def lr_lambda(current_step: int): | 87 | def lr_lambda(current_step: int): |
85 | phase = [p for p in phases if current_step >= p.step_min][-1] | 88 | phase = [p for p in phases if current_step >= p.step_min][-1] |
86 | return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) | 89 | return phase.min + phase.func( |
90 | (current_step - phase.step_min) / (phase.step_max - phase.step_min) | ||
91 | ) * (phase.max - phase.min) | ||
87 | 92 | ||
88 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 93 | return LambdaLR(optimizer, lr_lambda, last_epoch) |
89 | 94 | ||
90 | 95 | ||
91 | def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1): | 96 | def get_exponential_growing_schedule( |
97 | optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1 | ||
98 | ): | ||
92 | def lr_lambda(base_lr: float, current_step: int): | 99 | def lr_lambda(base_lr: float, current_step: int): |
93 | return (end_lr / base_lr) ** (current_step / num_training_steps) | 100 | return (end_lr / base_lr) ** (current_step / num_training_steps) |
94 | 101 | ||
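get_exponential_growing_schedule interpolates geometrically from the optimizer's base LR to end_lr, lr(t) = base_lr * (end_lr / base_lr) ** (t / T), giving base_lr at t = 0 and end_lr at t = T; this is the sweep shape used by --find_lr. Endpoint check:

    base_lr, end_lr, T = 1e-5, 1e-1, 100
    lr = lambda t: base_lr * (end_lr / base_lr) ** (t / T)
    assert abs(lr(0) - base_lr) < 1e-15
    assert abs(lr(T) - end_lr) < 1e-12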
@@ -132,7 +139,14 @@ def get_scheduler( | |||
132 | ) | 139 | ) |
133 | elif id == "exponential_growth": | 140 | elif id == "exponential_growth": |
134 | if cycles is None: | 141 | if cycles is None: |
135 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | 142 | cycles = math.ceil( |
143 | math.sqrt( | ||
144 | ( | ||
145 | (num_training_steps - num_warmup_steps) | ||
146 | / num_training_steps_per_epoch | ||
147 | ) | ||
148 | ) | ||
149 | ) | ||
136 | 150 | ||
137 | lr_scheduler = get_exponential_growing_schedule( | 151 | lr_scheduler = get_exponential_growing_schedule( |
138 | optimizer=optimizer, | 152 | optimizer=optimizer, |
@@ -141,7 +155,14 @@ def get_scheduler( | |||
141 | ) | 155 | ) |
142 | elif id == "cosine_with_restarts": | 156 | elif id == "cosine_with_restarts": |
143 | if cycles is None: | 157 | if cycles is None: |
144 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | 158 | cycles = math.ceil( |
159 | math.sqrt( | ||
160 | ( | ||
161 | (num_training_steps - num_warmup_steps) | ||
162 | / num_training_steps_per_epoch | ||
163 | ) | ||
164 | ) | ||
165 | ) | ||
145 | 166 | ||
146 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 167 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
147 | optimizer=optimizer, | 168 | optimizer=optimizer, |
@@ -150,10 +171,7 @@ def get_scheduler( | |||
150 | num_cycles=cycles, | 171 | num_cycles=cycles, |
151 | ) | 172 | ) |
152 | elif id == "adafactor": | 173 | elif id == "adafactor": |
153 | lr_scheduler = get_adafactor_schedule( | 174 | lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr) |
154 | optimizer, | ||
155 | initial_lr=min_lr | ||
156 | ) | ||
157 | else: | 175 | else: |
158 | lr_scheduler = get_scheduler_( | 176 | lr_scheduler = get_scheduler_( |
159 | id, | 177 | id, |
diff --git a/training/sampler.py b/training/sampler.py index bdb3e90..0487d66 100644 --- a/training/sampler.py +++ b/training/sampler.py | |||
@@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler): | |||
134 | def weights(self): | 134 | def weights(self): |
135 | if not self._warmed_up(): | 135 | if not self._warmed_up(): |
136 | return np.ones([self.num_timesteps], dtype=np.float64) | 136 | return np.ones([self.num_timesteps], dtype=np.float64) |
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) | 137 | weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) |
138 | weights /= np.sum(weights) | 138 | weights /= np.sum(weights) |
139 | weights *= 1 - self.uniform_prob | 139 | weights *= 1 - self.uniform_prob |
140 | weights += self.uniform_prob / len(weights) | 140 | weights += self.uniform_prob / len(weights) |
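The weights above implement loss-second-moment importance sampling over timesteps: w_t proportional to sqrt(E[L_t^2]), normalized, then mixed with a uniform floor so every timestep keeps nonzero sampling probability. NumPy check (history shape assumed):

    import numpy as np

    uniform_prob = 0.001
    loss_history = np.abs(np.random.default_rng(0).normal(size=(10, 8)))

    weights = np.sqrt(np.mean(loss_history**2, axis=-1))
    weights /= np.sum(weights)
    weights *= 1 - uniform_prob
    weights += uniform_prob / len(weights)
    assert np.isclose(weights.sum(), 1.0)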
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e6fcc89..88b441b 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks( | |||
29 | sample_output_dir: Path, | 29 | sample_output_dir: Path, |
30 | checkpoint_output_dir: Path, | 30 | checkpoint_output_dir: Path, |
31 | seed: int, | 31 | seed: int, |
32 | train_text_encoder_epochs: int, | 32 | train_text_encoder_cycles: int, |
33 | max_grad_norm: float = 1.0, | 33 | max_grad_norm: float = 1.0, |
34 | use_ema: bool = False, | 34 | use_ema: bool = False, |
35 | ema_inv_gamma: float = 1.0, | 35 | ema_inv_gamma: float = 1.0, |
@@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks( | |||
85 | return nullcontext() | 85 | return nullcontext() |
86 | 86 | ||
87 | @contextmanager | 87 | @contextmanager |
88 | def on_train(epoch: int): | 88 | def on_train(cycle: int): |
89 | unet.train() | 89 | unet.train() |
90 | tokenizer.train() | 90 | tokenizer.train() |
91 | 91 | ||
92 | if epoch < train_text_encoder_epochs: | 92 | if cycle < train_text_encoder_cycles: |
93 | text_encoder.train() | 93 | text_encoder.train() |
94 | elif epoch == train_text_encoder_epochs: | 94 | tokenizer.train() |
95 | text_encoder.requires_grad_(False) | ||
96 | text_encoder.eval() | ||
97 | 95 | ||
98 | yield | 96 | yield |
99 | 97 | ||
@@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks( | |||
106 | with ema_context(): | 104 | with ema_context(): |
107 | yield | 105 | yield |
108 | 106 | ||
109 | def on_before_optimize(epoch: int): | 107 | def on_before_optimize(cycle: int): |
110 | params_to_clip = [unet.parameters()] | 108 | params_to_clip = [unet.parameters()] |
111 | if epoch < train_text_encoder_epochs: | 109 | if cycle < train_text_encoder_cycles: |
112 | params_to_clip.append(text_encoder.parameters()) | 110 | params_to_clip.append(text_encoder.parameters()) |
113 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) | 111 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) |
114 | 112 | ||
@@ -189,8 +187,16 @@ def dreambooth_prepare( | |||
189 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 187 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
190 | **kwargs | 188 | **kwargs |
191 | ): | 189 | ): |
192 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 190 | ( |
193 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 191 | text_encoder, |
192 | unet, | ||
193 | optimizer, | ||
194 | train_dataloader, | ||
195 | val_dataloader, | ||
196 | lr_scheduler, | ||
197 | ) = accelerator.prepare( | ||
198 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
199 | ) | ||
194 | 200 | ||
195 | text_encoder.text_model.embeddings.requires_grad_(False) | 201 | text_encoder.text_model.embeddings.requires_grad_(False) |
196 | 202 | ||
@@ -198,6 +204,5 @@ def dreambooth_prepare( | |||
198 | 204 | ||
199 | 205 | ||
200 | dreambooth_strategy = TrainingStrategy( | 206 | dreambooth_strategy = TrainingStrategy( |
201 | callbacks=dreambooth_strategy_callbacks, | 207 | callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare |
202 | prepare=dreambooth_prepare | ||
203 | ) | 208 | ) |
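The epoch-to-cycle rename also simplifies the gating: the explicit freeze-at-the-boundary branch (requires_grad_(False) plus eval()) is gone, and the text encoder simply trains while cycle < train_text_encoder_cycles. A stripped-down sketch of the resulting pattern; the modules here are stand-ins and the tokenizer handling is omitted:

    from contextlib import contextmanager

    import torch

    @contextmanager
    def on_train_sketch(cycle: int, unet: torch.nn.Module, text_encoder: torch.nn.Module,
                        train_text_encoder_cycles: int):
        unet.train()
        if cycle < train_text_encoder_cycles:
            text_encoder.train()  # text encoder only trains for the first N cycles
        yield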
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index f942b76..14e3384 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -81,7 +81,7 @@ def lora_strategy_callbacks( | |||
81 | tokenizer.eval() | 81 | tokenizer.eval() |
82 | yield | 82 | yield |
83 | 83 | ||
84 | def on_before_optimize(epoch: int): | 84 | def on_before_optimize(cycle: int): |
85 | if not pti_mode: | 85 | if not pti_mode: |
86 | accelerator.clip_grad_norm_( | 86 | accelerator.clip_grad_norm_( |
87 | itertools.chain( | 87 | itertools.chain( |
@@ -89,7 +89,7 @@ def lora_strategy_callbacks( | |||
89 | text_encoder.text_model.encoder.parameters(), | 89 | text_encoder.text_model.encoder.parameters(), |
90 | text_encoder.text_model.final_layer_norm.parameters(), | 90 | text_encoder.text_model.final_layer_norm.parameters(), |
91 | ), | 91 | ), |
92 | max_grad_norm | 92 | max_grad_norm, |
93 | ) | 93 | ) |
94 | 94 | ||
95 | if len(placeholder_tokens) != 0 and use_emb_decay: | 95 | if len(placeholder_tokens) != 0 and use_emb_decay: |
@@ -108,7 +108,9 @@ def lora_strategy_callbacks( | |||
108 | 108 | ||
109 | if lambda_ != 0: | 109 | if lambda_ != 0: |
110 | norm = w[:, :].norm(dim=-1, keepdim=True) | 110 | norm = w[:, :].norm(dim=-1, keepdim=True) |
111 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | 111 | w[:].add_( |
112 | (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) | ||
113 | ) | ||
112 | 114 | ||
113 | @torch.no_grad() | 115 | @torch.no_grad() |
114 | def on_checkpoint(step, postfix): | 116 | def on_checkpoint(step, postfix): |
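The reflowed decay line is worth unpacking: each embedding row w is moved along its own direction by lambda_ * (emb_decay_target - ||w||), so rows longer than the target shrink and shorter ones grow. A toy, self-contained version (tensor values are random, constants invented):

    import torch

    w = torch.randn(3, 768)
    lambda_, emb_decay_target = 0.01, 0.4

    norm = w.norm(dim=-1, keepdim=True)
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
    # after the update, each row's norm has moved a lambda_-fraction closer to the target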
@@ -128,25 +130,32 @@ def lora_strategy_callbacks( | |||
128 | 130 | ||
129 | if not pti_mode: | 131 | if not pti_mode: |
130 | lora_config = {} | 132 | lora_config = {} |
131 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) | 133 | state_dict = get_peft_model_state_dict( |
134 | unet_, state_dict=accelerator.get_state_dict(unet_) | ||
135 | ) | ||
132 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) | 136 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) |
133 | 137 | ||
134 | text_encoder_state_dict = get_peft_model_state_dict( | 138 | text_encoder_state_dict = get_peft_model_state_dict( |
135 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) | 139 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) |
136 | ) | 140 | ) |
137 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} | 141 | text_encoder_state_dict = { |
142 | f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items() | ||
143 | } | ||
138 | state_dict.update(text_encoder_state_dict) | 144 | state_dict.update(text_encoder_state_dict) |
139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 145 | lora_config[ |
146 | "text_encoder_peft_config" | ||
147 | ] = text_encoder_.get_peft_config_as_dict(inference=True) | ||
140 | 148 | ||
141 | if len(placeholder_tokens) != 0: | 149 | if len(placeholder_tokens) != 0: |
142 | ti_state_dict = { | 150 | ti_state_dict = { |
143 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) | 151 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) |
144 | for (token, ids) | 152 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids) |
145 | in zip(placeholder_tokens, placeholder_token_ids) | ||
146 | } | 153 | } |
147 | state_dict.update(ti_state_dict) | 154 | state_dict.update(ti_state_dict) |
148 | 155 | ||
149 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") | 156 | save_file( |
157 | state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors" | ||
158 | ) | ||
150 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 159 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
151 | json.dump(lora_config, f) | 160 | json.dump(lora_config, f) |
152 | 161 | ||
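Everything lands in a single safetensors file per checkpoint: UNet LoRA weights unprefixed, text-encoder LoRA weights under a text_encoder_ prefix, trained embeddings under ti_$, with both PEFT configs serialized alongside in lora_config.json. A hypothetical consumer of that layout (paths, step, and postfix are invented):

    import json

    from safetensors.torch import load_file

    state_dict = load_file("checkpoints/1000_end.safetensors")  # example name
    with open("checkpoints/lora_config.json") as f:
        lora_config = json.load(f)

    unet_sd = {k: v for k, v in state_dict.items()
               if not k.startswith(("text_encoder_", "ti_$"))}
    text_encoder_sd = {k.removeprefix("text_encoder_"): v
                       for k, v in state_dict.items() if k.startswith("text_encoder_")}
    ti_embeds = {k.removeprefix("ti_$"): v
                 for k, v in state_dict.items() if k.startswith("ti_$")}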
@@ -185,10 +194,18 @@ def lora_prepare( | |||
185 | train_dataloader: DataLoader, | 194 | train_dataloader: DataLoader, |
186 | val_dataloader: Optional[DataLoader], | 195 | val_dataloader: Optional[DataLoader], |
187 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 196 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
188 | **kwargs | 197 | **kwargs, |
189 | ): | 198 | ): |
190 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 199 | ( |
191 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 200 | text_encoder, |
201 | unet, | ||
202 | optimizer, | ||
203 | train_dataloader, | ||
204 | val_dataloader, | ||
205 | lr_scheduler, | ||
206 | ) = accelerator.prepare( | ||
207 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
208 | ) | ||
192 | 209 | ||
193 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) | 210 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) |
194 | 211 | ||
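The long tuple unpacking in these *_prepare helpers relies on accelerator.prepare returning its arguments wrapped but in the same order they were passed. A minimal demonstration with toy objects:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # prepare() preserves argument order, so positional unpacking is safe.
    model, optimizer = accelerator.prepare(model, optimizer)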
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6bc1d7d..7373982 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( | |||
104 | yield | 104 | yield |
105 | 105 | ||
106 | @torch.no_grad() | 106 | @torch.no_grad() |
107 | def on_before_optimize(epoch: int): | 107 | def on_before_optimize(cycle: int): |
108 | if use_emb_decay: | 108 | if use_emb_decay: |
109 | params = [ | 109 | params = [ |
110 | p | 110 | p |
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks( | |||
116 | @torch.no_grad() | 116 | @torch.no_grad() |
117 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
118 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 119 | ema_embeddings.step( |
120 | text_encoder.text_model.embeddings.token_embedding.parameters() | ||
121 | ) | ||
120 | 122 | ||
121 | if use_emb_decay and w is not None: | 123 | if use_emb_decay and w is not None: |
122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] | 124 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] |
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks( | |||
124 | 126 | ||
125 | if lambda_ != 0: | 127 | if lambda_ != 0: |
126 | norm = w[:, :].norm(dim=-1, keepdim=True) | 128 | norm = w[:, :].norm(dim=-1, keepdim=True) |
127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | 129 | w[:].add_( |
130 | (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) | ||
131 | ) | ||
128 | 132 | ||
129 | def on_log(): | 133 | def on_log(): |
130 | if ema_embeddings is not None: | 134 | if ema_embeddings is not None: |
@@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks( | |||
136 | print(f"Saving checkpoint for step {step}...") | 140 | print(f"Saving checkpoint for step {step}...") |
137 | 141 | ||
138 | with ema_context(): | 142 | with ema_context(): |
139 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 143 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): |
140 | text_encoder.text_model.embeddings.save_embed( | 144 | text_encoder.text_model.embeddings.save_embed( |
141 | ids, | 145 | ids, |
142 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 146 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin", |
143 | ) | 147 | ) |
144 | 148 | ||
145 | @torch.no_grad() | 149 | @torch.no_grad() |
@@ -183,7 +187,7 @@ def textual_inversion_prepare( | |||
183 | val_dataloader: Optional[DataLoader], | 187 | val_dataloader: Optional[DataLoader], |
184 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 188 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
185 | gradient_checkpointing: bool = False, | 189 | gradient_checkpointing: bool = False, |
186 | **kwargs | 190 | **kwargs, |
187 | ): | 191 | ): |
188 | weight_dtype = torch.float32 | 192 | weight_dtype = torch.float32 |
189 | if accelerator.state.mixed_precision == "fp16": | 193 | if accelerator.state.mixed_precision == "fp16": |
@@ -191,8 +195,15 @@ def textual_inversion_prepare( | |||
191 | elif accelerator.state.mixed_precision == "bf16": | 195 | elif accelerator.state.mixed_precision == "bf16": |
192 | weight_dtype = torch.bfloat16 | 196 | weight_dtype = torch.bfloat16 |
193 | 197 | ||
194 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 198 | ( |
195 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 199 | text_encoder, |
200 | optimizer, | ||
201 | train_dataloader, | ||
202 | val_dataloader, | ||
203 | lr_scheduler, | ||
204 | ) = accelerator.prepare( | ||
205 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
206 | ) | ||
196 | 207 | ||
197 | unet.to(accelerator.device, dtype=weight_dtype) | 208 | unet.to(accelerator.device, dtype=weight_dtype) |
198 | unet.requires_grad_(False) | 209 | unet.requires_grad_(False) |
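The prepare step's dtype branch maps accelerate's mixed-precision string onto a torch dtype so the frozen UNet can sit on-device in half precision while only the text encoder trains. The same branch condensed into a small helper (the function name is ours, not the repo's):

    import torch
    from accelerate import Accelerator

    def resolve_weight_dtype(accelerator: Accelerator) -> torch.dtype:
        # "no" (the default) and anything unrecognized fall back to fp32
        return {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }.get(accelerator.state.mixed_precision, torch.float32)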