author     Volpeon <git@volpeon.ink>  2022-10-24 23:46:18 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-24 23:46:18 +0200
commit     baba91864a45939cef4f77f6ca96ade7ae5ef274 (patch)
tree       c40fc949a94d5a2bee81b2b505b814e7c7f82cc1
parent     Update (diff)
download   textual-inversion-diff-baba91864a45939cef4f77f6ca96ade7ae5ef274.tar.gz
           textual-inversion-diff-baba91864a45939cef4f77f6ca96ade7ae5ef274.tar.bz2
           textual-inversion-diff-baba91864a45939cef4f77f6ca96ade7ae5ef274.zip
Advanced datasets
-rw-r--r--  data/csv.py           64
-rw-r--r--  dreambooth.py         68
-rw-r--r--  textual_inversion.py  58
3 files changed, 115 insertions(+), 75 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 5144c0a..f9b5e39 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -1,16 +1,20 @@
 import math
-import pandas as pd
 import torch
+import json
 from pathlib import Path
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
-from typing import NamedTuple, List, Optional
+from typing import Dict, NamedTuple, List, Optional, Union
 
 from models.clip.prompt import PromptProcessor
 
 
+def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+    return {"content": prompt} if isinstance(prompt, str) else prompt
+
+
 class CSVDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -60,24 +64,32 @@ class CSVDataModule(pl.LightningDataModule):
         self.collate_fn = collate_fn
         self.batch_size = batch_size
 
-    def prepare_subdata(self, data, num_class_images=1):
+    def prepare_subdata(self, template, data, num_class_images=1):
+        image = template["image"] if "image" in template else "{}"
+        prompt = template["prompt"] if "prompt" in template else "{content}"
+        nprompt = template["nprompt"] if "nprompt" in template else "{content}"
+
         image_multiplier = max(math.ceil(num_class_images / len(data)), 1)
 
         return [
             CSVDataItem(
-                self.data_root.joinpath(item.image),
-                self.class_root.joinpath(f"{Path(item.image).stem}_{i}{Path(item.image).suffix}"),
-                item.prompt,
-                item.nprompt
+                self.data_root.joinpath(image.format(item["image"])),
+                self.class_root.joinpath(f"{Path(item['image']).stem}_{i}{Path(item['image']).suffix}"),
+                prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
+                nprompt.format(**prepare_prompt(item["nprompt"] if "nprompt" in item else ""))
             )
             for item in data
             for i in range(image_multiplier)
         ]
 
     def prepare_data(self):
-        metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
-        num_images = len(metadata)
+        with open(self.data_file, 'rt') as f:
+            metadata = json.load(f)
+        template = metadata["template"] if "template" in metadata else {}
+        items = metadata["items"] if "items" in metadata else []
+
+        items = [item for item in items if not "skip" in item or item["skip"] != True]
+        num_images = len(items)
 
         valid_set_size = int(num_images * 0.2)
         if self.valid_set_size:
@@ -85,10 +97,10 @@ class CSVDataModule(pl.LightningDataModule):
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size
 
-        data_train, data_val = random_split(metadata, [train_set_size, valid_set_size], self.generator)
+        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
 
-        self.data_train = self.prepare_subdata(data_train, self.num_class_images)
-        self.data_val = self.prepare_subdata(data_val)
+        self.data_train = self.prepare_subdata(template, data_train, self.num_class_images)
+        self.data_val = self.prepare_subdata(template, data_val)
 
     def setup(self, stage=None):
         train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
@@ -133,8 +145,8 @@ class CSVDataset(Dataset):
         self.instance_identifier = instance_identifier
         self.class_identifier = class_identifier
         self.num_class_images = num_class_images
-        self.cache = {}
         self.image_cache = {}
+        self.input_id_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -168,12 +180,19 @@ class CSVDataset(Dataset):
 
         return image
 
+    def get_input_ids(self, prompt, identifier):
+        prompt = prompt.format(identifier)
+
+        if prompt in self.input_id_cache:
+            return self.input_id_cache[prompt]
+
+        input_ids = self.prompt_processor.get_input_ids(prompt)
+        self.input_id_cache[prompt] = input_ids
+
+        return input_ids
+
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
-        cache_key = f"{item.instance_image_path}_{item.class_image_path}"
-
-        if cache_key in self.cache:
-            return self.cache[cache_key]
 
         example = {}
 
@@ -181,17 +200,12 @@ class CSVDataset(Dataset):
         example["nprompts"] = item.nprompt
 
         example["instance_images"] = self.get_image(item.instance_image_path)
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            item.prompt.format(self.instance_identifier)
-        )
+        example["instance_prompt_ids"] = self.get_input_ids(item.prompt, self.instance_identifier)
 
         if self.num_class_images != 0:
             example["class_images"] = self.get_image(item.class_image_path)
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(
-                item.nprompt.format(self.class_identifier)
-            )
+            example["class_prompt_ids"] = self.get_input_ids(item.nprompt, self.class_identifier)
 
-        self.cache[cache_key] = example
         return example
 
     def __getitem__(self, i):
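With this change data/csv.py no longer reads a flat table via pandas: prepare_data() expects a JSON file with an optional "template" section ("image"/"prompt"/"nprompt" format strings) and an "items" list, where items can carry a "skip" flag and prompts can be plain strings or dicts. A minimal sketch of such a file, written from Python so the field names stay aligned with the code above; every concrete path and prompt text is only an example:

import json

# Illustrative metadata for the new JSON loader; only the "template"/"items"
# keys and the per-item "image"/"prompt"/"nprompt"/"skip" fields are taken
# from prepare_data() above, the values are made up.
metadata = {
    "template": {
        "image": "images/{}",   # resolved against each item's "image" via str.format
        "prompt": "{content}",  # "{content}" is also the default when omitted
        "nprompt": "{content}",
    },
    "items": [
        {"image": "01.jpg", "prompt": "a photo of a red bicycle", "nprompt": "blurry"},
        # a prompt may also be a dict; prepare_prompt() passes it through unchanged
        {"image": "02.jpg", "prompt": {"content": "a photo of a red bicycle at night"}},
        # items flagged with "skip" are dropped before the train/validation split
        {"image": "99.jpg", "prompt": "a broken shot", "skip": True},
    ],
}

with open("metadata.json", "wt") as f:  # hypothetical file name
    json.dump(metadata, f, indent=2)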
diff --git a/dreambooth.py b/dreambooth.py
index 5c26f12..2c24908 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -71,13 +71,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -316,6 +316,18 @@ def parse_args():
     if args.instance_identifier is None:
         raise ValueError("You must specify --instance_identifier")
 
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
+
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -379,9 +391,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_embedding(self, step, postfix):
-        if self.placeholder_token_id is None:
-            return
-
         print("Saving checkpoint for step %d..." % step)
 
         checkpoints_path = self.output_dir.joinpath("checkpoints")
@@ -389,12 +398,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -467,7 +477,7 @@ class Checkpointer:
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
             prompt = [
-                prompt.format(self.instance_identifier)
+                prompt.format(identifier=self.instance_identifier)
                 for batch in batches
                 for prompt in batch["prompts"]
             ][:self.sample_batch_size]
@@ -516,8 +526,8 @@ def main():
 
     instance_identifier = args.instance_identifier
 
-    if args.placeholder_token is not None:
-        instance_identifier = instance_identifier.format(args.placeholder_token)
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
     basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
@@ -565,18 +575,16 @@ def main():
     # Freeze text_encoder and vae
     freeze_params(vae.parameters())
 
-    if args.initializer_token is not None:
+    if len(args.initializer_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
-        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
-        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+        initializer_token_ids = torch.stack([
+            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+            for token in args.initializer_token
+        ])
 
-        # Add the placeholder token in tokenizer
         num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-        if num_added_tokens == 0:
-            print(f"Re-using existing token {args.placeholder_token}.")
-        else:
-            print(f"Training new token {args.placeholder_token}.")
+        print(f"Added {num_added_tokens} new tokens.")
+
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
         # Resize the token embeddings as we are adding new special tokens to the tokenizer
@@ -586,7 +594,9 @@ def main():
         token_embeds = text_encoder.get_input_embeddings().weight.data
         original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
         freeze_params(itertools.chain(
             text_encoder.text_model.encoder.parameters(),
@@ -594,7 +604,7 @@ def main():
             text_encoder.text_model.embeddings.position_embedding.parameters(),
         ))
     else:
-        placeholder_token_id = None
+        placeholder_token_id = []
 
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
@@ -721,7 +731,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [item.class_image_path for item in batch]
-                prompt = [item.prompt.format(args.class_identifier) for item in batch]
+                prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
                 nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
@@ -787,7 +797,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("dreambooth", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("dreambooth", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -932,6 +945,9 @@ def main():
                 global_step += 1
 
                 if global_step % args.sample_frequency == 0:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     checkpointer.save_embedding(global_step, "training")
                     sample_checkpoint = True
 
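Both training scripts now accept --placeholder_token and --initializer_token as lists (nargs='*'), add every placeholder to the tokenizer, and seed each new embedding row from the first sub-token of the matching initializer. The standalone sketch below mirrors that flow against a Hugging Face CLIP text encoder; the model id and token strings are examples only, not values from the commit:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Example inputs; in the scripts these come from --placeholder_token / --initializer_token.
placeholder_tokens = ["<obj>", "<style>"]
initializer_tokens = ["bicycle", "watercolor"]

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# One embedding row per initializer: keep only the first sub-token, as the diff does.
initializer_token_ids = torch.stack([
    torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
    for token in initializer_tokens
])

num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
print(f"Added {num_added_tokens} new tokens.")
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

# Grow the embedding matrix and copy each initializer embedding into its placeholder row.
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
initializer_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)

for token_id, embedding in zip(placeholder_token_ids, initializer_embeddings):
    token_embeds[token_id] = embedding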
diff --git a/textual_inversion.py b/textual_inversion.py
index c42762f..bcdfd3a 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -70,13 +70,13 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
-        default="<*>",
+        nargs='*',
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_token",
         type=str,
-        default=None,
+        nargs='*',
         help="A token to use as initializer word."
     )
     parser.add_argument(
@@ -299,12 +299,21 @@ def parse_args():
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
 
-    if args.placeholder_token is None:
-        raise ValueError("You must specify --placeholder_token")
+    if isinstance(args.initializer_token, str):
+        args.initializer_token = [args.initializer_token]
 
-    if args.initializer_token is None:
+    if len(args.initializer_token) == 0:
         raise ValueError("You must specify --initializer_token")
 
+    if isinstance(args.placeholder_token, str):
+        args.placeholder_token = [args.placeholder_token]
+
+    if len(args.placeholder_token) == 0:
+        args.placeholder_token = [f"<*{i}>" for i in range(args.initializer_token)]
+
+    if len(args.placeholder_token) != len(args.initializer_token):
+        raise ValueError("You must specify --placeholder_token")
+
     if args.output_dir is None:
         raise ValueError("You must specify --output_dir")
 
@@ -373,12 +382,13 @@ class Checkpointer:
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
-        # Save a checkpoint
-        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
+            # Save a checkpoint
+            learned_embeds = unwrapped.get_input_embeddings().weight[placeholder_token_id]
+            learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
 
-        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+            filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -422,7 +432,7 @@ class Checkpointer:
 
         for i in range(self.sample_batches):
             batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(self.instance_identifier)
+            prompt = [prompt.format(identifier=self.instance_identifier)
                       for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
             nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
@@ -498,16 +508,13 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
     # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
-    print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
-    initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
 
-    # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    if num_added_tokens == 0:
-        print(f"Re-using existing token {args.placeholder_token}.")
-    else:
-        print(f"Training new token {args.placeholder_token}.")
+    print(f"Added {num_added_tokens} new tokens.")
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
@@ -533,11 +540,11 @@ def main():
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
     if args.resume_checkpoint is not None:
-        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
-            args.placeholder_token]
+        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
     else:
         initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        token_embeds[placeholder_token_id] = initializer_token_embeddings
+        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+            token_embeds[token_id] = embeddings
 
     # Freeze vae and unet
     freeze_params(vae.parameters())
@@ -648,7 +655,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.class_identifier) for p in batch]
+                prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -716,7 +723,10 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion", config=vars(args))
+        config = vars(args).copy()
+        config["initializer_token"] = " ".join(config["initializer_token"])
+        config["placeholder_token"] = " ".join(config["placeholder_token"])
+        accelerator.init_trackers("textual_inversion", config=config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
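save_embedding() now writes one file per placeholder token: a torch.save()d dict mapping that token to its learned embedding, named from the slugified token, the step, and a postfix, under <output_dir>/checkpoints. A short sketch of loading such a file back into a fresh tokenizer/text encoder pair; the checkpoint path and model id below are placeholders, not paths from the commit:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Hypothetical path; each saved file maps a single placeholder token to its embedding tensor.
learned = torch.load("output/checkpoints/obj_500_training.bin", map_location="cpu")

for token, embedding in learned.items():
    tokenizer.add_tokens(token)
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    text_encoder.get_input_embeddings().weight.data[token_id] = embedding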