summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py10
-rw-r--r--dreambooth.py41
-rw-r--r--textual_inversion.py10
3 files changed, 50 insertions, 11 deletions
diff --git a/data/csv.py b/data/csv.py
index f9b5e39..6bd7f9b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -38,7 +38,8 @@ class CSVDataModule(pl.LightningDataModule):
38 center_crop: bool = False, 38 center_crop: bool = False,
39 valid_set_size: Optional[int] = None, 39 valid_set_size: Optional[int] = None,
40 generator: Optional[torch.Generator] = None, 40 generator: Optional[torch.Generator] = None,
41 collate_fn=None 41 collate_fn=None,
42 num_workers: int = 0
42 ): 43 ):
43 super().__init__() 44 super().__init__()
44 45
@@ -62,6 +63,7 @@ class CSVDataModule(pl.LightningDataModule):
62 self.valid_set_size = valid_set_size 63 self.valid_set_size = valid_set_size
63 self.generator = generator 64 self.generator = generator
64 self.collate_fn = collate_fn 65 self.collate_fn = collate_fn
66 self.num_workers = num_workers
65 self.batch_size = batch_size 67 self.batch_size = batch_size
66 68
67 def prepare_subdata(self, template, data, num_class_images=1): 69 def prepare_subdata(self, template, data, num_class_images=1):
@@ -113,9 +115,11 @@ class CSVDataModule(pl.LightningDataModule):
113 size=self.size, interpolation=self.interpolation, 115 size=self.size, interpolation=self.interpolation,
114 center_crop=self.center_crop) 116 center_crop=self.center_crop)
115 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 117 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
116 shuffle=True, pin_memory=True, collate_fn=self.collate_fn) 118 shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
119 num_workers=self.num_workers)
117 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, 120 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
118 pin_memory=True, collate_fn=self.collate_fn) 121 pin_memory=True, collate_fn=self.collate_fn,
122 num_workers=self.num_workers)
119 123
120 def train_dataloader(self): 124 def train_dataloader(self):
121 return self.train_dataloader_ 125 return self.train_dataloader_
diff --git a/dreambooth.py b/dreambooth.py
index db097e5..e71b7f0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -72,15 +72,23 @@ def parse_args():
72 "--placeholder_token", 72 "--placeholder_token",
73 type=str, 73 type=str,
74 nargs='*', 74 nargs='*',
75 default=[],
75 help="A token to use as a placeholder for the concept.", 76 help="A token to use as a placeholder for the concept.",
76 ) 77 )
77 parser.add_argument( 78 parser.add_argument(
78 "--initializer_token", 79 "--initializer_token",
79 type=str, 80 type=str,
80 nargs='*', 81 nargs='*',
82 default=[],
81 help="A token to use as initializer word." 83 help="A token to use as initializer word."
82 ) 84 )
83 parser.add_argument( 85 parser.add_argument(
86 "--train_text_encoder",
87 action="store_true",
88 default=True,
89 help="Whether to train the whole text encoder."
90 )
91 parser.add_argument(
84 "--num_class_images", 92 "--num_class_images",
85 type=int, 93 type=int,
86 default=400, 94 default=400,
@@ -119,6 +127,15 @@ def parse_args():
119 help="Whether to center crop images before resizing to resolution" 127 help="Whether to center crop images before resizing to resolution"
120 ) 128 )
121 parser.add_argument( 129 parser.add_argument(
130 "--dataloader_num_workers",
131 type=int,
132 default=0,
133 help=(
134 "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
135 " process."
136 ),
137 )
138 parser.add_argument(
122 "--num_train_epochs", 139 "--num_train_epochs",
123 type=int, 140 type=int,
124 default=100 141 default=100
@@ -323,7 +340,7 @@ def parse_args():
323 args.placeholder_token = [args.placeholder_token] 340 args.placeholder_token = [args.placeholder_token]
324 341
325 if len(args.placeholder_token) == 0: 342 if len(args.placeholder_token) == 0:
326 args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)] 343 args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
327 344
328 if len(args.placeholder_token) != len(args.initializer_token): 345 if len(args.placeholder_token) != len(args.initializer_token):
329 raise ValueError("Number of items in --placeholder_token and --initializer_token must match") 346 raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
@@ -391,6 +408,9 @@ class Checkpointer:
391 408
392 @torch.no_grad() 409 @torch.no_grad()
393 def save_embedding(self, step, postfix): 410 def save_embedding(self, step, postfix):
411 if len(self.placeholder_token) == 0:
412 return
413
394 print("Saving checkpoint for step %d..." % step) 414 print("Saving checkpoint for step %d..." % step)
395 415
396 checkpoints_path = self.output_dir.joinpath("checkpoints") 416 checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -406,9 +426,6 @@ class Checkpointer:
406 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 426 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
407 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename)) 427 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
408 428
409 del unwrapped
410 del learned_embeds
411
412 @torch.no_grad() 429 @torch.no_grad()
413 def save_model(self): 430 def save_model(self):
414 print("Saving model...") 431 print("Saving model...")
@@ -575,7 +592,9 @@ def main():
575 # Freeze text_encoder and vae 592 # Freeze text_encoder and vae
576 freeze_params(vae.parameters()) 593 freeze_params(vae.parameters())
577 594
578 if len(args.initializer_token) != 0: 595 if len(args.placeholder_token) != 0:
596 print(f"Adding text embeddings: {args.placeholder_token}")
597
579 # Convert the initializer_token, placeholder_token to ids 598 # Convert the initializer_token, placeholder_token to ids
580 initializer_token_ids = torch.stack([ 599 initializer_token_ids = torch.stack([
581 torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1]) 600 torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
@@ -597,14 +616,19 @@ def main():
597 616
598 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): 617 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
599 token_embeds[token_id] = embeddings 618 token_embeds[token_id] = embeddings
619 else:
620 placeholder_token_id = []
621
622 if args.train_text_encoder:
623 print(f"Training entire text encoder.")
624 else:
625 print(f"Training added text embeddings")
600 626
601 freeze_params(itertools.chain( 627 freeze_params(itertools.chain(
602 text_encoder.text_model.encoder.parameters(), 628 text_encoder.text_model.encoder.parameters(),
603 text_encoder.text_model.final_layer_norm.parameters(), 629 text_encoder.text_model.final_layer_norm.parameters(),
604 text_encoder.text_model.embeddings.position_embedding.parameters(), 630 text_encoder.text_model.embeddings.position_embedding.parameters(),
605 )) 631 ))
606 else:
607 placeholder_token_id = []
608 632
609 prompt_processor = PromptProcessor(tokenizer, text_encoder) 633 prompt_processor = PromptProcessor(tokenizer, text_encoder)
610 634
@@ -700,6 +724,7 @@ def main():
700 repeats=args.repeats, 724 repeats=args.repeats,
701 center_crop=args.center_crop, 725 center_crop=args.center_crop,
702 valid_set_size=args.sample_batch_size*args.sample_batches, 726 valid_set_size=args.sample_batch_size*args.sample_batches,
727 num_workers=args.dataloader_num_workers,
703 collate_fn=collate_fn 728 collate_fn=collate_fn
704 ) 729 )
705 730
@@ -906,7 +931,7 @@ def main():
906 931
907 accelerator.backward(loss) 932 accelerator.backward(loss)
908 933
909 if args.initializer_token is not None: 934 if not args.train_text_encoder:
910 # Keep the token embeddings fixed except the newly added 935 # Keep the token embeddings fixed except the newly added
911 # embeddings for the concept, as we only want to optimize the concept embeddings 936 # embeddings for the concept, as we only want to optimize the concept embeddings
912 if accelerator.num_processes > 1: 937 if accelerator.num_processes > 1:
diff --git a/textual_inversion.py b/textual_inversion.py
index dd7c3bd..115f3aa 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -117,6 +117,15 @@ def parse_args():
117 help="Whether to center crop images before resizing to resolution" 117 help="Whether to center crop images before resizing to resolution"
118 ) 118 )
119 parser.add_argument( 119 parser.add_argument(
120 "--dataloader_num_workers",
121 type=int,
122 default=0,
123 help=(
124 "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
125 " process."
126 ),
127 )
128 parser.add_argument(
120 "--num_train_epochs", 129 "--num_train_epochs",
121 type=int, 130 type=int,
122 default=100 131 default=100
@@ -626,6 +635,7 @@ def main():
626 repeats=args.repeats, 635 repeats=args.repeats,
627 center_crop=args.center_crop, 636 center_crop=args.center_crop,
628 valid_set_size=args.sample_batch_size*args.sample_batches, 637 valid_set_size=args.sample_batch_size*args.sample_batches,
638 num_workers=args.dataloader_num_workers,
629 collate_fn=collate_fn 639 collate_fn=collate_fn
630 ) 640 )
631 641