| author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
| commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
| tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 | |
| parent | Fix LoRA training with DAdan (diff) | |
| download | textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.gz textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.tar.bz2 textual-inversion-diff-8364ce697ddf6117fdd4f7222832d546d63880de.zip | |
Update
| -rw-r--r-- | .vscode/settings.json | 6 |
| -rw-r--r-- | data/csv.py | 183 |
| -rw-r--r-- | data/keywords.py | 8 |
| -rw-r--r-- | environment.yaml | 11 |
| -rw-r--r-- | infer.py | 124 |
| -rw-r--r-- | models/attention/control.py | 104 |
| -rw-r--r-- | models/attention/hook.py | 5 |
| -rw-r--r-- | models/attention/structured.py | 65 |
| -rw-r--r-- | models/clip/embeddings.py | 29 |
| -rw-r--r-- | models/clip/tokenizer.py | 23 |
| -rw-r--r-- | models/clip/util.py | 17 |
| -rw-r--r-- | models/convnext/discriminator.py | 11 |
| -rw-r--r-- | models/sparse.py | 12 |
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 262 |
| -rw-r--r-- | train_dreambooth.py | 770 |
| -rw-r--r-- | train_lora.py | 489 |
| -rw-r--r-- | train_ti.py | 379 |
| -rw-r--r-- | training/functional.py | 221 |
| -rw-r--r-- | training/lr.py | 4 |
| -rw-r--r-- | training/optimization.py | 38 |
| -rw-r--r-- | training/sampler.py | 2 |
| -rw-r--r-- | training/strategy/dreambooth.py | 29 |
| -rw-r--r-- | training/strategy/lora.py | 41 |
| -rw-r--r-- | training/strategy/ti.py | 27 |
24 files changed, 1872 insertions, 988 deletions
diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..cee7b74 --- /dev/null +++ b/.vscode/settings.json
| @@ -0,0 +1,6 @@ | |||
| 1 | { | ||
| 2 | "[python]": { | ||
| 3 | "editor.defaultFormatter": "ms-python.black-formatter" | ||
| 4 | }, | ||
| 5 | "python.formatting.provider": "none" | ||
| 6 | } | ||
diff --git a/data/csv.py b/data/csv.py index 14380e8..d726033 100644 --- a/data/csv.py +++ b/data/csv.py
| @@ -49,7 +49,7 @@ def generate_buckets( | |||
| 49 | max_pixels: Optional[int] = None, | 49 | max_pixels: Optional[int] = None, |
| 50 | num_buckets: int = 4, | 50 | num_buckets: int = 4, |
| 51 | progressive_buckets: bool = False, | 51 | progressive_buckets: bool = False, |
| 52 | return_tensor: bool = True | 52 | return_tensor: bool = True, |
| 53 | ): | 53 | ): |
| 54 | if max_pixels is None: | 54 | if max_pixels is None: |
| 55 | max_pixels = (base_size + step_size) ** 2 | 55 | max_pixels = (base_size + step_size) ** 2 |
| @@ -62,7 +62,11 @@ def generate_buckets( | |||
| 62 | 62 | ||
| 63 | for i in range(1, num_buckets + 1): | 63 | for i in range(1, num_buckets + 1): |
| 64 | long_side = base_size + i * step_size | 64 | long_side = base_size + i * step_size |
| 65 | short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) | 65 | short_side = min( |
| 66 | base_size | ||
| 67 | - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, | ||
| 68 | base_size, | ||
| 69 | ) | ||
| 66 | buckets.append(long_side / short_side) | 70 | buckets.append(long_side / short_side) |
| 67 | buckets.append(short_side / long_side) | 71 | buckets.append(short_side / long_side) |
| 68 | 72 | ||
| @@ -106,7 +110,7 @@ def collate_fn( | |||
| 106 | max_token_id_length: Optional[int], | 110 | max_token_id_length: Optional[int], |
| 107 | with_guidance: bool, | 111 | with_guidance: bool, |
| 108 | with_prior_preservation: bool, | 112 | with_prior_preservation: bool, |
| 109 | examples | 113 | examples, |
| 110 | ): | 114 | ): |
| 111 | prompt_ids = [example["prompt_ids"] for example in examples] | 115 | prompt_ids = [example["prompt_ids"] for example in examples] |
| 112 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 116 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
| @@ -125,7 +129,9 @@ def collate_fn( | |||
| 125 | prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) | 129 | prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) |
| 126 | nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) | 130 | nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) |
| 127 | inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) | 131 | inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) |
| 128 | negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length) | 132 | negative_inputs = unify_input_ids( |
| 133 | tokenizer, negative_input_ids, max_token_id_length | ||
| 134 | ) | ||
| 129 | 135 | ||
| 130 | batch = { | 136 | batch = { |
| 131 | "prompt_ids": prompts.input_ids, | 137 | "prompt_ids": prompts.input_ids, |
| @@ -149,35 +155,39 @@ class VlpnDataItem(NamedTuple): | |||
| 149 | nprompt: str | 155 | nprompt: str |
| 150 | collection: list[str] | 156 | collection: list[str] |
| 151 | 157 | ||
| 152 | def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None): | 158 | def full_prompt( |
| 153 | return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator) | 159 | self, |
| 160 | dropout: float = 0, | ||
| 161 | shuffle: bool = False, | ||
| 162 | npgenerator: Optional[np.random.Generator] = None, | ||
| 163 | ): | ||
| 164 | return keywords_to_str( | ||
| 165 | self.keywords, [self.prompt], dropout, shuffle, npgenerator | ||
| 166 | ) | ||
| 154 | 167 | ||
| 155 | 168 | ||
| 156 | def keyword_filter( | 169 | def keyword_filter( |
| 157 | placeholder_tokens: Optional[list[str]], | 170 | placeholder_tokens: Optional[list[str]], |
| 158 | collections: Optional[list[str]], | 171 | collections: Optional[list[str]], |
| 159 | exclude_collections: Optional[list[str]], | 172 | exclude_collections: Optional[list[str]], |
| 160 | item: VlpnDataItem | 173 | item: VlpnDataItem, |
| 161 | ): | 174 | ): |
| 162 | full_prompt = item.full_prompt() | 175 | full_prompt = item.full_prompt() |
| 163 | 176 | ||
| 164 | cond1 = placeholder_tokens is None or any( | 177 | cond1 = placeholder_tokens is None or any( |
| 165 | token in full_prompt | 178 | token in full_prompt for token in placeholder_tokens |
| 166 | for token in placeholder_tokens | ||
| 167 | ) | 179 | ) |
| 168 | cond2 = collections is None or any( | 180 | cond2 = collections is None or any( |
| 169 | collection in item.collection | 181 | collection in item.collection for collection in collections |
| 170 | for collection in collections | ||
| 171 | ) | 182 | ) |
| 172 | cond3 = exclude_collections is None or not any( | 183 | cond3 = exclude_collections is None or not any( |
| 173 | collection in item.collection | 184 | collection in item.collection for collection in exclude_collections |
| 174 | for collection in exclude_collections | ||
| 175 | ) | 185 | ) |
| 176 | 186 | ||
| 177 | return cond1 and cond2 and cond3 | 187 | return cond1 and cond2 and cond3 |
| 178 | 188 | ||
| 179 | 189 | ||
| 180 | class VlpnDataModule(): | 190 | class VlpnDataModule: |
| 181 | def __init__( | 191 | def __init__( |
| 182 | self, | 192 | self, |
| 183 | batch_size: int, | 193 | batch_size: int, |
| @@ -222,7 +232,7 @@ class VlpnDataModule(): | |||
| 222 | 232 | ||
| 223 | self.constant_prompt_length = constant_prompt_length | 233 | self.constant_prompt_length = constant_prompt_length |
| 224 | self.max_token_id_length = None | 234 | self.max_token_id_length = None |
| 225 | 235 | ||
| 226 | self.tokenizer = tokenizer | 236 | self.tokenizer = tokenizer |
| 227 | self.size = size | 237 | self.size = size |
| 228 | self.num_buckets = num_buckets | 238 | self.num_buckets = num_buckets |
| @@ -259,23 +269,29 @@ class VlpnDataModule(): | |||
| 259 | nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") | 269 | nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") |
| 260 | collection = item["collection"].split(", ") if "collection" in item else [] | 270 | collection = item["collection"].split(", ") if "collection" in item else [] |
| 261 | 271 | ||
| 262 | saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) | 272 | saturated_keywords = str_to_keywords( |
| 273 | tpl_keywords.format(**keywords), expansions | ||
| 274 | ) | ||
| 263 | 275 | ||
| 264 | inverted_tokens = keywords_to_str([ | 276 | inverted_tokens = keywords_to_str( |
| 265 | f"inv_{token}" | 277 | [ |
| 266 | for token in self.placeholder_tokens | 278 | f"inv_{token}" |
| 267 | if token in saturated_keywords | 279 | for token in self.placeholder_tokens |
| 268 | ]) | 280 | if token in saturated_keywords |
| 281 | ] | ||
| 282 | ) | ||
| 269 | 283 | ||
| 270 | items.append(VlpnDataItem( | 284 | items.append( |
| 271 | self.data_root / image, | 285 | VlpnDataItem( |
| 272 | None, | 286 | self.data_root / image, |
| 273 | saturated_keywords, | 287 | None, |
| 274 | tpl_prompt.format(**prompt), | 288 | saturated_keywords, |
| 275 | tpl_cprompt.format(**prompt), | 289 | tpl_prompt.format(**prompt), |
| 276 | tpl_nprompt.format(_inv=inverted_tokens, **nprompt), | 290 | tpl_cprompt.format(**prompt), |
| 277 | collection | 291 | tpl_nprompt.format(_inv=inverted_tokens, **nprompt), |
| 278 | )) | 292 | collection, |
| 293 | ) | ||
| 294 | ) | ||
| 279 | 295 | ||
| 280 | return items | 296 | return items |
| 281 | 297 | ||
| @@ -285,13 +301,16 @@ class VlpnDataModule(): | |||
| 285 | 301 | ||
| 286 | return [item for item in items if self.filter(item)] | 302 | return [item for item in items if self.filter(item)] |
| 287 | 303 | ||
| 288 | def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: | 304 | def pad_items( |
| 305 | self, items: list[VlpnDataItem], num_class_images: int = 1 | ||
| 306 | ) -> list[VlpnDataItem]: | ||
| 289 | image_multiplier = max(num_class_images, 1) | 307 | image_multiplier = max(num_class_images, 1) |
| 290 | 308 | ||
| 291 | return [ | 309 | return [ |
| 292 | VlpnDataItem( | 310 | VlpnDataItem( |
| 293 | item.instance_image_path, | 311 | item.instance_image_path, |
| 294 | self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", | 312 | self.class_root |
| 313 | / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", | ||
| 295 | item.keywords, | 314 | item.keywords, |
| 296 | item.prompt, | 315 | item.prompt, |
| 297 | item.cprompt, | 316 | item.cprompt, |
| @@ -303,7 +322,7 @@ class VlpnDataModule(): | |||
| 303 | ] | 322 | ] |
| 304 | 323 | ||
| 305 | def setup(self): | 324 | def setup(self): |
| 306 | with open(self.data_file, 'rt') as f: | 325 | with open(self.data_file, "rt") as f: |
| 307 | metadata = json.load(f) | 326 | metadata = json.load(f) |
| 308 | template = metadata[self.template_key] if self.template_key in metadata else {} | 327 | template = metadata[self.template_key] if self.template_key in metadata else {} |
| 309 | expansions = metadata["expansions"] if "expansions" in metadata else {} | 328 | expansions = metadata["expansions"] if "expansions" in metadata else {} |
| @@ -312,25 +331,41 @@ class VlpnDataModule(): | |||
| 312 | items = self.prepare_items(template, expansions, items) | 331 | items = self.prepare_items(template, expansions, items) |
| 313 | items = self.filter_items(items) | 332 | items = self.filter_items(items) |
| 314 | self.npgenerator.shuffle(items) | 333 | self.npgenerator.shuffle(items) |
| 315 | 334 | ||
| 316 | if self.constant_prompt_length: | 335 | if self.constant_prompt_length: |
| 317 | all_input_ids = unify_input_ids( | 336 | all_input_ids = unify_input_ids( |
| 318 | self.tokenizer, | 337 | self.tokenizer, |
| 319 | [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items] | 338 | [ |
| 339 | self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids | ||
| 340 | for item in items | ||
| 341 | ], | ||
| 320 | ).input_ids | 342 | ).input_ids |
| 321 | self.max_token_id_length = all_input_ids.shape[1] | 343 | self.max_token_id_length = all_input_ids.shape[1] |
| 322 | 344 | ||
| 323 | num_images = len(items) | 345 | num_images = len(items) |
| 324 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 | 346 | valid_set_size = ( |
| 347 | min(self.valid_set_size, num_images) | ||
| 348 | if self.valid_set_size is not None | ||
| 349 | else num_images // 10 | ||
| 350 | ) | ||
| 325 | train_set_size = max(num_images - valid_set_size, 1) | 351 | train_set_size = max(num_images - valid_set_size, 1) |
| 326 | valid_set_size = num_images - train_set_size | 352 | valid_set_size = num_images - train_set_size |
| 327 | 353 | ||
| 328 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0) | 354 | collate_fn_ = partial( |
| 355 | collate_fn, | ||
| 356 | self.dtype, | ||
| 357 | self.tokenizer, | ||
| 358 | self.max_token_id_length, | ||
| 359 | self.with_guidance, | ||
| 360 | self.num_class_images != 0, | ||
| 361 | ) | ||
| 329 | 362 | ||
| 330 | if valid_set_size == 0: | 363 | if valid_set_size == 0: |
| 331 | data_train, data_val = items, items | 364 | data_train, data_val = items, items |
| 332 | else: | 365 | else: |
| 333 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator) | 366 | data_train, data_val = random_split( |
| 367 | items, [train_set_size, valid_set_size], generator=self.generator | ||
| 368 | ) | ||
| 334 | 369 | ||
| 335 | data_train = self.pad_items(data_train, self.num_class_images) | 370 | data_train = self.pad_items(data_train, self.num_class_images) |
| 336 | 371 | ||
| @@ -338,17 +373,25 @@ class VlpnDataModule(): | |||
| 338 | data_train *= math.ceil(self.train_set_pad / len(data_train)) | 373 | data_train *= math.ceil(self.train_set_pad / len(data_train)) |
| 339 | 374 | ||
| 340 | self.train_dataset = VlpnDataset( | 375 | self.train_dataset = VlpnDataset( |
| 341 | data_train, self.tokenizer, | 376 | data_train, |
| 342 | num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, | 377 | self.tokenizer, |
| 343 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 378 | num_buckets=self.num_buckets, |
| 344 | batch_size=self.batch_size, fill_batch=True, generator=self.generator, | 379 | progressive_buckets=self.progressive_buckets, |
| 345 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, | 380 | bucket_step_size=self.bucket_step_size, |
| 346 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 381 | bucket_max_pixels=self.bucket_max_pixels, |
| 382 | batch_size=self.batch_size, | ||
| 383 | fill_batch=True, | ||
| 384 | generator=self.generator, | ||
| 385 | size=self.size, | ||
| 386 | interpolation=self.interpolation, | ||
| 387 | color_jitter=self.color_jitter, | ||
| 388 | num_class_images=self.num_class_images, | ||
| 389 | dropout=self.dropout, | ||
| 390 | shuffle=self.shuffle, | ||
| 347 | ) | 391 | ) |
| 348 | 392 | ||
| 349 | self.train_dataloader = DataLoader( | 393 | self.train_dataloader = DataLoader( |
| 350 | self.train_dataset, | 394 | self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 351 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | ||
| 352 | ) | 395 | ) |
| 353 | 396 | ||
| 354 | if len(data_val) != 0: | 397 | if len(data_val) != 0: |
| @@ -358,16 +401,24 @@ class VlpnDataModule(): | |||
| 358 | data_val *= math.ceil(self.valid_set_pad / len(data_val)) | 401 | data_val *= math.ceil(self.valid_set_pad / len(data_val)) |
| 359 | 402 | ||
| 360 | self.val_dataset = VlpnDataset( | 403 | self.val_dataset = VlpnDataset( |
| 361 | data_val, self.tokenizer, | 404 | data_val, |
| 362 | num_buckets=self.num_buckets, progressive_buckets=True, | 405 | self.tokenizer, |
| 363 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | 406 | num_buckets=self.num_buckets, |
| 364 | batch_size=self.batch_size, generator=self.generator, | 407 | progressive_buckets=True, |
| 365 | size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, | 408 | bucket_step_size=self.bucket_step_size, |
| 409 | bucket_max_pixels=self.bucket_max_pixels, | ||
| 410 | batch_size=self.batch_size, | ||
| 411 | generator=self.generator, | ||
| 412 | size=self.size, | ||
| 413 | interpolation=self.interpolation, | ||
| 414 | color_jitter=self.color_jitter, | ||
| 366 | ) | 415 | ) |
| 367 | 416 | ||
| 368 | self.val_dataloader = DataLoader( | 417 | self.val_dataloader = DataLoader( |
| 369 | self.val_dataset, | 418 | self.val_dataset, |
| 370 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 419 | batch_size=None, |
| 420 | pin_memory=True, | ||
| 421 | collate_fn=collate_fn_, | ||
| 371 | ) | 422 | ) |
| 372 | else: | 423 | else: |
| 373 | self.val_dataloader = None | 424 | self.val_dataloader = None |
| @@ -418,7 +469,13 @@ class VlpnDataset(IterableDataset): | |||
| 418 | 469 | ||
| 419 | self.bucket_item_range = torch.arange(len(self.bucket_items)) | 470 | self.bucket_item_range = torch.arange(len(self.bucket_items)) |
| 420 | 471 | ||
| 421 | self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() | 472 | self.length_ = ( |
| 473 | (self.bucket_assignments.bincount() / self.batch_size) | ||
| 474 | .ceil() | ||
| 475 | .long() | ||
| 476 | .sum() | ||
| 477 | .item() | ||
| 478 | ) | ||
| 422 | 479 | ||
| 423 | def get_input_ids(self, text: str): | 480 | def get_input_ids(self, text: str): |
| 424 | return self.tokenizer(text, padding="do_not_pad").input_ids | 481 | return self.tokenizer(text, padding="do_not_pad").input_ids |
| @@ -430,7 +487,9 @@ class VlpnDataset(IterableDataset): | |||
| 430 | worker_info = torch.utils.data.get_worker_info() | 487 | worker_info = torch.utils.data.get_worker_info() |
| 431 | 488 | ||
| 432 | if self.shuffle: | 489 | if self.shuffle: |
| 433 | perm = torch.randperm(len(self.bucket_assignments), generator=self.generator) | 490 | perm = torch.randperm( |
| 491 | len(self.bucket_assignments), generator=self.generator | ||
| 492 | ) | ||
| 434 | self.bucket_items = self.bucket_items[perm] | 493 | self.bucket_items = self.bucket_items[perm] |
| 435 | self.bucket_assignments = self.bucket_assignments[perm] | 494 | self.bucket_assignments = self.bucket_assignments[perm] |
| 436 | 495 | ||
| @@ -488,7 +547,9 @@ class VlpnDataset(IterableDataset): | |||
| 488 | 547 | ||
| 489 | if len(bucket_items) == 0: | 548 | if len(bucket_items) == 0: |
| 490 | bucket_items = self.bucket_items[self.bucket_assignments == bucket] | 549 | bucket_items = self.bucket_items[self.bucket_assignments == bucket] |
| 491 | item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)] | 550 | item_index = bucket_items[ |
| 551 | torch.randint(len(bucket_items), (1,), generator=self.generator) | ||
| 552 | ] | ||
| 492 | else: | 553 | else: |
| 493 | item_index = bucket_items[0] | 554 | item_index = bucket_items[0] |
| 494 | mask[self.bucket_item_range[bucket_mask][0]] = False | 555 | mask[self.bucket_item_range[bucket_mask][0]] = False |
| @@ -500,12 +561,18 @@ class VlpnDataset(IterableDataset): | |||
| 500 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) | 561 | example["prompt_ids"] = self.get_input_ids(item.full_prompt()) |
| 501 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) | 562 | example["nprompt_ids"] = self.get_input_ids(item.nprompt) |
| 502 | 563 | ||
| 503 | example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator)) | 564 | example["instance_prompt_ids"] = self.get_input_ids( |
| 565 | item.full_prompt(self.dropout, True, self.npgenerator) | ||
| 566 | ) | ||
| 504 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) | 567 | example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) |
| 505 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | 568 | example["instance_images"] = image_transforms( |
| 569 | get_image(item.instance_image_path) | ||
| 570 | ) | ||
| 506 | 571 | ||
| 507 | if self.num_class_images != 0: | 572 | if self.num_class_images != 0: |
| 508 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) | 573 | example["class_prompt_ids"] = self.get_input_ids(item.cprompt) |
| 509 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | 574 | example["class_images"] = image_transforms( |
| 575 | get_image(item.class_image_path) | ||
| 576 | ) | ||
| 510 | 577 | ||
| 511 | batch.append(example) | 578 | batch.append(example) |
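The data/csv.py changes are almost entirely Black reformatting; the underlying logic is unchanged. As a minimal standalone sketch of the aspect-ratio arithmetic in the `generate_buckets` hunk above (the initialization of `buckets` with the square ratio is an assumption, and `base_size=768` / `step_size=64` are example values, not values from this commit):

```python
import math

def bucket_ratios(base_size=768, step_size=64, num_buckets=4, max_pixels=None):
    # Same default pixel budget as the hunk: (base_size + step_size) ** 2
    if max_pixels is None:
        max_pixels = (base_size + step_size) ** 2

    ratios = [1.0]  # assumed square starting bucket
    for i in range(1, num_buckets + 1):
        long_side = base_size + i * step_size
        # Shrink the short side in step_size increments until
        # long_side * short_side fits within max_pixels.
        short_side = min(
            base_size
            - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size,
            base_size,
        )
        ratios.append(long_side / short_side)  # landscape bucket
        ratios.append(short_side / long_side)  # portrait bucket
    return ratios

print(bucket_ratios())  # widest bucket here is 1024/640 = 1.6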
diff --git a/data/keywords.py b/data/keywords.py index 8632d67..83fe9ff 100644 --- a/data/keywords.py +++ b/data/keywords.py
| @@ -8,7 +8,7 @@ def keywords_to_str( | |||
| 8 | undroppable_keywords: list[str] = [], | 8 | undroppable_keywords: list[str] = [], |
| 9 | dropout: float = 0, | 9 | dropout: float = 0, |
| 10 | shuffle: bool = False, | 10 | shuffle: bool = False, |
| 11 | npgenerator: Optional[np.random.Generator] = None | 11 | npgenerator: Optional[np.random.Generator] = None, |
| 12 | ) -> str: | 12 | ) -> str: |
| 13 | if dropout != 0: | 13 | if dropout != 0: |
| 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] | 14 | keywords = [keyword for keyword in keywords if np.random.random() > dropout] |
| @@ -23,7 +23,11 @@ def keywords_to_str( | |||
| 23 | 23 | ||
| 24 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: | 24 | def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: |
| 25 | def expand_keyword(keyword: str) -> list[str]: | 25 | def expand_keyword(keyword: str) -> list[str]: |
| 26 | return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] | 26 | return ( |
| 27 | [keyword] + expansions[keyword].split(", ") | ||
| 28 | if keyword in expansions | ||
| 29 | else [keyword] | ||
| 30 | ) | ||
| 27 | 31 | ||
| 28 | return [ | 32 | return [ |
| 29 | kw | 33 | kw |
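The two helpers touched here implement keyword expansion and prompt-keyword dropout/shuffle. A self-contained sketch of the same logic (the `", "` split/join delimiters and the omission of `undroppable_keywords` are assumptions for brevity):

```python
import numpy as np

def str_to_keywords(s, expansions={}):
    # Expand a keyword into itself plus its comma-separated aliases,
    # mirroring expand_keyword in the hunk above.
    def expand(kw):
        return [kw] + expansions[kw].split(", ") if kw in expansions else [kw]
    return [e for kw in s.split(", ") if kw for e in expand(kw)]

def keywords_to_str(keywords, dropout=0.0, shuffle=False, rng=None):
    # Drop each keyword independently with probability `dropout`,
    # optionally shuffle, then rejoin into a prompt fragment.
    if dropout != 0:
        keywords = [kw for kw in keywords if np.random.random() > dropout]
    if shuffle:
        (rng if rng is not None else np.random.default_rng()).shuffle(keywords)
    return ", ".join(keywords)

kws = str_to_keywords("portrait, cat", {"cat": "feline, whiskers"})
print(kws)  # ['portrait', 'cat', 'feline', 'whiskers']
print(keywords_to_str(kws, dropout=0.5))
```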
diff --git a/environment.yaml b/environment.yaml index 1a55967..2c81a90 100644 --- a/environment.yaml +++ b/environment.yaml
| @@ -14,16 +14,17 @@ dependencies: | |||
| 14 | - numpy=1.24.3 | 14 | - numpy=1.24.3 |
| 15 | - pip=22.3.1 | 15 | - pip=22.3.1 |
| 16 | - python=3.10.8 | 16 | - python=3.10.8 |
| 17 | - pytorch=2.0.0=*cuda11.8* | 17 | - pytorch=2.0.1=*cuda11.8* |
| 18 | - torchvision=0.15.0 | 18 | - scipy=1.10.1 |
| 19 | - xformers=0.0.20.dev528 | 19 | - torchvision=0.15.2 |
| 20 | - xformers=0.0.21.dev542+git.a205b24 | ||
| 20 | - pip: | 21 | - pip: |
| 21 | - -e . | 22 | - -e . |
| 22 | - -e git+https://github.com/huggingface/accelerate#egg=accelerate | 23 | - -e git+https://github.com/huggingface/accelerate#egg=accelerate |
| 23 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 24 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
| 24 | - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation | 25 | - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation |
| 25 | - --pre --extra-index-url https://download.hidet.org/whl hidet | 26 | - --pre --extra-index-url https://download.hidet.org/whl hidet |
| 26 | - bitsandbytes==0.38.1 | 27 | - bitsandbytes==0.39.1 |
| 27 | - lion-pytorch==0.0.7 | 28 | - lion-pytorch==0.0.7 |
| 28 | - peft==0.3.0 | 29 | - peft==0.3.0 |
| 29 | - python-slugify>=6.1.2 | 30 | - python-slugify>=6.1.2 |
| @@ -31,4 +32,4 @@ dependencies: | |||
| 31 | - setuptools==65.6.3 | 32 | - setuptools==65.6.3 |
| 32 | - test-tube>=0.7.5 | 33 | - test-tube>=0.7.5 |
| 33 | - timm==0.9.2 | 34 | - timm==0.9.2 |
| 34 | - transformers==4.29.1 | 35 | - transformers==4.30.1 |
diff --git a/infer.py b/infer.py --- a/infer.py +++ b/infer.py
| @@ -24,7 +24,7 @@ from diffusers import ( | |||
| 24 | KDPM2DiscreteScheduler, | 24 | KDPM2DiscreteScheduler, |
| 25 | KDPM2AncestralDiscreteScheduler, | 25 | KDPM2AncestralDiscreteScheduler, |
| 26 | DEISMultistepScheduler, | 26 | DEISMultistepScheduler, |
| 27 | UniPCMultistepScheduler | 27 | UniPCMultistepScheduler, |
| 28 | ) | 28 | ) |
| 29 | from peft import LoraConfig, LoraModel, set_peft_model_state_dict | 29 | from peft import LoraConfig, LoraModel, set_peft_model_state_dict |
| 30 | from safetensors.torch import load_file | 30 | from safetensors.torch import load_file |
| @@ -61,7 +61,7 @@ default_cmds = { | |||
| 61 | "negative_prompt": None, | 61 | "negative_prompt": None, |
| 62 | "shuffle": False, | 62 | "shuffle": False, |
| 63 | "image": None, | 63 | "image": None, |
| 64 | "image_noise": .7, | 64 | "image_noise": 0.7, |
| 65 | "width": 768, | 65 | "width": 768, |
| 66 | "height": 768, | 66 | "height": 768, |
| 67 | "batch_size": 1, | 67 | "batch_size": 1, |
| @@ -69,7 +69,6 @@ default_cmds = { | |||
| 69 | "steps": 30, | 69 | "steps": 30, |
| 70 | "guidance_scale": 7.0, | 70 | "guidance_scale": 7.0, |
| 71 | "sag_scale": 0, | 71 | "sag_scale": 0, |
| 72 | "brightness_offset": 0, | ||
| 73 | "seed": None, | 72 | "seed": None, |
| 74 | "config": None, | 73 | "config": None, |
| 75 | } | 74 | } |
| @@ -85,9 +84,7 @@ def merge_dicts(d1, *args): | |||
| 85 | 84 | ||
| 86 | 85 | ||
| 87 | def create_args_parser(): | 86 | def create_args_parser(): |
| 88 | parser = argparse.ArgumentParser( | 87 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 89 | description="Simple example of a training script." | ||
| 90 | ) | ||
| 91 | parser.add_argument( | 88 | parser.add_argument( |
| 92 | "--model", | 89 | "--model", |
| 93 | type=str, | 90 | type=str, |
| @@ -118,9 +115,7 @@ def create_args_parser(): | |||
| 118 | 115 | ||
| 119 | 116 | ||
| 120 | def create_cmd_parser(): | 117 | def create_cmd_parser(): |
| 121 | parser = argparse.ArgumentParser( | 118 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 122 | description="Simple example of a training script." | ||
| 123 | ) | ||
| 124 | parser.add_argument( | 119 | parser.add_argument( |
| 125 | "--project", | 120 | "--project", |
| 126 | type=str, | 121 | type=str, |
| @@ -130,13 +125,34 @@ def create_cmd_parser(): | |||
| 130 | parser.add_argument( | 125 | parser.add_argument( |
| 131 | "--scheduler", | 126 | "--scheduler", |
| 132 | type=str, | 127 | type=str, |
| 133 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"], | 128 | choices=[ |
| 129 | "plms", | ||
| 130 | "ddim", | ||
| 131 | "klms", | ||
| 132 | "dpmsm", | ||
| 133 | "dpmss", | ||
| 134 | "euler_a", | ||
| 135 | "kdpm2", | ||
| 136 | "kdpm2_a", | ||
| 137 | "deis", | ||
| 138 | "unipc", | ||
| 139 | ], | ||
| 134 | ) | 140 | ) |
| 135 | parser.add_argument( | 141 | parser.add_argument( |
| 136 | "--subscheduler", | 142 | "--subscheduler", |
| 137 | type=str, | 143 | type=str, |
| 138 | default=None, | 144 | default=None, |
| 139 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"], | 145 | choices=[ |
| 146 | "plms", | ||
| 147 | "ddim", | ||
| 148 | "klms", | ||
| 149 | "dpmsm", | ||
| 150 | "dpmss", | ||
| 151 | "euler_a", | ||
| 152 | "kdpm2", | ||
| 153 | "kdpm2_a", | ||
| 154 | "deis", | ||
| 155 | ], | ||
| 140 | ) | 156 | ) |
| 141 | parser.add_argument( | 157 | parser.add_argument( |
| 142 | "--template", | 158 | "--template", |
| @@ -193,10 +209,6 @@ def create_cmd_parser(): | |||
| 193 | type=float, | 209 | type=float, |
| 194 | ) | 210 | ) |
| 195 | parser.add_argument( | 211 | parser.add_argument( |
| 196 | "--brightness_offset", | ||
| 197 | type=float, | ||
| 198 | ) | ||
| 199 | parser.add_argument( | ||
| 200 | "--seed", | 212 | "--seed", |
| 201 | type=int, | 213 | type=int, |
| 202 | ) | 214 | ) |
| @@ -214,7 +226,9 @@ def run_parser(parser, defaults, input=None): | |||
| 214 | 226 | ||
| 215 | if args.config is not None: | 227 | if args.config is not None: |
| 216 | conf_args = load_config(args.config) | 228 | conf_args = load_config(args.config) |
| 217 | conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0] | 229 | conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[ |
| 230 | 0 | ||
| 231 | ] | ||
| 218 | 232 | ||
| 219 | res = defaults.copy() | 233 | res = defaults.copy() |
| 220 | for dict in [vars(conf_args), vars(args)]: | 234 | for dict in [vars(conf_args), vars(args)]: |
| @@ -234,10 +248,12 @@ def load_embeddings_dir(pipeline, embeddings_dir): | |||
| 234 | added_tokens, added_ids = load_embeddings_from_dir( | 248 | added_tokens, added_ids = load_embeddings_from_dir( |
| 235 | pipeline.tokenizer, | 249 | pipeline.tokenizer, |
| 236 | pipeline.text_encoder.text_model.embeddings, | 250 | pipeline.text_encoder.text_model.embeddings, |
| 237 | Path(embeddings_dir) | 251 | Path(embeddings_dir), |
| 238 | ) | 252 | ) |
| 239 | pipeline.text_encoder.text_model.embeddings.persist() | 253 | pipeline.text_encoder.text_model.embeddings.persist() |
| 240 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 254 | print( |
| 255 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
| 256 | ) | ||
| 241 | 257 | ||
| 242 | 258 | ||
| 243 | def load_lora(pipeline, path): | 259 | def load_lora(pipeline, path): |
| @@ -255,9 +271,13 @@ def load_lora(pipeline, path): | |||
| 255 | return | 271 | return |
| 256 | 272 | ||
| 257 | lora_checkpoint_sd = load_file(path / tensor_files[0]) | 273 | lora_checkpoint_sd = load_file(path / tensor_files[0]) |
| 258 | unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} | 274 | unet_lora_ds = { |
| 275 | k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k | ||
| 276 | } | ||
| 259 | text_encoder_lora_ds = { | 277 | text_encoder_lora_ds = { |
| 260 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k | 278 | k.replace("text_encoder_", ""): v |
| 279 | for k, v in lora_checkpoint_sd.items() | ||
| 280 | if "text_encoder_" in k | ||
| 261 | } | 281 | } |
| 262 | ti_lora_ds = { | 282 | ti_lora_ds = { |
| 263 | k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k | 283 | k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k |
| @@ -282,7 +302,9 @@ def load_lora(pipeline, path): | |||
| 282 | token_embeddings=token_embeddings, | 302 | token_embeddings=token_embeddings, |
| 283 | ) | 303 | ) |
| 284 | pipeline.text_encoder.text_model.embeddings.persist() | 304 | pipeline.text_encoder.text_model.embeddings.persist() |
| 285 | print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}") | 305 | print( |
| 306 | f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}" | ||
| 307 | ) | ||
| 286 | 308 | ||
| 287 | return | 309 | return |
| 288 | 310 | ||
| @@ -315,17 +337,25 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None) | |||
| 315 | solver_p=create_scheduler(config, subscheduler), | 337 | solver_p=create_scheduler(config, subscheduler), |
| 316 | ) | 338 | ) |
| 317 | else: | 339 | else: |
| 318 | raise ValueError(f"Unknown scheduler \"{scheduler}\"") | 340 | raise ValueError(f'Unknown scheduler "{scheduler}"') |
| 319 | 341 | ||
| 320 | 342 | ||
| 321 | def create_pipeline(model, dtype): | 343 | def create_pipeline(model, dtype): |
| 322 | print("Loading Stable Diffusion pipeline...") | 344 | print("Loading Stable Diffusion pipeline...") |
| 323 | 345 | ||
| 324 | tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 346 | tokenizer = MultiCLIPTokenizer.from_pretrained( |
| 325 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) | 347 | model, subfolder="tokenizer", torch_dtype=dtype |
| 326 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) | 348 | ) |
| 327 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 349 | text_encoder = CLIPTextModel.from_pretrained( |
| 328 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | 350 | model, subfolder="text_encoder", torch_dtype=dtype |
| 351 | ) | ||
| 352 | vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype) | ||
| 353 | unet = UNet2DConditionModel.from_pretrained( | ||
| 354 | model, subfolder="unet", torch_dtype=dtype | ||
| 355 | ) | ||
| 356 | scheduler = DDIMScheduler.from_pretrained( | ||
| 357 | model, subfolder="scheduler", torch_dtype=dtype | ||
| 358 | ) | ||
| 329 | 359 | ||
| 330 | patch_managed_embeddings(text_encoder) | 360 | patch_managed_embeddings(text_encoder) |
| 331 | 361 | ||
| @@ -347,7 +377,9 @@ def create_pipeline(model, dtype): | |||
| 347 | 377 | ||
| 348 | 378 | ||
| 349 | def shuffle_prompts(prompts: list[str]) -> list[str]: | 379 | def shuffle_prompts(prompts: list[str]) -> list[str]: |
| 350 | return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts] | 380 | return [ |
| 381 | keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts | ||
| 382 | ] | ||
| 351 | 383 | ||
| 352 | 384 | ||
| 353 | @torch.inference_mode() | 385 | @torch.inference_mode() |
| @@ -386,12 +418,13 @@ def generate(output_dir: Path, pipeline, args): | |||
| 386 | else: | 418 | else: |
| 387 | init_image = None | 419 | init_image = None |
| 388 | 420 | ||
| 389 | pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler) | 421 | pipeline.scheduler = create_scheduler( |
| 422 | pipeline.scheduler.config, args.scheduler, args.subscheduler | ||
| 423 | ) | ||
| 390 | 424 | ||
| 391 | for i in range(args.batch_num): | 425 | for i in range(args.batch_num): |
| 392 | pipeline.set_progress_bar_config( | 426 | pipeline.set_progress_bar_config( |
| 393 | desc=f"Batch {i + 1} of {args.batch_num}", | 427 | desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True |
| 394 | dynamic_ncols=True | ||
| 395 | ) | 428 | ) |
| 396 | 429 | ||
| 397 | seed = args.seed + i | 430 | seed = args.seed + i |
| @@ -409,7 +442,6 @@ def generate(output_dir: Path, pipeline, args): | |||
| 409 | generator=generator, | 442 | generator=generator, |
| 410 | image=init_image, | 443 | image=init_image, |
| 411 | strength=args.image_noise, | 444 | strength=args.image_noise, |
| 412 | brightness_offset=args.brightness_offset, | ||
| 413 | ).images | 445 | ).images |
| 414 | 446 | ||
| 415 | for j, image in enumerate(images): | 447 | for j, image in enumerate(images): |
| @@ -418,7 +450,7 @@ def generate(output_dir: Path, pipeline, args): | |||
| 418 | 450 | ||
| 419 | image.save(dir / f"{basename}.png") | 451 | image.save(dir / f"{basename}.png") |
| 420 | image.save(dir / f"{basename}.jpg", quality=85) | 452 | image.save(dir / f"{basename}.jpg", quality=85) |
| 421 | with open(dir / f"{basename}.txt", 'w') as f: | 453 | with open(dir / f"{basename}.txt", "w") as f: |
| 422 | f.write(prompt[j % len(args.prompt)]) | 454 | f.write(prompt[j % len(args.prompt)]) |
| 423 | 455 | ||
| 424 | if torch.cuda.is_available(): | 456 | if torch.cuda.is_available(): |
| @@ -426,10 +458,12 @@ def generate(output_dir: Path, pipeline, args): | |||
| 426 | 458 | ||
| 427 | 459 | ||
| 428 | class CmdParse(cmd.Cmd): | 460 | class CmdParse(cmd.Cmd): |
| 429 | prompt = 'dream> ' | 461 | prompt = "dream> " |
| 430 | commands = [] | 462 | commands = [] |
| 431 | 463 | ||
| 432 | def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): | 464 | def __init__( |
| 465 | self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser | ||
| 466 | ): | ||
| 433 | super().__init__() | 467 | super().__init__() |
| 434 | 468 | ||
| 435 | self.output_dir = output_dir | 469 | self.output_dir = output_dir |
| @@ -447,10 +481,10 @@ class CmdParse(cmd.Cmd): | |||
| 447 | print(str(e)) | 481 | print(str(e)) |
| 448 | return | 482 | return |
| 449 | 483 | ||
| 450 | if elements[0] == 'q': | 484 | if elements[0] == "q": |
| 451 | return True | 485 | return True |
| 452 | 486 | ||
| 453 | if elements[0] == 'reload_embeddings': | 487 | if elements[0] == "reload_embeddings": |
| 454 | load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) | 488 | load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) |
| 455 | return | 489 | return |
| 456 | 490 | ||
| @@ -458,7 +492,7 @@ class CmdParse(cmd.Cmd): | |||
| 458 | args = run_parser(self.parser, default_cmds, elements) | 492 | args = run_parser(self.parser, default_cmds, elements) |
| 459 | 493 | ||
| 460 | if len(args.prompt) == 0: | 494 | if len(args.prompt) == 0: |
| 461 | print('Try again with a prompt!') | 495 | print("Try again with a prompt!") |
| 462 | return | 496 | return |
| 463 | except SystemExit: | 497 | except SystemExit: |
| 464 | traceback.print_exc() | 498 | traceback.print_exc() |
| @@ -471,7 +505,7 @@ class CmdParse(cmd.Cmd): | |||
| 471 | try: | 505 | try: |
| 472 | generate(self.output_dir, self.pipeline, args) | 506 | generate(self.output_dir, self.pipeline, args) |
| 473 | except KeyboardInterrupt: | 507 | except KeyboardInterrupt: |
| 474 | print('Generation cancelled.') | 508 | print("Generation cancelled.") |
| 475 | except Exception as e: | 509 | except Exception as e: |
| 476 | traceback.print_exc() | 510 | traceback.print_exc() |
| 477 | return | 511 | return |
| @@ -487,7 +521,9 @@ def main(): | |||
| 487 | args = run_parser(args_parser, default_args) | 521 | args = run_parser(args_parser, default_args) |
| 488 | 522 | ||
| 489 | output_dir = Path(args.output_dir) | 523 | output_dir = Path(args.output_dir) |
| 490 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | 524 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[ |
| 525 | args.precision | ||
| 526 | ] | ||
| 491 | 527 | ||
| 492 | pipeline = create_pipeline(args.model, dtype) | 528 | pipeline = create_pipeline(args.model, dtype) |
| 493 | 529 | ||
| @@ -496,7 +532,13 @@ def main(): | |||
| 496 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | 532 | # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) |
| 497 | 533 | ||
| 498 | cmd_parser = create_cmd_parser() | 534 | cmd_parser = create_cmd_parser() |
| 499 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) | 535 | cmd_prompt = CmdParse( |
| 536 | output_dir, | ||
| 537 | args.ti_embeddings_dir, | ||
| 538 | args.lora_embeddings_dir, | ||
| 539 | pipeline, | ||
| 540 | cmd_parser, | ||
| 541 | ) | ||
| 500 | cmd_prompt.cmdloop() | 542 | cmd_prompt.cmdloop() |
| 501 | 543 | ||
| 502 | 544 | ||
diff --git a/models/attention/control.py b/models/attention/control.py index 248bd9f..ec378c4 100644 --- a/models/attention/control.py +++ b/models/attention/control.py
| @@ -23,7 +23,7 @@ class AttentionControl(abc.ABC): | |||
| 23 | attn = self.forward(attn, is_cross, place_in_unet) | 23 | attn = self.forward(attn, is_cross, place_in_unet) |
| 24 | else: | 24 | else: |
| 25 | h = attn.shape[0] | 25 | h = attn.shape[0] |
| 26 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) | 26 | attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet) |
| 27 | self.cur_att_layer += 1 | 27 | self.cur_att_layer += 1 |
| 28 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: | 28 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: |
| 29 | self.cur_att_layer = 0 | 29 | self.cur_att_layer = 0 |
| @@ -49,12 +49,18 @@ class EmptyControl(AttentionControl): | |||
| 49 | class AttentionStore(AttentionControl): | 49 | class AttentionStore(AttentionControl): |
| 50 | @staticmethod | 50 | @staticmethod |
| 51 | def get_empty_store(): | 51 | def get_empty_store(): |
| 52 | return {"down_cross": [], "mid_cross": [], "up_cross": [], | 52 | return { |
| 53 | "down_self": [], "mid_self": [], "up_self": []} | 53 | "down_cross": [], |
| 54 | "mid_cross": [], | ||
| 55 | "up_cross": [], | ||
| 56 | "down_self": [], | ||
| 57 | "mid_self": [], | ||
| 58 | "up_self": [], | ||
| 59 | } | ||
| 54 | 60 | ||
| 55 | def forward(self, attn, is_cross: bool, place_in_unet: str): | 61 | def forward(self, attn, is_cross: bool, place_in_unet: str): |
| 56 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" | 62 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" |
| 57 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead | 63 | if attn.shape[1] <= 32**2: # avoid memory overhead |
| 58 | self.step_store[key].append(attn) | 64 | self.step_store[key].append(attn) |
| 59 | return attn | 65 | return attn |
| 60 | 66 | ||
| @@ -68,8 +74,10 @@ class AttentionStore(AttentionControl): | |||
| 68 | self.step_store = self.get_empty_store() | 74 | self.step_store = self.get_empty_store() |
| 69 | 75 | ||
| 70 | def get_average_attention(self): | 76 | def get_average_attention(self): |
| 71 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] | 77 | average_attention = { |
| 72 | for key in self.attention_store} | 78 | key: [item / self.cur_step for item in self.attention_store[key]] |
| 79 | for key in self.attention_store | ||
| 80 | } | ||
| 73 | return average_attention | 81 | return average_attention |
| 74 | 82 | ||
| 75 | def reset(self): | 83 | def reset(self): |
| @@ -90,7 +98,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC): | |||
| 90 | return x_t | 98 | return x_t |
| 91 | 99 | ||
| 92 | def replace_self_attention(self, attn_base, att_replace): | 100 | def replace_self_attention(self, attn_base, att_replace): |
| 93 | if att_replace.shape[2] <= 16 ** 2: | 101 | if att_replace.shape[2] <= 16**2: |
| 94 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) | 102 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) |
| 95 | else: | 103 | else: |
| 96 | return att_replace | 104 | return att_replace |
| @@ -101,41 +109,62 @@ class AttentionControlEdit(AttentionStore, abc.ABC): | |||
| 101 | 109 | ||
| 102 | def forward(self, attn, is_cross: bool, place_in_unet: str): | 110 | def forward(self, attn, is_cross: bool, place_in_unet: str): |
| 103 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) | 111 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) |
| 104 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): | 112 | if is_cross or ( |
| 113 | self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1] | ||
| 114 | ): | ||
| 105 | h = attn.shape[0] // (self.batch_size) | 115 | h = attn.shape[0] // (self.batch_size) |
| 106 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) | 116 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) |
| 107 | attn_base, attn_repalce = attn[0], attn[1:] | 117 | attn_base, attn_repalce = attn[0], attn[1:] |
| 108 | if is_cross: | 118 | if is_cross: |
| 109 | alpha_words = self.cross_replace_alpha[self.cur_step] | 119 | alpha_words = self.cross_replace_alpha[self.cur_step] |
| 110 | attn_repalce_new = self.replace_cross_attention( | 120 | attn_repalce_new = ( |
| 111 | attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce | 121 | self.replace_cross_attention(attn_base, attn_repalce) * alpha_words |
| 122 | + (1 - alpha_words) * attn_repalce | ||
| 123 | ) | ||
| 112 | attn[1:] = attn_repalce_new | 124 | attn[1:] = attn_repalce_new |
| 113 | else: | 125 | else: |
| 114 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) | 126 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) |
| 115 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) | 127 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) |
| 116 | return attn | 128 | return attn |
| 117 | 129 | ||
| 118 | def __init__(self, prompts, num_steps: int, | 130 | def __init__( |
| 119 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], | 131 | self, |
| 120 | self_replace_steps: Union[float, Tuple[float, float]], | 132 | prompts, |
| 121 | local_blend: Optional[LocalBlend]): | 133 | num_steps: int, |
| 134 | cross_replace_steps: Union[ | ||
| 135 | float, Tuple[float, float], Dict[str, Tuple[float, float]] | ||
| 136 | ], | ||
| 137 | self_replace_steps: Union[float, Tuple[float, float]], | ||
| 138 | local_blend: Optional[LocalBlend], | ||
| 139 | ): | ||
| 122 | super(AttentionControlEdit, self).__init__() | 140 | super(AttentionControlEdit, self).__init__() |
| 123 | self.batch_size = len(prompts) | 141 | self.batch_size = len(prompts) |
| 124 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha( | 142 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha( |
| 125 | prompts, num_steps, cross_replace_steps, tokenizer).to(device) | 143 | prompts, num_steps, cross_replace_steps, tokenizer |
| 144 | ).to(device) | ||
| 126 | if type(self_replace_steps) is float: | 145 | if type(self_replace_steps) is float: |
| 127 | self_replace_steps = 0, self_replace_steps | 146 | self_replace_steps = 0, self_replace_steps |
| 128 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) | 147 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int( |
| 148 | num_steps * self_replace_steps[1] | ||
| 149 | ) | ||
| 129 | self.local_blend = local_blend | 150 | self.local_blend = local_blend |
| 130 | 151 | ||
| 131 | 152 | ||
| 132 | class AttentionReplace(AttentionControlEdit): | 153 | class AttentionReplace(AttentionControlEdit): |
| 133 | def replace_cross_attention(self, attn_base, att_replace): | 154 | def replace_cross_attention(self, attn_base, att_replace): |
| 134 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) | 155 | return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper) |
| 135 | 156 | ||
| 136 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, | 157 | def __init__( |
| 137 | local_blend: Optional[LocalBlend] = None): | 158 | self, |
| 138 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) | 159 | prompts, |
| 160 | num_steps: int, | ||
| 161 | cross_replace_steps: float, | ||
| 162 | self_replace_steps: float, | ||
| 163 | local_blend: Optional[LocalBlend] = None, | ||
| 164 | ): | ||
| 165 | super(AttentionReplace, self).__init__( | ||
| 166 | prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend | ||
| 167 | ) | ||
| 139 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) | 168 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) |
| 140 | 169 | ||
| 141 | 170 | ||
| @@ -145,9 +174,17 @@ class AttentionRefine(AttentionControlEdit): | |||
| 145 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) | 174 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) |
| 146 | return attn_replace | 175 | return attn_replace |
| 147 | 176 | ||
| 148 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, | 177 | def __init__( |
| 149 | local_blend: Optional[LocalBlend] = None): | 178 | self, |
| 150 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) | 179 | prompts, |
| 180 | num_steps: int, | ||
| 181 | cross_replace_steps: float, | ||
| 182 | self_replace_steps: float, | ||
| 183 | local_blend: Optional[LocalBlend] = None, | ||
| 184 | ): | ||
| 185 | super(AttentionRefine, self).__init__( | ||
| 186 | prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend | ||
| 187 | ) | ||
| 151 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) | 188 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) |
| 152 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) | 189 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) |
| 153 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) | 190 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) |
| @@ -156,13 +193,24 @@ class AttentionRefine(AttentionControlEdit): | |||
| 156 | class AttentionReweight(AttentionControlEdit): | 193 | class AttentionReweight(AttentionControlEdit): |
| 157 | def replace_cross_attention(self, attn_base, att_replace): | 194 | def replace_cross_attention(self, attn_base, att_replace): |
| 158 | if self.prev_controller is not None: | 195 | if self.prev_controller is not None: |
| 159 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) | 196 | attn_base = self.prev_controller.replace_cross_attention( |
| 197 | attn_base, att_replace | ||
| 198 | ) | ||
| 160 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] | 199 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] |
| 161 | return attn_replace | 200 | return attn_replace |
| 162 | 201 | ||
| 163 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, | 202 | def __init__( |
| 164 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): | 203 | self, |
| 165 | super(AttentionReweight, self).__init__(prompts, num_steps, | 204 | prompts, |
| 166 | cross_replace_steps, self_replace_steps, local_blend) | 205 | num_steps: int, |
| 206 | cross_replace_steps: float, | ||
| 207 | self_replace_steps: float, | ||
| 208 | equalizer, | ||
| 209 | local_blend: Optional[LocalBlend] = None, | ||
| 210 | controller: Optional[AttentionControlEdit] = None, | ||
| 211 | ): | ||
| 212 | super(AttentionReweight, self).__init__( | ||
| 213 | prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend | ||
| 214 | ) | ||
| 167 | self.equalizer = equalizer.to(device) | 215 | self.equalizer = equalizer.to(device) |
| 168 | self.prev_controller = controller | 216 | self.prev_controller = controller |
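These classes are the prompt-to-prompt attention editing code; every change shown is formatting only. The core operation is the per-word blend in `AttentionControlEdit.forward` combined with the einsum mapping in `AttentionReplace`; a toy-shapes sketch (all shapes and the identity mapper are illustrative assumptions):

```python
import torch

b, h, p, w = 1, 8, 16, 77  # edited prompts, heads, query tokens, prompt tokens
attn_base = torch.rand(h, p, w)        # attention for the source prompt
mapper = torch.eye(w).expand(b, w, w)  # token replacement map (identity here)

# AttentionReplace.replace_cross_attention: remap the source prompt's
# attention onto the edited prompt's tokens.
attn_mapped = torch.einsum("hpw,bwn->bhpn", attn_base, mapper)

# AttentionControlEdit.forward: blend mapped and edited attention per word.
alpha_words = torch.rand(b, 1, 1, w)   # cross_replace_alpha[cur_step]
attn_replace = torch.rand(b, h, p, w)
attn_new = attn_mapped * alpha_words + (1 - alpha_words) * attn_replace
print(attn_new.shape)  # torch.Size([1, 8, 16, 77])
```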
diff --git a/models/attention/hook.py b/models/attention/hook.py index 903de02..6b5fb68 100644 --- a/models/attention/hook.py +++ b/models/attention/hook.py
| @@ -3,6 +3,7 @@ import torch | |||
| 3 | 3 | ||
| 4 | try: | 4 | try: |
| 5 | import xformers.ops | 5 | import xformers.ops |
| 6 | |||
| 6 | xformers._is_functorch_available = True | 7 | xformers._is_functorch_available = True |
| 7 | MEM_EFFICIENT_ATTN = True | 8 | MEM_EFFICIENT_ATTN = True |
| 8 | except ImportError: | 9 | except ImportError: |
| @@ -42,10 +43,10 @@ def register_attention_control(model, controller): | |||
| 42 | return forward | 43 | return forward |
| 43 | 44 | ||
| 44 | def register_recr(net_, count, place_in_unet): | 45 | def register_recr(net_, count, place_in_unet): |
| 45 | if net_.__class__.__name__ == 'CrossAttention': | 46 | if net_.__class__.__name__ == "CrossAttention": |
| 46 | net_.forward = ca_forward(net_, place_in_unet) | 47 | net_.forward = ca_forward(net_, place_in_unet) |
| 47 | return count + 1 | 48 | return count + 1 |
| 48 | elif hasattr(net_, 'children'): | 49 | elif hasattr(net_, "children"): |
| 49 | for net__ in net_.children(): | 50 | for net__ in net_.children(): |
| 50 | count = register_recr(net__, count, place_in_unet) | 51 | count = register_recr(net__, count, place_in_unet) |
| 51 | return count | 52 | return count |
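`register_attention_control` walks the model recursively and patches every `CrossAttention` module's forward; this hunk only switches quote style. The traversal pattern, sketched on its own (`ca_forward` is the patched forward from the real hook and is stubbed out here; `unet` in the usage comment is assumed to exist):

```python
def register_recr(net_, count, place_in_unet, patch_forward=None):
    # Patch CrossAttention modules (identified by class name) and count them;
    # for anything else, recurse into children, as in the hunk above.
    if net_.__class__.__name__ == "CrossAttention":
        if patch_forward is not None:
            net_.forward = patch_forward(net_, place_in_unet)
        return count + 1
    elif hasattr(net_, "children"):
        for net__ in net_.children():
            count = register_recr(net__, count, place_in_unet, patch_forward)
    return count

# num_att_layers = register_recr(unet, 0, "down")
```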
diff --git a/models/attention/structured.py b/models/attention/structured.py index 24d889f..5bbbc06 100644 --- a/models/attention/structured.py +++ b/models/attention/structured.py
| @@ -16,7 +16,9 @@ class StructuredAttentionControl(AttentionControl): | |||
| 16 | if self.struct_attn: | 16 | if self.struct_attn: |
| 17 | out = self.struct_qkv(q, context, mask) | 17 | out = self.struct_qkv(q, context, mask) |
| 18 | else: | 18 | else: |
| 19 | context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context | 19 | context = torch.cat( |
| 20 | [context[0], context[1]["k"][0]], dim=0 | ||
| 21 | ) # use key tensor for context | ||
| 20 | out = self.normal_qkv(q, context, mask) | 22 | out = self.normal_qkv(q, context, mask) |
| 21 | else: | 23 | else: |
| 22 | context = default(context, x) | 24 | context = default(context, x) |
| @@ -29,11 +31,13 @@ class StructuredAttentionControl(AttentionControl): | |||
| 29 | context: list of [uc, list of conditional context] | 31 | context: list of [uc, list of conditional context] |
| 30 | """ | 32 | """ |
| 31 | uc_context = context[0] | 33 | uc_context = context[0] |
| 32 | context_k, context_v = context[1]['k'], context[1]['v'] | 34 | context_k, context_v = context[1]["k"], context[1]["v"] |
| 33 | 35 | ||
| 34 | if isinstance(context_k, list) and isinstance(context_v, list): | 36 | if isinstance(context_k, list) and isinstance(context_v, list): |
| 35 | out = self.multi_qkv(q, uc_context, context_k, context_v, mask) | 37 | out = self.multi_qkv(q, uc_context, context_k, context_v, mask) |
| 36 | elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor): | 38 | elif isinstance(context_k, torch.Tensor) and isinstance( |
| 39 | context_v, torch.Tensor | ||
| 40 | ): | ||
| 37 | out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) | 41 | out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) |
| 38 | else: | 42 | else: |
| 39 | raise NotImplementedError | 43 | raise NotImplementedError |
| @@ -50,36 +54,45 @@ class StructuredAttentionControl(AttentionControl): | |||
| 50 | k_c = [self.to_k(c_k) for c_k in context_k] | 54 | k_c = [self.to_k(c_k) for c_k in context_k] |
| 51 | v_c = [self.to_v(c_v) for c_v in context_v] | 55 | v_c = [self.to_v(c_v) for c_v in context_v] |
| 52 | 56 | ||
| 53 | q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) | 57 | q = rearrange(q, "b n (h d) -> (b h) n d", h=h) |
| 54 | 58 | ||
| 55 | k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h) | 59 | k_uc = rearrange(k_uc, "b n (h d) -> (b h) n d", h=h) |
| 56 | v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h) | 60 | v_uc = rearrange(v_uc, "b n (h d) -> (b h) n d", h=h) |
| 57 | 61 | ||
| 58 | k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point | 62 | k_c = [ |
| 59 | v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c] | 63 | rearrange(k, "b n (h d) -> (b h) n d", h=h) for k in k_c |
| 64 | ] # NOTE: modification point | ||
| 65 | v_c = [rearrange(v, "b n (h d) -> (b h) n d", h=h) for v in v_c] | ||
| 60 | 66 | ||
| 61 | # get composition | 67 | # get composition |
| 62 | sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale | 68 | sim_uc = einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * self.scale |
| 63 | sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c] | 69 | sim_c = [ |
| 70 | einsum("b i d, b j d -> b i j", q[true_bs:], k) * self.scale for k in k_c | ||
| 71 | ] | ||
| 64 | 72 | ||
| 65 | attn_uc = sim_uc.softmax(dim=-1) | 73 | attn_uc = sim_uc.softmax(dim=-1) |
| 66 | attn_c = [sim.softmax(dim=-1) for sim in sim_c] | 74 | attn_c = [sim.softmax(dim=-1) for sim in sim_c] |
| 67 | 75 | ||
| 68 | # get uc output | 76 | # get uc output |
| 69 | out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc) | 77 | out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc) |
| 70 | 78 | ||
| 71 | # get c output | 79 | # get c output |
| 72 | if len(v_c) == 1: | 80 | if len(v_c) == 1: |
| 73 | out_c_collect = [] | 81 | out_c_collect = [] |
| 74 | for attn in attn_c: | 82 | for attn in attn_c: |
| 75 | for v in v_c: | 83 | for v in v_c: |
| 76 | out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v)) | 84 | out_c_collect.append(einsum("b i j, b j d -> b i d", attn, v)) |
| 77 | out_c = sum(out_c_collect) / len(out_c_collect) | 85 | out_c = sum(out_c_collect) / len(out_c_collect) |
| 78 | else: | 86 | else: |
| 79 | out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c) | 87 | out_c = sum( |
| 88 | [ | ||
| 89 | einsum("b i j, b j d -> b i d", attn, v) | ||
| 90 | for attn, v in zip(attn_c, v_c) | ||
| 91 | ] | ||
| 92 | ) / len(v_c) | ||
| 80 | 93 | ||
| 81 | out = torch.cat([out_uc, out_c], dim=0) | 94 | out = torch.cat([out_uc, out_c], dim=0) |
| 82 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | 95 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) |
| 83 | 96 | ||
| 84 | return out | 97 | return out |
| 85 | 98 | ||
| @@ -88,21 +101,21 @@ class StructuredAttentionControl(AttentionControl): | |||
| 88 | k = self.to_k(context) | 101 | k = self.to_k(context) |
| 89 | v = self.to_v(context) | 102 | v = self.to_v(context) |
| 90 | 103 | ||
| 91 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | 104 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) |
| 92 | 105 | ||
| 93 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | 106 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale |
| 94 | 107 | ||
| 95 | if exists(mask): | 108 | if exists(mask): |
| 96 | mask = rearrange(mask, 'b ... -> b (...)') | 109 | mask = rearrange(mask, "b ... -> b (...)") |
| 97 | max_neg_value = -torch.finfo(sim.dtype).max | 110 | max_neg_value = -torch.finfo(sim.dtype).max |
| 98 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | 111 | mask = repeat(mask, "b j -> (b h) () j", h=h) |
| 99 | sim.masked_fill_(~mask, max_neg_value) | 112 | sim.masked_fill_(~mask, max_neg_value) |
| 100 | 113 | ||
| 101 | # attention, what we cannot get enough of | 114 | # attention, what we cannot get enough of |
| 102 | attn = sim.softmax(dim=-1) | 115 | attn = sim.softmax(dim=-1) |
| 103 | 116 | ||
| 104 | out = einsum('b i j, b j d -> b i d', attn, v) | 117 | out = einsum("b i j, b j d -> b i d", attn, v) |
| 105 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | 118 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) |
| 106 | 119 | ||
| 107 | return out | 120 | return out |
| 108 | 121 | ||
| @@ -111,21 +124,21 @@ class StructuredAttentionControl(AttentionControl): | |||
| 111 | k = self.to_k(torch.cat([uc_context, context_k], dim=0)) | 124 | k = self.to_k(torch.cat([uc_context, context_k], dim=0)) |
| 112 | v = self.to_v(torch.cat([uc_context, context_v], dim=0)) | 125 | v = self.to_v(torch.cat([uc_context, context_v], dim=0)) |
| 113 | 126 | ||
| 114 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) | 127 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) |
| 115 | 128 | ||
| 116 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale | 129 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale |
| 117 | 130 | ||
| 118 | if exists(mask): | 131 | if exists(mask): |
| 119 | mask = rearrange(mask, 'b ... -> b (...)') | 132 | mask = rearrange(mask, "b ... -> b (...)") |
| 120 | max_neg_value = -torch.finfo(sim.dtype).max | 133 | max_neg_value = -torch.finfo(sim.dtype).max |
| 121 | mask = repeat(mask, 'b j -> (b h) () j', h=h) | 134 | mask = repeat(mask, "b j -> (b h) () j", h=h) |
| 122 | sim.masked_fill_(~mask, max_neg_value) | 135 | sim.masked_fill_(~mask, max_neg_value) |
| 123 | 136 | ||
| 124 | # attention, what we cannot get enough of | 137 | # attention, what we cannot get enough of |
| 125 | attn = sim.softmax(dim=-1) | 138 | attn = sim.softmax(dim=-1) |
| 126 | 139 | ||
| 127 | out = einsum('b i j, b j d -> b i d', attn, v) | 140 | out = einsum("b i j, b j d -> b i d", attn, v) |
| 128 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) | 141 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) |
| 129 | return out | 142 | return out |
| 130 | 143 | ||
| 131 | def get_kv(self, context): | 144 | def get_kv(self, context): |
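The hunks above only reflow the attention code; the math is unchanged. For reference, a minimal self-contained sketch of the multi-head einsum attention pattern that `multi_qkv`, `heterogeous_qkv`, and the masked variant all share (the head count and shapes below are illustrative assumptions, not values from the repo):

```python
import torch
from einops import rearrange
from torch import einsum

def attention(q, k, v, h=8, scale=None):
    """Multi-head scaled dot-product attention via einsum, as in the hunks above."""
    scale = scale if scale is not None else (q.shape[-1] // h) ** -0.5
    # split heads: (b, n, h*d) -> (b*h, n, d)
    q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))
    sim = einsum("b i d, b j d -> b i j", q, k) * scale  # attention scores
    attn = sim.softmax(dim=-1)                           # normalize over keys
    out = einsum("b i j, b j d -> b i d", attn, v)       # weighted values
    return rearrange(out, "(b h) n d -> b n (h d)", h=h)

x = torch.randn(2, 16, 64)     # batch 2, 16 tokens, 8 heads x 8 dims per head
out = attention(x, x, x, h=8)  # self-attention -> shape (2, 16, 64)
```

The multi-context paths differ only in how they batch keys and values: `multi_qkv` attends to each conditional context separately and averages the outputs, while `heterogeous_qkv` concatenates the unconditional and conditional contexts along the batch dimension.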
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 7c7f2ac..8c3c6d4 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
| @@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding | |||
| 14 | 14 | ||
| 15 | 15 | ||
| 16 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | 16 | class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): |
| 17 | def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0): | 17 | def __init__( |
| 18 | self, | ||
| 19 | config: CLIPTextConfig, | ||
| 20 | embeddings: CLIPTextEmbeddings, | ||
| 21 | alpha: int = 8, | ||
| 22 | dropout: float = 0.0, | ||
| 23 | ): | ||
| 18 | super().__init__(config) | 24 | super().__init__(config) |
| 19 | 25 | ||
| 20 | self.position_embedding = embeddings.position_embedding | 26 | self.position_embedding = embeddings.position_embedding |
| @@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 28 | self.token_embedding.weight = embeddings.token_embedding.weight | 34 | self.token_embedding.weight = embeddings.token_embedding.weight |
| 29 | 35 | ||
| 30 | def resize(self, size: int): | 36 | def resize(self, size: int): |
| 31 | self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) | 37 | self.token_embedding = self.token_embedding.new_resized( |
| 38 | size, self.initializer_factor | ||
| 39 | ) | ||
| 32 | 40 | ||
| 33 | def add_embed( | 41 | def add_embed( |
| 34 | self, | 42 | self, |
| @@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 46 | initializer = [initializer] | 54 | initializer = [initializer] |
| 47 | 55 | ||
| 48 | if isinstance(initializer, list): | 56 | if isinstance(initializer, list): |
| 49 | initializer = (initializer * len(token_ids))[:len(token_ids)] | 57 | initializer = (initializer * len(token_ids))[: len(token_ids)] |
| 50 | 58 | ||
| 51 | with torch.no_grad(): | 59 | with torch.no_grad(): |
| 52 | initializer = self.get_embed(initializer) | 60 | initializer = self.get_embed(initializer) |
| @@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
| 76 | 84 | ||
| 77 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): | 85 | def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): |
| 78 | if isinstance(input_ids, list): | 86 | if isinstance(input_ids, list): |
| 79 | input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) | 87 | input_ids = torch.tensor( |
| 88 | input_ids, device=self.token_embedding.weight.device, dtype=torch.long | ||
| 89 | ) | ||
| 80 | 90 | ||
| 81 | return self.token_embedding(input_ids) | 91 | return self.token_embedding(input_ids) |
| 82 | 92 | ||
| 83 | 93 | ||
| 84 | def patch_managed_embeddings( | 94 | def patch_managed_embeddings( |
| 85 | text_encoder: CLIPTextModel, | 95 | text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0 |
| 86 | alpha: int = 8, | ||
| 87 | dropout: float = 0.0 | ||
| 88 | ) -> ManagedCLIPTextEmbeddings: | 96 | ) -> ManagedCLIPTextEmbeddings: |
| 89 | if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): | 97 | if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): |
| 90 | return text_encoder.text_model.embeddings | 98 | return text_encoder.text_model.embeddings |
| 91 | 99 | ||
| 92 | text_embeddings = ManagedCLIPTextEmbeddings( | 100 | text_embeddings = ManagedCLIPTextEmbeddings( |
| 93 | text_encoder.config, | 101 | text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout |
| 94 | text_encoder.text_model.embeddings, | ||
| 95 | alpha, | ||
| 96 | dropout | ||
| 97 | ) | 102 | ) |
| 98 | text_encoder.text_model.embeddings = text_embeddings | 103 | text_encoder.text_model.embeddings = text_embeddings |
| 99 | return text_embeddings | 104 | return text_embeddings |
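The changes above are formatting plus a tidier `patch_managed_embeddings` signature. A hedged usage sketch follows; the model id and token ids are placeholders, and the positional `add_embed(token_ids, initializer)` call is an assumption inferred from the hunk, not a documented signature:

```python
from transformers import CLIPTextModel
from models.clip.embeddings import patch_managed_embeddings

# Placeholder model id; alpha/dropout mirror the defaults shown above.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
embeddings = patch_managed_embeddings(text_encoder, alpha=8, dropout=0.0)

new_size = text_encoder.config.vocab_size + 2
embeddings.resize(new_size)  # grow the sparse token table by two rows

# Assumed signature: add_embed(token_ids, initializer). Here both new rows
# are initialized from an existing token id (320 is a placeholder).
embeddings.add_embed([new_size - 2, new_size - 1], [320])
vec = embeddings.get_embed([new_size - 2])  # look one row back up
```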
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 789b525..a866641 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
| @@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 91 | self.vector_shuffle = shuffle_none | 91 | self.vector_shuffle = shuffle_none |
| 92 | 92 | ||
| 93 | def add_multi_tokens( | 93 | def add_multi_tokens( |
| 94 | self, | 94 | self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1 |
| 95 | new_tokens: Union[str, list[str]], | ||
| 96 | num_vectors: Union[int, list[int]] = 1 | ||
| 97 | ) -> Union[list[int], list[list[int]]]: | 95 | ) -> Union[list[int], list[list[int]]]: |
| 98 | if isinstance(new_tokens, list): | 96 | if isinstance(new_tokens, list): |
| 99 | if isinstance(num_vectors, int): | 97 | if isinstance(num_vectors, int): |
| 100 | num_vectors = [num_vectors] * len(new_tokens) | 98 | num_vectors = [num_vectors] * len(new_tokens) |
| 101 | 99 | ||
| 102 | if len(num_vectors) != len(new_tokens): | 100 | if len(num_vectors) != len(new_tokens): |
| 103 | raise ValueError("Expected new_tokens and num_vectors to have the same len") | 101 | raise ValueError( |
| 102 | "Expected new_tokens and num_vectors to have the same len" | ||
| 103 | ) | ||
| 104 | 104 | ||
| 105 | return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)] | 105 | return [ |
| 106 | self.add_multi_tokens(new_token, vecs) | ||
| 107 | for new_token, vecs in zip(new_tokens, num_vectors) | ||
| 108 | ] | ||
| 106 | 109 | ||
| 107 | if isinstance(num_vectors, list): | 110 | if isinstance(num_vectors, list): |
| 108 | raise ValueError("Expected num_vectors to be int for single token") | 111 | raise ValueError("Expected num_vectors to be int for single token") |
| @@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 129 | return [id] | 132 | return [id] |
| 130 | 133 | ||
| 131 | def expand_ids(self, ids: list[int]): | 134 | def expand_ids(self, ids: list[int]): |
| 132 | return [ | 135 | return [new_id for id in ids for new_id in self.expand_id(id)] |
| 133 | new_id | ||
| 134 | for id in ids | ||
| 135 | for new_id in self.expand_id(id) | ||
| 136 | ] | ||
| 137 | 136 | ||
| 138 | def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): | 137 | def expand_batched_ids( |
| 138 | self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]] | ||
| 139 | ): | ||
| 139 | if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): | 140 | if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): |
| 140 | return [self.expand_ids(batch) for batch in input_ids] | 141 | return [self.expand_ids(batch) for batch in input_ids] |
| 141 | else: | 142 | else: |
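Formatting only again, but the hunk shows the multi-vector token API in full. A hedged sketch of how it fits together (model id and token strings are placeholders):

```python
from models.clip.tokenizer import MultiCLIPTokenizer

# MultiCLIPTokenizer subclasses CLIPTokenizer, so from_pretrained applies.
tokenizer = MultiCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# "<concept>" is backed by 3 embedding vectors, "<style>" by 1.
ids = tokenizer.add_multi_tokens(["<concept>", "<style>"], num_vectors=[3, 1])

# Expand a whole batch of ids so that every multi-vector token occupies
# num_vectors slots in the sequence before it reaches the text encoder.
batch = tokenizer(["a photo of <concept>"]).input_ids
batch = tokenizer.expand_batched_ids(batch)
```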
diff --git a/models/clip/util.py b/models/clip/util.py index f94fbc7..7196bb6 100644 --- a/models/clip/util.py +++ b/models/clip/util.py | |||
| @@ -5,27 +5,32 @@ import torch | |||
| 5 | from transformers import CLIPTokenizer, CLIPTextModel | 5 | from transformers import CLIPTokenizer, CLIPTextModel |
| 6 | 6 | ||
| 7 | 7 | ||
| 8 | def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None): | 8 | def unify_input_ids( |
| 9 | tokenizer: CLIPTokenizer, | ||
| 10 | input_ids: list[list[int]], | ||
| 11 | max_length: Optional[int] = None, | ||
| 12 | ): | ||
| 9 | if max_length is None: | 13 | if max_length is None: |
| 10 | return tokenizer.pad( | 14 | return tokenizer.pad( |
| 11 | {"input_ids": input_ids}, | 15 | {"input_ids": input_ids}, |
| 12 | padding=True, | 16 | padding=True, |
| 13 | pad_to_multiple_of=tokenizer.model_max_length, | 17 | pad_to_multiple_of=tokenizer.model_max_length, |
| 14 | return_tensors="pt" | 18 | return_tensors="pt", |
| 15 | ) | 19 | ) |
| 16 | else: | 20 | else: |
| 17 | return tokenizer.pad( | 21 | return tokenizer.pad( |
| 18 | {"input_ids": input_ids}, | 22 | {"input_ids": input_ids}, |
| 19 | padding="max_length", | 23 | padding="max_length", |
| 20 | max_length=max_length, | 24 | max_length=max_length, |
| 21 | return_tensors="pt" | 25 | return_tensors="pt", |
| 22 | ) | 26 | ) |
| 23 | 27 | ||
| 28 | |||
| 24 | def get_extended_embeddings( | 29 | def get_extended_embeddings( |
| 25 | text_encoder: CLIPTextModel, | 30 | text_encoder: CLIPTextModel, |
| 26 | input_ids: torch.LongTensor, | 31 | input_ids: torch.LongTensor, |
| 27 | position_ids: Optional[torch.LongTensor] = None, | 32 | position_ids: Optional[torch.LongTensor] = None, |
| 28 | attention_mask=None | 33 | attention_mask=None, |
| 29 | ): | 34 | ): |
| 30 | model_max_length = text_encoder.config.max_position_embeddings | 35 | model_max_length = text_encoder.config.max_position_embeddings |
| 31 | prompts = input_ids.shape[0] | 36 | prompts = input_ids.shape[0] |
| @@ -36,6 +41,8 @@ def get_extended_embeddings( | |||
| 36 | if attention_mask is not None: | 41 | if attention_mask is not None: |
| 37 | attention_mask = attention_mask.view((-1, model_max_length)) | 42 | attention_mask = attention_mask.view((-1, model_max_length)) |
| 38 | 43 | ||
| 39 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | 44 | text_embeddings = text_encoder( |
| 45 | input_ids, position_ids=position_ids, attention_mask=attention_mask | ||
| 46 | )[0] | ||
| 40 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | 47 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) |
| 41 | return text_embeddings | 48 | return text_embeddings |
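Together these two helpers implement long-prompt encoding: pad every prompt to a multiple of the encoder's 77-token window, encode window-wise, then stitch the chunks back into one sequence. A hedged end-to-end sketch (the model id is a placeholder):

```python
from transformers import CLIPTokenizer, CLIPTextModel
from models.clip.util import unify_input_ids, get_extended_embeddings

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a photo of a cat", "a much longer prompt " * 20]
ids = tokenizer(prompts, padding="do_not_pad").input_ids

inputs = unify_input_ids(tokenizer, ids)  # pad to a multiple of 77 tokens
embeds = get_extended_embeddings(text_encoder, inputs.input_ids)
print(embeds.shape)  # (2, padded_len, hidden_size): chunks re-joined per prompt
```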
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py index 571b915..5798bcf 100644 --- a/models/convnext/discriminator.py +++ b/models/convnext/discriminator.py | |||
| @@ -5,7 +5,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |||
| 5 | from torch.nn import functional as F | 5 | from torch.nn import functional as F |
| 6 | 6 | ||
| 7 | 7 | ||
| 8 | class ConvNeXtDiscriminator(): | 8 | class ConvNeXtDiscriminator: |
| 9 | def __init__(self, model: ConvNeXt, input_size: int) -> None: | 9 | def __init__(self, model: ConvNeXt, input_size: int) -> None: |
| 10 | self.net = model | 10 | self.net = model |
| 11 | 11 | ||
| @@ -22,8 +22,13 @@ class ConvNeXtDiscriminator(): | |||
| 22 | img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) | 22 | img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) |
| 23 | img_std = self.img_std.to(device=img.device, dtype=img.dtype) | 23 | img_std = self.img_std.to(device=img.device, dtype=img.dtype) |
| 24 | 24 | ||
| 25 | img = ((img + 1.) / 2.).sub(img_mean).div(img_std) | 25 | img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std) |
| 26 | 26 | ||
| 27 | img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True) | 27 | img = F.interpolate( |
| 28 | img, | ||
| 29 | size=(self.input_size, self.input_size), | ||
| 30 | mode="bicubic", | ||
| 31 | align_corners=True, | ||
| 32 | ) | ||
| 28 | pred = self.net(img) | 33 | pred = self.net(img) |
| 29 | return pred | 34 | return pred |
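The hunk reformats the discriminator's input mapping: images in [-1, 1] are shifted to [0, 1], normalized with ImageNet statistics, and resized before reaching the ConvNeXt classifier. A self-contained sketch of just that transform (the input size is an illustrative value, not the repo's configuration):

```python
import torch
import torch.nn.functional as F
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

def to_classifier_input(img: torch.Tensor, input_size: int) -> torch.Tensor:
    # img: (b, 3, h, w) in [-1, 1], as produced by the VAE decoder
    mean = torch.tensor(IMAGENET_DEFAULT_MEAN, device=img.device).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_DEFAULT_STD, device=img.device).view(1, 3, 1, 1)
    img = ((img + 1.0) / 2.0).sub(mean).div(std)  # [-1,1] -> ImageNet-normalized
    return F.interpolate(
        img, size=(input_size, input_size), mode="bicubic", align_corners=True
    )

x = torch.randn(2, 3, 64, 64).clamp(-1, 1)  # fake batch in [-1, 1]
y = to_classifier_input(x, input_size=384)  # ready for the ConvNeXt net
```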
diff --git a/models/sparse.py b/models/sparse.py index bd45696..e5897c9 100644 --- a/models/sparse.py +++ b/models/sparse.py | |||
| @@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding): | |||
| 15 | ): | 15 | ): |
| 16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) | 16 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) |
| 17 | 17 | ||
| 18 | self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) | 18 | self.register_buffer( |
| 19 | "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1 | ||
| 20 | ) | ||
| 19 | 21 | ||
| 20 | self.trainable = nn.ParameterList() | 22 | self.trainable = nn.ParameterList() |
| 21 | self.scaling = alpha | 23 | self.scaling = alpha |
| 22 | self.dropout_p = dropout | 24 | self.dropout_p = dropout |
| 23 | self.weight.requires_grad = False | 25 | self.weight.requires_grad = False |
| 24 | 26 | ||
| 25 | if dropout > 0.: | 27 | if dropout > 0.0: |
| 26 | self.dropout = nn.Dropout(p=dropout) | 28 | self.dropout = nn.Dropout(p=dropout) |
| 27 | else: | 29 | else: |
| 28 | self.dropout = nn.Identity() | 30 | self.dropout = nn.Identity() |
| 29 | 31 | ||
| 30 | self.reset_parameters() | 32 | self.reset_parameters() |
| 31 | 33 | ||
| 32 | def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): | 34 | def new_resized( |
| 35 | self, new_num_embeddings: int, initializer_factor: Optional[float] = None | ||
| 36 | ): | ||
| 33 | n = min(self.num_embeddings, new_num_embeddings) | 37 | n = min(self.num_embeddings, new_num_embeddings) |
| 34 | 38 | ||
| 35 | new_emb = SparseEmbedding( | 39 | new_emb = SparseEmbedding( |
| @@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding): | |||
| 38 | self.scaling, | 42 | self.scaling, |
| 39 | self.dropout_p, | 43 | self.dropout_p, |
| 40 | device=self.weight.device, | 44 | device=self.weight.device, |
| 41 | dtype=self.weight.dtype | 45 | dtype=self.weight.dtype, |
| 42 | ) | 46 | ) |
| 43 | if initializer_factor is not None: | 47 | if initializer_factor is not None: |
| 44 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) | 48 | new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) |
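The `SparseEmbedding` idea visible in this hunk is a frozen base table plus a small set of trainable override rows, indexed through the `trainable_ids` buffer (-1 means "use the frozen row"). A hedged toy illustration of the lookup, not the repo's actual `forward()`:

```python
import torch

weight = torch.randn(10, 4)                       # frozen base table
trainable_ids = torch.full((10,), -1)             # -1 = keep frozen row
overrides = [torch.zeros(4, requires_grad=True)]  # trainable rows
trainable_ids[3] = 0                              # token 3 -> overrides[0]

def lookup(ids: torch.Tensor) -> torch.Tensor:
    rows = weight[ids].clone()
    for pos, tid in enumerate(trainable_ids[ids].tolist()):
        if tid >= 0:
            rows[pos] = overrides[tid]  # splice in the trainable row
    return rows

out = lookup(torch.tensor([2, 3]))
print(out.requires_grad)  # True: gradients flow only through token 3's row
```

Only the few override rows ever receive gradients, which keeps optimizer state small when training a handful of new tokens against a large frozen vocabulary.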
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index aa446ec..16b8456 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -21,7 +21,9 @@ from diffusers import ( | |||
| 21 | LMSDiscreteScheduler, | 21 | LMSDiscreteScheduler, |
| 22 | PNDMScheduler, | 22 | PNDMScheduler, |
| 23 | ) | 23 | ) |
| 24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | 24 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( |
| 25 | StableDiffusionPipelineOutput, | ||
| 26 | ) | ||
| 25 | from diffusers.utils import logging, randn_tensor | 27 | from diffusers.utils import logging, randn_tensor |
| 26 | from transformers import CLIPTextModel, CLIPTokenizer | 28 | from transformers import CLIPTextModel, CLIPTokenizer |
| 27 | 29 | ||
| @@ -62,13 +64,35 @@ def gaussian_blur_2d(img, kernel_size, sigma): | |||
| 62 | return img | 64 | return img |
| 63 | 65 | ||
| 64 | 66 | ||
| 67 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): | ||
| 68 | """ | ||
| 69 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and | ||
| 70 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 | ||
| 71 | """ | ||
| 72 | std_text = noise_pred_text.std( | ||
| 73 | dim=list(range(1, noise_pred_text.ndim)), keepdim=True | ||
| 74 | ) | ||
| 75 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) | ||
| 76 | # rescale the results from guidance (fixes overexposure) | ||
| 77 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) | ||
| 78 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images | ||
| 79 | noise_cfg = ( | ||
| 80 | guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg | ||
| 81 | ) | ||
| 82 | return noise_cfg | ||
| 83 | |||
| 84 | |||
| 65 | class CrossAttnStoreProcessor: | 85 | class CrossAttnStoreProcessor: |
| 66 | def __init__(self): | 86 | def __init__(self): |
| 67 | self.attention_probs = None | 87 | self.attention_probs = None |
| 68 | 88 | ||
| 69 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): | 89 | def __call__( |
| 90 | self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None | ||
| 91 | ): | ||
| 70 | batch_size, sequence_length, _ = hidden_states.shape | 92 | batch_size, sequence_length, _ = hidden_states.shape |
| 71 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | 93 | attention_mask = attn.prepare_attention_mask( |
| 94 | attention_mask, sequence_length, batch_size | ||
| 95 | ) | ||
| 72 | query = attn.to_q(hidden_states) | 96 | query = attn.to_q(hidden_states) |
| 73 | 97 | ||
| 74 | if encoder_hidden_states is None: | 98 | if encoder_hidden_states is None: |
| @@ -113,7 +137,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 113 | ): | 137 | ): |
| 114 | super().__init__() | 138 | super().__init__() |
| 115 | 139 | ||
| 116 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: | 140 | if ( |
| 141 | hasattr(scheduler.config, "steps_offset") | ||
| 142 | and scheduler.config.steps_offset != 1 | ||
| 143 | ): | ||
| 117 | warnings.warn( | 144 | warnings.warn( |
| 118 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" | 145 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" |
| 119 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " | 146 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " |
| @@ -179,7 +206,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 179 | 206 | ||
| 180 | device = torch.device("cuda") | 207 | device = torch.device("cuda") |
| 181 | 208 | ||
| 182 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: | 209 | for cpu_offloaded_model in [ |
| 210 | self.unet, | ||
| 211 | self.text_encoder, | ||
| 212 | self.vae, | ||
| 213 | self.safety_checker, | ||
| 214 | ]: | ||
| 183 | if cpu_offloaded_model is not None: | 215 | if cpu_offloaded_model is not None: |
| 184 | cpu_offload(cpu_offloaded_model, device) | 216 | cpu_offload(cpu_offloaded_model, device) |
| 185 | 217 | ||
| @@ -223,35 +255,47 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 223 | width: int, | 255 | width: int, |
| 224 | height: int, | 256 | height: int, |
| 225 | strength: float, | 257 | strength: float, |
| 226 | callback_steps: Optional[int] | 258 | callback_steps: Optional[int], |
| 227 | ): | 259 | ): |
| 228 | if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): | 260 | if isinstance(prompt, str) or ( |
| 261 | isinstance(prompt, list) and isinstance(prompt[0], int) | ||
| 262 | ): | ||
| 229 | prompt = [prompt] | 263 | prompt = [prompt] |
| 230 | 264 | ||
| 231 | if negative_prompt is None: | 265 | if negative_prompt is None: |
| 232 | negative_prompt = "" | 266 | negative_prompt = "" |
| 233 | 267 | ||
| 234 | if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): | 268 | if isinstance(negative_prompt, str) or ( |
| 269 | isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int) | ||
| 270 | ): | ||
| 235 | negative_prompt = [negative_prompt] * len(prompt) | 271 | negative_prompt = [negative_prompt] * len(prompt) |
| 236 | 272 | ||
| 237 | if not isinstance(prompt, list): | 273 | if not isinstance(prompt, list): |
| 238 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") | 274 | raise ValueError( |
| 275 | f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" | ||
| 276 | ) | ||
| 239 | 277 | ||
| 240 | if not isinstance(negative_prompt, list): | 278 | if not isinstance(negative_prompt, list): |
| 241 | raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") | 279 | raise ValueError( |
| 280 | f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}" | ||
| 281 | ) | ||
| 242 | 282 | ||
| 243 | if len(negative_prompt) != len(prompt): | 283 | if len(negative_prompt) != len(prompt): |
| 244 | raise ValueError( | 284 | raise ValueError( |
| 245 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") | 285 | f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}" |
| 286 | ) | ||
| 246 | 287 | ||
| 247 | if strength < 0 or strength > 1: | 288 | if strength < 0 or strength > 1: |
| 248 | raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}") | 289 | raise ValueError(f"`strength` should be in [0.0, 1.0] but is {strength}") |
| 249 | 290 | ||
| 250 | if height % 8 != 0 or width % 8 != 0: | 291 | if height % 8 != 0 or width % 8 != 0: |
| 251 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | 292 | raise ValueError( |
| 293 | f"`height` and `width` have to be divisible by 8 but are {height} and {width}." | ||
| 294 | ) | ||
| 252 | 295 | ||
| 253 | if (callback_steps is None) or ( | 296 | if (callback_steps is None) or ( |
| 254 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) | 297 | callback_steps is not None |
| 298 | and (not isinstance(callback_steps, int) or callback_steps <= 0) | ||
| 255 | ): | 299 | ): |
| 256 | raise ValueError( | 300 | raise ValueError( |
| 257 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" | 301 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
| @@ -266,7 +310,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 266 | negative_prompt: Union[List[str], List[List[int]]], | 310 | negative_prompt: Union[List[str], List[List[int]]], |
| 267 | num_images_per_prompt: int, | 311 | num_images_per_prompt: int, |
| 268 | do_classifier_free_guidance: bool, | 312 | do_classifier_free_guidance: bool, |
| 269 | device | 313 | device, |
| 270 | ): | 314 | ): |
| 271 | if isinstance(prompt[0], str): | 315 | if isinstance(prompt[0], str): |
| 272 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids | 316 | text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids |
| @@ -277,7 +321,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 277 | 321 | ||
| 278 | if do_classifier_free_guidance: | 322 | if do_classifier_free_guidance: |
| 279 | if isinstance(prompt[0], str): | 323 | if isinstance(prompt[0], str): |
| 280 | unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids | 324 | unconditional_input_ids = self.tokenizer( |
| 325 | negative_prompt, padding="do_not_pad" | ||
| 326 | ).input_ids | ||
| 281 | else: | 327 | else: |
| 282 | unconditional_input_ids = negative_prompt | 328 | unconditional_input_ids = negative_prompt |
| 283 | unconditional_input_ids *= num_images_per_prompt | 329 | unconditional_input_ids *= num_images_per_prompt |
| @@ -286,12 +332,17 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 286 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) | 332 | text_inputs = unify_input_ids(self.tokenizer, text_input_ids) |
| 287 | text_input_ids = text_inputs.input_ids | 333 | text_input_ids = text_inputs.input_ids |
| 288 | 334 | ||
| 289 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: | 335 | if ( |
| 336 | hasattr(self.text_encoder.config, "use_attention_mask") | ||
| 337 | and self.text_encoder.config.use_attention_mask | ||
| 338 | ): | ||
| 290 | attention_mask = text_inputs.attention_mask.to(device) | 339 | attention_mask = text_inputs.attention_mask.to(device) |
| 291 | else: | 340 | else: |
| 292 | attention_mask = None | 341 | attention_mask = None |
| 293 | 342 | ||
| 294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) | 343 | prompt_embeds = get_extended_embeddings( |
| 344 | self.text_encoder, text_input_ids.to(device), attention_mask | ||
| 345 | ) | ||
| 295 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) | 346 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) |
| 296 | 347 | ||
| 297 | return prompt_embeds | 348 | return prompt_embeds |
| @@ -301,25 +352,21 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 301 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) | 352 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) |
| 302 | 353 | ||
| 303 | t_start = max(num_inference_steps - init_timestep, 0) | 354 | t_start = max(num_inference_steps - init_timestep, 0) |
| 304 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] | 355 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] |
| 305 | 356 | ||
| 306 | timesteps = timesteps.to(device) | 357 | timesteps = timesteps.to(device) |
| 307 | 358 | ||
| 308 | return timesteps, num_inference_steps - t_start | 359 | return timesteps, num_inference_steps - t_start |
| 309 | 360 | ||
| 310 | def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): | 361 | def prepare_latents_from_image( |
| 311 | offset_image = perlin_noise( | 362 | self, |
| 312 | (batch_size, 1, width, height), | 363 | init_image, |
| 313 | res=1, | 364 | timestep, |
| 314 | generator=generator, | 365 | batch_size, |
| 315 | dtype=dtype, | 366 | dtype, |
| 316 | device=device | 367 | device, |
| 317 | ) | 368 | generator=None, |
| 318 | offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) | 369 | ): |
| 319 | offset_latents = self.vae.config.scaling_factor * offset_latents | ||
| 320 | return offset_latents | ||
| 321 | |||
| 322 | def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None): | ||
| 323 | init_image = init_image.to(device=device, dtype=dtype) | 370 | init_image = init_image.to(device=device, dtype=dtype) |
| 324 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) | 371 | latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) |
| 325 | latents = self.vae.config.scaling_factor * latents | 372 | latents = self.vae.config.scaling_factor * latents |
| @@ -333,20 +380,32 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 333 | latents = torch.cat([latents] * batch_multiplier, dim=0) | 380 | latents = torch.cat([latents] * batch_multiplier, dim=0) |
| 334 | 381 | ||
| 335 | # add noise to latents using the timesteps | 382 | # add noise to latents using the timesteps |
| 336 | noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype) | 383 | noise = torch.randn( |
| 337 | 384 | latents.shape, generator=generator, device=device, dtype=dtype | |
| 338 | if brightness_offset != 0: | 385 | ) |
| 339 | noise += brightness_offset * self.prepare_brightness_offset( | ||
| 340 | batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator | ||
| 341 | ) | ||
| 342 | 386 | ||
| 343 | # get latents | 387 | # get latents |
| 344 | latents = self.scheduler.add_noise(latents, noise, timestep) | 388 | latents = self.scheduler.add_noise(latents, noise, timestep) |
| 345 | 389 | ||
| 346 | return latents | 390 | return latents |
| 347 | 391 | ||
| 348 | def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): | 392 | def prepare_latents( |
| 349 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) | 393 | self, |
| 394 | batch_size, | ||
| 395 | num_channels_latents, | ||
| 396 | height, | ||
| 397 | width, | ||
| 398 | dtype, | ||
| 399 | device, | ||
| 400 | generator, | ||
| 401 | latents=None, | ||
| 402 | ): | ||
| 403 | shape = ( | ||
| 404 | batch_size, | ||
| 405 | num_channels_latents, | ||
| 406 | height // self.vae_scale_factor, | ||
| 407 | width // self.vae_scale_factor, | ||
| 408 | ) | ||
| 350 | if isinstance(generator, list) and len(generator) != batch_size: | 409 | if isinstance(generator, list) and len(generator) != batch_size: |
| 351 | raise ValueError( | 410 | raise ValueError( |
| 352 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" | 411 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
| @@ -354,15 +413,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 354 | ) | 413 | ) |
| 355 | 414 | ||
| 356 | if latents is None: | 415 | if latents is None: |
| 357 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) | 416 | latents = randn_tensor( |
| 417 | shape, generator=generator, device=device, dtype=dtype | ||
| 418 | ) | ||
| 358 | else: | 419 | else: |
| 359 | latents = latents.to(device) | 420 | latents = latents.to(device) |
| 360 | 421 | ||
| 361 | if brightness_offset != 0: | ||
| 362 | latents += brightness_offset * self.prepare_brightness_offset( | ||
| 363 | batch_size, height, width, dtype, device, generator | ||
| 364 | ) | ||
| 365 | |||
| 366 | # scale the initial noise by the standard deviation required by the scheduler | 422 | # scale the initial noise by the standard deviation required by the scheduler |
| 367 | latents = latents * self.scheduler.init_noise_sigma | 423 | latents = latents * self.scheduler.init_noise_sigma |
| 368 | return latents | 424 | return latents |
| @@ -373,13 +429,17 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 373 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | 429 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 |
| 374 | # and should be between [0, 1] | 430 | # and should be between [0, 1] |
| 375 | 431 | ||
| 376 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 432 | accepts_eta = "eta" in set( |
| 433 | inspect.signature(self.scheduler.step).parameters.keys() | ||
| 434 | ) | ||
| 377 | extra_step_kwargs = {} | 435 | extra_step_kwargs = {} |
| 378 | if accepts_eta: | 436 | if accepts_eta: |
| 379 | extra_step_kwargs["eta"] = eta | 437 | extra_step_kwargs["eta"] = eta |
| 380 | 438 | ||
| 381 | # check if the scheduler accepts generator | 439 | # check if the scheduler accepts generator |
| 382 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) | 440 | accepts_generator = "generator" in set( |
| 441 | inspect.signature(self.scheduler.step).parameters.keys() | ||
| 442 | ) | ||
| 383 | if accepts_generator: | 443 | if accepts_generator: |
| 384 | extra_step_kwargs["generator"] = generator | 444 | extra_step_kwargs["generator"] = generator |
| 385 | return extra_step_kwargs | 445 | return extra_step_kwargs |
| @@ -396,7 +456,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 396 | def __call__( | 456 | def __call__( |
| 397 | self, | 457 | self, |
| 398 | prompt: Union[str, List[str], List[int], List[List[int]]], | 458 | prompt: Union[str, List[str], List[int], List[List[int]]], |
| 399 | negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, | 459 | negative_prompt: Optional[ |
| 460 | Union[str, List[str], List[int], List[List[int]]] | ||
| 461 | ] = None, | ||
| 400 | num_images_per_prompt: int = 1, | 462 | num_images_per_prompt: int = 1, |
| 401 | strength: float = 1.0, | 463 | strength: float = 1.0, |
| 402 | height: Optional[int] = None, | 464 | height: Optional[int] = None, |
| @@ -407,12 +469,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 407 | eta: float = 0.0, | 469 | eta: float = 0.0, |
| 408 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, | 470 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| 409 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, | 471 | image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, |
| 410 | brightness_offset: Union[float, torch.FloatTensor] = 0, | ||
| 411 | output_type: str = "pil", | 472 | output_type: str = "pil", |
| 412 | return_dict: bool = True, | 473 | return_dict: bool = True, |
| 413 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 474 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
| 414 | callback_steps: int = 1, | 475 | callback_steps: int = 1, |
| 415 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, | 476 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
| 477 | guidance_rescale: float = 0.0, | ||
| 416 | ): | 478 | ): |
| 417 | r""" | 479 | r""" |
| 418 | Function invoked when calling the pipeline for generation. | 480 | Function invoked when calling the pipeline for generation. |
| @@ -472,7 +534,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 472 | width = width or self.unet.config.sample_size * self.vae_scale_factor | 534 | width = width or self.unet.config.sample_size * self.vae_scale_factor |
| 473 | 535 | ||
| 474 | # 1. Check inputs. Raise error if not correct | 536 | # 1. Check inputs. Raise error if not correct |
| 475 | prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) | 537 | prompt, negative_prompt = self.check_inputs( |
| 538 | prompt, negative_prompt, width, height, strength, callback_steps | ||
| 539 | ) | ||
| 476 | 540 | ||
| 477 | # 2. Define call parameters | 541 | # 2. Define call parameters |
| 478 | batch_size = len(prompt) | 542 | batch_size = len(prompt) |
| @@ -488,7 +552,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 488 | negative_prompt, | 552 | negative_prompt, |
| 489 | num_images_per_prompt, | 553 | num_images_per_prompt, |
| 490 | do_classifier_free_guidance, | 554 | do_classifier_free_guidance, |
| 491 | device | 555 | device, |
| 492 | ) | 556 | ) |
| 493 | 557 | ||
| 494 | # 4. Prepare latent variables | 558 | # 4. Prepare latent variables |
| @@ -497,7 +561,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 497 | 561 | ||
| 498 | # 5. Prepare timesteps | 562 | # 5. Prepare timesteps |
| 499 | self.scheduler.set_timesteps(num_inference_steps, device=device) | 563 | self.scheduler.set_timesteps(num_inference_steps, device=device) |
| 500 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) | 564 | timesteps, num_inference_steps = self.get_timesteps( |
| 565 | num_inference_steps, strength, device | ||
| 566 | ) | ||
| 501 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) | 567 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) |
| 502 | 568 | ||
| 503 | # 6. Prepare latent variables | 569 | # 6. Prepare latent variables |
| @@ -506,7 +572,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 506 | image, | 572 | image, |
| 507 | latent_timestep, | 573 | latent_timestep, |
| 508 | batch_size * num_images_per_prompt, | 574 | batch_size * num_images_per_prompt, |
| 509 | brightness_offset, | ||
| 510 | prompt_embeds.dtype, | 575 | prompt_embeds.dtype, |
| 511 | device, | 576 | device, |
| 512 | generator, | 577 | generator, |
| @@ -517,7 +582,6 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 517 | num_channels_latents, | 582 | num_channels_latents, |
| 518 | height, | 583 | height, |
| 519 | width, | 584 | width, |
| 520 | brightness_offset, | ||
| 521 | prompt_embeds.dtype, | 585 | prompt_embeds.dtype, |
| 522 | device, | 586 | device, |
| 523 | generator, | 587 | generator, |
| @@ -530,14 +594,20 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 530 | # 8. Denoising loop | 594 | # 8. Denoising loop |
| 531 | if do_self_attention_guidance: | 595 | if do_self_attention_guidance: |
| 532 | store_processor = CrossAttnStoreProcessor() | 596 | store_processor = CrossAttnStoreProcessor() |
| 533 | self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor | 597 | self.unet.mid_block.attentions[0].transformer_blocks[ |
| 598 | 0 | ||
| 599 | ].attn1.processor = store_processor | ||
| 534 | 600 | ||
| 535 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order | 601 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
| 536 | with self.progress_bar(total=num_inference_steps) as progress_bar: | 602 | with self.progress_bar(total=num_inference_steps) as progress_bar: |
| 537 | for i, t in enumerate(timesteps): | 603 | for i, t in enumerate(timesteps): |
| 538 | # expand the latents if we are doing classifier free guidance | 604 | # expand the latents if we are doing classifier free guidance |
| 539 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | 605 | latent_model_input = ( |
| 540 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 606 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| 607 | ) | ||
| 608 | latent_model_input = self.scheduler.scale_model_input( | ||
| 609 | latent_model_input, t | ||
| 610 | ) | ||
| 541 | 611 | ||
| 542 | # predict the noise residual | 612 | # predict the noise residual |
| 543 | noise_pred = self.unet( | 613 | noise_pred = self.unet( |
| @@ -551,7 +621,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 551 | # perform guidance | 621 | # perform guidance |
| 552 | if do_classifier_free_guidance: | 622 | if do_classifier_free_guidance: |
| 553 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | 623 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
| 554 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | 624 | noise_pred = noise_pred_uncond + guidance_scale * ( |
| 625 | noise_pred_text - noise_pred_uncond | ||
| 626 | ) | ||
| 627 | noise_pred = rescale_noise_cfg( | ||
| 628 | noise_pred, noise_pred_text, guidance_rescale=guidance_rescale | ||
| 629 | ) | ||
| 555 | 630 | ||
| 556 | if do_self_attention_guidance: | 631 | if do_self_attention_guidance: |
| 557 | # classifier-free guidance produces two chunks of attention map | 632 | # classifier-free guidance produces two chunks of attention map |
| @@ -561,15 +636,24 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 561 | # DDIM-like prediction of x0 | 636 | # DDIM-like prediction of x0 |
| 562 | pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) | 637 | pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) |
| 563 | # get the stored attention maps | 638 | # get the stored attention maps |
| 564 | uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) | 639 | uncond_attn, cond_attn = store_processor.attention_probs.chunk( |
| 640 | 2 | ||
| 641 | ) | ||
| 565 | # self-attention-based degrading of latents | 642 | # self-attention-based degrading of latents |
| 566 | degraded_latents = self.sag_masking( | 643 | degraded_latents = self.sag_masking( |
| 567 | pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) | 644 | pred_x0, |
| 645 | uncond_attn, | ||
| 646 | t, | ||
| 647 | self.pred_epsilon(latents, noise_pred_uncond, t), | ||
| 568 | ) | 648 | ) |
| 569 | uncond_emb, _ = prompt_embeds.chunk(2) | 649 | uncond_emb, _ = prompt_embeds.chunk(2) |
| 570 | # forward and give guidance | 650 | # forward and give guidance |
| 571 | degraded_pred = self.unet( | 651 | degraded_pred = self.unet( |
| 572 | degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0] | 652 | degraded_latents, |
| 653 | t, | ||
| 654 | encoder_hidden_states=uncond_emb, | ||
| 655 | return_dict=False, | ||
| 656 | )[0] | ||
| 573 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) | 657 | noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) |
| 574 | else: | 658 | else: |
| 575 | # DDIM-like prediction of x0 | 659 | # DDIM-like prediction of x0 |
| @@ -578,18 +662,29 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 578 | cond_attn = store_processor.attention_probs | 662 | cond_attn = store_processor.attention_probs |
| 579 | # self-attention-based degrading of latents | 663 | # self-attention-based degrading of latents |
| 580 | degraded_latents = self.sag_masking( | 664 | degraded_latents = self.sag_masking( |
| 581 | pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) | 665 | pred_x0, |
| 666 | cond_attn, | ||
| 667 | t, | ||
| 668 | self.pred_epsilon(latents, noise_pred, t), | ||
| 582 | ) | 669 | ) |
| 583 | # forward and give guidance | 670 | # forward and give guidance |
| 584 | degraded_pred = self.unet( | 671 | degraded_pred = self.unet( |
| 585 | degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0] | 672 | degraded_latents, |
| 673 | t, | ||
| 674 | encoder_hidden_states=prompt_embeds, | ||
| 675 | return_dict=False, | ||
| 676 | )[0] | ||
| 586 | noise_pred += sag_scale * (noise_pred - degraded_pred) | 677 | noise_pred += sag_scale * (noise_pred - degraded_pred) |
| 587 | 678 | ||
| 588 | # compute the previous noisy sample x_t -> x_t-1 | 679 | # compute the previous noisy sample x_t -> x_t-1 |
| 589 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] | 680 | latents = self.scheduler.step( |
| 681 | noise_pred, t, latents, **extra_step_kwargs, return_dict=False | ||
| 682 | )[0] | ||
| 590 | 683 | ||
| 591 | # call the callback, if provided | 684 | # call the callback, if provided |
| 592 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): | 685 | if i == len(timesteps) - 1 or ( |
| 686 | (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 | ||
| 687 | ): | ||
| 593 | progress_bar.update() | 688 | progress_bar.update() |
| 594 | if callback is not None and i % callback_steps == 0: | 689 | if callback is not None and i % callback_steps == 0: |
| 595 | callback(i, t, latents) | 690 | callback(i, t, latents) |
| @@ -615,7 +710,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 615 | if not return_dict: | 710 | if not return_dict: |
| 616 | return (image, has_nsfw_concept) | 711 | return (image, has_nsfw_concept) |
| 617 | 712 | ||
| 618 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | 713 | return StableDiffusionPipelineOutput( |
| 714 | images=image, nsfw_content_detected=has_nsfw_concept | ||
| 715 | ) | ||
| 619 | 716 | ||
| 620 | # Self-Attention-Guided (SAG) Stable Diffusion | 717 | # Self-Attention-Guided (SAG) Stable Diffusion |
| 621 | 718 | ||
| @@ -632,16 +729,23 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 632 | attn_map = attn_map.reshape(b, h, hw1, hw2) | 729 | attn_map = attn_map.reshape(b, h, hw1, hw2) |
| 633 | attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 | 730 | attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 |
| 634 | attn_mask = ( | 731 | attn_mask = ( |
| 635 | attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype) | 732 | attn_mask.reshape(b, map_size, map_size) |
| 733 | .unsqueeze(1) | ||
| 734 | .repeat(1, latent_channel, 1, 1) | ||
| 735 | .type(attn_map.dtype) | ||
| 636 | ) | 736 | ) |
| 637 | attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) | 737 | attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) |
| 638 | 738 | ||
| 639 | # Blur according to the self-attention mask | 739 | # Blur according to the self-attention mask |
| 640 | degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) | 740 | degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) |
| 641 | degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) | 741 | degraded_latents = degraded_latents * attn_mask + original_latents * ( |
| 742 | 1 - attn_mask | ||
| 743 | ) | ||
| 642 | 744 | ||
| 643 | # Noise it again to match the noise level | 745 | # Noise it again to match the noise level |
| 644 | degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) | 746 | degraded_latents = self.scheduler.add_noise( |
| 747 | degraded_latents, noise=eps, timesteps=t | ||
| 748 | ) | ||
| 645 | 749 | ||
| 646 | return degraded_latents | 750 | return degraded_latents |
| 647 | 751 | ||
| @@ -652,13 +756,19 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 652 | 756 | ||
| 653 | beta_prod_t = 1 - alpha_prod_t | 757 | beta_prod_t = 1 - alpha_prod_t |
| 654 | if self.scheduler.config.prediction_type == "epsilon": | 758 | if self.scheduler.config.prediction_type == "epsilon": |
| 655 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) | 759 | pred_original_sample = ( |
| 760 | sample - beta_prod_t ** (0.5) * model_output | ||
| 761 | ) / alpha_prod_t ** (0.5) | ||
| 656 | elif self.scheduler.config.prediction_type == "sample": | 762 | elif self.scheduler.config.prediction_type == "sample": |
| 657 | pred_original_sample = model_output | 763 | pred_original_sample = model_output |
| 658 | elif self.scheduler.config.prediction_type == "v_prediction": | 764 | elif self.scheduler.config.prediction_type == "v_prediction": |
| 659 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output | 765 | pred_original_sample = (alpha_prod_t**0.5) * sample - ( |
| 766 | beta_prod_t**0.5 | ||
| 767 | ) * model_output | ||
| 660 | # predict V | 768 | # predict V |
| 661 | model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample | 769 | model_output = (alpha_prod_t**0.5) * model_output + ( |
| 770 | beta_prod_t**0.5 | ||
| 771 | ) * sample | ||
| 662 | else: | 772 | else: |
| 663 | raise ValueError( | 773 | raise ValueError( |
| 664 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," | 774 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," |
| @@ -674,9 +784,13 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 674 | if self.scheduler.config.prediction_type == "epsilon": | 784 | if self.scheduler.config.prediction_type == "epsilon": |
| 675 | pred_eps = model_output | 785 | pred_eps = model_output |
| 676 | elif self.scheduler.config.prediction_type == "sample": | 786 | elif self.scheduler.config.prediction_type == "sample": |
| 677 | pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) | 787 | pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / ( |
| 788 | beta_prod_t**0.5 | ||
| 789 | ) | ||
| 678 | elif self.scheduler.config.prediction_type == "v_prediction": | 790 | elif self.scheduler.config.prediction_type == "v_prediction": |
| 679 | pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output | 791 | pred_eps = (beta_prod_t**0.5) * sample + ( |
| 792 | alpha_prod_t**0.5 | ||
| 793 | ) * model_output | ||
| 680 | else: | 794 | else: |
| 681 | raise ValueError( | 795 | raise ValueError( |
| 682 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," | 796 | f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2aca1e7..659b84c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -5,34 +5,70 @@ import itertools | |||
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | from functools import partial | 6 | from functools import partial |
| 7 | import math | 7 | import math |
| 8 | import warnings | ||
| 8 | 9 | ||
| 9 | import torch | 10 | import torch |
| 11 | import torch._dynamo | ||
| 10 | import torch.utils.checkpoint | 12 | import torch.utils.checkpoint |
| 13 | import hidet | ||
| 11 | 14 | ||
| 12 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
| 13 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 14 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 15 | from slugify import slugify | 18 | |
| 19 | # from diffusers.models.attention_processor import AttnProcessor | ||
| 20 | from diffusers.utils.import_utils import is_xformers_available | ||
| 16 | import transformers | 21 | import transformers |
| 17 | 22 | ||
| 18 | from util.files import load_config, load_embeddings_from_dir | 23 | import numpy as np |
| 24 | from slugify import slugify | ||
| 25 | |||
| 19 | from data.csv import VlpnDataModule, keyword_filter | 26 | from data.csv import VlpnDataModule, keyword_filter |
| 20 | from training.functional import train, get_models | 27 | from models.clip.embeddings import patch_managed_embeddings |
| 28 | from training.functional import train, add_placeholder_tokens, get_models | ||
| 21 | from training.strategy.dreambooth import dreambooth_strategy | 29 | from training.strategy.dreambooth import dreambooth_strategy |
| 22 | from training.optimization import get_scheduler | 30 | from training.optimization import get_scheduler |
| 23 | from training.util import save_args | 31 | from training.sampler import create_named_schedule_sampler |
| 32 | from training.util import AverageMeter, save_args | ||
| 33 | from util.files import load_config, load_embeddings_from_dir | ||
| 34 | |||
| 24 | 35 | ||
| 25 | logger = get_logger(__name__) | 36 | logger = get_logger(__name__) |
| 26 | 37 | ||
| 38 | warnings.filterwarnings("ignore") | ||
| 39 | |||
| 27 | 40 | ||
| 28 | torch.backends.cuda.matmul.allow_tf32 = True | 41 | torch.backends.cuda.matmul.allow_tf32 = True |
| 29 | torch.backends.cudnn.benchmark = True | 42 | torch.backends.cudnn.benchmark = True |
| 30 | 43 | ||
| 44 | # torch._dynamo.config.log_level = logging.WARNING | ||
| 45 | torch._dynamo.config.suppress_errors = True | ||
| 46 | |||
| 47 | hidet.torch.dynamo_config.use_tensor_core(True) | ||
| 48 | hidet.torch.dynamo_config.search_space(0) | ||
| 49 | |||
| 50 | |||
| 51 | def patch_xformers(dtype): | ||
| 52 | if is_xformers_available(): | ||
| 53 | import xformers | ||
| 54 | import xformers.ops | ||
| 55 | |||
| 56 | orig_xformers_memory_efficient_attention = ( | ||
| 57 | xformers.ops.memory_efficient_attention | ||
| 58 | ) | ||
| 59 | |||
| 60 | def xformers_memory_efficient_attention( | ||
| 61 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
| 62 | ): | ||
| 63 | return orig_xformers_memory_efficient_attention( | ||
| 64 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
| 65 | ) | ||
| 66 | |||
| 67 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
| 68 | |||
| 31 | 69 | ||
| 32 | def parse_args(): | 70 | def parse_args(): |
| 33 | parser = argparse.ArgumentParser( | 71 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 34 | description="Simple example of a training script." | ||
| 35 | ) | ||
| 36 | parser.add_argument( | 72 | parser.add_argument( |
| 37 | "--pretrained_model_name_or_path", | 73 | "--pretrained_model_name_or_path", |
| 38 | type=str, | 74 | type=str, |
| @@ -49,7 +85,7 @@ def parse_args(): | |||
| 49 | "--train_data_file", | 85 | "--train_data_file", |
| 50 | type=str, | 86 | type=str, |
| 51 | default=None, | 87 | default=None, |
| 52 | help="A folder containing the training data." | 88 | help="A folder containing the training data.", |
| 53 | ) | 89 | ) |
| 54 | parser.add_argument( | 90 | parser.add_argument( |
| 55 | "--train_data_template", | 91 | "--train_data_template", |
| @@ -60,13 +96,13 @@ def parse_args(): | |||
| 60 | "--train_set_pad", | 96 | "--train_set_pad", |
| 61 | type=int, | 97 | type=int, |
| 62 | default=None, | 98 | default=None, |
| 63 | help="The number to fill train dataset items up to." | 99 | help="The number to fill train dataset items up to.", |
| 64 | ) | 100 | ) |
| 65 | parser.add_argument( | 101 | parser.add_argument( |
| 66 | "--valid_set_pad", | 102 | "--valid_set_pad", |
| 67 | type=int, | 103 | type=int, |
| 68 | default=None, | 104 | default=None, |
| 69 | help="The number to fill validation dataset items up to." | 105 | help="The number to fill validation dataset items up to.", |
| 70 | ) | 106 | ) |
| 71 | parser.add_argument( | 107 | parser.add_argument( |
| 72 | "--project", | 108 | "--project", |
| @@ -75,20 +111,58 @@ def parse_args(): | |||
| 75 | help="The name of the current project.", | 111 | help="The name of the current project.", |
| 76 | ) | 112 | ) |
| 77 | parser.add_argument( | 113 | parser.add_argument( |
| 78 | "--exclude_collections", | 114 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
| 115 | ) | ||
| 116 | parser.add_argument( | ||
| 117 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." | ||
| 118 | ) | ||
| 119 | parser.add_argument( | ||
| 120 | "--placeholder_tokens", | ||
| 79 | type=str, | 121 | type=str, |
| 80 | nargs='*', | 122 | nargs="*", |
| 81 | help="Exclude all items with a listed collection.", | 123 | help="A token to use as a placeholder for the concept.", |
| 82 | ) | 124 | ) |
| 83 | parser.add_argument( | 125 | parser.add_argument( |
| 84 | "--train_text_encoder_epochs", | 126 | "--initializer_tokens", |
| 85 | default=999999, | 127 | type=str, |
| 86 | help="Number of epochs the text encoder will be trained." | 128 | nargs="*", |
| 129 | help="A token to use as initializer word.", | ||
| 130 | ) | ||
| 131 | parser.add_argument( | ||
| 132 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." | ||
| 133 | ) | ||
| 134 | parser.add_argument( | ||
| 135 | "--initializer_noise", | ||
| 136 | type=float, | ||
| 137 | default=0, | ||
| 138 | help="Noise to apply to the initializer word", | ||
| 139 | ) | ||
| 140 | parser.add_argument( | ||
| 141 | "--alias_tokens", | ||
| 142 | type=str, | ||
| 143 | nargs="*", | ||
| 144 | default=[], | ||
| 145 | help="Tokens to create an alias for.", | ||
| 146 | ) | ||
| 147 | parser.add_argument( | ||
| 148 | "--inverted_initializer_tokens", | ||
| 149 | type=str, | ||
| 150 | nargs="*", | ||
| 151 | help="A token to use as initializer word.", | ||
| 152 | ) | ||
| 153 | parser.add_argument( | ||
| 154 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." | ||
| 155 | ) | ||
| 156 | parser.add_argument( | ||
| 157 | "--exclude_collections", | ||
| 158 | type=str, | ||
| 159 | nargs="*", | ||
| 160 | help="Exclude all items with a listed collection.", | ||
| 87 | ) | 161 | ) |
| 88 | parser.add_argument( | 162 | parser.add_argument( |
| 89 | "--num_buckets", | 163 | "--num_buckets", |
| 90 | type=int, | 164 | type=int, |
| 91 | default=0, | 165 | default=2, |
| 92 | help="Number of aspect ratio buckets in either direction.", | 166 | help="Number of aspect ratio buckets in either direction.", |
| 93 | ) | 167 | ) |
| 94 | parser.add_argument( | 168 | parser.add_argument( |
| @@ -120,19 +194,6 @@ def parse_args(): | |||
| 120 | help="Shuffle tags.", | 194 | help="Shuffle tags.", |
| 121 | ) | 195 | ) |
| 122 | parser.add_argument( | 196 | parser.add_argument( |
| 123 | "--vector_dropout", | ||
| 124 | type=int, | ||
| 125 | default=0, | ||
| 126 | help="Vector dropout probability.", | ||
| 127 | ) | ||
| 128 | parser.add_argument( | ||
| 129 | "--vector_shuffle", | ||
| 130 | type=str, | ||
| 131 | default="auto", | ||
| 132 | choices=["all", "trailing", "leading", "between", "auto", "off"], | ||
| 133 | help='Vector shuffling algorithm.', | ||
| 134 | ) | ||
| 135 | parser.add_argument( | ||
| 136 | "--guidance_scale", | 197 | "--guidance_scale", |
| 137 | type=float, | 198 | type=float, |
| 138 | default=0, | 199 | default=0, |
| @@ -141,7 +202,7 @@ def parse_args(): | |||
| 141 | "--num_class_images", | 202 | "--num_class_images", |
| 142 | type=int, | 203 | type=int, |
| 143 | default=0, | 204 | default=0, |
| 144 | help="How many class images to generate." | 205 | help="How many class images to generate.", |
| 145 | ) | 206 | ) |
| 146 | parser.add_argument( | 207 | parser.add_argument( |
| 147 | "--class_image_dir", | 208 | "--class_image_dir", |
| @@ -162,16 +223,18 @@ def parse_args(): | |||
| 162 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 223 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 163 | ) | 224 | ) |
| 164 | parser.add_argument( | 225 | parser.add_argument( |
| 226 | "--train_dir_embeddings", | ||
| 227 | action="store_true", | ||
| 228 | help="Train embeddings loaded from embeddings directory.", | ||
| 229 | ) | ||
| 230 | parser.add_argument( | ||
| 165 | "--collection", | 231 | "--collection", |
| 166 | type=str, | 232 | type=str, |
| 167 | nargs='*', | 233 | nargs="*", |
| 168 | help="A collection to filter the dataset.", | 234 | help="A collection to filter the dataset.", |
| 169 | ) | 235 | ) |
| 170 | parser.add_argument( | 236 | parser.add_argument( |
| 171 | "--seed", | 237 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 172 | type=int, | ||
| 173 | default=None, | ||
| 174 | help="A seed for reproducible training." | ||
| 175 | ) | 238 | ) |
| 176 | parser.add_argument( | 239 | parser.add_argument( |
| 177 | "--resolution", | 240 | "--resolution", |
| @@ -189,15 +252,13 @@ def parse_args(): | |||
| 189 | help="Perlin offset noise strength.", | 252 | help="Perlin offset noise strength.", |
| 190 | ) | 253 | ) |
| 191 | parser.add_argument( | 254 | parser.add_argument( |
| 192 | "--num_train_epochs", | 255 | "--input_pertubation", |
| 193 | type=int, | 256 | type=float, |
| 194 | default=None | 257 | default=0, |
| 195 | ) | 258 | help="The scale of input perturbation. Recommended 0.1.", |
| 196 | parser.add_argument( | ||
| 197 | "--num_train_steps", | ||
| 198 | type=int, | ||
| 199 | default=2000 | ||
| 200 | ) | 259 | ) |
| 260 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
| 261 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
| 201 | parser.add_argument( | 262 | parser.add_argument( |
| 202 | "--gradient_accumulation_steps", | 263 | "--gradient_accumulation_steps", |
| 203 | type=int, | 264 | type=int, |
| @@ -205,9 +266,9 @@ def parse_args(): | |||
| 205 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 266 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
| 206 | ) | 267 | ) |
| 207 | parser.add_argument( | 268 | parser.add_argument( |
| 208 | "--gradient_checkpointing", | 269 | "--train_text_encoder_cycles", |
| 209 | action="store_true", | 270 | default=999999, |
| 210 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | 271 | help="Number of epochs the text encoder will be trained.", |
| 211 | ) | 272 | ) |
| 212 | parser.add_argument( | 273 | parser.add_argument( |
| 213 | "--find_lr", | 274 | "--find_lr", |
| @@ -215,9 +276,15 @@ def parse_args(): | |||
| 215 | help="Automatically find a learning rate (no training).", | 276 | help="Automatically find a learning rate (no training).", |
| 216 | ) | 277 | ) |
| 217 | parser.add_argument( | 278 | parser.add_argument( |
| 218 | "--learning_rate", | 279 | "--learning_rate_unet", |
| 219 | type=float, | 280 | type=float, |
| 220 | default=2e-6, | 281 | default=1e-4, |
| 282 | help="Initial learning rate (after the potential warmup period) to use.", | ||
| 283 | ) | ||
| 284 | parser.add_argument( | ||
| 285 | "--learning_rate_text", | ||
| 286 | type=float, | ||
| 287 | default=5e-5, | ||
| 221 | help="Initial learning rate (after the potential warmup period) to use.", | 288 | help="Initial learning rate (after the potential warmup period) to use.", |
| 222 | ) | 289 | ) |
| 223 | parser.add_argument( | 290 | parser.add_argument( |
| @@ -229,27 +296,31 @@ def parse_args(): | |||
| 229 | "--lr_scheduler", | 296 | "--lr_scheduler", |
| 230 | type=str, | 297 | type=str, |
| 231 | default="one_cycle", | 298 | default="one_cycle", |
| 232 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 299 | choices=[ |
| 233 | "constant", "constant_with_warmup", "one_cycle"], | 300 | "linear", |
| 234 | help='The scheduler type to use.', | 301 | "cosine", |
| 302 | "cosine_with_restarts", | ||
| 303 | "polynomial", | ||
| 304 | "constant", | ||
| 305 | "constant_with_warmup", | ||
| 306 | "one_cycle", | ||
| 307 | ], | ||
| 308 | help="The scheduler type to use.", | ||
| 235 | ) | 309 | ) |
| 236 | parser.add_argument( | 310 | parser.add_argument( |
| 237 | "--lr_warmup_epochs", | 311 | "--lr_warmup_epochs", |
| 238 | type=int, | 312 | type=int, |
| 239 | default=10, | 313 | default=10, |
| 240 | help="Number of steps for the warmup in the lr scheduler." | 314 | help="Number of steps for the warmup in the lr scheduler.", |
| 241 | ) | 315 | ) |
| 242 | parser.add_argument( | 316 | parser.add_argument( |
| 243 | "--lr_mid_point", | 317 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
| 244 | type=float, | ||
| 245 | default=0.3, | ||
| 246 | help="OneCycle schedule mid point." | ||
| 247 | ) | 318 | ) |
| 248 | parser.add_argument( | 319 | parser.add_argument( |
| 249 | "--lr_cycles", | 320 | "--lr_cycles", |
| 250 | type=int, | 321 | type=int, |
| 251 | default=None, | 322 | default=None, |
| 252 | help="Number of restart cycles in the lr scheduler (if supported)." | 323 | help="Number of restart cycles in the lr scheduler (if supported).", |
| 253 | ) | 324 | ) |
| 254 | parser.add_argument( | 325 | parser.add_argument( |
| 255 | "--lr_warmup_func", | 326 | "--lr_warmup_func", |
| @@ -261,7 +332,7 @@ def parse_args(): | |||
| 261 | "--lr_warmup_exp", | 332 | "--lr_warmup_exp", |
| 262 | type=int, | 333 | type=int, |
| 263 | default=1, | 334 | default=1, |
| 264 | help='If lr_warmup_func is "cos", exponent to modify the function' | 335 | help='If lr_warmup_func is "cos", exponent to modify the function', |
| 265 | ) | 336 | ) |
| 266 | parser.add_argument( | 337 | parser.add_argument( |
| 267 | "--lr_annealing_func", | 338 | "--lr_annealing_func", |
| @@ -273,76 +344,76 @@ def parse_args(): | |||
| 273 | "--lr_annealing_exp", | 344 | "--lr_annealing_exp", |
| 274 | type=int, | 345 | type=int, |
| 275 | default=3, | 346 | default=3, |
| 276 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 347 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
| 277 | ) | 348 | ) |
| 278 | parser.add_argument( | 349 | parser.add_argument( |
| 279 | "--lr_min_lr", | 350 | "--lr_min_lr", |
| 280 | type=float, | 351 | type=float, |
| 281 | default=0.04, | 352 | default=0.04, |
| 282 | help="Minimum learning rate in the lr scheduler." | 353 | help="Minimum learning rate in the lr scheduler.", |
| 283 | ) | 354 | ) |
| 355 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
| 284 | parser.add_argument( | 356 | parser.add_argument( |
| 285 | "--use_ema", | 357 | "--schedule_sampler", |
| 286 | action="store_true", | 358 | type=str, |
| 287 | help="Whether to use EMA model." | 359 | default="uniform", |
| 288 | ) | 360 | choices=["uniform", "loss-second-moment"], |
| 289 | parser.add_argument( | 361 | help="Noise schedule sampler.", |
| 290 | "--ema_inv_gamma", | ||
| 291 | type=float, | ||
| 292 | default=1.0 | ||
| 293 | ) | ||
| 294 | parser.add_argument( | ||
| 295 | "--ema_power", | ||
| 296 | type=float, | ||
| 297 | default=6/7 | ||
| 298 | ) | ||
| 299 | parser.add_argument( | ||
| 300 | "--ema_max_decay", | ||
| 301 | type=float, | ||
| 302 | default=0.9999 | ||
| 303 | ) | 362 | ) |
| 304 | parser.add_argument( | 363 | parser.add_argument( |
| 305 | "--optimizer", | 364 | "--optimizer", |
| 306 | type=str, | 365 | type=str, |
| 307 | default="dadan", | 366 | default="adan", |
| 308 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 367 | choices=[ |
| 309 | help='Optimizer to use' | 368 | "adam", |
| 369 | "adam8bit", | ||
| 370 | "adan", | ||
| 371 | "lion", | ||
| 372 | "dadam", | ||
| 373 | "dadan", | ||
| 374 | "dlion", | ||
| 375 | "adafactor", | ||
| 376 | ], | ||
| 377 | help="Optimizer to use", | ||
| 310 | ) | 378 | ) |
| 311 | parser.add_argument( | 379 | parser.add_argument( |
| 312 | "--dadaptation_d0", | 380 | "--dadaptation_d0", |
| 313 | type=float, | 381 | type=float, |
| 314 | default=1e-6, | 382 | default=1e-6, |
| 315 | help="The d0 parameter for Dadaptation optimizers." | 383 | help="The d0 parameter for Dadaptation optimizers.", |
| 384 | ) | ||
| 385 | parser.add_argument( | ||
| 386 | "--dadaptation_growth_rate", | ||
| 387 | type=float, | ||
| 388 | default=math.inf, | ||
| 389 | help="The growth_rate parameter for Dadaptation optimizers.", | ||
| 316 | ) | 390 | ) |
| 317 | parser.add_argument( | 391 | parser.add_argument( |
| 318 | "--adam_beta1", | 392 | "--adam_beta1", |
| 319 | type=float, | 393 | type=float, |
| 320 | default=None, | 394 | default=None, |
| 321 | help="The beta1 parameter for the Adam optimizer." | 395 | help="The beta1 parameter for the Adam optimizer.", |
| 322 | ) | 396 | ) |
| 323 | parser.add_argument( | 397 | parser.add_argument( |
| 324 | "--adam_beta2", | 398 | "--adam_beta2", |
| 325 | type=float, | 399 | type=float, |
| 326 | default=None, | 400 | default=None, |
| 327 | help="The beta2 parameter for the Adam optimizer." | 401 | help="The beta2 parameter for the Adam optimizer.", |
| 328 | ) | 402 | ) |
| 329 | parser.add_argument( | 403 | parser.add_argument( |
| 330 | "--adam_weight_decay", | 404 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
| 331 | type=float, | ||
| 332 | default=1e-2, | ||
| 333 | help="Weight decay to use." | ||
| 334 | ) | 405 | ) |
| 335 | parser.add_argument( | 406 | parser.add_argument( |
| 336 | "--adam_epsilon", | 407 | "--adam_epsilon", |
| 337 | type=float, | 408 | type=float, |
| 338 | default=1e-08, | 409 | default=1e-08, |
| 339 | help="Epsilon value for the Adam optimizer" | 410 | help="Epsilon value for the Adam optimizer", |
| 340 | ) | 411 | ) |
| 341 | parser.add_argument( | 412 | parser.add_argument( |
| 342 | "--adam_amsgrad", | 413 | "--adam_amsgrad", |
| 343 | type=bool, | 414 | type=bool, |
| 344 | default=False, | 415 | default=False, |
| 345 | help="Amsgrad value for the Adam optimizer" | 416 | help="Amsgrad value for the Adam optimizer", |
| 346 | ) | 417 | ) |
| 347 | parser.add_argument( | 418 | parser.add_argument( |
| 348 | "--mixed_precision", | 419 | "--mixed_precision", |
| @@ -356,12 +427,28 @@ def parse_args(): | |||
| 356 | ), | 427 | ), |
| 357 | ) | 428 | ) |
| 358 | parser.add_argument( | 429 | parser.add_argument( |
| 430 | "--compile_unet", | ||
| 431 | action="store_true", | ||
| 432 | help="Compile UNet with Torch Dynamo.", | ||
| 433 | ) | ||
| 434 | parser.add_argument( | ||
| 435 | "--use_xformers", | ||
| 436 | action="store_true", | ||
| 437 | help="Use xformers.", | ||
| 438 | ) | ||
| 439 | parser.add_argument( | ||
| 359 | "--sample_frequency", | 440 | "--sample_frequency", |
| 360 | type=int, | 441 | type=int, |
| 361 | default=1, | 442 | default=1, |
| 362 | help="How often to save a checkpoint and sample image", | 443 | help="How often to save a checkpoint and sample image", |
| 363 | ) | 444 | ) |
| 364 | parser.add_argument( | 445 | parser.add_argument( |
| 446 | "--sample_num", | ||
| 447 | type=int, | ||
| 448 | default=None, | ||
| 449 | help="How often to save a checkpoint and sample image (in number of samples)", | ||
| 450 | ) | ||
| 451 | parser.add_argument( | ||
| 365 | "--sample_image_size", | 452 | "--sample_image_size", |
| 366 | type=int, | 453 | type=int, |
| 367 | default=768, | 454 | default=768, |
| @@ -383,19 +470,19 @@ def parse_args(): | |||
| 383 | "--valid_set_size", | 470 | "--valid_set_size", |
| 384 | type=int, | 471 | type=int, |
| 385 | default=None, | 472 | default=None, |
| 386 | help="Number of images in the validation dataset." | 473 | help="Number of images in the validation dataset.", |
| 387 | ) | 474 | ) |
| 388 | parser.add_argument( | 475 | parser.add_argument( |
| 389 | "--valid_set_repeat", | 476 | "--valid_set_repeat", |
| 390 | type=int, | 477 | type=int, |
| 391 | default=1, | 478 | default=1, |
| 392 | help="Times the images in the validation dataset are repeated." | 479 | help="Times the images in the validation dataset are repeated.", |
| 393 | ) | 480 | ) |
| 394 | parser.add_argument( | 481 | parser.add_argument( |
| 395 | "--train_batch_size", | 482 | "--train_batch_size", |
| 396 | type=int, | 483 | type=int, |
| 397 | default=1, | 484 | default=1, |
| 398 | help="Batch size (per device) for the training dataloader." | 485 | help="Batch size (per device) for the training dataloader.", |
| 399 | ) | 486 | ) |
| 400 | parser.add_argument( | 487 | parser.add_argument( |
| 401 | "--sample_steps", | 488 | "--sample_steps", |
| @@ -407,13 +494,18 @@ def parse_args(): | |||
| 407 | "--prior_loss_weight", | 494 | "--prior_loss_weight", |
| 408 | type=float, | 495 | type=float, |
| 409 | default=1.0, | 496 | default=1.0, |
| 410 | help="The weight of prior preservation loss." | 497 | help="The weight of prior preservation loss.", |
| 411 | ) | 498 | ) |
| 499 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
| 500 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
| 412 | parser.add_argument( | 501 | parser.add_argument( |
| 413 | "--max_grad_norm", | 502 | "--emb_dropout", |
| 414 | default=1.0, | ||
| 415 | type=float, | 503 | type=float, |
| 416 | help="Max gradient norm." | 504 | default=0, |
| 505 | help="Embedding dropout probability.", | ||
| 506 | ) | ||
| 507 | parser.add_argument( | ||
| 508 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." | ||
| 417 | ) | 509 | ) |
| 418 | parser.add_argument( | 510 | parser.add_argument( |
| 419 | "--noise_timesteps", | 511 | "--noise_timesteps", |
| @@ -424,7 +516,7 @@ def parse_args(): | |||
| 424 | "--config", | 516 | "--config", |
| 425 | type=str, | 517 | type=str, |
| 426 | default=None, | 518 | default=None, |
| 427 | help="Path to a JSON configuration file containing arguments for invoking this script." | 519 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
| 428 | ) | 520 | ) |
| 429 | 521 | ||
| 430 | args = parser.parse_args() | 522 | args = parser.parse_args() |
| @@ -441,6 +533,67 @@ def parse_args(): | |||
| 441 | if args.project is None: | 533 | if args.project is None: |
| 442 | raise ValueError("You must specify --project") | 534 | raise ValueError("You must specify --project") |
| 443 | 535 | ||
| 536 | if args.initializer_tokens is None: | ||
| 537 | args.initializer_tokens = [] | ||
| 538 | |||
| 539 | if args.placeholder_tokens is None: | ||
| 540 | args.placeholder_tokens = [] | ||
| 541 | |||
| 542 | if isinstance(args.placeholder_tokens, str): | ||
| 543 | args.placeholder_tokens = [args.placeholder_tokens] | ||
| 544 | |||
| 545 | if isinstance(args.initializer_tokens, str): | ||
| 546 | args.initializer_tokens = [args.initializer_tokens] * len( | ||
| 547 | args.placeholder_tokens | ||
| 548 | ) | ||
| 549 | |||
| 550 | if len(args.placeholder_tokens) == 0: | ||
| 551 | args.placeholder_tokens = [ | ||
| 552 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
| 553 | ] | ||
| 554 | |||
| 555 | if len(args.initializer_tokens) == 0: | ||
| 556 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
| 557 | |||
| 558 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | ||
| 559 | raise ValueError( | ||
| 560 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
| 561 | ) | ||
| 562 | |||
| 563 | if isinstance(args.inverted_initializer_tokens, str): | ||
| 564 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( | ||
| 565 | args.placeholder_tokens | ||
| 566 | ) | ||
| 567 | |||
| 568 | if ( | ||
| 569 | isinstance(args.inverted_initializer_tokens, list) | ||
| 570 | and len(args.inverted_initializer_tokens) != 0 | ||
| 571 | ): | ||
| 572 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
| 573 | args.initializer_tokens += args.inverted_initializer_tokens | ||
| 574 | |||
| 575 | if isinstance(args.num_vectors, int): | ||
| 576 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | ||
| 577 | |||
| 578 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( | ||
| 579 | args.num_vectors | ||
| 580 | ): | ||
| 581 | raise ValueError( | ||
| 582 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
| 583 | ) | ||
| 584 | |||
| 585 | if args.alias_tokens is None: | ||
| 586 | args.alias_tokens = [] | ||
| 587 | |||
| 588 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | ||
| 589 | raise ValueError("--alias_tokens must be a list with an even number of items") | ||
| 590 | |||
| 591 | if args.filter_tokens is None: | ||
| 592 | args.filter_tokens = args.placeholder_tokens.copy() | ||
| 593 | |||
| 594 | if isinstance(args.filter_tokens, str): | ||
| 595 | args.filter_tokens = [args.filter_tokens] | ||
| 596 | |||
| 444 | if isinstance(args.collection, str): | 597 | if isinstance(args.collection, str): |
| 445 | args.collection = [args.collection] | 598 | args.collection = [args.collection] |
| 446 | 599 | ||
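Editor's note: the validation block added in the hunk above broadcasts scalar arguments and auto-generates placeholders. A small sketch of those rules outside of argparse, with hypothetical values:

```python
# Hypothetical inputs mirroring the normalization rules added above.
placeholder_tokens: list[str] = []   # nothing passed on the CLI
initializer_tokens = ["dog", "sketch"]

# Rule: no placeholders given -> generate <*0>, <*1>, ... one per initializer.
if len(placeholder_tokens) == 0:
    placeholder_tokens = [f"<*{i}>" for i in range(len(initializer_tokens))]

# Rule: a single int for --num_vectors is broadcast to every placeholder.
num_vectors = 2
if isinstance(num_vectors, int):
    num_vectors = [num_vectors] * len(placeholder_tokens)

assert len(placeholder_tokens) == len(initializer_tokens) == len(num_vectors)
print(placeholder_tokens, num_vectors)  # ['<*0>', '<*1>'] [2, 2]
```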
| @@ -451,15 +604,15 @@ def parse_args(): | |||
| 451 | raise ValueError("You must specify --output_dir") | 604 | raise ValueError("You must specify --output_dir") |
| 452 | 605 | ||
| 453 | if args.adam_beta1 is None: | 606 | if args.adam_beta1 is None: |
| 454 | if args.optimizer in ('adam', 'adam8bit'): | 607 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
| 455 | args.adam_beta1 = 0.9 | 608 | args.adam_beta1 = 0.9 |
| 456 | elif args.optimizer == 'lion': | 609 | elif args.optimizer in ("lion", "dlion"): |
| 457 | args.adam_beta1 = 0.95 | 610 | args.adam_beta1 = 0.95 |
| 458 | 611 | ||
| 459 | if args.adam_beta2 is None: | 612 | if args.adam_beta2 is None: |
| 460 | if args.optimizer in ('adam', 'adam8bit'): | 613 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
| 461 | args.adam_beta2 = 0.999 | 614 | args.adam_beta2 = 0.999 |
| 462 | elif args.optimizer == 'lion': | 615 | elif args.optimizer in ("lion", "dlion"): |
| 463 | args.adam_beta2 = 0.98 | 616 | args.adam_beta2 = 0.98 |
| 464 | 617 | ||
| 465 | return args | 618 | return args |
| @@ -475,7 +628,7 @@ def main(): | |||
| 475 | accelerator = Accelerator( | 628 | accelerator = Accelerator( |
| 476 | log_with=LoggerType.TENSORBOARD, | 629 | log_with=LoggerType.TENSORBOARD, |
| 477 | project_dir=f"{output_dir}", | 630 | project_dir=f"{output_dir}", |
| 478 | mixed_precision=args.mixed_precision | 631 | mixed_precision=args.mixed_precision, |
| 479 | ) | 632 | ) |
| 480 | 633 | ||
| 481 | weight_dtype = torch.float32 | 634 | weight_dtype = torch.float32 |
| @@ -484,6 +637,8 @@ def main(): | |||
| 484 | elif args.mixed_precision == "bf16": | 637 | elif args.mixed_precision == "bf16": |
| 485 | weight_dtype = torch.bfloat16 | 638 | weight_dtype = torch.bfloat16 |
| 486 | 639 | ||
| 640 | patch_xformers(weight_dtype) | ||
| 641 | |||
| 487 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 642 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
| 488 | 643 | ||
| 489 | if args.seed is None: | 644 | if args.seed is None: |
| @@ -493,44 +648,125 @@ def main(): | |||
| 493 | 648 | ||
| 494 | save_args(output_dir, args) | 649 | save_args(output_dir, args) |
| 495 | 650 | ||
| 496 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 651 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
| 497 | args.pretrained_model_name_or_path) | 652 | args.pretrained_model_name_or_path |
| 498 | 653 | ) | |
| 499 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 654 | embeddings = patch_managed_embeddings( |
| 500 | tokenizer.set_dropout(args.vector_dropout) | 655 | text_encoder, args.emb_alpha, args.emb_dropout |
| 656 | ) | ||
| 657 | schedule_sampler = create_named_schedule_sampler( | ||
| 658 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
| 659 | ) | ||
| 501 | 660 | ||
| 502 | vae.enable_slicing() | 661 | vae.enable_slicing() |
| 503 | vae.set_use_memory_efficient_attention_xformers(True) | 662 | |
| 504 | unet.enable_xformers_memory_efficient_attention() | 663 | if args.use_xformers: |
| 664 | vae.set_use_memory_efficient_attention_xformers(True) | ||
| 665 | unet.enable_xformers_memory_efficient_attention() | ||
| 666 | # elif args.compile_unet: | ||
| 667 | # unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False | ||
| 668 | # | ||
| 669 | # proc = AttnProcessor() | ||
| 670 | # | ||
| 671 | # def fn_recursive_set_proc(module: torch.nn.Module): | ||
| 672 | # if hasattr(module, "processor"): | ||
| 673 | # module.processor = proc | ||
| 674 | # | ||
| 675 | # for child in module.children(): | ||
| 676 | # fn_recursive_set_proc(child) | ||
| 677 | # | ||
| 678 | # fn_recursive_set_proc(unet) | ||
| 505 | 679 | ||
| 506 | if args.gradient_checkpointing: | 680 | if args.gradient_checkpointing: |
| 507 | unet.enable_gradient_checkpointing() | 681 | unet.enable_gradient_checkpointing() |
| 508 | text_encoder.gradient_checkpointing_enable() | 682 | |
| 683 | if len(args.alias_tokens) != 0: | ||
| 684 | alias_placeholder_tokens = args.alias_tokens[::2] | ||
| 685 | alias_initializer_tokens = args.alias_tokens[1::2] | ||
| 686 | |||
| 687 | added_tokens, added_ids = add_placeholder_tokens( | ||
| 688 | tokenizer=tokenizer, | ||
| 689 | embeddings=embeddings, | ||
| 690 | placeholder_tokens=alias_placeholder_tokens, | ||
| 691 | initializer_tokens=alias_initializer_tokens, | ||
| 692 | ) | ||
| 693 | embeddings.persist() | ||
| 694 | print( | ||
| 695 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
| 696 | ) | ||
| 697 | |||
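Editor's note: `--alias_tokens` is consumed as a flat, pairwise list in the block above: even indices are the alias placeholders, odd indices their initializers, which is why `parse_args` rejects odd-length lists. A tiny sketch with hypothetical tokens:

```python
# Flat pairwise list as it arrives from the CLI (hypothetical values).
alias_tokens = ["<dog>", "dog", "<cat>", "cat"]
assert len(alias_tokens) % 2 == 0  # enforced in parse_args

alias_placeholder_tokens = alias_tokens[::2]   # ['<dog>', '<cat>']
alias_initializer_tokens = alias_tokens[1::2]  # ['dog', 'cat']

for placeholder, initializer in zip(alias_placeholder_tokens, alias_initializer_tokens):
    print(f"{placeholder} -> {initializer}")
```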
| 698 | placeholder_tokens = [] | ||
| 699 | placeholder_token_ids = [] | ||
| 509 | 700 | ||
| 510 | if args.embeddings_dir is not None: | 701 | if args.embeddings_dir is not None: |
| 511 | embeddings_dir = Path(args.embeddings_dir) | 702 | embeddings_dir = Path(args.embeddings_dir) |
| 512 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 703 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 513 | raise ValueError("--embeddings_dir must point to an existing directory") | 704 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 514 | 705 | ||
| 515 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 706 | added_tokens, added_ids = load_embeddings_from_dir( |
| 516 | embeddings.persist() | 707 | tokenizer, embeddings, embeddings_dir |
| 517 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 708 | ) |
| 709 | |||
| 710 | placeholder_tokens = added_tokens | ||
| 711 | placeholder_token_ids = added_ids | ||
| 712 | |||
| 713 | print( | ||
| 714 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
| 715 | ) | ||
| 716 | |||
| 717 | if args.train_dir_embeddings: | ||
| 718 | print("Training embeddings from embeddings dir") | ||
| 719 | else: | ||
| 720 | embeddings.persist() | ||
| 721 | |||
| 722 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | ||
| 723 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
| 724 | tokenizer=tokenizer, | ||
| 725 | embeddings=embeddings, | ||
| 726 | placeholder_tokens=args.placeholder_tokens, | ||
| 727 | initializer_tokens=args.initializer_tokens, | ||
| 728 | num_vectors=args.num_vectors, | ||
| 729 | initializer_noise=args.initializer_noise, | ||
| 730 | ) | ||
| 731 | |||
| 732 | placeholder_tokens = args.placeholder_tokens | ||
| 733 | |||
| 734 | stats = list( | ||
| 735 | zip( | ||
| 736 | placeholder_tokens, | ||
| 737 | placeholder_token_ids, | ||
| 738 | args.initializer_tokens, | ||
| 739 | initializer_token_ids, | ||
| 740 | ) | ||
| 741 | ) | ||
| 742 | print(f"Training embeddings: {stats}") | ||
| 518 | 743 | ||
| 519 | if args.scale_lr: | 744 | if args.scale_lr: |
| 520 | args.learning_rate = ( | 745 | args.learning_rate_unet = ( |
| 521 | args.learning_rate * args.gradient_accumulation_steps * | 746 | args.learning_rate_unet |
| 522 | args.train_batch_size * accelerator.num_processes | 747 | * args.gradient_accumulation_steps |
| 748 | * args.train_batch_size | ||
| 749 | * accelerator.num_processes | ||
| 750 | ) | ||
| 751 | args.learning_rate_text = ( | ||
| 752 | args.learning_rate_text | ||
| 753 | * args.gradient_accumulation_steps | ||
| 754 | * args.train_batch_size | ||
| 755 | * accelerator.num_processes | ||
| 523 | ) | 756 | ) |
| 524 | 757 | ||
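Editor's note: with `--scale_lr`, both learning rates now grow with the effective batch size (accumulation steps × per-device batch × process count). A worked example with hypothetical numbers:

```python
# Hypothetical settings for the --scale_lr rule above.
learning_rate_unet = 1e-4
gradient_accumulation_steps = 4
train_batch_size = 2
num_processes = 2  # accelerator.num_processes

scaled = (
    learning_rate_unet
    * gradient_accumulation_steps
    * train_batch_size
    * num_processes
)
print(f"{scaled:.4g}")  # 0.0016, i.e. 16x the base rate for a 16x effective batch
```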
| 525 | if args.find_lr: | 758 | if args.find_lr: |
| 526 | args.learning_rate = 1e-6 | 759 | args.learning_rate_unet = 1e-6 |
| 760 | args.learning_rate_text = 1e-6 | ||
| 527 | args.lr_scheduler = "exponential_growth" | 761 | args.lr_scheduler = "exponential_growth" |
| 528 | 762 | ||
| 529 | if args.optimizer == 'adam8bit': | 763 | if args.optimizer == "adam8bit": |
| 530 | try: | 764 | try: |
| 531 | import bitsandbytes as bnb | 765 | import bitsandbytes as bnb |
| 532 | except ImportError: | 766 | except ImportError: |
| 533 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 767 | raise ImportError( |
| 768 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
| 769 | ) | ||
| 534 | 770 | ||
| 535 | create_optimizer = partial( | 771 | create_optimizer = partial( |
| 536 | bnb.optim.AdamW8bit, | 772 | bnb.optim.AdamW8bit, |
| @@ -539,7 +775,7 @@ def main(): | |||
| 539 | eps=args.adam_epsilon, | 775 | eps=args.adam_epsilon, |
| 540 | amsgrad=args.adam_amsgrad, | 776 | amsgrad=args.adam_amsgrad, |
| 541 | ) | 777 | ) |
| 542 | elif args.optimizer == 'adam': | 778 | elif args.optimizer == "adam": |
| 543 | create_optimizer = partial( | 779 | create_optimizer = partial( |
| 544 | torch.optim.AdamW, | 780 | torch.optim.AdamW, |
| 545 | betas=(args.adam_beta1, args.adam_beta2), | 781 | betas=(args.adam_beta1, args.adam_beta2), |
| @@ -547,22 +783,27 @@ def main(): | |||
| 547 | eps=args.adam_epsilon, | 783 | eps=args.adam_epsilon, |
| 548 | amsgrad=args.adam_amsgrad, | 784 | amsgrad=args.adam_amsgrad, |
| 549 | ) | 785 | ) |
| 550 | elif args.optimizer == 'adan': | 786 | elif args.optimizer == "adan": |
| 551 | try: | 787 | try: |
| 552 | import timm.optim | 788 | import timm.optim |
| 553 | except ImportError: | 789 | except ImportError: |
| 554 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 790 | raise ImportError( |
| 791 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
| 792 | ) | ||
| 555 | 793 | ||
| 556 | create_optimizer = partial( | 794 | create_optimizer = partial( |
| 557 | timm.optim.Adan, | 795 | timm.optim.Adan, |
| 558 | weight_decay=args.adam_weight_decay, | 796 | weight_decay=args.adam_weight_decay, |
| 559 | eps=args.adam_epsilon, | 797 | eps=args.adam_epsilon, |
| 798 | no_prox=True, | ||
| 560 | ) | 799 | ) |
| 561 | elif args.optimizer == 'lion': | 800 | elif args.optimizer == "lion": |
| 562 | try: | 801 | try: |
| 563 | import lion_pytorch | 802 | import lion_pytorch |
| 564 | except ImportError: | 803 | except ImportError: |
| 565 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 804 | raise ImportError( |
| 805 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
| 806 | ) | ||
| 566 | 807 | ||
| 567 | create_optimizer = partial( | 808 | create_optimizer = partial( |
| 568 | lion_pytorch.Lion, | 809 | lion_pytorch.Lion, |
| @@ -570,7 +811,7 @@ def main(): | |||
| 570 | weight_decay=args.adam_weight_decay, | 811 | weight_decay=args.adam_weight_decay, |
| 571 | use_triton=True, | 812 | use_triton=True, |
| 572 | ) | 813 | ) |
| 573 | elif args.optimizer == 'adafactor': | 814 | elif args.optimizer == "adafactor": |
| 574 | create_optimizer = partial( | 815 | create_optimizer = partial( |
| 575 | transformers.optimization.Adafactor, | 816 | transformers.optimization.Adafactor, |
| 576 | weight_decay=args.adam_weight_decay, | 817 | weight_decay=args.adam_weight_decay, |
| @@ -580,13 +821,16 @@ def main(): | |||
| 580 | ) | 821 | ) |
| 581 | 822 | ||
| 582 | args.lr_scheduler = "adafactor" | 823 | args.lr_scheduler = "adafactor" |
| 583 | args.lr_min_lr = args.learning_rate | 824 | args.lr_min_lr = args.learning_rate_unet |
| 584 | args.learning_rate = None | 825 | args.learning_rate_unet = None |
| 585 | elif args.optimizer == 'dadam': | 826 | args.learning_rate_text = None |
| 827 | elif args.optimizer == "dadam": | ||
| 586 | try: | 828 | try: |
| 587 | import dadaptation | 829 | import dadaptation |
| 588 | except ImportError: | 830 | except ImportError: |
| 589 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 831 | raise ImportError( |
| 832 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
| 833 | ) | ||
| 590 | 834 | ||
| 591 | create_optimizer = partial( | 835 | create_optimizer = partial( |
| 592 | dadaptation.DAdaptAdam, | 836 | dadaptation.DAdaptAdam, |
| @@ -595,46 +839,65 @@ def main(): | |||
| 595 | eps=args.adam_epsilon, | 839 | eps=args.adam_epsilon, |
| 596 | decouple=True, | 840 | decouple=True, |
| 597 | d0=args.dadaptation_d0, | 841 | d0=args.dadaptation_d0, |
| 842 | growth_rate=args.dadaptation_growth_rate, | ||
| 598 | ) | 843 | ) |
| 599 | 844 | ||
| 600 | args.learning_rate = 1.0 | 845 | args.learning_rate_unet = 1.0 |
| 601 | elif args.optimizer == 'dadan': | 846 | args.learning_rate_text = 1.0 |
| 847 | elif args.optimizer == "dadan": | ||
| 602 | try: | 848 | try: |
| 603 | import dadaptation | 849 | import dadaptation |
| 604 | except ImportError: | 850 | except ImportError: |
| 605 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 851 | raise ImportError( |
| 852 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
| 853 | ) | ||
| 606 | 854 | ||
| 607 | create_optimizer = partial( | 855 | create_optimizer = partial( |
| 608 | dadaptation.DAdaptAdan, | 856 | dadaptation.DAdaptAdan, |
| 609 | weight_decay=args.adam_weight_decay, | 857 | weight_decay=args.adam_weight_decay, |
| 610 | eps=args.adam_epsilon, | 858 | eps=args.adam_epsilon, |
| 611 | d0=args.dadaptation_d0, | 859 | d0=args.dadaptation_d0, |
| 860 | growth_rate=args.dadaptation_growth_rate, | ||
| 612 | ) | 861 | ) |
| 613 | 862 | ||
| 614 | args.learning_rate = 1.0 | 863 | args.learning_rate_unet = 1.0 |
| 864 | args.learning_rate_text = 1.0 | ||
| 865 | elif args.optimizer == "dlion": | ||
| 866 | raise ImportError("DLion has not been merged into dadaptation yet") | ||
| 615 | else: | 867 | else: |
| 616 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 868 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
| 617 | 869 | ||
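Editor's note: every branch above builds a `functools.partial` rather than an optimizer, so hyperparameters are bound once and the parameter groups, which exist only after the models are set up, are supplied later, once per training cycle. A sketch of the pattern with AdamW standing in for any branch:

```python
from functools import partial

import torch

# Bind hyperparameters now; leave the params argument open.
create_optimizer = partial(
    torch.optim.AdamW,
    betas=(0.9, 0.999),
    weight_decay=2e-2,
    eps=1e-8,
)

# Later, with per-group learning rates (model is a hypothetical stand-in).
model = torch.nn.Linear(4, 4)
optimizer = create_optimizer([{"params": model.parameters(), "lr": 1e-4}])
```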
| 618 | trainer = partial( | 870 | trainer = partial( |
| 619 | train, | 871 | train, |
| 620 | accelerator=accelerator, | 872 | accelerator=accelerator, |
| 621 | unet=unet, | 873 | unet=unet, |
| 622 | text_encoder=text_encoder, | 874 | text_encoder=text_encoder, |
| 875 | tokenizer=tokenizer, | ||
| 623 | vae=vae, | 876 | vae=vae, |
| 624 | noise_scheduler=noise_scheduler, | 877 | noise_scheduler=noise_scheduler, |
| 878 | schedule_sampler=schedule_sampler, | ||
| 879 | min_snr_gamma=args.min_snr_gamma, | ||
| 625 | dtype=weight_dtype, | 880 | dtype=weight_dtype, |
| 881 | seed=args.seed, | ||
| 882 | compile_unet=args.compile_unet, | ||
| 626 | guidance_scale=args.guidance_scale, | 883 | guidance_scale=args.guidance_scale, |
| 627 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 884 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 628 | no_val=args.valid_set_size == 0, | 885 | sample_scheduler=sample_scheduler, |
| 886 | sample_batch_size=args.sample_batch_size, | ||
| 887 | sample_num_batches=args.sample_batches, | ||
| 888 | sample_num_steps=args.sample_steps, | ||
| 889 | sample_image_size=args.sample_image_size, | ||
| 890 | max_grad_norm=args.max_grad_norm, | ||
| 629 | ) | 891 | ) |
| 630 | 892 | ||
| 631 | checkpoint_output_dir = output_dir / "model" | 893 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
| 632 | sample_output_dir = output_dir / "samples" | 894 | data_npgenerator = np.random.default_rng(args.seed) |
| 633 | 895 | ||
| 634 | datamodule = VlpnDataModule( | 896 | create_datamodule = partial( |
| 897 | VlpnDataModule, | ||
| 635 | data_file=args.train_data_file, | 898 | data_file=args.train_data_file, |
| 636 | batch_size=args.train_batch_size, | ||
| 637 | tokenizer=tokenizer, | 899 | tokenizer=tokenizer, |
| 900 | constant_prompt_length=args.compile_unet, | ||
| 638 | class_subdir=args.class_image_dir, | 901 | class_subdir=args.class_image_dir, |
| 639 | with_guidance=args.guidance_scale != 0, | 902 | with_guidance=args.guidance_scale != 0, |
| 640 | num_class_images=args.num_class_images, | 903 | num_class_images=args.num_class_images, |
| @@ -643,83 +906,186 @@ def main(): | |||
| 643 | progressive_buckets=args.progressive_buckets, | 906 | progressive_buckets=args.progressive_buckets, |
| 644 | bucket_step_size=args.bucket_step_size, | 907 | bucket_step_size=args.bucket_step_size, |
| 645 | bucket_max_pixels=args.bucket_max_pixels, | 908 | bucket_max_pixels=args.bucket_max_pixels, |
| 646 | dropout=args.tag_dropout, | ||
| 647 | shuffle=not args.no_tag_shuffle, | 909 | shuffle=not args.no_tag_shuffle, |
| 648 | template_key=args.train_data_template, | 910 | template_key=args.train_data_template, |
| 649 | valid_set_size=args.valid_set_size, | ||
| 650 | train_set_pad=args.train_set_pad, | 911 | train_set_pad=args.train_set_pad, |
| 651 | valid_set_pad=args.valid_set_pad, | 912 | valid_set_pad=args.valid_set_pad, |
| 652 | seed=args.seed, | 913 | dtype=weight_dtype, |
| 653 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 914 | generator=data_generator, |
| 654 | dtype=weight_dtype | 915 | npgenerator=data_npgenerator, |
| 655 | ) | ||
| 656 | datamodule.setup() | ||
| 657 | |||
| 658 | num_train_epochs = args.num_train_epochs | ||
| 659 | sample_frequency = args.sample_frequency | ||
| 660 | if num_train_epochs is None: | ||
| 661 | num_train_epochs = math.ceil( | ||
| 662 | args.num_train_steps / len(datamodule.train_dataset) | ||
| 663 | ) * args.gradient_accumulation_steps | ||
| 664 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | ||
| 665 | |||
| 666 | params_to_optimize = (unet.parameters(), ) | ||
| 667 | if args.train_text_encoder_epochs != 0: | ||
| 668 | params_to_optimize += ( | ||
| 669 | text_encoder.text_model.encoder.parameters(), | ||
| 670 | text_encoder.text_model.final_layer_norm.parameters(), | ||
| 671 | ) | ||
| 672 | |||
| 673 | optimizer = create_optimizer( | ||
| 674 | itertools.chain(*params_to_optimize), | ||
| 675 | lr=args.learning_rate, | ||
| 676 | ) | 916 | ) |
| 677 | 917 | ||
| 678 | lr_scheduler = get_scheduler( | 918 | create_lr_scheduler = partial( |
| 679 | args.lr_scheduler, | 919 | get_scheduler, |
| 680 | optimizer=optimizer, | ||
| 681 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
| 682 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 683 | min_lr=args.lr_min_lr, | 920 | min_lr=args.lr_min_lr, |
| 684 | warmup_func=args.lr_warmup_func, | 921 | warmup_func=args.lr_warmup_func, |
| 685 | annealing_func=args.lr_annealing_func, | 922 | annealing_func=args.lr_annealing_func, |
| 686 | warmup_exp=args.lr_warmup_exp, | 923 | warmup_exp=args.lr_warmup_exp, |
| 687 | annealing_exp=args.lr_annealing_exp, | 924 | annealing_exp=args.lr_annealing_exp, |
| 688 | cycles=args.lr_cycles, | ||
| 689 | end_lr=1e2, | 925 | end_lr=1e2, |
| 690 | train_epochs=num_train_epochs, | ||
| 691 | warmup_epochs=args.lr_warmup_epochs, | ||
| 692 | mid_point=args.lr_mid_point, | 926 | mid_point=args.lr_mid_point, |
| 693 | ) | 927 | ) |
| 694 | 928 | ||
| 695 | trainer( | 929 | # Dreambooth |
| 696 | strategy=dreambooth_strategy, | 930 | # -------------------------------------------------------------------------------- |
| 697 | project="dreambooth", | 931 | |
| 698 | train_dataloader=datamodule.train_dataloader, | 932 | dreambooth_datamodule = create_datamodule( |
| 699 | val_dataloader=datamodule.val_dataloader, | 933 | valid_set_size=args.valid_set_size, |
| 700 | seed=args.seed, | 934 | batch_size=args.train_batch_size, |
| 701 | optimizer=optimizer, | 935 | dropout=args.tag_dropout, |
| 702 | lr_scheduler=lr_scheduler, | 936 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
| 703 | num_train_epochs=num_train_epochs, | 937 | ) |
| 704 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 938 | dreambooth_datamodule.setup() |
| 705 | sample_frequency=sample_frequency, | 939 | |
| 706 | offset_noise_strength=args.offset_noise_strength, | 940 | num_train_epochs = args.num_train_epochs |
| 707 | # -- | 941 | dreambooth_sample_frequency = args.sample_frequency |
| 708 | tokenizer=tokenizer, | 942 | if num_train_epochs is None: |
| 709 | sample_scheduler=sample_scheduler, | 943 | num_train_epochs = ( |
| 710 | sample_output_dir=sample_output_dir, | 944 | math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset)) |
| 711 | checkpoint_output_dir=checkpoint_output_dir, | 945 | * args.gradient_accumulation_steps |
| 712 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 946 | ) |
| 713 | max_grad_norm=args.max_grad_norm, | 947 | dreambooth_sample_frequency = math.ceil( |
| 714 | use_ema=args.use_ema, | 948 | num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps) |
| 715 | ema_inv_gamma=args.ema_inv_gamma, | 949 | ) |
| 716 | ema_power=args.ema_power, | 950 | num_training_steps_per_epoch = math.ceil( |
| 717 | ema_max_decay=args.ema_max_decay, | 951 | len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps |
| 718 | sample_batch_size=args.sample_batch_size, | ||
| 719 | sample_num_batches=args.sample_batches, | ||
| 720 | sample_num_steps=args.sample_steps, | ||
| 721 | sample_image_size=args.sample_image_size, | ||
| 722 | ) | 952 | ) |
| 953 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | ||
| 954 | if args.sample_num is not None: | ||
| 955 | dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | ||
| 956 | |||
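Editor's note: when `--num_train_epochs` is unset, the epoch count is derived from `--num_train_steps`, and `--sample_num` converts a total sample budget into a per-epoch frequency. A worked example of the bookkeeping above, with hypothetical sizes:

```python
import math

# Hypothetical sizes mirroring the derivation above.
num_train_steps = 2000
dataset_len = 400                 # len(train_dataset)
gradient_accumulation_steps = 4

num_train_epochs = math.ceil(num_train_steps / dataset_len) * gradient_accumulation_steps
steps_per_epoch = math.ceil(dataset_len / gradient_accumulation_steps)

print(num_train_epochs)                    # 20
print(steps_per_epoch * num_train_epochs)  # 2000 optimizer steps in total

sample_num = 10
sample_frequency = math.ceil(num_train_epochs / sample_num)
print(sample_frequency)                    # sample every 2 epochs
```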
| 957 | dreambooth_project = "dreambooth" | ||
| 958 | |||
| 959 | if accelerator.is_main_process: | ||
| 960 | accelerator.init_trackers(dreambooth_project) | ||
| 961 | |||
| 962 | dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples" | ||
| 963 | |||
| 964 | training_iter = 0 | ||
| 965 | auto_cycles = list(args.auto_cycles) | ||
| 966 | learning_rate_unet = args.learning_rate_unet | ||
| 967 | learning_rate_text = args.learning_rate_text | ||
| 968 | lr_scheduler = args.lr_scheduler | ||
| 969 | lr_warmup_epochs = args.lr_warmup_epochs | ||
| 970 | lr_cycles = args.lr_cycles | ||
| 971 | |||
| 972 | avg_loss = AverageMeter() | ||
| 973 | avg_acc = AverageMeter() | ||
| 974 | avg_loss_val = AverageMeter() | ||
| 975 | avg_acc_val = AverageMeter() | ||
| 976 | |||
| 977 | params_to_optimize = [ | ||
| 978 | { | ||
| 979 | "params": (param for param in unet.parameters() if param.requires_grad), | ||
| 980 | "lr": learning_rate_unet, | ||
| 981 | }, | ||
| 982 | { | ||
| 983 | "params": ( | ||
| 984 | param for param in text_encoder.parameters() if param.requires_grad | ||
| 985 | ), | ||
| 986 | "lr": learning_rate_text, | ||
| 987 | }, | ||
| 988 | ] | ||
| 989 | group_labels = ["unet", "text"] | ||
| 990 | |||
| 991 | dreambooth_optimizer = create_optimizer(params_to_optimize) | ||
| 992 | |||
| 993 | while True: | ||
| 994 | if len(auto_cycles) != 0: | ||
| 995 | response = auto_cycles.pop(0) | ||
| 996 | else: | ||
| 997 | response = input( | ||
| 998 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " | ||
| 999 | ) | ||
| 1000 | |||
| 1001 | if response.lower().strip() == "o": | ||
| 1002 | if args.learning_rate_unet is not None: | ||
| 1003 | learning_rate_unet = ( | ||
| 1004 | args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) | ||
| 1005 | ) | ||
| 1006 | if args.learning_rate_text is not None: | ||
| 1007 | learning_rate_text = ( | ||
| 1008 | args.learning_rate_text * 2 * (args.cycle_decay**training_iter) | ||
| 1009 | ) | ||
| 1010 | else: | ||
| 1011 | learning_rate_unet = args.learning_rate_unet * ( | ||
| 1012 | args.cycle_decay**training_iter | ||
| 1013 | ) | ||
| 1014 | learning_rate_text = args.learning_rate_text * ( | ||
| 1015 | args.cycle_decay**training_iter | ||
| 1016 | ) | ||
| 1017 | |||
| 1018 | if response.lower().strip() == "o": | ||
| 1019 | lr_scheduler = "one_cycle" | ||
| 1020 | lr_warmup_epochs = args.lr_warmup_epochs | ||
| 1021 | lr_cycles = args.lr_cycles | ||
| 1022 | elif response.lower().strip() == "w": | ||
| 1023 | lr_scheduler = "constant_with_warmup" | ||
| 1024 | lr_warmup_epochs = num_train_epochs | ||
| 1025 | elif response.lower().strip() == "c": | ||
| 1026 | lr_scheduler = "constant" | ||
| 1027 | elif response.lower().strip() == "d": | ||
| 1028 | lr_scheduler = "cosine" | ||
| 1029 | lr_warmup_epochs = 0 | ||
| 1030 | lr_cycles = 1 | ||
| 1031 | elif response.lower().strip() == "s": | ||
| 1032 | break | ||
| 1033 | else: | ||
| 1034 | continue | ||
| 1035 | |||
| 1036 | print("") | ||
| 1037 | print( | ||
| 1038 | f"============ Dreambooth cycle {training_iter + 1}: {response} ============" | ||
| 1039 | ) | ||
| 1040 | print("") | ||
| 1041 | |||
| 1042 | for group, lr in zip( | ||
| 1043 | dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text] | ||
| 1044 | ): | ||
| 1045 | group["lr"] = lr | ||
| 1046 | |||
| 1047 | dreambooth_lr_scheduler = create_lr_scheduler( | ||
| 1048 | lr_scheduler, | ||
| 1049 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 1050 | optimizer=dreambooth_optimizer, | ||
| 1051 | num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader), | ||
| 1052 | train_epochs=num_train_epochs, | ||
| 1053 | cycles=lr_cycles, | ||
| 1054 | warmup_epochs=lr_warmup_epochs, | ||
| 1055 | ) | ||
| 1056 | |||
| 1057 | dreambooth_checkpoint_output_dir = ( | ||
| 1058 | output_dir / dreambooth_project / f"model_{training_iter}" | ||
| 1059 | ) | ||
| 1060 | |||
| 1061 | trainer( | ||
| 1062 | strategy=dreambooth_strategy, | ||
| 1063 | train_dataloader=dreambooth_datamodule.train_dataloader, | ||
| 1064 | val_dataloader=dreambooth_datamodule.val_dataloader, | ||
| 1065 | optimizer=dreambooth_optimizer, | ||
| 1066 | lr_scheduler=dreambooth_lr_scheduler, | ||
| 1067 | num_train_epochs=num_train_epochs, | ||
| 1068 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 1069 | global_step_offset=training_iter * num_train_steps, | ||
| 1070 | cycle=training_iter, | ||
| 1071 | train_text_encoder_cycles=args.train_text_encoder_cycles, | ||
| 1072 | # -- | ||
| 1073 | group_labels=group_labels, | ||
| 1074 | sample_output_dir=dreambooth_sample_output_dir, | ||
| 1075 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | ||
| 1076 | sample_frequency=dreambooth_sample_frequency, | ||
| 1077 | offset_noise_strength=args.offset_noise_strength, | ||
| 1078 | input_pertubation=args.input_pertubation, | ||
| 1079 | no_val=args.valid_set_size == 0, | ||
| 1080 | avg_loss=avg_loss, | ||
| 1081 | avg_acc=avg_acc, | ||
| 1082 | avg_loss_val=avg_loss_val, | ||
| 1083 | avg_acc_val=avg_acc_val, | ||
| 1084 | ) | ||
| 1085 | |||
| 1086 | training_iter += 1 | ||
| 1087 | |||
| 1088 | accelerator.end_training() | ||
| 723 | 1089 | ||
| 724 | 1090 | ||
| 725 | if __name__ == "__main__": | 1091 | if __name__ == "__main__": |
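Editor's note: the new `while True` loop above turns training into repeatable cycles: `--auto_cycles` is consumed one letter at a time, then the script falls back to interactive input; `--cycle_decay` shrinks the learning rate each cycle, and the `o` (one_cycle) action doubles the peak rate. A compressed sketch of that control flow, with hypothetical values:

```python
# Hypothetical reduction of the cycle loop added above.
auto_cycles = list("ood")  # two one_cycle runs, then a cosine decay
cycle_decay = 0.9
base_lr = 1e-4

training_iter = 0
while True:
    if auto_cycles:
        response = auto_cycles.pop(0)
    else:
        response = input("[o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop --> ")
    response = response.lower().strip()

    if response == "s":
        break
    if response not in ("o", "w", "c", "d"):
        continue

    # one_cycle doubles the peak LR; every cycle also applies the decay factor.
    lr = base_lr * (2 if response == "o" else 1) * cycle_decay**training_iter
    print(f"cycle {training_iter + 1}: action={response!r}, lr={lr:.2e}")
    training_iter += 1
```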
diff --git a/train_lora.py b/train_lora.py index c74dd8f..fccf48d 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -16,6 +16,7 @@ from accelerate import Accelerator | |||
| 16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from peft import LoraConfig, get_peft_model | 18 | from peft import LoraConfig, get_peft_model |
| 19 | |||
| 19 | # from diffusers.models.attention_processor import AttnProcessor | 20 | # from diffusers.models.attention_processor import AttnProcessor |
| 20 | from diffusers.utils.import_utils import is_xformers_available | 21 | from diffusers.utils.import_utils import is_xformers_available |
| 21 | import transformers | 22 | import transformers |
| @@ -34,15 +35,20 @@ from util.files import load_config, load_embeddings_from_dir | |||
| 34 | 35 | ||
| 35 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 36 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
| 36 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] | 37 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] |
| 37 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"] | 38 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0", "to_k", "key"] # [] |
| 38 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] | 39 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] |
| 39 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] | 40 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + [ |
| 40 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] | 41 | "out_proj", |
| 42 | "k_proj", | ||
| 43 | ] # [] | ||
| 44 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + [ | ||
| 45 | "token_embedding" | ||
| 46 | ] | ||
| 41 | 47 | ||
| 42 | 48 | ||
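Editor's note: the widened module lists above now include the key projections (`to_k`/`key`, `k_proj`), so LoRA adapters attach to every attention projection. A sketch of how such a list typically feeds peft, assuming peft is installed; the rank/alpha/dropout values are this script's argparse defaults and everything else is illustrative:

```python
from peft import LoraConfig

UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value", "to_out.0", "to_k", "key"]

unet_lora_config = LoraConfig(
    r=8,                                # --lora_r
    lora_alpha=32,                      # --lora_alpha
    target_modules=UNET_TARGET_MODULES,
    lora_dropout=0.0,                   # --lora_dropout
    bias="none",                        # --lora_bias
)

# unet = get_peft_model(unet, unet_lora_config)  # wraps only the listed projections
```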
| 43 | logger = get_logger(__name__) | 49 | logger = get_logger(__name__) |
| 44 | 50 | ||
| 45 | warnings.filterwarnings('ignore') | 51 | warnings.filterwarnings("ignore") |
| 46 | 52 | ||
| 47 | 53 | ||
| 48 | torch.backends.cuda.matmul.allow_tf32 = True | 54 | torch.backends.cuda.matmul.allow_tf32 = True |
| @@ -55,20 +61,27 @@ hidet.torch.dynamo_config.use_tensor_core(True) | |||
| 55 | hidet.torch.dynamo_config.search_space(0) | 61 | hidet.torch.dynamo_config.search_space(0) |
| 56 | 62 | ||
| 57 | 63 | ||
| 58 | if is_xformers_available(): | 64 | def patch_xformers(dtype): |
| 59 | import xformers | 65 | if is_xformers_available(): |
| 60 | import xformers.ops | 66 | import xformers |
| 61 | 67 | import xformers.ops | |
| 62 | orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention | 68 | |
| 63 | def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs): | 69 | orig_xformers_memory_efficient_attention = ( |
| 64 | return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs) | 70 | xformers.ops.memory_efficient_attention |
| 65 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | 71 | ) |
| 72 | |||
| 73 | def xformers_memory_efficient_attention( | ||
| 74 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
| 75 | ): | ||
| 76 | return orig_xformers_memory_efficient_attention( | ||
| 77 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
| 78 | ) | ||
| 79 | |||
| 80 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
| 66 | 81 | ||
| 67 | 82 | ||
| 68 | def parse_args(): | 83 | def parse_args(): |
| 69 | parser = argparse.ArgumentParser( | 84 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 70 | description="Simple example of a training script." | ||
| 71 | ) | ||
| 72 | parser.add_argument( | 85 | parser.add_argument( |
| 73 | "--pretrained_model_name_or_path", | 86 | "--pretrained_model_name_or_path", |
| 74 | type=str, | 87 | type=str, |
| @@ -85,7 +98,7 @@ def parse_args(): | |||
| 85 | "--train_data_file", | 98 | "--train_data_file", |
| 86 | type=str, | 99 | type=str, |
| 87 | default=None, | 100 | default=None, |
| 88 | help="A folder containing the training data." | 101 | help="A folder containing the training data.", |
| 89 | ) | 102 | ) |
| 90 | parser.add_argument( | 103 | parser.add_argument( |
| 91 | "--train_data_template", | 104 | "--train_data_template", |
| @@ -96,13 +109,13 @@ def parse_args(): | |||
| 96 | "--train_set_pad", | 109 | "--train_set_pad", |
| 97 | type=int, | 110 | type=int, |
| 98 | default=None, | 111 | default=None, |
| 99 | help="The number to fill train dataset items up to." | 112 | help="The number to fill train dataset items up to.", |
| 100 | ) | 113 | ) |
| 101 | parser.add_argument( | 114 | parser.add_argument( |
| 102 | "--valid_set_pad", | 115 | "--valid_set_pad", |
| 103 | type=int, | 116 | type=int, |
| 104 | default=None, | 117 | default=None, |
| 105 | help="The number to fill validation dataset items up to." | 118 | help="The number to fill validation dataset items up to.", |
| 106 | ) | 119 | ) |
| 107 | parser.add_argument( | 120 | parser.add_argument( |
| 108 | "--project", | 121 | "--project", |
| @@ -111,64 +124,52 @@ def parse_args(): | |||
| 111 | help="The name of the current project.", | 124 | help="The name of the current project.", |
| 112 | ) | 125 | ) |
| 113 | parser.add_argument( | 126 | parser.add_argument( |
| 114 | "--auto_cycles", | 127 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
| 115 | type=str, | ||
| 116 | default="o", | ||
| 117 | help="Cycles to run automatically." | ||
| 118 | ) | 128 | ) |
| 119 | parser.add_argument( | 129 | parser.add_argument( |
| 120 | "--cycle_decay", | 130 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." |
| 121 | type=float, | ||
| 122 | default=1.0, | ||
| 123 | help="Learning rate decay per cycle." | ||
| 124 | ) | 131 | ) |
| 125 | parser.add_argument( | 132 | parser.add_argument( |
| 126 | "--placeholder_tokens", | 133 | "--placeholder_tokens", |
| 127 | type=str, | 134 | type=str, |
| 128 | nargs='*', | 135 | nargs="*", |
| 129 | help="A token to use as a placeholder for the concept.", | 136 | help="A token to use as a placeholder for the concept.", |
| 130 | ) | 137 | ) |
| 131 | parser.add_argument( | 138 | parser.add_argument( |
| 132 | "--initializer_tokens", | 139 | "--initializer_tokens", |
| 133 | type=str, | 140 | type=str, |
| 134 | nargs='*', | 141 | nargs="*", |
| 135 | help="A token to use as initializer word." | 142 | help="A token to use as initializer word.", |
| 136 | ) | 143 | ) |
| 137 | parser.add_argument( | 144 | parser.add_argument( |
| 138 | "--filter_tokens", | 145 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." |
| 139 | type=str, | ||
| 140 | nargs='*', | ||
| 141 | help="Tokens to filter the dataset by." | ||
| 142 | ) | 146 | ) |
| 143 | parser.add_argument( | 147 | parser.add_argument( |
| 144 | "--initializer_noise", | 148 | "--initializer_noise", |
| 145 | type=float, | 149 | type=float, |
| 146 | default=0, | 150 | default=0, |
| 147 | help="Noise to apply to the initializer word" | 151 | help="Noise to apply to the initializer word", |
| 148 | ) | 152 | ) |
| 149 | parser.add_argument( | 153 | parser.add_argument( |
| 150 | "--alias_tokens", | 154 | "--alias_tokens", |
| 151 | type=str, | 155 | type=str, |
| 152 | nargs='*', | 156 | nargs="*", |
| 153 | default=[], | 157 | default=[], |
| 154 | help="Tokens to create an alias for." | 158 | help="Tokens to create an alias for.", |
| 155 | ) | 159 | ) |
| 156 | parser.add_argument( | 160 | parser.add_argument( |
| 157 | "--inverted_initializer_tokens", | 161 | "--inverted_initializer_tokens", |
| 158 | type=str, | 162 | type=str, |
| 159 | nargs='*', | 163 | nargs="*", |
| 160 | help="A token to use as initializer word." | 164 | help="A token to use as initializer word.", |
| 161 | ) | 165 | ) |
| 162 | parser.add_argument( | 166 | parser.add_argument( |
| 163 | "--num_vectors", | 167 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
| 164 | type=int, | ||
| 165 | nargs='*', | ||
| 166 | help="Number of vectors per embedding." | ||
| 167 | ) | 168 | ) |
| 168 | parser.add_argument( | 169 | parser.add_argument( |
| 169 | "--exclude_collections", | 170 | "--exclude_collections", |
| 170 | type=str, | 171 | type=str, |
| 171 | nargs='*', | 172 | nargs="*", |
| 172 | help="Exclude all items with a listed collection.", | 173 | help="Exclude all items with a listed collection.", |
| 173 | ) | 174 | ) |
| 174 | parser.add_argument( | 175 | parser.add_argument( |
| @@ -214,7 +215,7 @@ def parse_args(): | |||
| 214 | "--num_class_images", | 215 | "--num_class_images", |
| 215 | type=int, | 216 | type=int, |
| 216 | default=0, | 217 | default=0, |
| 217 | help="How many class images to generate." | 218 | help="How many class images to generate.", |
| 218 | ) | 219 | ) |
| 219 | parser.add_argument( | 220 | parser.add_argument( |
| 220 | "--class_image_dir", | 221 | "--class_image_dir", |
| @@ -242,14 +243,11 @@ def parse_args(): | |||
| 242 | parser.add_argument( | 243 | parser.add_argument( |
| 243 | "--collection", | 244 | "--collection", |
| 244 | type=str, | 245 | type=str, |
| 245 | nargs='*', | 246 | nargs="*", |
| 246 | help="A collection to filter the dataset.", | 247 | help="A collection to filter the dataset.", |
| 247 | ) | 248 | ) |
| 248 | parser.add_argument( | 249 | parser.add_argument( |
| 249 | "--seed", | 250 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 250 | type=int, | ||
| 251 | default=None, | ||
| 252 | help="A seed for reproducible training." | ||
| 253 | ) | 251 | ) |
| 254 | parser.add_argument( | 252 | parser.add_argument( |
| 255 | "--resolution", | 253 | "--resolution", |
| @@ -270,18 +268,10 @@ def parse_args(): | |||
| 270 | "--input_pertubation", | 268 | "--input_pertubation", |
| 271 | type=float, | 269 | type=float, |
| 272 | default=0, | 270 | default=0, |
| 273 | help="The scale of input pretubation. Recommended 0.1." | 271 | help="The scale of input pretubation. Recommended 0.1.", |
| 274 | ) | ||
| 275 | parser.add_argument( | ||
| 276 | "--num_train_epochs", | ||
| 277 | type=int, | ||
| 278 | default=None | ||
| 279 | ) | ||
| 280 | parser.add_argument( | ||
| 281 | "--num_train_steps", | ||
| 282 | type=int, | ||
| 283 | default=2000 | ||
| 284 | ) | 272 | ) |
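Note on `--input_pertubation` (the flag name keeps its historical spelling): it scales an extra noise term mixed into the sampled noise before the forward diffusion step, while the loss target stays the original noise. A minimal sketch of the usual wiring, assuming a diffusers-style `noise_scheduler.add_noise` and `latents`/`noise`/`timesteps` from the training loop; the exact integration lives in `training/functional.py` and may differ:

```python
import torch

def add_noise_with_perturbation(noise_scheduler, latents, noise, timesteps, scale: float):
    # Diffuse with slightly perturbed noise; the caller still trains the model
    # to predict the unperturbed `noise`, which reduces exposure bias.
    if scale > 0:
        noise = noise + scale * torch.randn_like(noise)
    return noise_scheduler.add_noise(latents, noise, timesteps)
```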
| 273 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
| 274 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
| 285 | parser.add_argument( | 275 | parser.add_argument( |
| 286 | "--gradient_accumulation_steps", | 276 | "--gradient_accumulation_steps", |
| 287 | type=int, | 277 | type=int, |
| @@ -289,22 +279,19 @@ def parse_args(): | |||
| 289 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 279 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
| 290 | ) | 280 | ) |
| 291 | parser.add_argument( | 281 | parser.add_argument( |
| 292 | "--lora_r", | 282 | "--lora_r", type=int, default=8, help="Lora rank, only used if use_lora is True" |
| 293 | type=int, | ||
| 294 | default=8, | ||
| 295 | help="Lora rank, only used if use_lora is True" | ||
| 296 | ) | 283 | ) |
| 297 | parser.add_argument( | 284 | parser.add_argument( |
| 298 | "--lora_alpha", | 285 | "--lora_alpha", |
| 299 | type=int, | 286 | type=int, |
| 300 | default=32, | 287 | default=32, |
| 301 | help="Lora alpha, only used if use_lora is True" | 288 | help="Lora alpha, only used if use_lora is True", |
| 302 | ) | 289 | ) |
| 303 | parser.add_argument( | 290 | parser.add_argument( |
| 304 | "--lora_dropout", | 291 | "--lora_dropout", |
| 305 | type=float, | 292 | type=float, |
| 306 | default=0.0, | 293 | default=0.0, |
| 307 | help="Lora dropout, only used if use_lora is True" | 294 | help="Lora dropout, only used if use_lora is True", |
| 308 | ) | 295 | ) |
| 309 | parser.add_argument( | 296 | parser.add_argument( |
| 310 | "--lora_bias", | 297 | "--lora_bias", |
| @@ -344,7 +331,7 @@ def parse_args(): | |||
| 344 | parser.add_argument( | 331 | parser.add_argument( |
| 345 | "--train_text_encoder_cycles", | 332 | "--train_text_encoder_cycles", |
| 346 | default=999999, | 333 | default=999999, |
| 347 | help="Number of epochs the text encoder will be trained." | 334 | help="Number of epochs the text encoder will be trained.", |
| 348 | ) | 335 | ) |
| 349 | parser.add_argument( | 336 | parser.add_argument( |
| 350 | "--find_lr", | 337 | "--find_lr", |
| @@ -378,27 +365,31 @@ def parse_args(): | |||
| 378 | "--lr_scheduler", | 365 | "--lr_scheduler", |
| 379 | type=str, | 366 | type=str, |
| 380 | default="one_cycle", | 367 | default="one_cycle", |
| 381 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 368 | choices=[ |
| 382 | "constant", "constant_with_warmup", "one_cycle"], | 369 | "linear", |
| 383 | help='The scheduler type to use.', | 370 | "cosine", |
| 371 | "cosine_with_restarts", | ||
| 372 | "polynomial", | ||
| 373 | "constant", | ||
| 374 | "constant_with_warmup", | ||
| 375 | "one_cycle", | ||
| 376 | ], | ||
| 377 | help="The scheduler type to use.", | ||
| 384 | ) | 378 | ) |
| 385 | parser.add_argument( | 379 | parser.add_argument( |
| 386 | "--lr_warmup_epochs", | 380 | "--lr_warmup_epochs", |
| 387 | type=int, | 381 | type=int, |
| 388 | default=10, | 382 | default=10, |
| 389 | help="Number of steps for the warmup in the lr scheduler." | 383 | help="Number of steps for the warmup in the lr scheduler.", |
| 390 | ) | 384 | ) |
| 391 | parser.add_argument( | 385 | parser.add_argument( |
| 392 | "--lr_mid_point", | 386 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
| 393 | type=float, | ||
| 394 | default=0.3, | ||
| 395 | help="OneCycle schedule mid point." | ||
| 396 | ) | 387 | ) |
| 397 | parser.add_argument( | 388 | parser.add_argument( |
| 398 | "--lr_cycles", | 389 | "--lr_cycles", |
| 399 | type=int, | 390 | type=int, |
| 400 | default=None, | 391 | default=None, |
| 401 | help="Number of restart cycles in the lr scheduler (if supported)." | 392 | help="Number of restart cycles in the lr scheduler (if supported).", |
| 402 | ) | 393 | ) |
| 403 | parser.add_argument( | 394 | parser.add_argument( |
| 404 | "--lr_warmup_func", | 395 | "--lr_warmup_func", |
| @@ -410,7 +401,7 @@ def parse_args(): | |||
| 410 | "--lr_warmup_exp", | 401 | "--lr_warmup_exp", |
| 411 | type=int, | 402 | type=int, |
| 412 | default=1, | 403 | default=1, |
| 413 | help='If lr_warmup_func is "cos", exponent to modify the function' | 404 | help='If lr_warmup_func is "cos", exponent to modify the function', |
| 414 | ) | 405 | ) |
| 415 | parser.add_argument( | 406 | parser.add_argument( |
| 416 | "--lr_annealing_func", | 407 | "--lr_annealing_func", |
| @@ -422,69 +413,76 @@ def parse_args(): | |||
| 422 | "--lr_annealing_exp", | 413 | "--lr_annealing_exp", |
| 423 | type=int, | 414 | type=int, |
| 424 | default=3, | 415 | default=3, |
| 425 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 416 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
| 426 | ) | 417 | ) |
| 427 | parser.add_argument( | 418 | parser.add_argument( |
| 428 | "--lr_min_lr", | 419 | "--lr_min_lr", |
| 429 | type=float, | 420 | type=float, |
| 430 | default=0.04, | 421 | default=0.04, |
| 431 | help="Minimum learning rate in the lr scheduler." | 422 | help="Minimum learning rate in the lr scheduler.", |
| 432 | ) | ||
| 433 | parser.add_argument( | ||
| 434 | "--min_snr_gamma", | ||
| 435 | type=int, | ||
| 436 | default=5, | ||
| 437 | help="MinSNR gamma." | ||
| 438 | ) | 423 | ) |
| 424 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
| 439 | parser.add_argument( | 425 | parser.add_argument( |
| 440 | "--schedule_sampler", | 426 | "--schedule_sampler", |
| 441 | type=str, | 427 | type=str, |
| 442 | default="uniform", | 428 | default="uniform", |
| 443 | choices=["uniform", "loss-second-moment"], | 429 | choices=["uniform", "loss-second-moment"], |
| 444 | help="Noise schedule sampler." | 430 | help="Noise schedule sampler.", |
| 445 | ) | 431 | ) |
| 446 | parser.add_argument( | 432 | parser.add_argument( |
| 447 | "--optimizer", | 433 | "--optimizer", |
| 448 | type=str, | 434 | type=str, |
| 449 | default="adan", | 435 | default="adan", |
| 450 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 436 | choices=[ |
| 451 | help='Optimizer to use' | 437 | "adam", |
| 438 | "adam8bit", | ||
| 439 | "adan", | ||
| 440 | "lion", | ||
| 441 | "dadam", | ||
| 442 | "dadan", | ||
| 443 | "dlion", | ||
| 444 | "adafactor", | ||
| 445 | ], | ||
| 446 | help="Optimizer to use", | ||
| 452 | ) | 447 | ) |
| 453 | parser.add_argument( | 448 | parser.add_argument( |
| 454 | "--dadaptation_d0", | 449 | "--dadaptation_d0", |
| 455 | type=float, | 450 | type=float, |
| 456 | default=1e-6, | 451 | default=1e-6, |
| 457 | help="The d0 parameter for Dadaptation optimizers." | 452 | help="The d0 parameter for Dadaptation optimizers.", |
| 453 | ) | ||
| 454 | parser.add_argument( | ||
| 455 | "--dadaptation_growth_rate", | ||
| 456 | type=float, | ||
| 457 | default=math.inf, | ||
| 458 | help="The growth_rate parameter for Dadaptation optimizers.", | ||
| 458 | ) | 459 | ) |
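The new `--dadaptation_growth_rate` is threaded through to `DAdaptAdam`/`DAdaptAdan` further down. In the dadaptation library, `growth_rate` caps how fast the adapted step-size estimate `d` may grow per step; the default `math.inf` leaves it uncapped. A conceptual sketch of the cap, not the library's exact code:

```python
import math

def capped_d(d_candidate: float, d_prev: float, growth_rate: float = math.inf) -> float:
    # d may grow at most multiplicatively by growth_rate per step;
    # growth_rate=math.inf (the default here) disables the cap.
    return min(d_candidate, growth_rate * d_prev)
```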
| 459 | parser.add_argument( | 460 | parser.add_argument( |
| 460 | "--adam_beta1", | 461 | "--adam_beta1", |
| 461 | type=float, | 462 | type=float, |
| 462 | default=None, | 463 | default=None, |
| 463 | help="The beta1 parameter for the Adam optimizer." | 464 | help="The beta1 parameter for the Adam optimizer.", |
| 464 | ) | 465 | ) |
| 465 | parser.add_argument( | 466 | parser.add_argument( |
| 466 | "--adam_beta2", | 467 | "--adam_beta2", |
| 467 | type=float, | 468 | type=float, |
| 468 | default=None, | 469 | default=None, |
| 469 | help="The beta2 parameter for the Adam optimizer." | 470 | help="The beta2 parameter for the Adam optimizer.", |
| 470 | ) | 471 | ) |
| 471 | parser.add_argument( | 472 | parser.add_argument( |
| 472 | "--adam_weight_decay", | 473 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
| 473 | type=float, | ||
| 474 | default=2e-2, | ||
| 475 | help="Weight decay to use." | ||
| 476 | ) | 474 | ) |
| 477 | parser.add_argument( | 475 | parser.add_argument( |
| 478 | "--adam_epsilon", | 476 | "--adam_epsilon", |
| 479 | type=float, | 477 | type=float, |
| 480 | default=1e-08, | 478 | default=1e-08, |
| 481 | help="Epsilon value for the Adam optimizer" | 479 | help="Epsilon value for the Adam optimizer", |
| 482 | ) | 480 | ) |
| 483 | parser.add_argument( | 481 | parser.add_argument( |
| 484 | "--adam_amsgrad", | 482 | "--adam_amsgrad", |
| 485 | type=bool, | 483 | type=bool, |
| 486 | default=False, | 484 | default=False, |
| 487 | help="Amsgrad value for the Adam optimizer" | 485 | help="Amsgrad value for the Adam optimizer", |
| 488 | ) | 486 | ) |
| 489 | parser.add_argument( | 487 | parser.add_argument( |
| 490 | "--mixed_precision", | 488 | "--mixed_precision", |
| @@ -547,19 +545,19 @@ def parse_args(): | |||
| 547 | "--valid_set_size", | 545 | "--valid_set_size", |
| 548 | type=int, | 546 | type=int, |
| 549 | default=None, | 547 | default=None, |
| 550 | help="Number of images in the validation dataset." | 548 | help="Number of images in the validation dataset.", |
| 551 | ) | 549 | ) |
| 552 | parser.add_argument( | 550 | parser.add_argument( |
| 553 | "--valid_set_repeat", | 551 | "--valid_set_repeat", |
| 554 | type=int, | 552 | type=int, |
| 555 | default=1, | 553 | default=1, |
| 556 | help="Times the images in the validation dataset are repeated." | 554 | help="Times the images in the validation dataset are repeated.", |
| 557 | ) | 555 | ) |
| 558 | parser.add_argument( | 556 | parser.add_argument( |
| 559 | "--train_batch_size", | 557 | "--train_batch_size", |
| 560 | type=int, | 558 | type=int, |
| 561 | default=1, | 559 | default=1, |
| 562 | help="Batch size (per device) for the training dataloader." | 560 | help="Batch size (per device) for the training dataloader.", |
| 563 | ) | 561 | ) |
| 564 | parser.add_argument( | 562 | parser.add_argument( |
| 565 | "--sample_steps", | 563 | "--sample_steps", |
| @@ -571,19 +569,10 @@ def parse_args(): | |||
| 571 | "--prior_loss_weight", | 569 | "--prior_loss_weight", |
| 572 | type=float, | 570 | type=float, |
| 573 | default=1.0, | 571 | default=1.0, |
| 574 | help="The weight of prior preservation loss." | 572 | help="The weight of prior preservation loss.", |
| 575 | ) | ||
| 576 | parser.add_argument( | ||
| 577 | "--run_pti", | ||
| 578 | action="store_true", | ||
| 579 | help="Whether to run PTI." | ||
| 580 | ) | ||
| 581 | parser.add_argument( | ||
| 582 | "--emb_alpha", | ||
| 583 | type=float, | ||
| 584 | default=1.0, | ||
| 585 | help="Embedding alpha" | ||
| 586 | ) | 573 | ) |
| 574 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
| 575 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
| 587 | parser.add_argument( | 576 | parser.add_argument( |
| 588 | "--emb_dropout", | 577 | "--emb_dropout", |
| 589 | type=float, | 578 | type=float, |
| @@ -591,27 +580,16 @@ def parse_args(): | |||
| 591 | help="Embedding dropout probability.", | 580 | help="Embedding dropout probability.", |
| 592 | ) | 581 | ) |
| 593 | parser.add_argument( | 582 | parser.add_argument( |
| 594 | "--use_emb_decay", | 583 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." |
| 595 | action="store_true", | ||
| 596 | help="Whether to use embedding decay." | ||
| 597 | ) | 584 | ) |
| 598 | parser.add_argument( | 585 | parser.add_argument( |
| 599 | "--emb_decay_target", | 586 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." |
| 600 | default=0.4, | ||
| 601 | type=float, | ||
| 602 | help="Embedding decay target." | ||
| 603 | ) | 587 | ) |
| 604 | parser.add_argument( | 588 | parser.add_argument( |
| 605 | "--emb_decay", | 589 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." |
| 606 | default=1e+2, | ||
| 607 | type=float, | ||
| 608 | help="Embedding decay factor." | ||
| 609 | ) | 590 | ) |
| 610 | parser.add_argument( | 591 | parser.add_argument( |
| 611 | "--max_grad_norm", | 592 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." |
| 612 | default=1.0, | ||
| 613 | type=float, | ||
| 614 | help="Max gradient norm." | ||
| 615 | ) | 593 | ) |
| 616 | parser.add_argument( | 594 | parser.add_argument( |
| 617 | "--noise_timesteps", | 595 | "--noise_timesteps", |
| @@ -622,7 +600,7 @@ def parse_args(): | |||
| 622 | "--config", | 600 | "--config", |
| 623 | type=str, | 601 | type=str, |
| 624 | default=None, | 602 | default=None, |
| 625 | help="Path to a JSON configuration file containing arguments for invoking this script." | 603 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
| 626 | ) | 604 | ) |
| 627 | 605 | ||
| 628 | args = parser.parse_args() | 606 | args = parser.parse_args() |
| @@ -649,29 +627,44 @@ def parse_args(): | |||
| 649 | args.placeholder_tokens = [args.placeholder_tokens] | 627 | args.placeholder_tokens = [args.placeholder_tokens] |
| 650 | 628 | ||
| 651 | if isinstance(args.initializer_tokens, str): | 629 | if isinstance(args.initializer_tokens, str): |
| 652 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 630 | args.initializer_tokens = [args.initializer_tokens] * len( |
| 631 | args.placeholder_tokens | ||
| 632 | ) | ||
| 653 | 633 | ||
| 654 | if len(args.placeholder_tokens) == 0: | 634 | if len(args.placeholder_tokens) == 0: |
| 655 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 635 | args.placeholder_tokens = [ |
| 636 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
| 637 | ] | ||
| 656 | 638 | ||
| 657 | if len(args.initializer_tokens) == 0: | 639 | if len(args.initializer_tokens) == 0: |
| 658 | args.initializer_tokens = args.placeholder_tokens.copy() | 640 | args.initializer_tokens = args.placeholder_tokens.copy() |
| 659 | 641 | ||
| 660 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 642 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
| 661 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 643 | raise ValueError( |
| 644 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
| 645 | ) | ||
| 662 | 646 | ||
| 663 | if isinstance(args.inverted_initializer_tokens, str): | 647 | if isinstance(args.inverted_initializer_tokens, str): |
| 664 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | 648 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( |
| 649 | args.placeholder_tokens | ||
| 650 | ) | ||
| 665 | 651 | ||
| 666 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | 652 | if ( |
| 653 | isinstance(args.inverted_initializer_tokens, list) | ||
| 654 | and len(args.inverted_initializer_tokens) != 0 | ||
| 655 | ): | ||
| 667 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | 656 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] |
| 668 | args.initializer_tokens += args.inverted_initializer_tokens | 657 | args.initializer_tokens += args.inverted_initializer_tokens |
| 669 | 658 | ||
| 670 | if isinstance(args.num_vectors, int): | 659 | if isinstance(args.num_vectors, int): |
| 671 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 660 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 672 | 661 | ||
| 673 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 662 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( |
| 674 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 663 | args.num_vectors |
| 664 | ): | ||
| 665 | raise ValueError( | ||
| 666 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
| 667 | ) | ||
| 675 | 668 | ||
| 676 | if args.alias_tokens is None: | 669 | if args.alias_tokens is None: |
| 677 | args.alias_tokens = [] | 670 | args.alias_tokens = [] |
| @@ -695,15 +688,15 @@ def parse_args(): | |||
| 695 | raise ValueError("You must specify --output_dir") | 688 | raise ValueError("You must specify --output_dir") |
| 696 | 689 | ||
| 697 | if args.adam_beta1 is None: | 690 | if args.adam_beta1 is None: |
| 698 | if args.optimizer in ('adam', 'adam8bit'): | 691 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
| 699 | args.adam_beta1 = 0.9 | 692 | args.adam_beta1 = 0.9 |
| 700 | elif args.optimizer == 'lion': | 693 | elif args.optimizer in ("lion", "dlion"): |
| 701 | args.adam_beta1 = 0.95 | 694 | args.adam_beta1 = 0.95 |
| 702 | 695 | ||
| 703 | if args.adam_beta2 is None: | 696 | if args.adam_beta2 is None: |
| 704 | if args.optimizer in ('adam', 'adam8bit'): | 697 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
| 705 | args.adam_beta2 = 0.999 | 698 | args.adam_beta2 = 0.999 |
| 706 | elif args.optimizer == 'lion': | 699 | elif args.optimizer in ("lion", "dlion"): |
| 707 | args.adam_beta2 = 0.98 | 700 | args.adam_beta2 = 0.98 |
| 708 | 701 | ||
| 709 | return args | 702 | return args |
| @@ -719,7 +712,7 @@ def main(): | |||
| 719 | accelerator = Accelerator( | 712 | accelerator = Accelerator( |
| 720 | log_with=LoggerType.TENSORBOARD, | 713 | log_with=LoggerType.TENSORBOARD, |
| 721 | project_dir=f"{output_dir}", | 714 | project_dir=f"{output_dir}", |
| 722 | mixed_precision=args.mixed_precision | 715 | mixed_precision=args.mixed_precision, |
| 723 | ) | 716 | ) |
| 724 | 717 | ||
| 725 | weight_dtype = torch.float32 | 718 | weight_dtype = torch.float32 |
| @@ -728,6 +721,8 @@ def main(): | |||
| 728 | elif args.mixed_precision == "bf16": | 721 | elif args.mixed_precision == "bf16": |
| 729 | weight_dtype = torch.bfloat16 | 722 | weight_dtype = torch.bfloat16 |
| 730 | 723 | ||
| 724 | patch_xformers(weight_dtype) | ||
| 725 | |||
| 731 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 726 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
| 732 | 727 | ||
| 733 | if args.seed is None: | 728 | if args.seed is None: |
| @@ -737,12 +732,18 @@ def main(): | |||
| 737 | 732 | ||
| 738 | save_args(output_dir, args) | 733 | save_args(output_dir, args) |
| 739 | 734 | ||
| 740 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) | 735 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
| 741 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 736 | args.pretrained_model_name_or_path |
| 742 | 737 | ) | |
| 738 | schedule_sampler = create_named_schedule_sampler( | ||
| 739 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
| 740 | ) | ||
| 741 | |||
| 743 | def ensure_embeddings(): | 742 | def ensure_embeddings(): |
| 744 | if args.lora_text_encoder_emb: | 743 | if args.lora_text_encoder_emb: |
| 745 | raise ValueError("Can't use TI options when training token embeddings with LoRA") | 744 | raise ValueError( |
| 745 | "Can't use TI options when training token embeddings with LoRA" | ||
| 746 | ) | ||
| 746 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) | 747 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) |
| 747 | 748 | ||
| 748 | unet_config = LoraConfig( | 749 | unet_config = LoraConfig( |
| @@ -757,7 +758,9 @@ def main(): | |||
| 757 | text_encoder_config = LoraConfig( | 758 | text_encoder_config = LoraConfig( |
| 758 | r=args.lora_text_encoder_r, | 759 | r=args.lora_text_encoder_r, |
| 759 | lora_alpha=args.lora_text_encoder_alpha, | 760 | lora_alpha=args.lora_text_encoder_alpha, |
| 760 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES, | 761 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING |
| 762 | if args.lora_text_encoder_emb | ||
| 763 | else TEXT_ENCODER_TARGET_MODULES, | ||
| 761 | lora_dropout=args.lora_text_encoder_dropout, | 764 | lora_dropout=args.lora_text_encoder_dropout, |
| 762 | bias=args.lora_text_encoder_bias, | 765 | bias=args.lora_text_encoder_bias, |
| 763 | ) | 766 | ) |
| @@ -787,7 +790,7 @@ def main(): | |||
| 787 | 790 | ||
| 788 | if len(args.alias_tokens) != 0: | 791 | if len(args.alias_tokens) != 0: |
| 789 | embeddings = ensure_embeddings() | 792 | embeddings = ensure_embeddings() |
| 790 | 793 | ||
| 791 | alias_placeholder_tokens = args.alias_tokens[::2] | 794 | alias_placeholder_tokens = args.alias_tokens[::2] |
| 792 | alias_initializer_tokens = args.alias_tokens[1::2] | 795 | alias_initializer_tokens = args.alias_tokens[1::2] |
| 793 | 796 | ||
| @@ -795,27 +798,33 @@ def main(): | |||
| 795 | tokenizer=tokenizer, | 798 | tokenizer=tokenizer, |
| 796 | embeddings=embeddings, | 799 | embeddings=embeddings, |
| 797 | placeholder_tokens=alias_placeholder_tokens, | 800 | placeholder_tokens=alias_placeholder_tokens, |
| 798 | initializer_tokens=alias_initializer_tokens | 801 | initializer_tokens=alias_initializer_tokens, |
| 799 | ) | 802 | ) |
| 800 | embeddings.persist() | 803 | embeddings.persist() |
| 801 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 804 | print( |
| 805 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
| 806 | ) | ||
| 802 | 807 | ||
| 803 | placeholder_tokens = [] | 808 | placeholder_tokens = [] |
| 804 | placeholder_token_ids = [] | 809 | placeholder_token_ids = [] |
| 805 | 810 | ||
| 806 | if args.embeddings_dir is not None: | 811 | if args.embeddings_dir is not None: |
| 807 | embeddings = ensure_embeddings() | 812 | embeddings = ensure_embeddings() |
| 808 | 813 | ||
| 809 | embeddings_dir = Path(args.embeddings_dir) | 814 | embeddings_dir = Path(args.embeddings_dir) |
| 810 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 815 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 811 | raise ValueError("--embeddings_dir must point to an existing directory") | 816 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 812 | 817 | ||
| 813 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 818 | added_tokens, added_ids = load_embeddings_from_dir( |
| 819 | tokenizer, embeddings, embeddings_dir | ||
| 820 | ) | ||
| 814 | 821 | ||
| 815 | placeholder_tokens = added_tokens | 822 | placeholder_tokens = added_tokens |
| 816 | placeholder_token_ids = added_ids | 823 | placeholder_token_ids = added_ids |
| 817 | 824 | ||
| 818 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 825 | print( |
| 826 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
| 827 | ) | ||
| 819 | 828 | ||
| 820 | if args.train_dir_embeddings: | 829 | if args.train_dir_embeddings: |
| 821 | print("Training embeddings from embeddings dir") | 830 | print("Training embeddings from embeddings dir") |
| @@ -824,7 +833,7 @@ def main(): | |||
| 824 | 833 | ||
| 825 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | 834 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: |
| 826 | embeddings = ensure_embeddings() | 835 | embeddings = ensure_embeddings() |
| 827 | 836 | ||
| 828 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 837 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 829 | tokenizer=tokenizer, | 838 | tokenizer=tokenizer, |
| 830 | embeddings=embeddings, | 839 | embeddings=embeddings, |
| @@ -836,23 +845,34 @@ def main(): | |||
| 836 | 845 | ||
| 837 | placeholder_tokens = args.placeholder_tokens | 846 | placeholder_tokens = args.placeholder_tokens |
| 838 | 847 | ||
| 839 | stats = list(zip( | 848 | stats = list( |
| 840 | placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | 849 | zip( |
| 841 | )) | 850 | placeholder_tokens, |
| 851 | placeholder_token_ids, | ||
| 852 | args.initializer_tokens, | ||
| 853 | initializer_token_ids, | ||
| 854 | ) | ||
| 855 | ) | ||
| 842 | print(f"Training embeddings: {stats}") | 856 | print(f"Training embeddings: {stats}") |
| 843 | 857 | ||
| 844 | if args.scale_lr: | 858 | if args.scale_lr: |
| 845 | args.learning_rate_unet = ( | 859 | args.learning_rate_unet = ( |
| 846 | args.learning_rate_unet * args.gradient_accumulation_steps * | 860 | args.learning_rate_unet |
| 847 | args.train_batch_size * accelerator.num_processes | 861 | * args.gradient_accumulation_steps |
| 862 | * args.train_batch_size | ||
| 863 | * accelerator.num_processes | ||
| 848 | ) | 864 | ) |
| 849 | args.learning_rate_text = ( | 865 | args.learning_rate_text = ( |
| 850 | args.learning_rate_text * args.gradient_accumulation_steps * | 866 | args.learning_rate_text |
| 851 | args.train_batch_size * accelerator.num_processes | 867 | * args.gradient_accumulation_steps |
| 868 | * args.train_batch_size | ||
| 869 | * accelerator.num_processes | ||
| 852 | ) | 870 | ) |
| 853 | args.learning_rate_emb = ( | 871 | args.learning_rate_emb = ( |
| 854 | args.learning_rate_emb * args.gradient_accumulation_steps * | 872 | args.learning_rate_emb |
| 855 | args.train_batch_size * accelerator.num_processes | 873 | * args.gradient_accumulation_steps |
| 874 | * args.train_batch_size | ||
| 875 | * accelerator.num_processes | ||
| 856 | ) | 876 | ) |
| 857 | 877 | ||
| 858 | if args.find_lr: | 878 | if args.find_lr: |
| @@ -861,11 +881,13 @@ def main(): | |||
| 861 | args.learning_rate_emb = 1e-6 | 881 | args.learning_rate_emb = 1e-6 |
| 862 | args.lr_scheduler = "exponential_growth" | 882 | args.lr_scheduler = "exponential_growth" |
| 863 | 883 | ||
| 864 | if args.optimizer == 'adam8bit': | 884 | if args.optimizer == "adam8bit": |
| 865 | try: | 885 | try: |
| 866 | import bitsandbytes as bnb | 886 | import bitsandbytes as bnb |
| 867 | except ImportError: | 887 | except ImportError: |
| 868 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 888 | raise ImportError( |
| 889 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
| 890 | ) | ||
| 869 | 891 | ||
| 870 | create_optimizer = partial( | 892 | create_optimizer = partial( |
| 871 | bnb.optim.AdamW8bit, | 893 | bnb.optim.AdamW8bit, |
| @@ -874,7 +896,7 @@ def main(): | |||
| 874 | eps=args.adam_epsilon, | 896 | eps=args.adam_epsilon, |
| 875 | amsgrad=args.adam_amsgrad, | 897 | amsgrad=args.adam_amsgrad, |
| 876 | ) | 898 | ) |
| 877 | elif args.optimizer == 'adam': | 899 | elif args.optimizer == "adam": |
| 878 | create_optimizer = partial( | 900 | create_optimizer = partial( |
| 879 | torch.optim.AdamW, | 901 | torch.optim.AdamW, |
| 880 | betas=(args.adam_beta1, args.adam_beta2), | 902 | betas=(args.adam_beta1, args.adam_beta2), |
| @@ -882,11 +904,13 @@ def main(): | |||
| 882 | eps=args.adam_epsilon, | 904 | eps=args.adam_epsilon, |
| 883 | amsgrad=args.adam_amsgrad, | 905 | amsgrad=args.adam_amsgrad, |
| 884 | ) | 906 | ) |
| 885 | elif args.optimizer == 'adan': | 907 | elif args.optimizer == "adan": |
| 886 | try: | 908 | try: |
| 887 | import timm.optim | 909 | import timm.optim |
| 888 | except ImportError: | 910 | except ImportError: |
| 889 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 911 | raise ImportError( |
| 912 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
| 913 | ) | ||
| 890 | 914 | ||
| 891 | create_optimizer = partial( | 915 | create_optimizer = partial( |
| 892 | timm.optim.Adan, | 916 | timm.optim.Adan, |
| @@ -894,11 +918,13 @@ def main(): | |||
| 894 | eps=args.adam_epsilon, | 918 | eps=args.adam_epsilon, |
| 895 | no_prox=True, | 919 | no_prox=True, |
| 896 | ) | 920 | ) |
| 897 | elif args.optimizer == 'lion': | 921 | elif args.optimizer == "lion": |
| 898 | try: | 922 | try: |
| 899 | import lion_pytorch | 923 | import lion_pytorch |
| 900 | except ImportError: | 924 | except ImportError: |
| 901 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 925 | raise ImportError( |
| 926 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
| 927 | ) | ||
| 902 | 928 | ||
| 903 | create_optimizer = partial( | 929 | create_optimizer = partial( |
| 904 | lion_pytorch.Lion, | 930 | lion_pytorch.Lion, |
| @@ -906,7 +932,7 @@ def main(): | |||
| 906 | weight_decay=args.adam_weight_decay, | 932 | weight_decay=args.adam_weight_decay, |
| 907 | use_triton=True, | 933 | use_triton=True, |
| 908 | ) | 934 | ) |
| 909 | elif args.optimizer == 'adafactor': | 935 | elif args.optimizer == "adafactor": |
| 910 | create_optimizer = partial( | 936 | create_optimizer = partial( |
| 911 | transformers.optimization.Adafactor, | 937 | transformers.optimization.Adafactor, |
| 912 | weight_decay=args.adam_weight_decay, | 938 | weight_decay=args.adam_weight_decay, |
| @@ -920,11 +946,13 @@ def main(): | |||
| 920 | args.learning_rate_unet = None | 946 | args.learning_rate_unet = None |
| 921 | args.learning_rate_text = None | 947 | args.learning_rate_text = None |
| 922 | args.learning_rate_emb = None | 948 | args.learning_rate_emb = None |
| 923 | elif args.optimizer == 'dadam': | 949 | elif args.optimizer == "dadam": |
| 924 | try: | 950 | try: |
| 925 | import dadaptation | 951 | import dadaptation |
| 926 | except ImportError: | 952 | except ImportError: |
| 927 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 953 | raise ImportError( |
| 954 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
| 955 | ) | ||
| 928 | 956 | ||
| 929 | create_optimizer = partial( | 957 | create_optimizer = partial( |
| 930 | dadaptation.DAdaptAdam, | 958 | dadaptation.DAdaptAdam, |
| @@ -933,29 +961,35 @@ def main(): | |||
| 933 | eps=args.adam_epsilon, | 961 | eps=args.adam_epsilon, |
| 934 | decouple=True, | 962 | decouple=True, |
| 935 | d0=args.dadaptation_d0, | 963 | d0=args.dadaptation_d0, |
| 964 | growth_rate=args.dadaptation_growth_rate, | ||
| 936 | ) | 965 | ) |
| 937 | 966 | ||
| 938 | args.learning_rate_unet = 1.0 | 967 | args.learning_rate_unet = 1.0 |
| 939 | args.learning_rate_text = 1.0 | 968 | args.learning_rate_text = 1.0 |
| 940 | args.learning_rate_emb = 1.0 | 969 | args.learning_rate_emb = 1.0 |
| 941 | elif args.optimizer == 'dadan': | 970 | elif args.optimizer == "dadan": |
| 942 | try: | 971 | try: |
| 943 | import dadaptation | 972 | import dadaptation |
| 944 | except ImportError: | 973 | except ImportError: |
| 945 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 974 | raise ImportError( |
| 975 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
| 976 | ) | ||
| 946 | 977 | ||
| 947 | create_optimizer = partial( | 978 | create_optimizer = partial( |
| 948 | dadaptation.DAdaptAdan, | 979 | dadaptation.DAdaptAdan, |
| 949 | weight_decay=args.adam_weight_decay, | 980 | weight_decay=args.adam_weight_decay, |
| 950 | eps=args.adam_epsilon, | 981 | eps=args.adam_epsilon, |
| 951 | d0=args.dadaptation_d0, | 982 | d0=args.dadaptation_d0, |
| 983 | growth_rate=args.dadaptation_growth_rate, | ||
| 952 | ) | 984 | ) |
| 953 | 985 | ||
| 954 | args.learning_rate_unet = 1.0 | 986 | args.learning_rate_unet = 1.0 |
| 955 | args.learning_rate_text = 1.0 | 987 | args.learning_rate_text = 1.0 |
| 956 | args.learning_rate_emb = 1.0 | 988 | args.learning_rate_emb = 1.0 |
| 989 | elif args.optimizer == "dlion": | ||
| 990 | raise ImportError("DLion has not been merged into dadaptation yet") | ||
| 957 | else: | 991 | else: |
| 958 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 992 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
| 959 | 993 | ||
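Each branch above binds the optimizer class and its hyperparameters into a `create_optimizer` factory via `functools.partial`; call sites later supply only the parameter groups. A standalone sketch of the pattern with plain AdamW (`model` is a hypothetical stand-in, values illustrative):

```python
from functools import partial

import torch

create_optimizer = partial(
    torch.optim.AdamW,
    betas=(0.9, 0.999),
    weight_decay=2e-2,
    eps=1e-8,
)

model = torch.nn.Linear(16, 16)  # hypothetical stand-in for the trainable module
optimizer = create_optimizer([{"params": model.parameters(), "lr": 1e-4}])
```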
| 960 | trainer = partial( | 994 | trainer = partial( |
| 961 | train, | 995 | train, |
| @@ -1026,25 +1060,33 @@ def main(): | |||
| 1026 | 1060 | ||
| 1027 | if args.run_pti and len(placeholder_tokens) != 0: | 1061 | if args.run_pti and len(placeholder_tokens) != 0: |
| 1028 | embeddings = ensure_embeddings() | 1062 | embeddings = ensure_embeddings() |
| 1029 | 1063 | ||
| 1030 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | 1064 | filter_tokens = [ |
| 1065 | token for token in args.filter_tokens if token in placeholder_tokens | ||
| 1066 | ] | ||
| 1031 | 1067 | ||
| 1032 | pti_datamodule = create_datamodule( | 1068 | pti_datamodule = create_datamodule( |
| 1033 | valid_set_size=0, | 1069 | valid_set_size=0, |
| 1034 | batch_size=args.train_batch_size, | 1070 | batch_size=args.train_batch_size, |
| 1035 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 1071 | filter=partial( |
| 1072 | keyword_filter, filter_tokens, args.collection, args.exclude_collections | ||
| 1073 | ), | ||
| 1036 | ) | 1074 | ) |
| 1037 | pti_datamodule.setup() | 1075 | pti_datamodule.setup() |
| 1038 | 1076 | ||
| 1039 | num_train_epochs = args.num_train_epochs | 1077 | num_train_epochs = args.num_train_epochs |
| 1040 | pti_sample_frequency = args.sample_frequency | 1078 | pti_sample_frequency = args.sample_frequency |
| 1041 | if num_train_epochs is None: | 1079 | if num_train_epochs is None: |
| 1042 | num_train_epochs = math.ceil( | 1080 | num_train_epochs = ( |
| 1043 | args.num_train_steps / len(pti_datamodule.train_dataset) | 1081 | math.ceil(args.num_train_steps / len(pti_datamodule.train_dataset)) |
| 1044 | ) * args.gradient_accumulation_steps | 1082 | * args.gradient_accumulation_steps |
| 1045 | pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps)) | 1083 | ) |
| 1084 | pti_sample_frequency = math.ceil( | ||
| 1085 | num_train_epochs * (pti_sample_frequency / args.num_train_steps) | ||
| 1086 | ) | ||
| 1046 | num_training_steps_per_epoch = math.ceil( | 1087 | num_training_steps_per_epoch = math.ceil( |
| 1047 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps) | 1088 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps |
| 1089 | ) | ||
| 1048 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 1090 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
| 1049 | if args.sample_num is not None: | 1091 | if args.sample_num is not None: |
| 1050 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1092 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
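When `--num_train_epochs` is unset, the epoch count above is derived from `--num_train_steps`. Worked through with illustrative numbers (400 training items, 2,000 target steps, 4 gradient-accumulation steps):

```python
import math

dataset_len, num_train_steps, accum = 400, 2000, 4

num_train_epochs = math.ceil(num_train_steps / dataset_len) * accum  # 5 * 4 = 20
steps_per_epoch = math.ceil(dataset_len / accum)                     # 100
total_steps = steps_per_epoch * num_train_epochs                     # 2000, as requested
```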
| @@ -1060,11 +1102,15 @@ def main(): | |||
| 1060 | print(f"============ PTI ============") | 1102 | print(f"============ PTI ============") |
| 1061 | print("") | 1103 | print("") |
| 1062 | 1104 | ||
| 1063 | pti_optimizer = create_optimizer([{ | 1105 | pti_optimizer = create_optimizer( |
| 1064 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | 1106 | [ |
| 1065 | "lr": args.learning_rate_emb, | 1107 | { |
| 1066 | "weight_decay": 0, | 1108 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 1067 | }]) | 1109 | "lr": args.learning_rate_emb, |
| 1110 | "weight_decay": 0, | ||
| 1111 | } | ||
| 1112 | ] | ||
| 1113 | ) | ||
| 1068 | 1114 | ||
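For the PTI phase only the token-embedding weights are optimized, and the group sets `weight_decay: 0` so the learned vectors are not pulled toward zero. Per-group keys override the defaults bound into `create_optimizer`; a sketch of that override semantics in plain PyTorch (values illustrative):

```python
import torch

emb = torch.nn.Embedding(10, 8)  # stand-in for the token embedding
optimizer = torch.optim.AdamW(
    [{"params": emb.parameters(), "lr": 1e-3, "weight_decay": 0}],
    weight_decay=2e-2,  # constructor default, overridden by the group above
)
```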
| 1069 | pti_lr_scheduler = create_lr_scheduler( | 1115 | pti_lr_scheduler = create_lr_scheduler( |
| 1070 | "constant_with_warmup", | 1116 | "constant_with_warmup", |
| @@ -1113,11 +1159,16 @@ def main(): | |||
| 1113 | num_train_epochs = args.num_train_epochs | 1159 | num_train_epochs = args.num_train_epochs |
| 1114 | lora_sample_frequency = args.sample_frequency | 1160 | lora_sample_frequency = args.sample_frequency |
| 1115 | if num_train_epochs is None: | 1161 | if num_train_epochs is None: |
| 1116 | num_train_epochs = math.ceil( | 1162 | num_train_epochs = ( |
| 1117 | args.num_train_steps / len(lora_datamodule.train_dataset) | 1163 | math.ceil(args.num_train_steps / len(lora_datamodule.train_dataset)) |
| 1118 | ) * args.gradient_accumulation_steps | 1164 | * args.gradient_accumulation_steps |
| 1119 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 1165 | ) |
| 1120 | num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps) | 1166 | lora_sample_frequency = math.ceil( |
| 1167 | num_train_epochs * (lora_sample_frequency / args.num_train_steps) | ||
| 1168 | ) | ||
| 1169 | num_training_steps_per_epoch = math.ceil( | ||
| 1170 | len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps | ||
| 1171 | ) | ||
| 1121 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 1172 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
| 1122 | if args.sample_num is not None: | 1173 | if args.sample_num is not None: |
| 1123 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1174 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
| @@ -1131,7 +1182,6 @@ def main(): | |||
| 1131 | 1182 | ||
| 1132 | training_iter = 0 | 1183 | training_iter = 0 |
| 1133 | auto_cycles = list(args.auto_cycles) | 1184 | auto_cycles = list(args.auto_cycles) |
| 1134 | learning_rate_emb = args.learning_rate_emb | ||
| 1135 | learning_rate_unet = args.learning_rate_unet | 1185 | learning_rate_unet = args.learning_rate_unet |
| 1136 | learning_rate_text = args.learning_rate_text | 1186 | learning_rate_text = args.learning_rate_text |
| 1137 | lr_scheduler = args.lr_scheduler | 1187 | lr_scheduler = args.lr_scheduler |
| @@ -1145,21 +1195,15 @@ def main(): | |||
| 1145 | 1195 | ||
| 1146 | params_to_optimize = [ | 1196 | params_to_optimize = [ |
| 1147 | { | 1197 | { |
| 1148 | "params": ( | 1198 | "params": (param for param in unet.parameters() if param.requires_grad), |
| 1149 | param | ||
| 1150 | for param in unet.parameters() | ||
| 1151 | if param.requires_grad | ||
| 1152 | ), | ||
| 1153 | "lr": learning_rate_unet, | 1199 | "lr": learning_rate_unet, |
| 1154 | }, | 1200 | }, |
| 1155 | { | 1201 | { |
| 1156 | "params": ( | 1202 | "params": ( |
| 1157 | param | 1203 | param for param in text_encoder.parameters() if param.requires_grad |
| 1158 | for param in text_encoder.parameters() | ||
| 1159 | if param.requires_grad | ||
| 1160 | ), | 1204 | ), |
| 1161 | "lr": learning_rate_text, | 1205 | "lr": learning_rate_text, |
| 1162 | } | 1206 | }, |
| 1163 | ] | 1207 | ] |
| 1164 | group_labels = ["unet", "text"] | 1208 | group_labels = ["unet", "text"] |
| 1165 | 1209 | ||
| @@ -1169,19 +1213,26 @@ def main(): | |||
| 1169 | if len(auto_cycles) != 0: | 1213 | if len(auto_cycles) != 0: |
| 1170 | response = auto_cycles.pop(0) | 1214 | response = auto_cycles.pop(0) |
| 1171 | else: | 1215 | else: |
| 1172 | response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 1216 | response = input( |
| 1217 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " | ||
| 1218 | ) | ||
| 1173 | 1219 | ||
| 1174 | if response.lower().strip() == "o": | 1220 | if response.lower().strip() == "o": |
| 1175 | if args.learning_rate_emb is not None: | ||
| 1176 | learning_rate_emb = args.learning_rate_emb * 2 | ||
| 1177 | if args.learning_rate_unet is not None: | 1221 | if args.learning_rate_unet is not None: |
| 1178 | learning_rate_unet = args.learning_rate_unet * 2 | 1222 | learning_rate_unet = ( |
| 1223 | args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) | ||
| 1224 | ) | ||
| 1179 | if args.learning_rate_text is not None: | 1225 | if args.learning_rate_text is not None: |
| 1180 | learning_rate_text = args.learning_rate_text * 2 | 1226 | learning_rate_text = ( |
| 1227 | args.learning_rate_text * 2 * (args.cycle_decay**training_iter) | ||
| 1228 | ) | ||
| 1181 | else: | 1229 | else: |
| 1182 | learning_rate_emb = args.learning_rate_emb | 1230 | learning_rate_unet = args.learning_rate_unet * ( |
| 1183 | learning_rate_unet = args.learning_rate_unet | 1231 | args.cycle_decay**training_iter |
| 1184 | learning_rate_text = args.learning_rate_text | 1232 | ) |
| 1233 | learning_rate_text = args.learning_rate_text * ( | ||
| 1234 | args.cycle_decay**training_iter | ||
| 1235 | ) | ||
| 1185 | 1236 | ||
| 1186 | if response.lower().strip() == "o": | 1237 | if response.lower().strip() == "o": |
| 1187 | lr_scheduler = "one_cycle" | 1238 | lr_scheduler = "one_cycle" |
| @@ -1204,9 +1255,11 @@ def main(): | |||
| 1204 | print("") | 1255 | print("") |
| 1205 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 1256 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") |
| 1206 | print("") | 1257 | print("") |
| 1207 | 1258 | ||
| 1208 | for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): | 1259 | for group, lr in zip( |
| 1209 | group['lr'] = lr | 1260 | lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text] |
| 1261 | ): | ||
| 1262 | group["lr"] = lr | ||
| 1210 | 1263 | ||
| 1211 | lora_lr_scheduler = create_lr_scheduler( | 1264 | lora_lr_scheduler = create_lr_scheduler( |
| 1212 | lr_scheduler, | 1265 | lr_scheduler, |
| @@ -1218,7 +1271,9 @@ def main(): | |||
| 1218 | warmup_epochs=lr_warmup_epochs, | 1271 | warmup_epochs=lr_warmup_epochs, |
| 1219 | ) | 1272 | ) |
| 1220 | 1273 | ||
| 1221 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" | 1274 | lora_checkpoint_output_dir = ( |
| 1275 | output_dir / lora_project / f"model_{training_iter}" | ||
| 1276 | ) | ||
| 1222 | 1277 | ||
| 1223 | trainer( | 1278 | trainer( |
| 1224 | strategy=lora_strategy, | 1279 | strategy=lora_strategy, |
| @@ -1246,12 +1301,6 @@ def main(): | |||
| 1246 | ) | 1301 | ) |
| 1247 | 1302 | ||
| 1248 | training_iter += 1 | 1303 | training_iter += 1 |
| 1249 | if learning_rate_emb is not None: | ||
| 1250 | learning_rate_emb *= args.cycle_decay | ||
| 1251 | if learning_rate_unet is not None: | ||
| 1252 | learning_rate_unet *= args.cycle_decay | ||
| 1253 | if learning_rate_text is not None: | ||
| 1254 | learning_rate_text *= args.cycle_decay | ||
| 1255 | 1304 | ||
| 1256 | accelerator.end_training() | 1305 | accelerator.end_training() |
| 1257 | 1306 | ||
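The block removed at the end of the LoRA cycle loop multiplied the running learning rates by `--cycle_decay` after each cycle, but the next iteration immediately reset them from `args.*`, so the decay never took effect. The rewritten branch recomputes each cycle's rate from the base value instead, with a factor of 2 for `one_cycle` runs; `learning_rate_emb` drops out of this loop entirely, presumably because embedding training is now handled by the PTI phase. A sketch of the new per-cycle rate:

```python
def cycle_lr(base_lr: float, cycle_decay: float, training_iter: int, one_cycle: bool) -> float:
    boost = 2.0 if one_cycle else 1.0
    return base_lr * boost * cycle_decay**training_iter

# e.g. base 1e-4, decay 0.9 under one_cycle: 2e-4, 1.8e-4, 1.62e-4, ...
```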
diff --git a/train_ti.py b/train_ti.py index f60e3e5..c6f0b3a 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -32,7 +32,7 @@ from util.files import load_config, load_embeddings_from_dir | |||
| 32 | 32 | ||
| 33 | logger = get_logger(__name__) | 33 | logger = get_logger(__name__) |
| 34 | 34 | ||
| 35 | warnings.filterwarnings('ignore') | 35 | warnings.filterwarnings("ignore") |
| 36 | 36 | ||
| 37 | 37 | ||
| 38 | torch.backends.cuda.matmul.allow_tf32 = True | 38 | torch.backends.cuda.matmul.allow_tf32 = True |
| @@ -46,9 +46,7 @@ hidet.torch.dynamo_config.search_space(0) | |||
| 46 | 46 | ||
| 47 | 47 | ||
| 48 | def parse_args(): | 48 | def parse_args(): |
| 49 | parser = argparse.ArgumentParser( | 49 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 50 | description="Simple example of a training script." | ||
| 51 | ) | ||
| 52 | parser.add_argument( | 50 | parser.add_argument( |
| 53 | "--pretrained_model_name_or_path", | 51 | "--pretrained_model_name_or_path", |
| 54 | type=str, | 52 | type=str, |
| @@ -65,12 +63,12 @@ def parse_args(): | |||
| 65 | "--train_data_file", | 63 | "--train_data_file", |
| 66 | type=str, | 64 | type=str, |
| 67 | default=None, | 65 | default=None, |
| 68 | help="A CSV file containing the training data." | 66 | help="A CSV file containing the training data.", |
| 69 | ) | 67 | ) |
| 70 | parser.add_argument( | 68 | parser.add_argument( |
| 71 | "--train_data_template", | 69 | "--train_data_template", |
| 72 | type=str, | 70 | type=str, |
| 73 | nargs='*', | 71 | nargs="*", |
| 74 | default="template", | 72 | default="template", |
| 75 | ) | 73 | ) |
| 76 | parser.add_argument( | 74 | parser.add_argument( |
| @@ -80,59 +78,47 @@ def parse_args(): | |||
| 80 | help="The name of the current project.", | 78 | help="The name of the current project.", |
| 81 | ) | 79 | ) |
| 82 | parser.add_argument( | 80 | parser.add_argument( |
| 83 | "--auto_cycles", | 81 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
| 84 | type=str, | ||
| 85 | default="o", | ||
| 86 | help="Cycles to run automatically." | ||
| 87 | ) | 82 | ) |
| 88 | parser.add_argument( | 83 | parser.add_argument( |
| 89 | "--cycle_decay", | 84 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." |
| 90 | type=float, | ||
| 91 | default=1.0, | ||
| 92 | help="Learning rate decay per cycle." | ||
| 93 | ) | 85 | ) |
| 94 | parser.add_argument( | 86 | parser.add_argument( |
| 95 | "--placeholder_tokens", | 87 | "--placeholder_tokens", |
| 96 | type=str, | 88 | type=str, |
| 97 | nargs='*', | 89 | nargs="*", |
| 98 | help="A token to use as a placeholder for the concept.", | 90 | help="A token to use as a placeholder for the concept.", |
| 99 | ) | 91 | ) |
| 100 | parser.add_argument( | 92 | parser.add_argument( |
| 101 | "--initializer_tokens", | 93 | "--initializer_tokens", |
| 102 | type=str, | 94 | type=str, |
| 103 | nargs='*', | 95 | nargs="*", |
| 104 | help="A token to use as initializer word." | 96 | help="A token to use as initializer word.", |
| 105 | ) | 97 | ) |
| 106 | parser.add_argument( | 98 | parser.add_argument( |
| 107 | "--filter_tokens", | 99 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." |
| 108 | type=str, | ||
| 109 | nargs='*', | ||
| 110 | help="Tokens to filter the dataset by." | ||
| 111 | ) | 100 | ) |
| 112 | parser.add_argument( | 101 | parser.add_argument( |
| 113 | "--initializer_noise", | 102 | "--initializer_noise", |
| 114 | type=float, | 103 | type=float, |
| 115 | default=0, | 104 | default=0, |
| 116 | help="Noise to apply to the initializer word" | 105 | help="Noise to apply to the initializer word", |
| 117 | ) | 106 | ) |
| 118 | parser.add_argument( | 107 | parser.add_argument( |
| 119 | "--alias_tokens", | 108 | "--alias_tokens", |
| 120 | type=str, | 109 | type=str, |
| 121 | nargs='*', | 110 | nargs="*", |
| 122 | default=[], | 111 | default=[], |
| 123 | help="Tokens to create an alias for." | 112 | help="Tokens to create an alias for.", |
| 124 | ) | 113 | ) |
| 125 | parser.add_argument( | 114 | parser.add_argument( |
| 126 | "--inverted_initializer_tokens", | 115 | "--inverted_initializer_tokens", |
| 127 | type=str, | 116 | type=str, |
| 128 | nargs='*', | 117 | nargs="*", |
| 129 | help="A token to use as initializer word." | 118 | help="A token to use as initializer word.", |
| 130 | ) | 119 | ) |
| 131 | parser.add_argument( | 120 | parser.add_argument( |
| 132 | "--num_vectors", | 121 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
| 133 | type=int, | ||
| 134 | nargs='*', | ||
| 135 | help="Number of vectors per embedding." | ||
| 136 | ) | 122 | ) |
| 137 | parser.add_argument( | 123 | parser.add_argument( |
| 138 | "--sequential", | 124 | "--sequential", |
| @@ -147,7 +133,7 @@ def parse_args(): | |||
| 147 | "--num_class_images", | 133 | "--num_class_images", |
| 148 | type=int, | 134 | type=int, |
| 149 | default=0, | 135 | default=0, |
| 150 | help="How many class images to generate." | 136 | help="How many class images to generate.", |
| 151 | ) | 137 | ) |
| 152 | parser.add_argument( | 138 | parser.add_argument( |
| 153 | "--class_image_dir", | 139 | "--class_image_dir", |
| @@ -158,7 +144,7 @@ def parse_args(): | |||
| 158 | parser.add_argument( | 144 | parser.add_argument( |
| 159 | "--exclude_collections", | 145 | "--exclude_collections", |
| 160 | type=str, | 146 | type=str, |
| 161 | nargs='*', | 147 | nargs="*", |
| 162 | help="Exclude all items with a listed collection.", | 148 | help="Exclude all items with a listed collection.", |
| 163 | ) | 149 | ) |
| 164 | parser.add_argument( | 150 | parser.add_argument( |
| @@ -181,14 +167,11 @@ def parse_args(): | |||
| 181 | parser.add_argument( | 167 | parser.add_argument( |
| 182 | "--collection", | 168 | "--collection", |
| 183 | type=str, | 169 | type=str, |
| 184 | nargs='*', | 170 | nargs="*", |
| 185 | help="A collection to filter the dataset.", | 171 | help="A collection to filter the dataset.", |
| 186 | ) | 172 | ) |
| 187 | parser.add_argument( | 173 | parser.add_argument( |
| 188 | "--seed", | 174 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 189 | type=int, | ||
| 190 | default=None, | ||
| 191 | help="A seed for reproducible training." | ||
| 192 | ) | 175 | ) |
| 193 | parser.add_argument( | 176 | parser.add_argument( |
| 194 | "--resolution", | 177 | "--resolution", |
| @@ -244,7 +227,7 @@ def parse_args(): | |||
| 244 | type=str, | 227 | type=str, |
| 245 | default="auto", | 228 | default="auto", |
| 246 | choices=["all", "trailing", "leading", "between", "auto", "off"], | 229 | choices=["all", "trailing", "leading", "between", "auto", "off"], |
| 247 | help='Vector shuffling algorithm.', | 230 | help="Vector shuffling algorithm.", |
| 248 | ) | 231 | ) |
| 249 | parser.add_argument( | 232 | parser.add_argument( |
| 250 | "--offset_noise_strength", | 233 | "--offset_noise_strength", |
| @@ -256,18 +239,10 @@ def parse_args(): | |||
| 256 | "--input_pertubation", | 239 | "--input_pertubation", |
| 257 | type=float, | 240 | type=float, |
| 258 | default=0, | 241 | default=0, |
| 259 | help="The scale of input pretubation. Recommended 0.1." | 242 | help="The scale of input pretubation. Recommended 0.1.", |
| 260 | ) | ||
| 261 | parser.add_argument( | ||
| 262 | "--num_train_epochs", | ||
| 263 | type=int, | ||
| 264 | default=None | ||
| 265 | ) | ||
| 266 | parser.add_argument( | ||
| 267 | "--num_train_steps", | ||
| 268 | type=int, | ||
| 269 | default=2000 | ||
| 270 | ) | 243 | ) |
| 244 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
| 245 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
| 271 | parser.add_argument( | 246 | parser.add_argument( |
| 272 | "--gradient_accumulation_steps", | 247 | "--gradient_accumulation_steps", |
| 273 | type=int, | 248 | type=int, |
| @@ -299,27 +274,31 @@ def parse_args(): | |||
| 299 | "--lr_scheduler", | 274 | "--lr_scheduler", |
| 300 | type=str, | 275 | type=str, |
| 301 | default="one_cycle", | 276 | default="one_cycle", |
| 302 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 277 | choices=[ |
| 303 | "constant", "constant_with_warmup", "one_cycle"], | 278 | "linear", |
| 304 | help='The scheduler type to use.', | 279 | "cosine", |
| 280 | "cosine_with_restarts", | ||
| 281 | "polynomial", | ||
| 282 | "constant", | ||
| 283 | "constant_with_warmup", | ||
| 284 | "one_cycle", | ||
| 285 | ], | ||
| 286 | help="The scheduler type to use.", | ||
| 305 | ) | 287 | ) |
| 306 | parser.add_argument( | 288 | parser.add_argument( |
| 307 | "--lr_warmup_epochs", | 289 | "--lr_warmup_epochs", |
| 308 | type=int, | 290 | type=int, |
| 309 | default=10, | 291 | default=10, |
| 310 | help="Number of steps for the warmup in the lr scheduler." | 292 | help="Number of steps for the warmup in the lr scheduler.", |
| 311 | ) | 293 | ) |
| 312 | parser.add_argument( | 294 | parser.add_argument( |
| 313 | "--lr_mid_point", | 295 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
| 314 | type=float, | ||
| 315 | default=0.3, | ||
| 316 | help="OneCycle schedule mid point." | ||
| 317 | ) | 296 | ) |
| 318 | parser.add_argument( | 297 | parser.add_argument( |
| 319 | "--lr_cycles", | 298 | "--lr_cycles", |
| 320 | type=int, | 299 | type=int, |
| 321 | default=None, | 300 | default=None, |
| 322 | help="Number of restart cycles in the lr scheduler." | 301 | help="Number of restart cycles in the lr scheduler.", |
| 323 | ) | 302 | ) |
| 324 | parser.add_argument( | 303 | parser.add_argument( |
| 325 | "--lr_warmup_func", | 304 | "--lr_warmup_func", |
| @@ -331,7 +310,7 @@ def parse_args(): | |||
| 331 | "--lr_warmup_exp", | 310 | "--lr_warmup_exp", |
| 332 | type=int, | 311 | type=int, |
| 333 | default=1, | 312 | default=1, |
| 334 | help='If lr_warmup_func is "cos", exponent to modify the function' | 313 | help='If lr_warmup_func is "cos", exponent to modify the function', |
| 335 | ) | 314 | ) |
| 336 | parser.add_argument( | 315 | parser.add_argument( |
| 337 | "--lr_annealing_func", | 316 | "--lr_annealing_func", |
| @@ -343,89 +322,67 @@ def parse_args(): | |||
| 343 | "--lr_annealing_exp", | 322 | "--lr_annealing_exp", |
| 344 | type=int, | 323 | type=int, |
| 345 | default=1, | 324 | default=1, |
| 346 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 325 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
| 347 | ) | 326 | ) |
| 348 | parser.add_argument( | 327 | parser.add_argument( |
| 349 | "--lr_min_lr", | 328 | "--lr_min_lr", |
| 350 | type=float, | 329 | type=float, |
| 351 | default=0.04, | 330 | default=0.04, |
| 352 | help="Minimum learning rate in the lr scheduler." | 331 | help="Minimum learning rate in the lr scheduler.", |
| 353 | ) | 332 | ) |
| 354 | parser.add_argument( | 333 | parser.add_argument( |
| 355 | "--use_ema", | 334 | "--use_ema", action="store_true", help="Whether to use EMA model." |
| 356 | action="store_true", | ||
| 357 | help="Whether to use EMA model." | ||
| 358 | ) | ||
| 359 | parser.add_argument( | ||
| 360 | "--ema_inv_gamma", | ||
| 361 | type=float, | ||
| 362 | default=1.0 | ||
| 363 | ) | ||
| 364 | parser.add_argument( | ||
| 365 | "--ema_power", | ||
| 366 | type=float, | ||
| 367 | default=4/5 | ||
| 368 | ) | ||
| 369 | parser.add_argument( | ||
| 370 | "--ema_max_decay", | ||
| 371 | type=float, | ||
| 372 | default=0.9999 | ||
| 373 | ) | ||
| 374 | parser.add_argument( | ||
| 375 | "--min_snr_gamma", | ||
| 376 | type=int, | ||
| 377 | default=5, | ||
| 378 | help="MinSNR gamma." | ||
| 379 | ) | 335 | ) |
| 336 | parser.add_argument("--ema_inv_gamma", type=float, default=1.0) | ||
| 337 | parser.add_argument("--ema_power", type=float, default=4 / 5) | ||
| 338 | parser.add_argument("--ema_max_decay", type=float, default=0.9999) | ||
| 339 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
| 380 | parser.add_argument( | 340 | parser.add_argument( |
| 381 | "--schedule_sampler", | 341 | "--schedule_sampler", |
| 382 | type=str, | 342 | type=str, |
| 383 | default="uniform", | 343 | default="uniform", |
| 384 | choices=["uniform", "loss-second-moment"], | 344 | choices=["uniform", "loss-second-moment"], |
| 385 | help="Noise schedule sampler." | 345 | help="Noise schedule sampler.", |
| 386 | ) | 346 | ) |
| 387 | parser.add_argument( | 347 | parser.add_argument( |
| 388 | "--optimizer", | 348 | "--optimizer", |
| 389 | type=str, | 349 | type=str, |
| 390 | default="adan", | 350 | default="adan", |
| 391 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 351 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
| 392 | help='Optimizer to use' | 352 | help="Optimizer to use", |
| 393 | ) | 353 | ) |
| 394 | parser.add_argument( | 354 | parser.add_argument( |
| 395 | "--dadaptation_d0", | 355 | "--dadaptation_d0", |
| 396 | type=float, | 356 | type=float, |
| 397 | default=1e-6, | 357 | default=1e-6, |
| 398 | help="The d0 parameter for Dadaptation optimizers." | 358 | help="The d0 parameter for Dadaptation optimizers.", |
| 399 | ) | 359 | ) |
| 400 | parser.add_argument( | 360 | parser.add_argument( |
| 401 | "--adam_beta1", | 361 | "--adam_beta1", |
| 402 | type=float, | 362 | type=float, |
| 403 | default=None, | 363 | default=None, |
| 404 | help="The beta1 parameter for the Adam optimizer." | 364 | help="The beta1 parameter for the Adam optimizer.", |
| 405 | ) | 365 | ) |
| 406 | parser.add_argument( | 366 | parser.add_argument( |
| 407 | "--adam_beta2", | 367 | "--adam_beta2", |
| 408 | type=float, | 368 | type=float, |
| 409 | default=None, | 369 | default=None, |
| 410 | help="The beta2 parameter for the Adam optimizer." | 370 | help="The beta2 parameter for the Adam optimizer.", |
| 411 | ) | 371 | ) |
| 412 | parser.add_argument( | 372 | parser.add_argument( |
| 413 | "--adam_weight_decay", | 373 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
| 414 | type=float, | ||
| 415 | default=2e-2, | ||
| 416 | help="Weight decay to use." | ||
| 417 | ) | 374 | ) |
| 418 | parser.add_argument( | 375 | parser.add_argument( |
| 419 | "--adam_epsilon", | 376 | "--adam_epsilon", |
| 420 | type=float, | 377 | type=float, |
| 421 | default=1e-08, | 378 | default=1e-08, |
| 422 | help="Epsilon value for the Adam optimizer" | 379 | help="Epsilon value for the Adam optimizer", |
| 423 | ) | 380 | ) |
| 424 | parser.add_argument( | 381 | parser.add_argument( |
| 425 | "--adam_amsgrad", | 382 | "--adam_amsgrad", |
| 426 | type=bool, | 383 | type=bool, |
| 427 | default=False, | 384 | default=False, |
| 428 | help="Amsgrad value for the Adam optimizer" | 385 | help="Amsgrad value for the Adam optimizer", |
| 429 | ) | 386 | ) |
| 430 | parser.add_argument( | 387 | parser.add_argument( |
| 431 | "--mixed_precision", | 388 | "--mixed_precision", |
| @@ -456,7 +413,7 @@ def parse_args(): | |||
| 456 | ) | 413 | ) |
| 457 | parser.add_argument( | 414 | parser.add_argument( |
| 458 | "--no_milestone_checkpoints", | 415 | "--no_milestone_checkpoints", |
| 459 | action='store_true', | 416 | action="store_true", |
| 460 | help="If checkpoints are saved on maximum accuracy", | 417 | help="If checkpoints are saved on maximum accuracy", |
| 461 | ) | 418 | ) |
| 462 | parser.add_argument( | 419 | parser.add_argument( |
| @@ -493,25 +450,25 @@ def parse_args(): | |||
| 493 | "--valid_set_size", | 450 | "--valid_set_size", |
| 494 | type=int, | 451 | type=int, |
| 495 | default=None, | 452 | default=None, |
| 496 | help="Number of images in the validation dataset." | 453 | help="Number of images in the validation dataset.", |
| 497 | ) | 454 | ) |
| 498 | parser.add_argument( | 455 | parser.add_argument( |
| 499 | "--train_set_pad", | 456 | "--train_set_pad", |
| 500 | type=int, | 457 | type=int, |
| 501 | default=None, | 458 | default=None, |
| 502 | help="The number to fill train dataset items up to." | 459 | help="The number to fill train dataset items up to.", |
| 503 | ) | 460 | ) |
| 504 | parser.add_argument( | 461 | parser.add_argument( |
| 505 | "--valid_set_pad", | 462 | "--valid_set_pad", |
| 506 | type=int, | 463 | type=int, |
| 507 | default=None, | 464 | default=None, |
| 508 | help="The number to fill validation dataset items up to." | 465 | help="The number to fill validation dataset items up to.", |
| 509 | ) | 466 | ) |
| 510 | parser.add_argument( | 467 | parser.add_argument( |
| 511 | "--train_batch_size", | 468 | "--train_batch_size", |
| 512 | type=int, | 469 | type=int, |
| 513 | default=1, | 470 | default=1, |
| 514 | help="Batch size (per device) for the training dataloader." | 471 | help="Batch size (per device) for the training dataloader.", |
| 515 | ) | 472 | ) |
| 516 | parser.add_argument( | 473 | parser.add_argument( |
| 517 | "--sample_steps", | 474 | "--sample_steps", |
| @@ -523,14 +480,9 @@ def parse_args(): | |||
| 523 | "--prior_loss_weight", | 480 | "--prior_loss_weight", |
| 524 | type=float, | 481 | type=float, |
| 525 | default=1.0, | 482 | default=1.0, |
| 526 | help="The weight of prior preservation loss." | 483 | help="The weight of prior preservation loss.", |
| 527 | ) | ||
| 528 | parser.add_argument( | ||
| 529 | "--emb_alpha", | ||
| 530 | type=float, | ||
| 531 | default=1.0, | ||
| 532 | help="Embedding alpha" | ||
| 533 | ) | 484 | ) |
| 485 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
| 534 | parser.add_argument( | 486 | parser.add_argument( |
| 535 | "--emb_dropout", | 487 | "--emb_dropout", |
| 536 | type=float, | 488 | type=float, |
| @@ -538,21 +490,13 @@ def parse_args(): | |||
| 538 | help="Embedding dropout probability.", | 490 | help="Embedding dropout probability.", |
| 539 | ) | 491 | ) |
| 540 | parser.add_argument( | 492 | parser.add_argument( |
| 541 | "--use_emb_decay", | 493 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." |
| 542 | action="store_true", | ||
| 543 | help="Whether to use embedding decay." | ||
| 544 | ) | 494 | ) |
| 545 | parser.add_argument( | 495 | parser.add_argument( |
| 546 | "--emb_decay_target", | 496 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." |
| 547 | default=0.4, | ||
| 548 | type=float, | ||
| 549 | help="Embedding decay target." | ||
| 550 | ) | 497 | ) |
| 551 | parser.add_argument( | 498 | parser.add_argument( |
| 552 | "--emb_decay", | 499 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." |
| 553 | default=1e+2, | ||
| 554 | type=float, | ||
| 555 | help="Embedding decay factor." | ||
| 556 | ) | 500 | ) |
| 557 | parser.add_argument( | 501 | parser.add_argument( |
| 558 | "--noise_timesteps", | 502 | "--noise_timesteps", |
| @@ -563,7 +507,7 @@ def parse_args(): | |||
| 563 | "--resume_from", | 507 | "--resume_from", |
| 564 | type=str, | 508 | type=str, |
| 565 | default=None, | 509 | default=None, |
| 566 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" | 510 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)", |
| 567 | ) | 511 | ) |
| 568 | parser.add_argument( | 512 | parser.add_argument( |
| 569 | "--global_step", | 513 | "--global_step", |
| @@ -574,7 +518,7 @@ def parse_args(): | |||
| 574 | "--config", | 518 | "--config", |
| 575 | type=str, | 519 | type=str, |
| 576 | default=None, | 520 | default=None, |
| 577 | help="Path to a JSON configuration file containing arguments for invoking this script." | 521 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
| 578 | ) | 522 | ) |
| 579 | 523 | ||
| 580 | args = parser.parse_args() | 524 | args = parser.parse_args() |
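`--config` names a JSON file of argument values. The merge logic isn't visible in this hunk, but the usual argparse pattern (a sketch, not necessarily this script's exact behavior) is to re-parse with the file's contents as defaults, so flags given explicitly on the CLI still win:

```python
import json

args = parser.parse_args()
if args.config is not None:
    with open(args.config) as f:
        parser.set_defaults(**json.load(f))
    args = parser.parse_args()  # CLI flags override the config file's values
```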
| @@ -595,29 +539,44 @@ def parse_args(): | |||
| 595 | args.placeholder_tokens = [args.placeholder_tokens] | 539 | args.placeholder_tokens = [args.placeholder_tokens] |
| 596 | 540 | ||
| 597 | if isinstance(args.initializer_tokens, str): | 541 | if isinstance(args.initializer_tokens, str): |
| 598 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 542 | args.initializer_tokens = [args.initializer_tokens] * len( |
| 543 | args.placeholder_tokens | ||
| 544 | ) | ||
| 599 | 545 | ||
| 600 | if len(args.placeholder_tokens) == 0: | 546 | if len(args.placeholder_tokens) == 0: |
| 601 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 547 | args.placeholder_tokens = [ |
| 548 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
| 549 | ] | ||
| 602 | 550 | ||
| 603 | if len(args.initializer_tokens) == 0: | 551 | if len(args.initializer_tokens) == 0: |
| 604 | args.initializer_tokens = args.placeholder_tokens.copy() | 552 | args.initializer_tokens = args.placeholder_tokens.copy() |
| 605 | 553 | ||
| 606 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 554 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
| 607 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 555 | raise ValueError( |
| 556 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
| 557 | ) | ||
| 608 | 558 | ||
| 609 | if isinstance(args.inverted_initializer_tokens, str): | 559 | if isinstance(args.inverted_initializer_tokens, str): |
| 610 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | 560 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( |
| 561 | args.placeholder_tokens | ||
| 562 | ) | ||
| 611 | 563 | ||
| 612 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | 564 | if ( |
| 565 | isinstance(args.inverted_initializer_tokens, list) | ||
| 566 | and len(args.inverted_initializer_tokens) != 0 | ||
| 567 | ): | ||
| 613 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | 568 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] |
| 614 | args.initializer_tokens += args.inverted_initializer_tokens | 569 | args.initializer_tokens += args.inverted_initializer_tokens |
| 615 | 570 | ||
| 616 | if isinstance(args.num_vectors, int): | 571 | if isinstance(args.num_vectors, int): |
| 617 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 572 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 618 | 573 | ||
| 619 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 574 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( |
| 620 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 575 | args.num_vectors |
| 576 | ): | ||
| 577 | raise ValueError( | ||
| 578 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
| 579 | ) | ||
| 621 | 580 | ||
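The normalization above reduces to one broadcast rule: a scalar argument is repeated to match `placeholder_tokens`, while an explicit list must match its length exactly. In miniature:

```python
def broadcast(value, n: int):
    """Mirrors the arg handling: scalars repeat, lists must match in length."""
    if not isinstance(value, list):
        return [value] * n
    if len(value) != n:
        raise ValueError("length mismatch")
    return value

assert broadcast(3, 2) == [3, 3]       # --num_vectors 3 with two placeholders
assert broadcast([3, 4], 2) == [3, 4]  # explicit per-token counts pass through
```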
| 622 | if args.alias_tokens is None: | 581 | if args.alias_tokens is None: |
| 623 | args.alias_tokens = [] | 582 | args.alias_tokens = [] |
| @@ -639,16 +598,22 @@ def parse_args(): | |||
| 639 | ] | 598 | ] |
| 640 | 599 | ||
| 641 | if isinstance(args.train_data_template, str): | 600 | if isinstance(args.train_data_template, str): |
| 642 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 601 | args.train_data_template = [args.train_data_template] * len( |
| 602 | args.placeholder_tokens | ||
| 603 | ) | ||
| 643 | 604 | ||
| 644 | if len(args.placeholder_tokens) != len(args.train_data_template): | 605 | if len(args.placeholder_tokens) != len(args.train_data_template): |
| 645 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 606 | raise ValueError( |
| 607 | "--placeholder_tokens and --train_data_template must have the same number of items" | ||
| 608 | ) | ||
| 646 | 609 | ||
| 647 | if args.num_vectors is None: | 610 | if args.num_vectors is None: |
| 648 | args.num_vectors = [None] * len(args.placeholder_tokens) | 611 | args.num_vectors = [None] * len(args.placeholder_tokens) |
| 649 | else: | 612 | else: |
| 650 | if isinstance(args.train_data_template, list): | 613 | if isinstance(args.train_data_template, list): |
| 651 | raise ValueError("--train_data_template can't be a list in simultaneous mode") | 614 | raise ValueError( |
| 615 | "--train_data_template can't be a list in simultaneous mode" | ||
| 616 | ) | ||
| 652 | 617 | ||
| 653 | if isinstance(args.collection, str): | 618 | if isinstance(args.collection, str): |
| 654 | args.collection = [args.collection] | 619 | args.collection = [args.collection] |
| @@ -660,13 +625,13 @@ def parse_args(): | |||
| 660 | raise ValueError("You must specify --output_dir") | 625 | raise ValueError("You must specify --output_dir") |
| 661 | 626 | ||
| 662 | if args.adam_beta1 is None: | 627 | if args.adam_beta1 is None: |
| 663 | if args.optimizer == 'lion': | 628 | if args.optimizer == "lion": |
| 664 | args.adam_beta1 = 0.95 | 629 | args.adam_beta1 = 0.95 |
| 665 | else: | 630 | else: |
| 666 | args.adam_beta1 = 0.9 | 631 | args.adam_beta1 = 0.9 |
| 667 | 632 | ||
| 668 | if args.adam_beta2 is None: | 633 | if args.adam_beta2 is None: |
| 669 | if args.optimizer == 'lion': | 634 | if args.optimizer == "lion": |
| 670 | args.adam_beta2 = 0.98 | 635 | args.adam_beta2 = 0.98 |
| 671 | else: | 636 | else: |
| 672 | args.adam_beta2 = 0.999 | 637 | args.adam_beta2 = 0.999 |
| @@ -679,13 +644,13 @@ def main(): | |||
| 679 | 644 | ||
| 680 | global_step_offset = args.global_step | 645 | global_step_offset = args.global_step |
| 681 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 646 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 682 | output_dir = Path(args.output_dir)/slugify(args.project)/now | 647 | output_dir = Path(args.output_dir) / slugify(args.project) / now |
| 683 | output_dir.mkdir(parents=True, exist_ok=True) | 648 | output_dir.mkdir(parents=True, exist_ok=True) |
| 684 | 649 | ||
| 685 | accelerator = Accelerator( | 650 | accelerator = Accelerator( |
| 686 | log_with=LoggerType.TENSORBOARD, | 651 | log_with=LoggerType.TENSORBOARD, |
| 687 | project_dir=f"{output_dir}", | 652 | project_dir=f"{output_dir}", |
| 688 | mixed_precision=args.mixed_precision | 653 | mixed_precision=args.mixed_precision, |
| 689 | ) | 654 | ) |
| 690 | 655 | ||
| 691 | weight_dtype = torch.float32 | 656 | weight_dtype = torch.float32 |
| @@ -703,9 +668,15 @@ def main(): | |||
| 703 | 668 | ||
| 704 | save_args(output_dir, args) | 669 | save_args(output_dir, args) |
| 705 | 670 | ||
| 706 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) | 671 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
| 707 | embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) | 672 | args.pretrained_model_name_or_path |
| 708 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 673 | ) |
| 674 | embeddings = patch_managed_embeddings( | ||
| 675 | text_encoder, args.emb_alpha, args.emb_dropout | ||
| 676 | ) | ||
| 677 | schedule_sampler = create_named_schedule_sampler( | ||
| 678 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
| 679 | ) | ||
| 709 | 680 | ||
| 710 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 681 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 711 | tokenizer.set_dropout(args.vector_dropout) | 682 | tokenizer.set_dropout(args.vector_dropout) |
| @@ -717,16 +688,16 @@ def main(): | |||
| 717 | unet.enable_xformers_memory_efficient_attention() | 688 | unet.enable_xformers_memory_efficient_attention() |
| 718 | elif args.compile_unet: | 689 | elif args.compile_unet: |
| 719 | unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False | 690 | unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False |
| 720 | 691 | ||
| 721 | proc = AttnProcessor() | 692 | proc = AttnProcessor() |
| 722 | 693 | ||
| 723 | def fn_recursive_set_proc(module: torch.nn.Module): | 694 | def fn_recursive_set_proc(module: torch.nn.Module): |
| 724 | if hasattr(module, "processor"): | 695 | if hasattr(module, "processor"): |
| 725 | module.processor = proc | 696 | module.processor = proc |
| 726 | 697 | ||
| 727 | for child in module.children(): | 698 | for child in module.children(): |
| 728 | fn_recursive_set_proc(child) | 699 | fn_recursive_set_proc(child) |
| 729 | 700 | ||
| 730 | fn_recursive_set_proc(unet) | 701 | fn_recursive_set_proc(unet) |
| 731 | 702 | ||
| 732 | if args.gradient_checkpointing: | 703 | if args.gradient_checkpointing: |
| @@ -751,18 +722,24 @@ def main(): | |||
| 751 | tokenizer=tokenizer, | 722 | tokenizer=tokenizer, |
| 752 | embeddings=embeddings, | 723 | embeddings=embeddings, |
| 753 | placeholder_tokens=alias_placeholder_tokens, | 724 | placeholder_tokens=alias_placeholder_tokens, |
| 754 | initializer_tokens=alias_initializer_tokens | 725 | initializer_tokens=alias_initializer_tokens, |
| 755 | ) | 726 | ) |
| 756 | embeddings.persist() | 727 | embeddings.persist() |
| 757 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 728 | print( |
| 729 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
| 730 | ) | ||
| 758 | 731 | ||
| 759 | if args.embeddings_dir is not None: | 732 | if args.embeddings_dir is not None: |
| 760 | embeddings_dir = Path(args.embeddings_dir) | 733 | embeddings_dir = Path(args.embeddings_dir) |
| 761 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 734 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 762 | raise ValueError("--embeddings_dir must point to an existing directory") | 735 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 763 | 736 | ||
| 764 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 737 | added_tokens, added_ids = load_embeddings_from_dir( |
| 765 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 738 | tokenizer, embeddings, embeddings_dir |
| 739 | ) | ||
| 740 | print( | ||
| 741 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
| 742 | ) | ||
| 766 | 743 | ||
| 767 | if args.train_dir_embeddings: | 744 | if args.train_dir_embeddings: |
| 768 | args.placeholder_tokens = added_tokens | 745 | args.placeholder_tokens = added_tokens |
| @@ -772,19 +749,23 @@ def main(): | |||
| 772 | 749 | ||
| 773 | if args.scale_lr: | 750 | if args.scale_lr: |
| 774 | args.learning_rate = ( | 751 | args.learning_rate = ( |
| 775 | args.learning_rate * args.gradient_accumulation_steps * | 752 | args.learning_rate |
| 776 | args.train_batch_size * accelerator.num_processes | 753 | * args.gradient_accumulation_steps |
| 754 | * args.train_batch_size | ||
| 755 | * accelerator.num_processes | ||
| 777 | ) | 756 | ) |
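With `--scale_lr`, the base rate is multiplied by the effective global batch size. For example, a base LR of 1e-4 with 4 gradient-accumulation steps, per-device batch size 2, and 2 processes becomes 1e-4 × 4 × 2 × 2 = 1.6e-3.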
| 778 | 757 | ||
| 779 | if args.find_lr: | 758 | if args.find_lr: |
| 780 | args.learning_rate = 1e-5 | 759 | args.learning_rate = 1e-5 |
| 781 | args.lr_scheduler = "exponential_growth" | 760 | args.lr_scheduler = "exponential_growth" |
| 782 | 761 | ||
| 783 | if args.optimizer == 'adam8bit': | 762 | if args.optimizer == "adam8bit": |
| 784 | try: | 763 | try: |
| 785 | import bitsandbytes as bnb | 764 | import bitsandbytes as bnb |
| 786 | except ImportError: | 765 | except ImportError: |
| 787 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 766 | raise ImportError( |
| 767 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
| 768 | ) | ||
| 788 | 769 | ||
| 789 | create_optimizer = partial( | 770 | create_optimizer = partial( |
| 790 | bnb.optim.AdamW8bit, | 771 | bnb.optim.AdamW8bit, |
| @@ -793,7 +774,7 @@ def main(): | |||
| 793 | eps=args.adam_epsilon, | 774 | eps=args.adam_epsilon, |
| 794 | amsgrad=args.adam_amsgrad, | 775 | amsgrad=args.adam_amsgrad, |
| 795 | ) | 776 | ) |
| 796 | elif args.optimizer == 'adam': | 777 | elif args.optimizer == "adam": |
| 797 | create_optimizer = partial( | 778 | create_optimizer = partial( |
| 798 | torch.optim.AdamW, | 779 | torch.optim.AdamW, |
| 799 | betas=(args.adam_beta1, args.adam_beta2), | 780 | betas=(args.adam_beta1, args.adam_beta2), |
| @@ -801,11 +782,13 @@ def main(): | |||
| 801 | eps=args.adam_epsilon, | 782 | eps=args.adam_epsilon, |
| 802 | amsgrad=args.adam_amsgrad, | 783 | amsgrad=args.adam_amsgrad, |
| 803 | ) | 784 | ) |
| 804 | elif args.optimizer == 'adan': | 785 | elif args.optimizer == "adan": |
| 805 | try: | 786 | try: |
| 806 | import timm.optim | 787 | import timm.optim |
| 807 | except ImportError: | 788 | except ImportError: |
| 808 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 789 | raise ImportError( |
| 790 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
| 791 | ) | ||
| 809 | 792 | ||
| 810 | create_optimizer = partial( | 793 | create_optimizer = partial( |
| 811 | timm.optim.Adan, | 794 | timm.optim.Adan, |
| @@ -813,11 +796,13 @@ def main(): | |||
| 813 | eps=args.adam_epsilon, | 796 | eps=args.adam_epsilon, |
| 814 | no_prox=True, | 797 | no_prox=True, |
| 815 | ) | 798 | ) |
| 816 | elif args.optimizer == 'lion': | 799 | elif args.optimizer == "lion": |
| 817 | try: | 800 | try: |
| 818 | import lion_pytorch | 801 | import lion_pytorch |
| 819 | except ImportError: | 802 | except ImportError: |
| 820 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 803 | raise ImportError( |
| 804 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
| 805 | ) | ||
| 821 | 806 | ||
| 822 | create_optimizer = partial( | 807 | create_optimizer = partial( |
| 823 | lion_pytorch.Lion, | 808 | lion_pytorch.Lion, |
| @@ -825,7 +810,7 @@ def main(): | |||
| 825 | weight_decay=args.adam_weight_decay, | 810 | weight_decay=args.adam_weight_decay, |
| 826 | use_triton=True, | 811 | use_triton=True, |
| 827 | ) | 812 | ) |
| 828 | elif args.optimizer == 'adafactor': | 813 | elif args.optimizer == "adafactor": |
| 829 | create_optimizer = partial( | 814 | create_optimizer = partial( |
| 830 | transformers.optimization.Adafactor, | 815 | transformers.optimization.Adafactor, |
| 831 | weight_decay=args.adam_weight_decay, | 816 | weight_decay=args.adam_weight_decay, |
| @@ -837,11 +822,13 @@ def main(): | |||
| 837 | args.lr_scheduler = "adafactor" | 822 | args.lr_scheduler = "adafactor" |
| 838 | args.lr_min_lr = args.learning_rate | 823 | args.lr_min_lr = args.learning_rate |
| 839 | args.learning_rate = None | 824 | args.learning_rate = None |
| 840 | elif args.optimizer == 'dadam': | 825 | elif args.optimizer == "dadam": |
| 841 | try: | 826 | try: |
| 842 | import dadaptation | 827 | import dadaptation |
| 843 | except ImportError: | 828 | except ImportError: |
| 844 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 829 | raise ImportError( |
| 830 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
| 831 | ) | ||
| 845 | 832 | ||
| 846 | create_optimizer = partial( | 833 | create_optimizer = partial( |
| 847 | dadaptation.DAdaptAdam, | 834 | dadaptation.DAdaptAdam, |
| @@ -851,11 +838,13 @@ def main(): | |||
| 851 | decouple=True, | 838 | decouple=True, |
| 852 | d0=args.dadaptation_d0, | 839 | d0=args.dadaptation_d0, |
| 853 | ) | 840 | ) |
| 854 | elif args.optimizer == 'dadan': | 841 | elif args.optimizer == "dadan": |
| 855 | try: | 842 | try: |
| 856 | import dadaptation | 843 | import dadaptation |
| 857 | except ImportError: | 844 | except ImportError: |
| 858 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 845 | raise ImportError( |
| 846 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
| 847 | ) | ||
| 859 | 848 | ||
| 860 | create_optimizer = partial( | 849 | create_optimizer = partial( |
| 861 | dadaptation.DAdaptAdan, | 850 | dadaptation.DAdaptAdan, |
| @@ -864,7 +853,7 @@ def main(): | |||
| 864 | d0=args.dadaptation_d0, | 853 | d0=args.dadaptation_d0, |
| 865 | ) | 854 | ) |
| 866 | else: | 855 | else: |
| 867 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 856 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
| 868 | 857 | ||
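Each branch above binds everything except the parameters and the learning rate into a `functools.partial`, so the trainer can instantiate any backend with one uniform call. In miniature:

```python
from functools import partial

import torch

create_optimizer = partial(
    torch.optim.AdamW, betas=(0.9, 0.999), weight_decay=2e-2
)

params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = create_optimizer(params, lr=1e-4)  # same call shape for every backend
```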
| 869 | trainer = partial( | 858 | trainer = partial( |
| 870 | train, | 859 | train, |
| @@ -904,10 +893,21 @@ def main(): | |||
| 904 | sample_image_size=args.sample_image_size, | 893 | sample_image_size=args.sample_image_size, |
| 905 | ) | 894 | ) |
| 906 | 895 | ||
| 896 | optimizer = create_optimizer( | ||
| 897 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
| 898 | lr=learning_rate, | ||
| 899 | ) | ||
| 900 | |||
| 907 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 901 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
| 908 | data_npgenerator = np.random.default_rng(args.seed) | 902 | data_npgenerator = np.random.default_rng(args.seed) |
| 909 | 903 | ||
| 910 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): | 904 | def run( |
| 905 | i: int, | ||
| 906 | placeholder_tokens: list[str], | ||
| 907 | initializer_tokens: list[str], | ||
| 908 | num_vectors: Union[int, list[int]], | ||
| 909 | data_template: str, | ||
| 910 | ): | ||
| 911 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 911 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 912 | tokenizer=tokenizer, | 912 | tokenizer=tokenizer, |
| 913 | embeddings=embeddings, | 913 | embeddings=embeddings, |
| @@ -917,14 +917,23 @@ def main(): | |||
| 917 | initializer_noise=args.initializer_noise, | 917 | initializer_noise=args.initializer_noise, |
| 918 | ) | 918 | ) |
| 919 | 919 | ||
| 920 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) | 920 | stats = list( |
| 921 | zip( | ||
| 922 | placeholder_tokens, | ||
| 923 | placeholder_token_ids, | ||
| 924 | initializer_tokens, | ||
| 925 | initializer_token_ids, | ||
| 926 | ) | ||
| 927 | ) | ||
| 921 | 928 | ||
| 922 | print("") | 929 | print("") |
| 923 | print(f"============ TI batch {i + 1} ============") | 930 | print(f"============ TI batch {i + 1} ============") |
| 924 | print("") | 931 | print("") |
| 925 | print(stats) | 932 | print(stats) |
| 926 | 933 | ||
| 927 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | 934 | filter_tokens = [ |
| 935 | token for token in args.filter_tokens if token in placeholder_tokens | ||
| 936 | ] | ||
| 928 | 937 | ||
| 929 | datamodule = VlpnDataModule( | 938 | datamodule = VlpnDataModule( |
| 930 | data_file=args.train_data_file, | 939 | data_file=args.train_data_file, |
| @@ -945,7 +954,9 @@ def main(): | |||
| 945 | valid_set_size=args.valid_set_size, | 954 | valid_set_size=args.valid_set_size, |
| 946 | train_set_pad=args.train_set_pad, | 955 | train_set_pad=args.train_set_pad, |
| 947 | valid_set_pad=args.valid_set_pad, | 956 | valid_set_pad=args.valid_set_pad, |
| 948 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 957 | filter=partial( |
| 958 | keyword_filter, filter_tokens, args.collection, args.exclude_collections | ||
| 959 | ), | ||
| 949 | dtype=weight_dtype, | 960 | dtype=weight_dtype, |
| 950 | generator=data_generator, | 961 | generator=data_generator, |
| 951 | npgenerator=data_npgenerator, | 962 | npgenerator=data_npgenerator, |
| @@ -955,11 +966,16 @@ def main(): | |||
| 955 | num_train_epochs = args.num_train_epochs | 966 | num_train_epochs = args.num_train_epochs |
| 956 | sample_frequency = args.sample_frequency | 967 | sample_frequency = args.sample_frequency |
| 957 | if num_train_epochs is None: | 968 | if num_train_epochs is None: |
| 958 | num_train_epochs = math.ceil( | 969 | num_train_epochs = ( |
| 959 | args.num_train_steps / len(datamodule.train_dataset) | 970 | math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
| 960 | ) * args.gradient_accumulation_steps | 971 | * args.gradient_accumulation_steps |
| 961 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 972 | ) |
| 962 | num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps) | 973 | sample_frequency = math.ceil( |
| 974 | num_train_epochs * (sample_frequency / args.num_train_steps) | ||
| 975 | ) | ||
| 976 | num_training_steps_per_epoch = math.ceil( | ||
| 977 | len(datamodule.train_dataset) / args.gradient_accumulation_steps | ||
| 978 | ) | ||
| 963 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 979 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
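A worked example of the conversion above: with `num_train_steps=3000`, 600 training items, and 4 gradient-accumulation steps, `num_train_epochs = ceil(3000 / 600) * 4 = 20` and `num_training_steps_per_epoch = ceil(600 / 4) = 150`, so the total comes back out as 150 × 20 = 3000 optimizer steps.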
| 964 | if args.sample_num is not None: | 980 | if args.sample_num is not None: |
| 965 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 981 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
| @@ -988,7 +1004,8 @@ def main(): | |||
| 988 | response = auto_cycles.pop(0) | 1004 | response = auto_cycles.pop(0) |
| 989 | else: | 1005 | else: |
| 990 | response = input( | 1006 | response = input( |
| 991 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 1007 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " |
| 1008 | ) | ||
| 992 | 1009 | ||
| 993 | if response.lower().strip() == "o": | 1010 | if response.lower().strip() == "o": |
| 994 | if args.learning_rate is not None: | 1011 | if args.learning_rate is not None: |
| @@ -1018,10 +1035,8 @@ def main(): | |||
| 1018 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") | 1035 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") |
| 1019 | print("") | 1036 | print("") |
| 1020 | 1037 | ||
| 1021 | optimizer = create_optimizer( | 1038 | for group, lr in zip(optimizer.param_groups, [learning_rate]): |
| 1022 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 1039 | group["lr"] = lr |
| 1023 | lr=learning_rate, | ||
| 1024 | ) | ||
| 1025 | 1040 | ||
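Note that this hunk is a behavioral change, not just a reformat: the optimizer is now created once before the cycle loop (new lines 896–899 above), and each cycle only rewrites the learning rate on the existing parameter groups, so optimizer state such as Adam's moment estimates survives across cycles. Standalone equivalent:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

# Start a new training cycle without resetting the moment estimates:
for group in optimizer.param_groups:
    group["lr"] = 5e-5
```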
| 1026 | lr_scheduler = get_scheduler( | 1041 | lr_scheduler = get_scheduler( |
| 1027 | lr_scheduler, | 1042 | lr_scheduler, |
| @@ -1040,7 +1055,9 @@ def main(): | |||
| 1040 | mid_point=args.lr_mid_point, | 1055 | mid_point=args.lr_mid_point, |
| 1041 | ) | 1056 | ) |
| 1042 | 1057 | ||
| 1043 | checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}" | 1058 | checkpoint_output_dir = ( |
| 1059 | output_dir / project / f"checkpoints_{training_iter}" | ||
| 1060 | ) | ||
| 1044 | 1061 | ||
| 1045 | trainer( | 1062 | trainer( |
| 1046 | train_dataloader=datamodule.train_dataloader, | 1063 | train_dataloader=datamodule.train_dataloader, |
| @@ -1070,14 +1087,20 @@ def main(): | |||
| 1070 | accelerator.end_training() | 1087 | accelerator.end_training() |
| 1071 | 1088 | ||
| 1072 | if not args.sequential: | 1089 | if not args.sequential: |
| 1073 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 1090 | run( |
| 1091 | 0, | ||
| 1092 | args.placeholder_tokens, | ||
| 1093 | args.initializer_tokens, | ||
| 1094 | args.num_vectors, | ||
| 1095 | args.train_data_template, | ||
| 1096 | ) | ||
| 1074 | else: | 1097 | else: |
| 1075 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 1098 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
| 1076 | range(len(args.placeholder_tokens)), | 1099 | range(len(args.placeholder_tokens)), |
| 1077 | args.placeholder_tokens, | 1100 | args.placeholder_tokens, |
| 1078 | args.initializer_tokens, | 1101 | args.initializer_tokens, |
| 1079 | args.num_vectors, | 1102 | args.num_vectors, |
| 1080 | args.train_data_template | 1103 | args.train_data_template, |
| 1081 | ): | 1104 | ): |
| 1082 | run(i, [placeholder_token], [initializer_token], num_vectors, data_template) | 1105 | run(i, [placeholder_token], [initializer_token], num_vectors, data_template) |
| 1083 | embeddings.persist() | 1106 | embeddings.persist() |
diff --git a/training/functional.py b/training/functional.py index fd3f9f4..f68faf9 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -14,7 +14,13 @@ import numpy as np | |||
| 14 | 14 | ||
| 15 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
| 16 | from transformers import CLIPTextModel | 16 | from transformers import CLIPTextModel |
| 17 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin | 17 | from diffusers import ( |
| 18 | AutoencoderKL, | ||
| 19 | UNet2DConditionModel, | ||
| 20 | DDPMScheduler, | ||
| 21 | UniPCMultistepScheduler, | ||
| 22 | SchedulerMixin, | ||
| 23 | ) | ||
| 18 | 24 | ||
| 19 | from tqdm.auto import tqdm | 25 | from tqdm.auto import tqdm |
| 20 | 26 | ||
| @@ -33,11 +39,12 @@ from util.noise import perlin_noise | |||
| 33 | def const(result=None): | 39 | def const(result=None): |
| 34 | def fn(*args, **kwargs): | 40 | def fn(*args, **kwargs): |
| 35 | return result | 41 | return result |
| 42 | |||
| 36 | return fn | 43 | return fn |
| 37 | 44 | ||
| 38 | 45 | ||
| 39 | @dataclass | 46 | @dataclass |
| 40 | class TrainingCallbacks(): | 47 | class TrainingCallbacks: |
| 41 | on_log: Callable[[], dict[str, Any]] = const({}) | 48 | on_log: Callable[[], dict[str, Any]] = const({}) |
| 42 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 49 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| 43 | on_before_optimize: Callable[[int], Any] = const() | 50 | on_before_optimize: Callable[[int], Any] = const() |
| @@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol): | |||
| 58 | train_dataloader: DataLoader, | 65 | train_dataloader: DataLoader, |
| 59 | val_dataloader: Optional[DataLoader], | 66 | val_dataloader: Optional[DataLoader], |
| 60 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 67 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 61 | **kwargs | 68 | **kwargs, |
| 62 | ) -> Tuple: ... | 69 | ) -> Tuple: |
| 70 | ... | ||
| 63 | 71 | ||
| 64 | 72 | ||
| 65 | @dataclass | 73 | @dataclass |
| 66 | class TrainingStrategy(): | 74 | class TrainingStrategy: |
| 67 | callbacks: Callable[..., TrainingCallbacks] | 75 | callbacks: Callable[..., TrainingCallbacks] |
| 68 | prepare: TrainingStrategyPrepareCallable | 76 | prepare: TrainingStrategyPrepareCallable |
| 69 | 77 | ||
| 70 | 78 | ||
| 71 | def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): | 79 | def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): |
| 72 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 80 | tokenizer = MultiCLIPTokenizer.from_pretrained( |
| 73 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype) | 81 | pretrained_model_name_or_path, subfolder="tokenizer" |
| 74 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype) | 82 | ) |
| 75 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype) | 83 | text_encoder = CLIPTextModel.from_pretrained( |
| 76 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 84 | pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype |
| 77 | sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 85 | ) |
| 86 | vae = AutoencoderKL.from_pretrained( | ||
| 87 | pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype | ||
| 88 | ) | ||
| 89 | unet = UNet2DConditionModel.from_pretrained( | ||
| 90 | pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype | ||
| 91 | ) | ||
| 92 | noise_scheduler = DDPMScheduler.from_pretrained( | ||
| 93 | pretrained_model_name_or_path, subfolder="scheduler" | ||
| 94 | ) | ||
| 95 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | ||
| 96 | pretrained_model_name_or_path, subfolder="scheduler" | ||
| 97 | ) | ||
| 78 | 98 | ||
| 79 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler | 99 | return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler |
| 80 | 100 | ||
| @@ -113,7 +133,9 @@ def save_samples( | |||
| 113 | 133 | ||
| 114 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) | 134 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) |
| 115 | 135 | ||
| 116 | datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] | 136 | datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [ |
| 137 | ("train", train_dataloader, None) | ||
| 138 | ] | ||
| 117 | 139 | ||
| 118 | if val_dataloader is not None: | 140 | if val_dataloader is not None: |
| 119 | datasets.append(("stable", val_dataloader, generator)) | 141 | datasets.append(("stable", val_dataloader, generator)) |
| @@ -124,17 +146,11 @@ def save_samples( | |||
| 124 | file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" | 146 | file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" |
| 125 | file_path.parent.mkdir(parents=True, exist_ok=True) | 147 | file_path.parent.mkdir(parents=True, exist_ok=True) |
| 126 | 148 | ||
| 127 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) | 149 | batches = list( |
| 128 | prompt_ids = [ | 150 | itertools.islice(itertools.cycle(data), batch_size * num_batches) |
| 129 | prompt | 151 | ) |
| 130 | for batch in batches | 152 | prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]] |
| 131 | for prompt in batch["prompt_ids"] | 153 | nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]] |
| 132 | ] | ||
| 133 | nprompt_ids = [ | ||
| 134 | prompt | ||
| 135 | for batch in batches | ||
| 136 | for prompt in batch["nprompt_ids"] | ||
| 137 | ] | ||
| 138 | 154 | ||
| 139 | with torch.inference_mode(): | 155 | with torch.inference_mode(): |
| 140 | for i in range(num_batches): | 156 | for i in range(num_batches): |
| @@ -165,7 +181,9 @@ def save_samples( | |||
| 165 | pass | 181 | pass |
| 166 | 182 | ||
| 167 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) | 183 | image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) |
| 168 | image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] | 184 | image_grid = pipeline.numpy_to_pil( |
| 185 | image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() | ||
| 186 | )[0] | ||
| 169 | image_grid.save(file_path, quality=85) | 187 | image_grid.save(file_path, quality=85) |
| 170 | 188 | ||
| 171 | del generator, pipeline | 189 | del generator, pipeline |
| @@ -184,15 +202,17 @@ def generate_class_images( | |||
| 184 | train_dataset: VlpnDataset, | 202 | train_dataset: VlpnDataset, |
| 185 | sample_batch_size: int, | 203 | sample_batch_size: int, |
| 186 | sample_image_size: int, | 204 | sample_image_size: int, |
| 187 | sample_steps: int | 205 | sample_steps: int, |
| 188 | ): | 206 | ): |
| 189 | missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] | 207 | missing_data = [ |
| 208 | item for item in train_dataset.items if not item.class_image_path.exists() | ||
| 209 | ] | ||
| 190 | 210 | ||
| 191 | if len(missing_data) == 0: | 211 | if len(missing_data) == 0: |
| 192 | return | 212 | return |
| 193 | 213 | ||
| 194 | batched_data = [ | 214 | batched_data = [ |
| 195 | missing_data[i:i+sample_batch_size] | 215 | missing_data[i : i + sample_batch_size] |
| 196 | for i in range(0, len(missing_data), sample_batch_size) | 216 | for i in range(0, len(missing_data), sample_batch_size) |
| 197 | ] | 217 | ] |
| 198 | 218 | ||
| @@ -216,7 +236,7 @@ def generate_class_images( | |||
| 216 | negative_prompt=nprompt, | 236 | negative_prompt=nprompt, |
| 217 | height=sample_image_size, | 237 | height=sample_image_size, |
| 218 | width=sample_image_size, | 238 | width=sample_image_size, |
| 219 | num_inference_steps=sample_steps | 239 | num_inference_steps=sample_steps, |
| 220 | ).images | 240 | ).images |
| 221 | 241 | ||
| 222 | for i, image in enumerate(images): | 242 | for i, image in enumerate(images): |
| @@ -245,8 +265,12 @@ def add_placeholder_tokens( | |||
| 245 | 265 | ||
| 246 | embeddings.resize(len(tokenizer)) | 266 | embeddings.resize(len(tokenizer)) |
| 247 | 267 | ||
| 248 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | 268 | for placeholder_token_id, initializer_token_id in zip( |
| 249 | embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) | 269 | placeholder_token_ids, initializer_token_ids |
| 270 | ): | ||
| 271 | embeddings.add_embed( | ||
| 272 | placeholder_token_id, initializer_token_id, initializer_noise | ||
| 273 | ) | ||
| 250 | 274 | ||
| 251 | return placeholder_token_ids, initializer_token_ids | 275 | return placeholder_token_ids, initializer_token_ids |
| 252 | 276 | ||
| @@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler): | |||
| 261 | 285 | ||
| 262 | # Expand the tensors. | 286 | # Expand the tensors. |
| 263 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 | 287 | # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 |
| 264 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | 288 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ |
| 289 | timesteps | ||
| 290 | ].float() | ||
| 265 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): | 291 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): |
| 266 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | 292 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] |
| 267 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) | 293 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) |
| 268 | 294 | ||
| 269 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | 295 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( |
| 296 | device=timesteps.device | ||
| 297 | )[timesteps].float() | ||
| 270 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): | 298 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): |
| 271 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | 299 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] |
| 272 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) | 300 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) |
| @@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler): | |||
| 277 | 305 | ||
| 278 | 306 | ||
| 279 | def get_original( | 307 | def get_original( |
| 280 | noise_scheduler, | 308 | noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor |
| 281 | model_output, | ||
| 282 | sample: torch.FloatTensor, | ||
| 283 | timesteps: torch.IntTensor | ||
| 284 | ): | 309 | ): |
| 285 | alphas_cumprod = noise_scheduler.alphas_cumprod | 310 | alphas_cumprod = noise_scheduler.alphas_cumprod |
| 286 | sqrt_alphas_cumprod = alphas_cumprod**0.5 | 311 | sqrt_alphas_cumprod = alphas_cumprod**0.5 |
| 287 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 | 312 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 |
| 288 | 313 | ||
| 289 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | 314 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ |
| 315 | timesteps | ||
| 316 | ].float() | ||
| 290 | while len(sqrt_alphas_cumprod.shape) < len(sample.shape): | 317 | while len(sqrt_alphas_cumprod.shape) < len(sample.shape): |
| 291 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] | 318 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] |
| 292 | alpha = sqrt_alphas_cumprod.expand(sample.shape) | 319 | alpha = sqrt_alphas_cumprod.expand(sample.shape) |
| 293 | 320 | ||
| 294 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() | 321 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( |
| 322 | device=timesteps.device | ||
| 323 | )[timesteps].float() | ||
| 295 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): | 324 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): |
| 296 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] | 325 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] |
| 297 | sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) | 326 | sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) |
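The remainder of `get_original` (elided from this hunk) presumably applies the standard recovery identities for the two prediction types handled in `loss_step`: for $\epsilon$-prediction, $x_0 = (x_t - \sigma_t\,\hat\epsilon)/\alpha_t$; for $v$-prediction, $x_0 = \alpha_t\,x_t - \sigma_t\,\hat v$.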
| @@ -329,7 +358,9 @@ def loss_step( | |||
| 329 | eval: bool = False, | 358 | eval: bool = False, |
| 330 | ): | 359 | ): |
| 331 | images = batch["pixel_values"] | 360 | images = batch["pixel_values"] |
| 332 | generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None | 361 | generator = ( |
| 362 | torch.Generator(device=images.device).manual_seed(seed + step) if eval else None | ||
| 363 | ) | ||
| 333 | bsz = images.shape[0] | 364 | bsz = images.shape[0] |
| 334 | 365 | ||
| 335 | # Convert images to latent space | 366 | # Convert images to latent space |
| @@ -342,7 +373,7 @@ def loss_step( | |||
| 342 | dtype=latents.dtype, | 373 | dtype=latents.dtype, |
| 343 | layout=latents.layout, | 374 | layout=latents.layout, |
| 344 | device=latents.device, | 375 | device=latents.device, |
| 345 | generator=generator | 376 | generator=generator, |
| 346 | ) | 377 | ) |
| 347 | applied_noise = noise | 378 | applied_noise = noise |
| 348 | 379 | ||
| @@ -353,7 +384,7 @@ def loss_step( | |||
| 353 | octaves=4, | 384 | octaves=4, |
| 354 | dtype=latents.dtype, | 385 | dtype=latents.dtype, |
| 355 | device=latents.device, | 386 | device=latents.device, |
| 356 | generator=generator | 387 | generator=generator, |
| 357 | ) | 388 | ) |
| 358 | 389 | ||
| 359 | if input_pertubation != 0: | 390 | if input_pertubation != 0: |
| @@ -362,7 +393,7 @@ def loss_step( | |||
| 362 | dtype=latents.dtype, | 393 | dtype=latents.dtype, |
| 363 | layout=latents.layout, | 394 | layout=latents.layout, |
| 364 | device=latents.device, | 395 | device=latents.device, |
| 365 | generator=generator | 396 | generator=generator, |
| 366 | ) | 397 | ) |
| 367 | 398 | ||
| 368 | # Sample a random timestep for each image | 399 | # Sample a random timestep for each image |
| @@ -375,25 +406,27 @@ def loss_step( | |||
| 375 | 406 | ||
| 376 | # Get the text embedding for conditioning | 407 | # Get the text embedding for conditioning |
| 377 | encoder_hidden_states = get_extended_embeddings( | 408 | encoder_hidden_states = get_extended_embeddings( |
| 378 | text_encoder, | 409 | text_encoder, batch["input_ids"], batch["attention_mask"] |
| 379 | batch["input_ids"], | ||
| 380 | batch["attention_mask"] | ||
| 381 | ) | 410 | ) |
| 382 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) | 411 | encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) |
| 383 | 412 | ||
| 384 | # Predict the noise residual | 413 | # Predict the noise residual |
| 385 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] | 414 | model_pred = unet( |
| 415 | noisy_latents, timesteps, encoder_hidden_states, return_dict=False | ||
| 416 | )[0] | ||
| 386 | 417 | ||
| 387 | if guidance_scale != 0: | 418 | if guidance_scale != 0: |
| 388 | uncond_encoder_hidden_states = get_extended_embeddings( | 419 | uncond_encoder_hidden_states = get_extended_embeddings( |
| 389 | text_encoder, | 420 | text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] |
| 390 | batch["negative_input_ids"], | ||
| 391 | batch["negative_attention_mask"] | ||
| 392 | ) | 421 | ) |
| 393 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | 422 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) |
| 394 | 423 | ||
| 395 | model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] | 424 | model_pred_uncond = unet( |
| 396 | model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) | 425 | noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False |
| 426 | )[0] | ||
| 427 | model_pred = model_pred_uncond + guidance_scale * ( | ||
| 428 | model_pred - model_pred_uncond | ||
| 429 | ) | ||
| 397 | 430 | ||
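This is classifier-free guidance applied at training time: the prediction is extrapolated away from the unconditional branch, $\hat\epsilon = \hat\epsilon_{\text{uncond}} + g\,(\hat\epsilon_{\text{cond}} - \hat\epsilon_{\text{uncond}})$, which reduces to the plain conditional prediction at $g = 1$ and is skipped entirely at $g = 0$.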
| 398 | # Get the target for loss depending on the prediction type | 431 | # Get the target for loss depending on the prediction type |
| 399 | if noise_scheduler.config.prediction_type == "epsilon": | 432 | if noise_scheduler.config.prediction_type == "epsilon": |
| @@ -401,7 +434,9 @@ def loss_step( | |||
| 401 | elif noise_scheduler.config.prediction_type == "v_prediction": | 434 | elif noise_scheduler.config.prediction_type == "v_prediction": |
| 402 | target = noise_scheduler.get_velocity(latents, noise, timesteps) | 435 | target = noise_scheduler.get_velocity(latents, noise, timesteps) |
| 403 | else: | 436 | else: |
| 404 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 437 | raise ValueError( |
| 438 | f"Unknown prediction type {noise_scheduler.config.prediction_type}" | ||
| 439 | ) | ||
| 405 | 440 | ||
| 406 | acc = (model_pred == target).float().mean() | 441 | acc = (model_pred == target).float().mean() |
| 407 | 442 | ||
| @@ -414,7 +449,9 @@ def loss_step( | |||
| 414 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") | 449 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") |
| 415 | 450 | ||
| 416 | # Compute prior loss | 451 | # Compute prior loss |
| 417 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") | 452 | prior_loss = F.mse_loss( |
| 453 | model_pred_prior.float(), target_prior.float(), reduction="none" | ||
| 454 | ) | ||
| 418 | 455 | ||
| 419 | # Add the prior loss to the instance loss. | 456 | # Add the prior loss to the instance loss. |
| 420 | loss = loss + prior_loss_weight * prior_loss | 457 | loss = loss + prior_loss_weight * prior_loss |
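This is DreamBooth-style prior preservation: with $w$ = `prior_loss_weight`, the per-element loss is $\mathcal{L} = \lVert \hat\epsilon - \epsilon \rVert^2 + w\,\lVert \hat\epsilon_{\text{prior}} - \epsilon_{\text{prior}} \rVert^2$, where the prior terms come from the class images produced by `generate_class_images` above.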
| @@ -433,7 +470,10 @@ def loss_step( | |||
| 433 | if min_snr_gamma != 0: | 470 | if min_snr_gamma != 0: |
| 434 | snr = compute_snr(timesteps, noise_scheduler) | 471 | snr = compute_snr(timesteps, noise_scheduler) |
| 435 | mse_loss_weights = ( | 472 | mse_loss_weights = ( |
| 436 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr | 473 | torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min( |
| 474 | dim=1 | ||
| 475 | )[0] | ||
| 476 | / snr | ||
| 437 | ) | 477 | ) |
| 438 | loss = loss * mse_loss_weights | 478 | loss = loss * mse_loss_weights |
| 439 | 479 | ||
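This is Min-SNR loss weighting (Hang et al., linked from the `compute_snr` comment above): each timestep's MSE is scaled by $w_t = \min(\mathrm{SNR}(t), \gamma)/\mathrm{SNR}(t)$ with $\gamma$ = `min_snr_gamma`, which caps the influence of low-noise (high-SNR) timesteps on the loss.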
| @@ -447,8 +487,14 @@ def loss_step( | |||
| 447 | 487 | ||
| 448 | 488 | ||
| 449 | class LossCallable(Protocol): | 489 | class LossCallable(Protocol): |
| 450 | def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any], | 490 | def __call__( |
| 451 | eval: bool = False) -> Tuple[Any, Any, int]: ... | 491 | self, |
| 492 | step: int, | ||
| 493 | batch: dict[Any, Any], | ||
| 494 | cache: dict[str, Any], | ||
| 495 | eval: bool = False, | ||
| 496 | ) -> Tuple[Any, Any, int]: | ||
| 497 | ... | ||
| 452 | 498 | ||
| 453 | 499 | ||
| 454 | def train_loop( | 500 | def train_loop( |
| @@ -472,9 +518,14 @@ def train_loop( | |||
| 472 | avg_acc_val: AverageMeter = AverageMeter(), | 518 | avg_acc_val: AverageMeter = AverageMeter(), |
| 473 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 519 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
| 474 | ): | 520 | ): |
| 475 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) | 521 | num_training_steps_per_epoch = math.ceil( |
| 476 | num_val_steps_per_epoch = math.ceil( | 522 | len(train_dataloader) / gradient_accumulation_steps |
| 477 | len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0 | 523 | ) |
| 524 | num_val_steps_per_epoch = ( | ||
| 525 | math.ceil(len(val_dataloader) / gradient_accumulation_steps) | ||
| 526 | if val_dataloader is not None | ||
| 527 | else 0 | ||
| 528 | ) | ||
| 478 | 529 | ||
| 479 | num_training_steps = num_training_steps_per_epoch * num_epochs | 530 | num_training_steps = num_training_steps_per_epoch * num_epochs |
| 480 | num_val_steps = num_val_steps_per_epoch * num_epochs | 531 | num_val_steps = num_val_steps_per_epoch * num_epochs |
| @@ -488,14 +539,14 @@ def train_loop( | |||
| 488 | local_progress_bar = tqdm( | 539 | local_progress_bar = tqdm( |
| 489 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), | 540 | range(num_training_steps_per_epoch + num_val_steps_per_epoch), |
| 490 | disable=not accelerator.is_local_main_process, | 541 | disable=not accelerator.is_local_main_process, |
| 491 | dynamic_ncols=True | 542 | dynamic_ncols=True, |
| 492 | ) | 543 | ) |
| 493 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") | 544 | local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") |
| 494 | 545 | ||
| 495 | global_progress_bar = tqdm( | 546 | global_progress_bar = tqdm( |
| 496 | range(num_training_steps + num_val_steps), | 547 | range(num_training_steps + num_val_steps), |
| 497 | disable=not accelerator.is_local_main_process, | 548 | disable=not accelerator.is_local_main_process, |
| 498 | dynamic_ncols=True | 549 | dynamic_ncols=True, |
| 499 | ) | 550 | ) |
| 500 | global_progress_bar.set_description("Total progress") | 551 | global_progress_bar.set_description("Total progress") |
| 501 | 552 | ||
| @@ -513,7 +564,9 @@ def train_loop( | |||
| 513 | try: | 564 | try: |
| 514 | import dadaptation | 565 | import dadaptation |
| 515 | 566 | ||
| 516 | isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) | 567 | isDadaptation = isinstance( |
| 568 | optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan) | ||
| 569 | ) | ||
| 517 | except ImportError: | 570 | except ImportError: |
| 518 | pass | 571 | pass |
| 519 | 572 | ||
| @@ -565,7 +618,10 @@ def train_loop( | |||
| 565 | label = group_labels[i] if i < len(group_labels) else f"{i}" | 618 | label = group_labels[i] if i < len(group_labels) else f"{i}" |
| 566 | logs[f"lr/{label}"] = lr | 619 | logs[f"lr/{label}"] = lr |
| 567 | if isDadaptation: | 620 | if isDadaptation: |
| 568 | lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] | 621 | lr = ( |
| 622 | optimizer.param_groups[i]["d"] | ||
| 623 | * optimizer.param_groups[i]["lr"] | ||
| 624 | ) | ||
| 569 | logs[f"d*lr/{label}"] = lr | 625 | logs[f"d*lr/{label}"] = lr |
| 570 | lrs[label] = lr | 626 | lrs[label] = lr |
| 571 | 627 | ||
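D-Adaptation optimizers learn a distance estimate `d` at run time, and their effective step size is the product `d * lr`; logging that product, rather than the nominal `lr` (which typically stays at 1.0 for these optimizers), is what makes the adapted schedule visible in TensorBoard.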
| @@ -573,8 +629,10 @@ def train_loop( | |||
| 573 | 629 | ||
| 574 | local_progress_bar.set_postfix(**logs) | 630 | local_progress_bar.set_postfix(**logs) |
| 575 | 631 | ||
| 576 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): | 632 | if ((step + 1) % gradient_accumulation_steps == 0) or ( |
| 577 | before_optimize_result = on_before_optimize(epoch) | 633 | (step + 1) == len(train_dataloader) |
| 634 | ): | ||
| 635 | before_optimize_result = on_before_optimize(cycle) | ||
| 578 | 636 | ||
| 579 | optimizer.step() | 637 | optimizer.step() |
| 580 | lr_scheduler.step() | 638 | lr_scheduler.step() |
| @@ -614,7 +672,9 @@ def train_loop( | |||
| 614 | } | 672 | } |
| 615 | local_progress_bar.set_postfix(**logs) | 673 | local_progress_bar.set_postfix(**logs) |
| 616 | 674 | ||
| 617 | if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)): | 675 | if ((step + 1) % gradient_accumulation_steps == 0) or ( |
| 676 | (step + 1) == len(val_dataloader) | ||
| 677 | ): | ||
| 618 | local_progress_bar.update(1) | 678 | local_progress_bar.update(1) |
| 619 | global_progress_bar.update(1) | 679 | global_progress_bar.update(1) |
| 620 | 680 | ||
| @@ -634,7 +694,8 @@ def train_loop( | |||
| 634 | global_progress_bar.clear() | 694 | global_progress_bar.clear() |
| 635 | 695 | ||
| 636 | accelerator.print( | 696 | accelerator.print( |
| 637 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") | 697 | f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}" |
| 698 | ) | ||
| 638 | on_checkpoint(global_step, "milestone") | 699 | on_checkpoint(global_step, "milestone") |
| 639 | best_acc_val = avg_acc_val.max | 700 | best_acc_val = avg_acc_val.max |
| 640 | else: | 701 | else: |
| @@ -644,7 +705,8 @@ def train_loop( | |||
| 644 | global_progress_bar.clear() | 705 | global_progress_bar.clear() |
| 645 | 706 | ||
| 646 | accelerator.print( | 707 | accelerator.print( |
| 647 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") | 708 | f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}" |
| 709 | ) | ||
| 648 | on_checkpoint(global_step, "milestone") | 710 | on_checkpoint(global_step, "milestone") |
| 649 | best_acc = avg_acc.max | 711 | best_acc = avg_acc.max |
| 650 | 712 | ||
| @@ -700,17 +762,32 @@ def train( | |||
| 700 | avg_acc_val: AverageMeter = AverageMeter(), | 762 | avg_acc_val: AverageMeter = AverageMeter(), |
| 701 | **kwargs, | 763 | **kwargs, |
| 702 | ): | 764 | ): |
| 703 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( | 765 | ( |
| 704 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) | 766 | text_encoder, |
| 767 | unet, | ||
| 768 | optimizer, | ||
| 769 | train_dataloader, | ||
| 770 | val_dataloader, | ||
| 771 | lr_scheduler, | ||
| 772 | ) = strategy.prepare( | ||
| 773 | accelerator, | ||
| 774 | text_encoder, | ||
| 775 | unet, | ||
| 776 | optimizer, | ||
| 777 | train_dataloader, | ||
| 778 | val_dataloader, | ||
| 779 | lr_scheduler, | ||
| 780 | **kwargs, | ||
| 781 | ) | ||
| 705 | 782 | ||
| 706 | vae.to(accelerator.device, dtype=dtype) | 783 | vae.to(accelerator.device, dtype=dtype) |
| 707 | vae.requires_grad_(False) | 784 | vae.requires_grad_(False) |
| 708 | vae.eval() | 785 | vae.eval() |
| 709 | 786 | ||
| 710 | vae = torch.compile(vae, backend='hidet') | 787 | vae = torch.compile(vae, backend="hidet") |
| 711 | 788 | ||
| 712 | if compile_unet: | 789 | if compile_unet: |
| 713 | unet = torch.compile(unet, backend='hidet') | 790 | unet = torch.compile(unet, backend="hidet") |
| 714 | # unet = torch.compile(unet, mode="reduce-overhead") | 791 | # unet = torch.compile(unet, mode="reduce-overhead") |
| 715 | 792 | ||
| 716 | callbacks = strategy.callbacks( | 793 | callbacks = strategy.callbacks( |
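A note on the `d*lr` logging reflowed above: D-Adaptation optimizers adapt their step size at runtime and expose it as a `d` entry on each param group, so the effective learning rate is `d * lr` rather than the configured `lr` alone. A minimal sketch of that readout, assuming the same param-group layout as in the loop above:

```python
def effective_lrs(optimizer, group_labels=()):
    """Per-group effective learning rates; for D-Adaptation the real
    step size is d * lr, with d stored on each param group."""
    lrs = {}
    for i, group in enumerate(optimizer.param_groups):
        label = group_labels[i] if i < len(group_labels) else f"{i}"
        lr = group["lr"]
        if "d" in group:  # present for DAdaptAdam / DAdaptAdan groups
            lr = group["d"] * group["lr"]
        lrs[label] = lr
    return lrs
```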
diff --git a/training/lr.py b/training/lr.py index f5b362f..a75078f 100644 --- a/training/lr.py +++ b/training/lr.py | |||
| @@ -23,12 +23,12 @@ def plot_metrics( | |||
| 23 | fig, ax_loss = plt.subplots() | 23 | fig, ax_loss = plt.subplots() |
| 24 | ax_acc = ax_loss.twinx() | 24 | ax_acc = ax_loss.twinx() |
| 25 | 25 | ||
| 26 | ax_loss.plot(lrs, losses, color='red') | 26 | ax_loss.plot(lrs, losses, color="red") |
| 27 | ax_loss.set_xscale("log") | 27 | ax_loss.set_xscale("log") |
| 28 | ax_loss.set_xlabel(f"Learning rate") | 28 | ax_loss.set_xlabel(f"Learning rate") |
| 29 | ax_loss.set_ylabel("Loss") | 29 | ax_loss.set_ylabel("Loss") |
| 30 | 30 | ||
| 31 | ax_acc.plot(lrs, accs, color='blue') | 31 | ax_acc.plot(lrs, accs, color="blue") |
| 32 | ax_acc.set_xscale("log") | 32 | ax_acc.set_xscale("log") |
| 33 | ax_acc.set_ylabel("Accuracy") | 33 | ax_acc.set_ylabel("Accuracy") |
| 34 | 34 | ||
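The plot code touched above draws loss and accuracy on twin y-axes against a log-scaled learning-rate sweep. A standalone sketch with dummy data standing in for the recorded metrics:

```python
import matplotlib.pyplot as plt
import numpy as np

lrs = np.logspace(-6, -2, 50)            # swept learning rates
losses = np.exp(-np.linspace(0, 3, 50))  # dummy curves for illustration
accs = 1.0 - losses

fig, ax_loss = plt.subplots()
ax_acc = ax_loss.twinx()  # accuracy shares the x-axis, gets its own y-axis

ax_loss.plot(lrs, losses, color="red")
ax_loss.set_xscale("log")
ax_loss.set_xlabel("Learning rate")
ax_loss.set_ylabel("Loss")

ax_acc.plot(lrs, accs, color="blue")
ax_acc.set_ylabel("Accuracy")

fig.savefig("lr_sweep.png")
```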
diff --git a/training/optimization.py b/training/optimization.py index d22a900..55531bf 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -5,7 +5,10 @@ from functools import partial | |||
| 5 | import torch | 5 | import torch |
| 6 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR |
| 7 | 7 | ||
| 8 | from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup | 8 | from diffusers.optimization import ( |
| 9 | get_scheduler as get_scheduler_, | ||
| 10 | get_cosine_with_hard_restarts_schedule_with_warmup, | ||
| 11 | ) | ||
| 9 | from transformers.optimization import get_adafactor_schedule | 12 | from transformers.optimization import get_adafactor_schedule |
| 10 | 13 | ||
| 11 | 14 | ||
| @@ -52,7 +55,7 @@ def get_one_cycle_schedule( | |||
| 52 | annealing_exp: int = 1, | 55 | annealing_exp: int = 1, |
| 53 | min_lr: float = 0.04, | 56 | min_lr: float = 0.04, |
| 54 | mid_point: float = 0.3, | 57 | mid_point: float = 0.3, |
| 55 | last_epoch: int = -1 | 58 | last_epoch: int = -1, |
| 56 | ): | 59 | ): |
| 57 | if warmup == "linear": | 60 | if warmup == "linear": |
| 58 | warmup_func = warmup_linear | 61 | warmup_func = warmup_linear |
| @@ -83,12 +86,16 @@ def get_one_cycle_schedule( | |||
| 83 | 86 | ||
| 84 | def lr_lambda(current_step: int): | 87 | def lr_lambda(current_step: int): |
| 85 | phase = [p for p in phases if current_step >= p.step_min][-1] | 88 | phase = [p for p in phases if current_step >= p.step_min][-1] |
| 86 | return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) | 89 | return phase.min + phase.func( |
| 90 | (current_step - phase.step_min) / (phase.step_max - phase.step_min) | ||
| 91 | ) * (phase.max - phase.min) | ||
| 87 | 92 | ||
| 88 | return LambdaLR(optimizer, lr_lambda, last_epoch) | 93 | return LambdaLR(optimizer, lr_lambda, last_epoch) |
| 89 | 94 | ||
| 90 | 95 | ||
| 91 | def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1): | 96 | def get_exponential_growing_schedule( |
| 97 | optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1 | ||
| 98 | ): | ||
| 92 | def lr_lambda(base_lr: float, current_step: int): | 99 | def lr_lambda(base_lr: float, current_step: int): |
| 93 | return (end_lr / base_lr) ** (current_step / num_training_steps) | 100 | return (end_lr / base_lr) ** (current_step / num_training_steps) |
| 94 | 101 | ||
| @@ -132,7 +139,14 @@ def get_scheduler( | |||
| 132 | ) | 139 | ) |
| 133 | elif id == "exponential_growth": | 140 | elif id == "exponential_growth": |
| 134 | if cycles is None: | 141 | if cycles is None: |
| 135 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | 142 | cycles = math.ceil( |
| 143 | math.sqrt( | ||
| 144 | ( | ||
| 145 | (num_training_steps - num_warmup_steps) | ||
| 146 | / num_training_steps_per_epoch | ||
| 147 | ) | ||
| 148 | ) | ||
| 149 | ) | ||
| 136 | 150 | ||
| 137 | lr_scheduler = get_exponential_growing_schedule( | 151 | lr_scheduler = get_exponential_growing_schedule( |
| 138 | optimizer=optimizer, | 152 | optimizer=optimizer, |
| @@ -141,7 +155,14 @@ def get_scheduler( | |||
| 141 | ) | 155 | ) |
| 142 | elif id == "cosine_with_restarts": | 156 | elif id == "cosine_with_restarts": |
| 143 | if cycles is None: | 157 | if cycles is None: |
| 144 | cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) | 158 | cycles = math.ceil( |
| 159 | math.sqrt( | ||
| 160 | ( | ||
| 161 | (num_training_steps - num_warmup_steps) | ||
| 162 | / num_training_steps_per_epoch | ||
| 163 | ) | ||
| 164 | ) | ||
| 165 | ) | ||
| 145 | 166 | ||
| 146 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 167 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
| 147 | optimizer=optimizer, | 168 | optimizer=optimizer, |
| @@ -150,10 +171,7 @@ def get_scheduler( | |||
| 150 | num_cycles=cycles, | 171 | num_cycles=cycles, |
| 151 | ) | 172 | ) |
| 152 | elif id == "adafactor": | 173 | elif id == "adafactor": |
| 153 | lr_scheduler = get_adafactor_schedule( | 174 | lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr) |
| 154 | optimizer, | ||
| 155 | initial_lr=min_lr | ||
| 156 | ) | ||
| 157 | else: | 175 | else: |
| 158 | lr_scheduler = get_scheduler_( | 176 | lr_scheduler = get_scheduler_( |
| 159 | id, | 177 | id, |
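Both schedule branches above fall back to the same default when `cycles` is unset: the ceiling of the square root of the number of post-warmup epochs. A worked sketch of the heuristic (the example numbers are illustrative):

```python
import math

def default_cycles(num_training_steps: int, num_warmup_steps: int,
                   steps_per_epoch: int) -> int:
    # ceil(sqrt(post-warmup epochs)): the restart count grows with run
    # length, but sub-linearly, so later cycles stay reasonably long.
    epochs = (num_training_steps - num_warmup_steps) / steps_per_epoch
    return math.ceil(math.sqrt(epochs))

# e.g. 3000 steps with 200 warmup at 100 steps/epoch -> 28 epochs -> 6 cycles
assert default_cycles(3000, 200, 100) == 6
```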
diff --git a/training/sampler.py b/training/sampler.py index bdb3e90..0487d66 100644 --- a/training/sampler.py +++ b/training/sampler.py | |||
| @@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler): | |||
| 134 | def weights(self): | 134 | def weights(self): |
| 135 | if not self._warmed_up(): | 135 | if not self._warmed_up(): |
| 136 | return np.ones([self.num_timesteps], dtype=np.float64) | 136 | return np.ones([self.num_timesteps], dtype=np.float64) |
| 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) | 137 | weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) |
| 138 | weights /= np.sum(weights) | 138 | weights /= np.sum(weights) |
| 139 | weights *= 1 - self.uniform_prob | 139 | weights *= 1 - self.uniform_prob |
| 140 | weights += self.uniform_prob / len(weights) | 140 | weights += self.uniform_prob / len(weights) |
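For context, the `**2` spacing fix above sits inside the loss-aware timestep sampler: each timestep is weighted by the RMS of its recent losses, normalized into a distribution, then blended with a uniform floor so low-loss timesteps are still sampled occasionally. A self-contained sketch, with the history shape and `uniform_prob` chosen for illustration:

```python
import numpy as np

def timestep_weights(loss_history: np.ndarray, uniform_prob: float = 0.001):
    """loss_history: [num_timesteps, history_len] recent per-timestep losses."""
    weights = np.sqrt(np.mean(loss_history**2, axis=-1))  # RMS loss per timestep
    weights /= np.sum(weights)                            # normalize
    weights *= 1 - uniform_prob                           # make room for the floor
    weights += uniform_prob / len(weights)                # uniform exploration floor
    return weights

rng = np.random.default_rng(0)
w = timestep_weights(rng.random((1000, 10)))
assert abs(w.sum() - 1.0) < 1e-9  # still a probability distribution
```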
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e6fcc89..88b441b 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks( | |||
| 29 | sample_output_dir: Path, | 29 | sample_output_dir: Path, |
| 30 | checkpoint_output_dir: Path, | 30 | checkpoint_output_dir: Path, |
| 31 | seed: int, | 31 | seed: int, |
| 32 | train_text_encoder_epochs: int, | 32 | train_text_encoder_cycles: int, |
| 33 | max_grad_norm: float = 1.0, | 33 | max_grad_norm: float = 1.0, |
| 34 | use_ema: bool = False, | 34 | use_ema: bool = False, |
| 35 | ema_inv_gamma: float = 1.0, | 35 | ema_inv_gamma: float = 1.0, |
| @@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks( | |||
| 85 | return nullcontext() | 85 | return nullcontext() |
| 86 | 86 | ||
| 87 | @contextmanager | 87 | @contextmanager |
| 88 | def on_train(epoch: int): | 88 | def on_train(cycle: int): |
| 89 | unet.train() | 89 | unet.train() |
| 90 | tokenizer.train() | 90 | tokenizer.train() |
| 91 | 91 | ||
| 92 | if epoch < train_text_encoder_epochs: | 92 | if cycle < train_text_encoder_cycles: |
| 93 | text_encoder.train() | 93 | text_encoder.train() |
| 94 | elif epoch == train_text_encoder_epochs: | 94 | tokenizer.train() |
| 95 | text_encoder.requires_grad_(False) | ||
| 96 | text_encoder.eval() | ||
| 97 | 95 | ||
| 98 | yield | 96 | yield |
| 99 | 97 | ||
| @@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks( | |||
| 106 | with ema_context(): | 104 | with ema_context(): |
| 107 | yield | 105 | yield |
| 108 | 106 | ||
| 109 | def on_before_optimize(epoch: int): | 107 | def on_before_optimize(cycle: int): |
| 110 | params_to_clip = [unet.parameters()] | 108 | params_to_clip = [unet.parameters()] |
| 111 | if epoch < train_text_encoder_epochs: | 109 | if cycle < train_text_encoder_cycles: |
| 112 | params_to_clip.append(text_encoder.parameters()) | 110 | params_to_clip.append(text_encoder.parameters()) |
| 113 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) | 111 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) |
| 114 | 112 | ||
| @@ -189,8 +187,16 @@ def dreambooth_prepare( | |||
| 189 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 187 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 190 | **kwargs | 188 | **kwargs |
| 191 | ): | 189 | ): |
| 192 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 190 | ( |
| 193 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 191 | text_encoder, |
| 192 | unet, | ||
| 193 | optimizer, | ||
| 194 | train_dataloader, | ||
| 195 | val_dataloader, | ||
| 196 | lr_scheduler, | ||
| 197 | ) = accelerator.prepare( | ||
| 198 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 199 | ) | ||
| 194 | 200 | ||
| 195 | text_encoder.text_model.embeddings.requires_grad_(False) | 201 | text_encoder.text_model.embeddings.requires_grad_(False) |
| 196 | 202 | ||
| @@ -198,6 +204,5 @@ def dreambooth_prepare( | |||
| 198 | 204 | ||
| 199 | 205 | ||
| 200 | dreambooth_strategy = TrainingStrategy( | 206 | dreambooth_strategy = TrainingStrategy( |
| 201 | callbacks=dreambooth_strategy_callbacks, | 207 | callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare |
| 202 | prepare=dreambooth_prepare | ||
| 203 | ) | 208 | ) |
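The switch from `train_text_encoder_epochs` to `train_text_encoder_cycles` changes the unit the gate counts, and the explicit freeze at the boundary (`requires_grad_(False)` plus `eval()`) is gone: the encoder simply stops being put into train mode once the window ends. A minimal sketch of the new gating, with placeholder modules standing in for the real models:

```python
from contextlib import contextmanager

import torch.nn as nn

unet = nn.Linear(4, 4)          # placeholders for illustration
text_encoder = nn.Linear(4, 4)
train_text_encoder_cycles = 2

@contextmanager
def on_train(cycle: int):
    unet.train()
    if cycle < train_text_encoder_cycles:
        text_encoder.train()    # still inside the text-encoder training window
    yield

text_encoder.eval()
with on_train(cycle=0):
    assert text_encoder.training      # cycle 0 < 2: encoder trains
text_encoder.eval()
with on_train(cycle=5):
    assert not text_encoder.training  # past the window: left in eval mode
```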
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index f942b76..14e3384 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -81,7 +81,7 @@ def lora_strategy_callbacks( | |||
| 81 | tokenizer.eval() | 81 | tokenizer.eval() |
| 82 | yield | 82 | yield |
| 83 | 83 | ||
| 84 | def on_before_optimize(epoch: int): | 84 | def on_before_optimize(cycle: int): |
| 85 | if not pti_mode: | 85 | if not pti_mode: |
| 86 | accelerator.clip_grad_norm_( | 86 | accelerator.clip_grad_norm_( |
| 87 | itertools.chain( | 87 | itertools.chain( |
| @@ -89,7 +89,7 @@ def lora_strategy_callbacks( | |||
| 89 | text_encoder.text_model.encoder.parameters(), | 89 | text_encoder.text_model.encoder.parameters(), |
| 90 | text_encoder.text_model.final_layer_norm.parameters(), | 90 | text_encoder.text_model.final_layer_norm.parameters(), |
| 91 | ), | 91 | ), |
| 92 | max_grad_norm | 92 | max_grad_norm, |
| 93 | ) | 93 | ) |
| 94 | 94 | ||
| 95 | if len(placeholder_tokens) != 0 and use_emb_decay: | 95 | if len(placeholder_tokens) != 0 and use_emb_decay: |
| @@ -108,7 +108,9 @@ def lora_strategy_callbacks( | |||
| 108 | 108 | ||
| 109 | if lambda_ != 0: | 109 | if lambda_ != 0: |
| 110 | norm = w[:, :].norm(dim=-1, keepdim=True) | 110 | norm = w[:, :].norm(dim=-1, keepdim=True) |
| 111 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | 111 | w[:].add_( |
| 112 | (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) | ||
| 113 | ) | ||
| 112 | 114 | ||
| 113 | @torch.no_grad() | 115 | @torch.no_grad() |
| 114 | def on_checkpoint(step, postfix): | 116 | def on_checkpoint(step, postfix): |
| @@ -128,25 +130,32 @@ def lora_strategy_callbacks( | |||
| 128 | 130 | ||
| 129 | if not pti_mode: | 131 | if not pti_mode: |
| 130 | lora_config = {} | 132 | lora_config = {} |
| 131 | state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) | 133 | state_dict = get_peft_model_state_dict( |
| 134 | unet_, state_dict=accelerator.get_state_dict(unet_) | ||
| 135 | ) | ||
| 132 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) | 136 | lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) |
| 133 | 137 | ||
| 134 | text_encoder_state_dict = get_peft_model_state_dict( | 138 | text_encoder_state_dict = get_peft_model_state_dict( |
| 135 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) | 139 | text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) |
| 136 | ) | 140 | ) |
| 137 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} | 141 | text_encoder_state_dict = { |
| 142 | f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items() | ||
| 143 | } | ||
| 138 | state_dict.update(text_encoder_state_dict) | 144 | state_dict.update(text_encoder_state_dict) |
| 139 | lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) | 145 | lora_config[ |
| 146 | "text_encoder_peft_config" | ||
| 147 | ] = text_encoder_.get_peft_config_as_dict(inference=True) | ||
| 140 | 148 | ||
| 141 | if len(placeholder_tokens) != 0: | 149 | if len(placeholder_tokens) != 0: |
| 142 | ti_state_dict = { | 150 | ti_state_dict = { |
| 143 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) | 151 | f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) |
| 144 | for (token, ids) | 152 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids) |
| 145 | in zip(placeholder_tokens, placeholder_token_ids) | ||
| 146 | } | 153 | } |
| 147 | state_dict.update(ti_state_dict) | 154 | state_dict.update(ti_state_dict) |
| 148 | 155 | ||
| 149 | save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") | 156 | save_file( |
| 157 | state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors" | ||
| 158 | ) | ||
| 150 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: | 159 | with open(checkpoint_output_dir / "lora_config.json", "w") as f: |
| 151 | json.dump(lora_config, f) | 160 | json.dump(lora_config, f) |
| 152 | 161 | ||
| @@ -185,10 +194,18 @@ def lora_prepare( | |||
| 185 | train_dataloader: DataLoader, | 194 | train_dataloader: DataLoader, |
| 186 | val_dataloader: Optional[DataLoader], | 195 | val_dataloader: Optional[DataLoader], |
| 187 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 196 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 188 | **kwargs | 197 | **kwargs, |
| 189 | ): | 198 | ): |
| 190 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 199 | ( |
| 191 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 200 | text_encoder, |
| 201 | unet, | ||
| 202 | optimizer, | ||
| 203 | train_dataloader, | ||
| 204 | val_dataloader, | ||
| 205 | lr_scheduler, | ||
| 206 | ) = accelerator.prepare( | ||
| 207 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 208 | ) | ||
| 192 | 209 | ||
| 193 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) | 210 | # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) |
| 194 | 211 | ||
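The reflowed `w[:].add_(...)` call above (shared with the TI strategy below) implements embedding norm decay: each learned embedding row is pushed along its own direction toward a target norm, so rows that drift long shrink and short ones grow. A standalone sketch; the shapes, `lambda_`, and target value are illustrative:

```python
import torch

def decay_toward_norm(w: torch.Tensor, lambda_: float, target: float) -> None:
    """Nudge each row of w toward `target` L2 norm, in place."""
    norm = w.norm(dim=-1, keepdim=True)
    # (w / |w|) * lambda * (target - |w|): step along the row's own
    # direction, proportional to how far its norm is from the target.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (target - norm))

w = torch.randn(4, 8)
gap_before = (w.norm(dim=-1) - 0.4).abs()
decay_toward_norm(w, lambda_=0.1, target=0.4)
gap_after = (w.norm(dim=-1) - 0.4).abs()
assert torch.all(gap_after <= gap_before)  # every row moved toward the target
```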
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6bc1d7d..7373982 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( | |||
| 104 | yield | 104 | yield |
| 105 | 105 | ||
| 106 | @torch.no_grad() | 106 | @torch.no_grad() |
| 107 | def on_before_optimize(epoch: int): | 107 | def on_before_optimize(cycle: int): |
| 108 | if use_emb_decay: | 108 | if use_emb_decay: |
| 109 | params = [ | 109 | params = [ |
| 110 | p | 110 | p |
| @@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks( | |||
| 116 | @torch.no_grad() | 116 | @torch.no_grad() |
| 117 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): |
| 118 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: |
| 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 119 | ema_embeddings.step( |
| 120 | text_encoder.text_model.embeddings.token_embedding.parameters() | ||
| 121 | ) | ||
| 120 | 122 | ||
| 121 | if use_emb_decay and w is not None: | 123 | if use_emb_decay and w is not None: |
| 122 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] | 124 | lr = lrs["emb"] if "emb" in lrs else lrs["0"] |
| @@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks( | |||
| 124 | 126 | ||
| 125 | if lambda_ != 0: | 127 | if lambda_ != 0: |
| 126 | norm = w[:, :].norm(dim=-1, keepdim=True) | 128 | norm = w[:, :].norm(dim=-1, keepdim=True) |
| 127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | 129 | w[:].add_( |
| 130 | (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) | ||
| 131 | ) | ||
| 128 | 132 | ||
| 129 | def on_log(): | 133 | def on_log(): |
| 130 | if ema_embeddings is not None: | 134 | if ema_embeddings is not None: |
| @@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks( | |||
| 136 | print(f"Saving checkpoint for step {step}...") | 140 | print(f"Saving checkpoint for step {step}...") |
| 137 | 141 | ||
| 138 | with ema_context(): | 142 | with ema_context(): |
| 139 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 143 | for token, ids in zip(placeholder_tokens, placeholder_token_ids): |
| 140 | text_encoder.text_model.embeddings.save_embed( | 144 | text_encoder.text_model.embeddings.save_embed( |
| 141 | ids, | 145 | ids, |
| 142 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" | 146 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin", |
| 143 | ) | 147 | ) |
| 144 | 148 | ||
| 145 | @torch.no_grad() | 149 | @torch.no_grad() |
| @@ -183,7 +187,7 @@ def textual_inversion_prepare( | |||
| 183 | val_dataloader: Optional[DataLoader], | 187 | val_dataloader: Optional[DataLoader], |
| 184 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 188 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 185 | gradient_checkpointing: bool = False, | 189 | gradient_checkpointing: bool = False, |
| 186 | **kwargs | 190 | **kwargs, |
| 187 | ): | 191 | ): |
| 188 | weight_dtype = torch.float32 | 192 | weight_dtype = torch.float32 |
| 189 | if accelerator.state.mixed_precision == "fp16": | 193 | if accelerator.state.mixed_precision == "fp16": |
| @@ -191,8 +195,15 @@ def textual_inversion_prepare( | |||
| 191 | elif accelerator.state.mixed_precision == "bf16": | 195 | elif accelerator.state.mixed_precision == "bf16": |
| 192 | weight_dtype = torch.bfloat16 | 196 | weight_dtype = torch.bfloat16 |
| 193 | 197 | ||
| 194 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 198 | ( |
| 195 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 199 | text_encoder, |
| 200 | optimizer, | ||
| 201 | train_dataloader, | ||
| 202 | val_dataloader, | ||
| 203 | lr_scheduler, | ||
| 204 | ) = accelerator.prepare( | ||
| 205 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | ||
| 206 | ) | ||
| 196 | 207 | ||
| 197 | unet.to(accelerator.device, dtype=weight_dtype) | 208 | unet.to(accelerator.device, dtype=weight_dtype) |
| 198 | unet.requires_grad_(False) | 209 | unet.requires_grad_(False) |
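In `textual_inversion_prepare` only the text encoder goes through `accelerator.prepare`; the frozen UNet is moved to the device by hand at a dtype matching the accelerator's mixed-precision mode, as above. A minimal sketch of that mapping (the mode strings follow `accelerate`'s convention):

```python
import torch

def weight_dtype_for(mixed_precision: str) -> torch.dtype:
    # Frozen, inference-only modules can live at reduced precision to
    # save memory; trainable weights stay fp32.
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

assert weight_dtype_for("fp16") is torch.float16
assert weight_dtype_for("no") is torch.float32
```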
