diff options
author | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-23 21:47:12 +0100 |
commit | 1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch) | |
tree | 42d3302610046dbc5d39d254f7a2d5d5f601aa18 /train_ti.py | |
parent | Fix (diff) | |
download | textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.gz textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.bz2 textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.zip |
Improved class prompt handling
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 29 |
1 file changed, 9 insertions, 20 deletions
diff --git a/train_ti.py b/train_ti.py index 9d06c50..55daa35 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -63,16 +63,10 @@ def parse_args(): | |||
63 | default="template", | 63 | default="template", |
64 | ) | 64 | ) |
65 | parser.add_argument( | 65 | parser.add_argument( |
66 | "--instance_identifier", | 66 | "--project", |
67 | type=str, | 67 | type=str, |
68 | default=None, | 68 | default=None, |
69 | help="A token to use as a placeholder for the concept.", | 69 | help="The name of the current project.", |
70 | ) | ||
71 | parser.add_argument( | ||
72 | "--class_identifier", | ||
73 | type=str, | ||
74 | default=None, | ||
75 | help="A token to use as a placeholder for the concept.", | ||
76 | ) | 70 | ) |
77 | parser.add_argument( | 71 | parser.add_argument( |
78 | "--placeholder_token", | 72 | "--placeholder_token", |
@@ -334,6 +328,9 @@ def parse_args(): | |||
334 | if args.pretrained_model_name_or_path is None: | 328 | if args.pretrained_model_name_or_path is None: |
335 | raise ValueError("You must specify --pretrained_model_name_or_path") | 329 | raise ValueError("You must specify --pretrained_model_name_or_path") |
336 | 330 | ||
331 | if args.project is None: | ||
332 | raise ValueError("You must specify --project") | ||
333 | |||
337 | if isinstance(args.initializer_token, str): | 334 | if isinstance(args.initializer_token, str): |
338 | args.initializer_token = [args.initializer_token] | 335 | args.initializer_token = [args.initializer_token] |
339 | 336 | ||
@@ -366,7 +363,6 @@ class Checkpointer(CheckpointerBase): | |||
366 | text_encoder, | 363 | text_encoder, |
367 | scheduler, | 364 | scheduler, |
368 | text_embeddings, | 365 | text_embeddings, |
369 | instance_identifier, | ||
370 | placeholder_token, | 366 | placeholder_token, |
371 | placeholder_token_id, | 367 | placeholder_token_id, |
372 | output_dir: Path, | 368 | output_dir: Path, |
@@ -378,7 +374,6 @@ class Checkpointer(CheckpointerBase): | |||
378 | super().__init__( | 374 | super().__init__( |
379 | datamodule=datamodule, | 375 | datamodule=datamodule, |
380 | output_dir=output_dir, | 376 | output_dir=output_dir, |
381 | instance_identifier=instance_identifier, | ||
382 | placeholder_token=placeholder_token, | 377 | placeholder_token=placeholder_token, |
383 | placeholder_token_id=placeholder_token_id, | 378 | placeholder_token_id=placeholder_token_id, |
384 | sample_image_size=sample_image_size, | 379 | sample_image_size=sample_image_size, |
@@ -441,14 +436,9 @@ class Checkpointer(CheckpointerBase): | |||
441 | def main(): | 436 | def main(): |
442 | args = parse_args() | 437 | args = parse_args() |
443 | 438 | ||
444 | instance_identifier = args.instance_identifier | ||
445 | |||
446 | if len(args.placeholder_token) != 0: | ||
447 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) | ||
448 | |||
449 | global_step_offset = args.global_step | 439 | global_step_offset = args.global_step |
450 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 440 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
451 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | 441 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) |
452 | basepath.mkdir(parents=True, exist_ok=True) | 442 | basepath.mkdir(parents=True, exist_ok=True) |
453 | 443 | ||
454 | accelerator = Accelerator( | 444 | accelerator = Accelerator( |
@@ -567,6 +557,7 @@ def main(): | |||
567 | 557 | ||
568 | def collate_fn(examples): | 558 | def collate_fn(examples): |
569 | prompts = [example["prompts"] for example in examples] | 559 | prompts = [example["prompts"] for example in examples] |
560 | cprompts = [example["cprompts"] for example in examples] | ||
570 | nprompts = [example["nprompts"] for example in examples] | 561 | nprompts = [example["nprompts"] for example in examples] |
571 | input_ids = [example["instance_prompt_ids"] for example in examples] | 562 | input_ids = [example["instance_prompt_ids"] for example in examples] |
572 | pixel_values = [example["instance_images"] for example in examples] | 563 | pixel_values = [example["instance_images"] for example in examples] |
@@ -583,6 +574,7 @@ def main(): | |||
583 | 574 | ||
584 | batch = { | 575 | batch = { |
585 | "prompts": prompts, | 576 | "prompts": prompts, |
577 | "cprompts": cprompts, | ||
586 | "nprompts": nprompts, | 578 | "nprompts": nprompts, |
587 | "input_ids": inputs.input_ids, | 579 | "input_ids": inputs.input_ids, |
588 | "pixel_values": pixel_values, | 580 | "pixel_values": pixel_values, |
@@ -594,8 +586,6 @@ def main(): | |||
594 | data_file=args.train_data_file, | 586 | data_file=args.train_data_file, |
595 | batch_size=args.train_batch_size, | 587 | batch_size=args.train_batch_size, |
596 | prompt_processor=prompt_processor, | 588 | prompt_processor=prompt_processor, |
597 | instance_identifier=args.instance_identifier, | ||
598 | class_identifier=args.class_identifier, | ||
599 | class_subdir="cls", | 589 | class_subdir="cls", |
600 | num_class_images=args.num_class_images, | 590 | num_class_images=args.num_class_images, |
601 | size=args.resolution, | 591 | size=args.resolution, |
@@ -634,7 +624,7 @@ def main(): | |||
634 | with torch.autocast("cuda"), torch.inference_mode(): | 624 | with torch.autocast("cuda"), torch.inference_mode(): |
635 | for batch in batched_data: | 625 | for batch in batched_data: |
636 | image_name = [item.class_image_path for item in batch] | 626 | image_name = [item.class_image_path for item in batch] |
637 | prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] | 627 | prompt = [item.cprompt for item in batch] |
638 | nprompt = [item.nprompt for item in batch] | 628 | nprompt = [item.nprompt for item in batch] |
639 | 629 | ||
640 | images = pipeline( | 630 | images = pipeline( |
@@ -744,7 +734,6 @@ def main(): | |||
744 | text_encoder=text_encoder, | 734 | text_encoder=text_encoder, |
745 | scheduler=checkpoint_scheduler, | 735 | scheduler=checkpoint_scheduler, |
746 | text_embeddings=text_embeddings, | 736 | text_embeddings=text_embeddings, |
747 | instance_identifier=args.instance_identifier, | ||
748 | placeholder_token=args.placeholder_token, | 737 | placeholder_token=args.placeholder_token, |
749 | placeholder_token_id=placeholder_token_id, | 738 | placeholder_token_id=placeholder_token_id, |
750 | output_dir=basepath, | 739 | output_dir=basepath, |