author    | Volpeon <git@volpeon.ink> | 2022-10-24 23:46:18 +0200
committer | Volpeon <git@volpeon.ink> | 2022-10-24 23:46:18 +0200
commit    | baba91864a45939cef4f77f6ca96ade7ae5ef274 (patch)
tree      | c40fc949a94d5a2bee81b2b505b814e7c7f82cc1 /dreambooth.py
parent    | Update (diff)
download  | textual-inversion-diff-baba91864a45939cef4f77f6ca96ade7ae5ef274.tar.gz
          | textual-inversion-diff-baba91864a45939cef4f77f6ca96ade7ae5ef274.tar.bz2
          | textual-inversion-diff-baba91864a45939cef4f77f6ca96ade7ae5ef274.zip
Advanced datasets
Diffstat (limited to 'dreambooth.py')

-rw-r--r-- | dreambooth.py | 68
1 file changed, 42 insertions(+), 26 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 5c26f12..2c24908 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -71,13 +71,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
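
Both flags now take `nargs='*'`, so a single run can define several placeholder/initializer pairs. A minimal, self-contained sketch of the new argument behavior (the token values are made up for illustration):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--placeholder_token", type=str, nargs='*',
                    help="A token to use as a placeholder for the concept.")
parser.add_argument("--initializer_token", type=str, nargs='*',
                    help="A token to use as initializer word.")

# e.g. dreambooth.py --placeholder_token "<cat>" "<dog>" --initializer_token cat dog
args = parser.parse_args(["--placeholder_token", "<cat>", "<dog>",
                          "--initializer_token", "cat", "dog"])
print(args.placeholder_token)  # ['<cat>', '<dog>']
print(args.initializer_token)  # ['cat', 'dog']

# Note: with nargs='*' and no default, an omitted flag parses to None, not [].
```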
@@ -316,6 +316,18 @@ def parse_args():
     if args.instance_identifier is None:
         raise ValueError("You must specify --instance_identifier")
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
+
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
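
The new checks normalize both flags to equal-length lists; when no placeholder is given, one is derived per initializer token. A quick standalone illustration:

```python
initializer_token = ["cat", "dog"]
placeholder_token = []

# Auto-generate one placeholder per initializer word: <*0>, <*1>, ...
if len(placeholder_token) == 0:
    placeholder_token = [f"<*{i}>" for i in range(len(initializer_token))]

print(placeholder_token)  # ['<*0>', '<*1>']
```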
@@ -379,9 +391,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
-        if self.placeholder_token_id is None:
-            return
-
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -389,12 +398,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
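
Each learned vector is now saved as a small `{token: tensor}` dict, one `.bin` per placeholder; and since `placeholder_token_id` becomes an empty list when nothing is trained (see the `else` branch below), the `zip` simply yields nothing, which is why the old `None` guard could go. A hedged sketch of how such a file could be loaded back into a text encoder; the file name and model checkpoint here are assumptions for illustration, not part of this commit:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Hypothetical checkpoint names, for illustration only.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

learned = torch.load("checkpoints/star-0_1000_training.bin")  # {token: tensor}
for token, embedding in learned.items():
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embedding
```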
@@ -467,7 +477,7 @@ class Checkpointer:
             for i in range(self.sample_batches):
                 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
                 prompt = [
-                    prompt.format(self.instance_identifier)
+                    prompt.format(identifier=self.instance_identifier)
                     for batch in batches
                     for prompt in batch["prompts"]
                 ][:self.sample_batch_size]
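
Prompt templates are now filled by keyword rather than by position, so a template must contain `{identifier}` instead of a bare `{}`. For example:

```python
template = "a photo of {identifier} on the beach"
print(template.format(identifier="<cat>"))  # a photo of <cat> on the beach

# A positional template would now fail:
# "a photo of {}".format(identifier="<cat>")  -> IndexError
```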
@@ -516,8 +526,8 @@ def main():
 
     instance_identifier = args.instance_identifier
 
-    if args.placeholder_token is not None:
-        instance_identifier = instance_identifier.format(args.placeholder_token)
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
@@ -565,18 +575,16 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if args.initializer_token is not None:
+    if len(args.initializer_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
-        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+        initializer_token_ids = torch.stack([
+            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+            for token in args.initializer_token
+        ])
 
-        # Add the placeholder token in tokenizer
         num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        if num_added_tokens == 0:
-            print(f"Re-using existing token {args.placeholder_token}.")
-        else:
-            print(f"Training new token {args.placeholder_token}.")
+        print(f"Added {num_added_tokens} new tokens.")
+
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
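
Each initializer word is now encoded separately and truncated to its first sub-token, so `initializer_token_ids` ends up with shape `(n_tokens, 1)`. A hedged sketch, assuming the Hugging Face CLIP tokenizer used by Stable Diffusion checkpoints:

```python
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

initializer_token = ["cat", "dog"]
initializer_token_ids = torch.stack([
    torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
    for token in initializer_token
])
print(initializer_token_ids.shape)  # torch.Size([2, 1])
# Words that split into several sub-tokens contribute only their first sub-token.
```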
@@ -586,7 +594,9 @@ def main():
         token_embeds = text_encoder.get_input_embeddings().weight.data
         original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
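
Because the ids were stacked as `(n_tokens, 1)`, the embedding lookup yields one `(1, hidden_dim)` row per placeholder, and each row broadcasts cleanly into the embedding matrix. A small shape check; the sizes are illustrative, not taken from the commit:

```python
import torch

token_embeds = torch.zeros(49410, 768)                 # (vocab_size, hidden_dim), illustrative
initializer_token_embeddings = torch.randn(2, 1, 768)  # one (1, 768) row per new token
placeholder_token_id = [49408, 49409]

for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
    token_embeds[token_id] = embeddings  # (1, 768) broadcasts into the (768,) row

print(token_embeds[49408].shape)  # torch.Size([768])
```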
@@ -594,7 +604,7 @@ def main():
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
     else:
-        placeholder_token_id = None
+        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -721,7 +731,7 @@ def main():
             with torch.inference_mode():
                 for batch in batched_data:
                     image_name = [item.class_image_path for item in batch]
-                    prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
                     nprompt = [item.nprompt for item in batch]
 
                     images = pipeline(
@@ -787,7 +797,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initialize automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
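
Tracker backends generally log hyperparameters as scalars (TensorBoard's hparams, for instance, accept only int/float/str/bool), so the token lists are flattened to space-separated strings before `init_trackers`. A tiny illustration with made-up values:

```python
config = {"placeholder_token": ["<cat>", "<dog>"], "learning_rate": 5e-6}
config = config.copy()
config["placeholder_token"] = " ".join(config["placeholder_token"])
print(config)  # {'placeholder_token': '<cat> <dog>', 'learning_rate': 5e-06}
```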
@@ -932,6 +945,9 @@ def main():
                 global_step += 1
 
                 if global_step % args.sample_frequency == 0:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     checkpointer.save_embedding(global_step, "training")
                     sample_checkpoint = True
 
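
`save_embedding` prints a status line, and clearing the tqdm bars first keeps that line from being garbled by the bar redraws. A minimal sketch of the pattern, assuming tqdm progress bars as in this script:

```python
from tqdm.auto import tqdm

progress_bar = tqdm(total=100)
for step in range(1, 101):
    progress_bar.update(1)
    if step % 25 == 0:
        progress_bar.clear()  # erase the bar before printing
        print(f"Saving checkpoint for step {step}...")
        # the bar is redrawn on the next update()
```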