| author | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
| commit | 1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch) | |
| tree | 42d3302610046dbc5d39d254f7a2d5d5f601aa18 /train_ti.py | |
| parent | Fix (diff) | |
Improved class prompt handling
Diffstat (limited to 'train_ti.py')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | train_ti.py | 29 |

1 file changed, 9 insertions(+), 20 deletions(-)
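For orientation, here is a minimal, self-contained sketch of the argument-handling change this commit makes. It only reuses names visible in the diff below (`--project` replacing `--instance_identifier`/`--class_identifier`); it is not the script's actual parser setup.

```python
import argparse

# Sketch of the new CLI surface: a single required --project replaces the old
# --instance_identifier / --class_identifier pair (only these arguments shown).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--project",
    type=str,
    default=None,
    help="The name of the current project.",
)
args = parser.parse_args()

# The commit also validates the new argument right after parsing:
if args.project is None:
    raise ValueError("You must specify --project")
```

The output directory is then derived from `slugify(args.project)` instead of the formatted instance identifier, as the hunk in `main()` below shows.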
```diff
diff --git a/train_ti.py b/train_ti.py
index 9d06c50..55daa35 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -63,16 +63,10 @@ def parse_args():
         default="template",
     )
     parser.add_argument(
-        "--instance_identifier",
+        "--project",
         type=str,
         default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--class_identifier",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
+        help="The name of the current project.",
     )
     parser.add_argument(
         "--placeholder_token",
@@ -334,6 +328,9 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
+    if args.project is None:
+        raise ValueError("You must specify --project")
+
     if isinstance(args.initializer_token, str):
         args.initializer_token = [args.initializer_token]
 
@@ -366,7 +363,6 @@ class Checkpointer(CheckpointerBase):
         text_encoder,
         scheduler,
         text_embeddings,
-        instance_identifier,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -378,7 +374,6 @@ class Checkpointer(CheckpointerBase):
         super().__init__(
             datamodule=datamodule,
             output_dir=output_dir,
-            instance_identifier=instance_identifier,
             placeholder_token=placeholder_token,
             placeholder_token_id=placeholder_token_id,
             sample_image_size=sample_image_size,
@@ -441,14 +436,9 @@ class Checkpointer(CheckpointerBase):
 def main():
     args = parse_args()
 
-    instance_identifier = args.instance_identifier
-
-    if len(args.placeholder_token) != 0:
-        instance_identifier = instance_identifier.format(args.placeholder_token[0])
-
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -567,6 +557,7 @@ def main():
 
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
+        cprompts = [example["cprompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
@@ -583,6 +574,7 @@ def main():
 
         batch = {
             "prompts": prompts,
+            "cprompts": cprompts,
             "nprompts": nprompts,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
@@ -594,8 +586,6 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
-        instance_identifier=args.instance_identifier,
-        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -634,7 +624,7 @@ def main():
         with torch.autocast("cuda"), torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                prompt = [item.cprompt for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
@@ -744,7 +734,6 @@ def main():
         text_encoder=text_encoder,
         scheduler=checkpoint_scheduler,
         text_embeddings=text_embeddings,
-        instance_identifier=args.instance_identifier,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
```
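And a rough, illustrative sketch of how class prompts flow after this change: each dataset example carries a pre-resolved `cprompts` entry that `collate_fn` batches, and class-image generation reads `item.cprompt` directly instead of formatting a global `--class_identifier` into a prompt template. The example dicts, the `Item` dataclass, and all prompt strings below are assumptions for illustration, not code from the repository.

```python
from dataclasses import dataclass

# Assumed shape of a dataset example after this commit: the class prompt is
# already resolved by the data layer and shipped under "cprompts".
examples = [
    {
        "prompts": "a photo of <new1>",     # instance prompt with the placeholder token
        "cprompts": "a photo of a person",  # per-item class prompt (illustrative wording)
        "nprompts": "lowres, blurry",       # negative prompt (illustrative wording)
    },
]


def collate_prompts(examples):
    # Mirrors only the prompt keys the patched collate_fn gathers; the tensor
    # handling (input_ids, pixel_values) is omitted for brevity.
    return {
        "prompts": [e["prompts"] for e in examples],
        "cprompts": [e["cprompts"] for e in examples],
        "nprompts": [e["nprompts"] for e in examples],
    }


@dataclass
class Item:
    """Hypothetical stand-in for the datamodule's class-image item."""
    class_image_path: str
    cprompt: str
    nprompt: str


batch = [Item("cls/0001.jpg", "a photo of a person", "lowres, blurry")]
prompt = [item.cprompt for item in batch]   # was: item.prompt.format(identifier=args.class_identifier)
nprompt = [item.nprompt for item in batch]

print(collate_prompts(examples)["cprompts"])  # ['a photo of a person']
print(prompt, nprompt)
```

Moving the resolved class prompt onto each item lets the dataset decide the class wording per concept, rather than relying on a single global `--class_identifier` format string shared by every example.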
