summaryrefslogtreecommitdiffstats
path: root/train_ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-23 21:47:12 +0100
committerVolpeon <git@volpeon.ink>2022-12-23 21:47:12 +0100
commit1bd386f98bb076fe62696808e02a9bd9b9b64b42 (patch)
tree42d3302610046dbc5d39d254f7a2d5d5f601aa18 /train_ti.py
parentFix (diff)
downloadtextual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.gz
textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.tar.bz2
textual-inversion-diff-1bd386f98bb076fe62696808e02a9bd9b9b64b42.zip
Improved class prompt handling
Diffstat (limited to 'train_ti.py')
-rw-r--r--train_ti.py29
1 file changed, 9 insertions, 20 deletions
diff --git a/train_ti.py b/train_ti.py
index 9d06c50..55daa35 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -63,16 +63,10 @@ def parse_args():
63 default="template", 63 default="template",
64 ) 64 )
65 parser.add_argument( 65 parser.add_argument(
66 "--instance_identifier", 66 "--project",
67 type=str, 67 type=str,
68 default=None, 68 default=None,
69 help="A token to use as a placeholder for the concept.", 69 help="The name of the current project.",
70 )
71 parser.add_argument(
72 "--class_identifier",
73 type=str,
74 default=None,
75 help="A token to use as a placeholder for the concept.",
76 ) 70 )
77 parser.add_argument( 71 parser.add_argument(
78 "--placeholder_token", 72 "--placeholder_token",
@@ -334,6 +328,9 @@ def parse_args():
334 if args.pretrained_model_name_or_path is None: 328 if args.pretrained_model_name_or_path is None:
335 raise ValueError("You must specify --pretrained_model_name_or_path") 329 raise ValueError("You must specify --pretrained_model_name_or_path")
336 330
331 if args.project is None:
332 raise ValueError("You must specify --project")
333
337 if isinstance(args.initializer_token, str): 334 if isinstance(args.initializer_token, str):
338 args.initializer_token = [args.initializer_token] 335 args.initializer_token = [args.initializer_token]
339 336
@@ -366,7 +363,6 @@ class Checkpointer(CheckpointerBase):
366 text_encoder, 363 text_encoder,
367 scheduler, 364 scheduler,
368 text_embeddings, 365 text_embeddings,
369 instance_identifier,
370 placeholder_token, 366 placeholder_token,
371 placeholder_token_id, 367 placeholder_token_id,
372 output_dir: Path, 368 output_dir: Path,
@@ -378,7 +374,6 @@ class Checkpointer(CheckpointerBase):
378 super().__init__( 374 super().__init__(
379 datamodule=datamodule, 375 datamodule=datamodule,
380 output_dir=output_dir, 376 output_dir=output_dir,
381 instance_identifier=instance_identifier,
382 placeholder_token=placeholder_token, 377 placeholder_token=placeholder_token,
383 placeholder_token_id=placeholder_token_id, 378 placeholder_token_id=placeholder_token_id,
384 sample_image_size=sample_image_size, 379 sample_image_size=sample_image_size,
@@ -441,14 +436,9 @@ class Checkpointer(CheckpointerBase):
441def main(): 436def main():
442 args = parse_args() 437 args = parse_args()
443 438
444 instance_identifier = args.instance_identifier
445
446 if len(args.placeholder_token) != 0:
447 instance_identifier = instance_identifier.format(args.placeholder_token[0])
448
449 global_step_offset = args.global_step 439 global_step_offset = args.global_step
450 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 440 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
451 basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) 441 basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
452 basepath.mkdir(parents=True, exist_ok=True) 442 basepath.mkdir(parents=True, exist_ok=True)
453 443
454 accelerator = Accelerator( 444 accelerator = Accelerator(
@@ -567,6 +557,7 @@ def main():
567 557
568 def collate_fn(examples): 558 def collate_fn(examples):
569 prompts = [example["prompts"] for example in examples] 559 prompts = [example["prompts"] for example in examples]
560 cprompts = [example["cprompts"] for example in examples]
570 nprompts = [example["nprompts"] for example in examples] 561 nprompts = [example["nprompts"] for example in examples]
571 input_ids = [example["instance_prompt_ids"] for example in examples] 562 input_ids = [example["instance_prompt_ids"] for example in examples]
572 pixel_values = [example["instance_images"] for example in examples] 563 pixel_values = [example["instance_images"] for example in examples]
@@ -583,6 +574,7 @@ def main():
583 574
584 batch = { 575 batch = {
585 "prompts": prompts, 576 "prompts": prompts,
577 "cprompts": cprompts,
586 "nprompts": nprompts, 578 "nprompts": nprompts,
587 "input_ids": inputs.input_ids, 579 "input_ids": inputs.input_ids,
588 "pixel_values": pixel_values, 580 "pixel_values": pixel_values,
@@ -594,8 +586,6 @@ def main():
594 data_file=args.train_data_file, 586 data_file=args.train_data_file,
595 batch_size=args.train_batch_size, 587 batch_size=args.train_batch_size,
596 prompt_processor=prompt_processor, 588 prompt_processor=prompt_processor,
597 instance_identifier=args.instance_identifier,
598 class_identifier=args.class_identifier,
599 class_subdir="cls", 589 class_subdir="cls",
600 num_class_images=args.num_class_images, 590 num_class_images=args.num_class_images,
601 size=args.resolution, 591 size=args.resolution,
@@ -634,7 +624,7 @@ def main():
634 with torch.autocast("cuda"), torch.inference_mode(): 624 with torch.autocast("cuda"), torch.inference_mode():
635 for batch in batched_data: 625 for batch in batched_data:
636 image_name = [item.class_image_path for item in batch] 626 image_name = [item.class_image_path for item in batch]
637 prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch] 627 prompt = [item.cprompt for item in batch]
638 nprompt = [item.nprompt for item in batch] 628 nprompt = [item.nprompt for item in batch]
639 629
640 images = pipeline( 630 images = pipeline(
@@ -744,7 +734,6 @@ def main():
744 text_encoder=text_encoder, 734 text_encoder=text_encoder,
745 scheduler=checkpoint_scheduler, 735 scheduler=checkpoint_scheduler,
746 text_embeddings=text_embeddings, 736 text_embeddings=text_embeddings,
747 instance_identifier=args.instance_identifier,
748 placeholder_token=args.placeholder_token, 737 placeholder_token=args.placeholder_token,
749 placeholder_token_id=placeholder_token_id, 738 placeholder_token_id=placeholder_token_id,
750 output_dir=basepath, 739 output_dir=basepath,