path: root/textual_inversion.py
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py  119
1 file changed, 54 insertions(+), 65 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index a9c3326..11babd8 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -170,7 +170,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="one_cycle",
+        default="constant_with_warmup",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup", "one_cycle"]'
@@ -231,14 +231,14 @@ def parse_args():
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
-        default=500,
-        help="How often to save a checkpoint and sample image",
+        default=5,
+        help="How often to save a checkpoint and sample image (in epochs)",
     )
     parser.add_argument(
         "--sample_frequency",
         type=int,
-        default=100,
-        help="How often to save a checkpoint and sample image",
+        default=1,
+        help="How often to save a checkpoint and sample image (in epochs)",
     )
     parser.add_argument(
         "--sample_image_size",
@@ -294,10 +294,9 @@ def parse_args():
294 help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" 294 help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
295 ) 295 )
296 parser.add_argument( 296 parser.add_argument(
297 "--resume_checkpoint", 297 "--global_step",
298 type=str, 298 type=int,
299 default=None, 299 default=0,
300 help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
301 ) 300 )
302 parser.add_argument( 301 parser.add_argument(
303 "--config", 302 "--config",
@@ -512,19 +511,10 @@ def main():
     if len(args.placeholder_token) != 0:
         instance_identifier = instance_identifier.format(args.placeholder_token[0])
 
-    global_step_offset = 0
-    if args.resume_from is not None:
-        basepath = Path(args.resume_from)
-        print("Resuming state from %s" % args.resume_from)
-        with open(basepath.joinpath("resume.json"), 'r') as f:
-            state = json.load(f)
-            global_step_offset = state["args"].get("global_step", 0)
-
-        print("We've trained %d steps so far" % global_step_offset)
-    else:
-        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-        basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
-        basepath.mkdir(parents=True, exist_ok=True)
+    global_step_offset = args.global_step
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
+    basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
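Note that resuming no longer reuses the old run directory: every invocation now creates a fresh timestamped directory under args.output_dir, the step offset is supplied by hand through --global_step, and the previous run is only consulted to load the saved embeddings (see the embedding hunk further down).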
@@ -557,6 +547,7 @@ def main():
         set_use_memory_efficient_attention_xformers(vae, True)
 
     if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
     print(f"Adding text embeddings: {args.placeholder_token}")
@@ -577,14 +568,25 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
-    if args.resume_checkpoint is not None:
-        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
-    else:
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+    if args.resume_from:
+        resumepath = Path(args.resume_from).joinpath("checkpoints")
+
+        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+            embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin")
+            embedding_data = torch.load(embedding_file, map_location="cpu")
+
+            emb = next(iter(embedding_data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+
+            token_embeds[token_id] = emb
+
+    original_token_embeds = token_embeds.clone().to(accelerator.device)
+
+    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+        token_embeds[token_id] = embeddings
 
     index_fixed_tokens = torch.arange(len(tokenizer))
     index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
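The resume branch above implies a specific on-disk format for the embedding checkpoints. A sketch of the writer side consistent with that reader (the Checkpointer lives elsewhere in the repo; names like checkpoints_path, token_id, and step are illustrative):

    import torch

    # One .bin per placeholder token: a single-entry dict mapping the token
    # string to its embedding tensor.
    learned_embeds = text_encoder.get_input_embeddings().weight[token_id].detach().cpu()
    torch.save({token: learned_embeds}, checkpoints_path.joinpath(f"{token}_{step}_end.bin"))

    # Hence next(iter(embedding_data.values())) recovers the tensor, and a bare
    # 1-D vector is unsqueezed to (1, embedding_dim) before being written back.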
@@ -891,21 +893,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Keep the token embeddings fixed except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    token_embeds = text_encoder.module.get_input_embeddings().weight
-                else:
-                    token_embeds = text_encoder.get_input_embeddings().weight
-
-                # Get the index for tokens that we want to freeze
-                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+                # Let's make sure we don't update any embedding weights besides the newly added token
+                with torch.no_grad():
+                    text_encoder.get_input_embeddings(
+                    ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+
                 loss = loss.detach().item()
                 train_loss += loss
 
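The freeze logic moves from before optimizer.step() to after it: every embedding row may receive an update, and the frozen rows are then restored from the pre-training copy under torch.no_grad(). A self-contained toy sketch of the same technique (all names here are illustrative, not taken from the script):

    import torch

    emb = torch.nn.Embedding(10, 4)
    original = emb.weight.detach().clone()
    frozen = torch.tensor([i for i in range(10) if i != 7])  # only row 7 is trainable

    opt = torch.optim.SGD(emb.parameters(), lr=0.1)
    loss = emb(torch.tensor([7, 3])).sum()
    loss.backward()
    opt.step()  # updates rows 3 and 7

    with torch.no_grad():
        emb.weight[frozen] = original[frozen]  # roll back every row except 7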
@@ -916,19 +913,6 @@ def main():
 
                 global_step += 1
 
-                if global_step % args.sample_frequency == 0:
-                    sample_checkpoint = True
-
-                if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                    local_progress_bar.clear()
-                    global_progress_bar.clear()
-
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_args(basepath, args, {
-                        "global_step": global_step + global_step_offset,
-                        "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
-                    })
-
                 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
 
                 accelerator.log(logs, step=global_step)
@@ -992,24 +976,30 @@ def main():
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if min_val_loss > val_loss:
-                accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
-                checkpointer.checkpoint(global_step + global_step_offset, "milestone")
-                min_val_loss = val_loss
+            if accelerator.is_main_process:
+                if min_val_loss > val_loss:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    min_val_loss = val_loss
 
-            if sample_checkpoint and accelerator.is_main_process:
-                checkpointer.save_samples(
-                    global_step + global_step_offset,
-                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+                if epoch % args.checkpoint_frequency == 0:
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+                    save_args(basepath, args, {
+                        "global_step": global_step + global_step_offset
+                    })
+
+                if epoch % args.sample_frequency == 0:
+                    checkpointer.save_samples(
+                        global_step + global_step_offset,
+                        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_args(basepath, args, {
-                "global_step": global_step + global_step_offset,
-                "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
+                "global_step": global_step + global_step_offset
             })
         accelerator.end_training()
 
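Checkpoint and sample cadence is now epoch-based (matching the new argparse defaults: training checkpoints every 5 epochs, sample images every epoch) and everything is gated on accelerator.is_main_process, so in multi-GPU runs only one process writes checkpoints, sample grids, and resume state. The per-step sample_checkpoint flag from the training loop is gone entirely.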
@@ -1018,8 +1008,7 @@ def main():
1018 print("Interrupted, saving checkpoint and resume state...") 1008 print("Interrupted, saving checkpoint and resume state...")
1019 checkpointer.checkpoint(global_step + global_step_offset, "end") 1009 checkpointer.checkpoint(global_step + global_step_offset, "end")
1020 save_args(basepath, args, { 1010 save_args(basepath, args, {
1021 "global_step": global_step + global_step_offset, 1011 "global_step": global_step + global_step_offset
1022 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
1023 }) 1012 })
1024 accelerator.end_training() 1013 accelerator.end_training()
1025 quit() 1014 quit()