diff options
Diffstat (limited to 'textual_inversion.py')
-rw-r--r-- | textual_inversion.py | 119 |
1 files changed, 54 insertions, 65 deletions
diff --git a/textual_inversion.py b/textual_inversion.py index a9c3326..11babd8 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -170,7 +170,7 @@ def parse_args(): | |||
170 | parser.add_argument( | 170 | parser.add_argument( |
171 | "--lr_scheduler", | 171 | "--lr_scheduler", |
172 | type=str, | 172 | type=str, |
173 | default="one_cycle", | 173 | default="constant_with_warmup", |
174 | help=( | 174 | help=( |
175 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | 175 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
176 | ' "constant", "constant_with_warmup", "one_cycle"]' | 176 | ' "constant", "constant_with_warmup", "one_cycle"]' |
@@ -231,14 +231,14 @@ def parse_args(): | |||
231 | parser.add_argument( | 231 | parser.add_argument( |
232 | "--checkpoint_frequency", | 232 | "--checkpoint_frequency", |
233 | type=int, | 233 | type=int, |
234 | default=500, | 234 | default=5, |
235 | help="How often to save a checkpoint and sample image", | 235 | help="How often to save a checkpoint and sample image (in epochs)", |
236 | ) | 236 | ) |
237 | parser.add_argument( | 237 | parser.add_argument( |
238 | "--sample_frequency", | 238 | "--sample_frequency", |
239 | type=int, | 239 | type=int, |
240 | default=100, | 240 | default=1, |
241 | help="How often to save a checkpoint and sample image", | 241 | help="How often to save a checkpoint and sample image (in epochs)", |
242 | ) | 242 | ) |
243 | parser.add_argument( | 243 | parser.add_argument( |
244 | "--sample_image_size", | 244 | "--sample_image_size", |
@@ -294,10 +294,9 @@ def parse_args(): | |||
294 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" | 294 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" |
295 | ) | 295 | ) |
296 | parser.add_argument( | 296 | parser.add_argument( |
297 | "--resume_checkpoint", | 297 | "--global_step", |
298 | type=str, | 298 | type=int, |
299 | default=None, | 299 | default=0, |
300 | help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)." | ||
301 | ) | 300 | ) |
302 | parser.add_argument( | 301 | parser.add_argument( |
303 | "--config", | 302 | "--config", |
@@ -512,19 +511,10 @@ def main(): | |||
512 | if len(args.placeholder_token) != 0: | 511 | if len(args.placeholder_token) != 0: |
513 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) | 512 | instance_identifier = instance_identifier.format(args.placeholder_token[0]) |
514 | 513 | ||
515 | global_step_offset = 0 | 514 | global_step_offset = args.global_step |
516 | if args.resume_from is not None: | 515 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
517 | basepath = Path(args.resume_from) | 516 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) |
518 | print("Resuming state from %s" % args.resume_from) | 517 | basepath.mkdir(parents=True, exist_ok=True) |
519 | with open(basepath.joinpath("resume.json"), 'r') as f: | ||
520 | state = json.load(f) | ||
521 | global_step_offset = state["args"].get("global_step", 0) | ||
522 | |||
523 | print("We've trained %d steps so far" % global_step_offset) | ||
524 | else: | ||
525 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
526 | basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now) | ||
527 | basepath.mkdir(parents=True, exist_ok=True) | ||
528 | 518 | ||
529 | accelerator = Accelerator( | 519 | accelerator = Accelerator( |
530 | log_with=LoggerType.TENSORBOARD, | 520 | log_with=LoggerType.TENSORBOARD, |
@@ -557,6 +547,7 @@ def main(): | |||
557 | set_use_memory_efficient_attention_xformers(vae, True) | 547 | set_use_memory_efficient_attention_xformers(vae, True) |
558 | 548 | ||
559 | if args.gradient_checkpointing: | 549 | if args.gradient_checkpointing: |
550 | unet.enable_gradient_checkpointing() | ||
560 | text_encoder.gradient_checkpointing_enable() | 551 | text_encoder.gradient_checkpointing_enable() |
561 | 552 | ||
562 | print(f"Adding text embeddings: {args.placeholder_token}") | 553 | print(f"Adding text embeddings: {args.placeholder_token}") |
@@ -577,14 +568,25 @@ def main(): | |||
577 | 568 | ||
578 | # Initialise the newly added placeholder token with the embeddings of the initializer token | 569 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
579 | token_embeds = text_encoder.get_input_embeddings().weight.data | 570 | token_embeds = text_encoder.get_input_embeddings().weight.data |
580 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) | ||
581 | 571 | ||
582 | if args.resume_checkpoint is not None: | 572 | if args.resume_from: |
583 | token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token] | 573 | resumepath = Path(args.resume_from).joinpath("checkpoints") |
584 | else: | 574 | |
585 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | 575 | for (token_id, token) in zip(placeholder_token_id, args.placeholder_token): |
586 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | 576 | embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin") |
587 | token_embeds[token_id] = embeddings | 577 | embedding_data = torch.load(embedding_file, map_location="cpu") |
578 | |||
579 | emb = next(iter(embedding_data.values())) | ||
580 | if len(emb.shape) == 1: | ||
581 | emb = emb.unsqueeze(0) | ||
582 | |||
583 | token_embeds[token_id] = emb | ||
584 | |||
585 | original_token_embeds = token_embeds.clone().to(accelerator.device) | ||
586 | |||
587 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
588 | for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings): | ||
589 | token_embeds[token_id] = embeddings | ||
588 | 590 | ||
589 | index_fixed_tokens = torch.arange(len(tokenizer)) | 591 | index_fixed_tokens = torch.arange(len(tokenizer)) |
590 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] | 592 | index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))] |
@@ -891,21 +893,16 @@ def main(): | |||
891 | 893 | ||
892 | accelerator.backward(loss) | 894 | accelerator.backward(loss) |
893 | 895 | ||
894 | # Keep the token embeddings fixed except the newly added | ||
895 | # embeddings for the concept, as we only want to optimize the concept embeddings | ||
896 | if accelerator.num_processes > 1: | ||
897 | token_embeds = text_encoder.module.get_input_embeddings().weight | ||
898 | else: | ||
899 | token_embeds = text_encoder.get_input_embeddings().weight | ||
900 | |||
901 | # Get the index for tokens that we want to freeze | ||
902 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | ||
903 | |||
904 | optimizer.step() | 896 | optimizer.step() |
905 | if not accelerator.optimizer_step_was_skipped: | 897 | if not accelerator.optimizer_step_was_skipped: |
906 | lr_scheduler.step() | 898 | lr_scheduler.step() |
907 | optimizer.zero_grad(set_to_none=True) | 899 | optimizer.zero_grad(set_to_none=True) |
908 | 900 | ||
901 | # Let's make sure we don't update any embedding weights besides the newly added token | ||
902 | with torch.no_grad(): | ||
903 | text_encoder.get_input_embeddings( | ||
904 | ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens] | ||
905 | |||
909 | loss = loss.detach().item() | 906 | loss = loss.detach().item() |
910 | train_loss += loss | 907 | train_loss += loss |
911 | 908 | ||
@@ -916,19 +913,6 @@ def main(): | |||
916 | 913 | ||
917 | global_step += 1 | 914 | global_step += 1 |
918 | 915 | ||
919 | if global_step % args.sample_frequency == 0: | ||
920 | sample_checkpoint = True | ||
921 | |||
922 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | ||
923 | local_progress_bar.clear() | ||
924 | global_progress_bar.clear() | ||
925 | |||
926 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
927 | save_args(basepath, args, { | ||
928 | "global_step": global_step + global_step_offset, | ||
929 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
930 | }) | ||
931 | |||
932 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 916 | logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
933 | 917 | ||
934 | accelerator.log(logs, step=global_step) | 918 | accelerator.log(logs, step=global_step) |
@@ -992,24 +976,30 @@ def main(): | |||
992 | local_progress_bar.clear() | 976 | local_progress_bar.clear() |
993 | global_progress_bar.clear() | 977 | global_progress_bar.clear() |
994 | 978 | ||
995 | if min_val_loss > val_loss: | 979 | if accelerator.is_main_process: |
996 | accelerator.print( | 980 | if min_val_loss > val_loss: |
997 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 981 | accelerator.print( |
998 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") | 982 | f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
999 | min_val_loss = val_loss | 983 | checkpointer.checkpoint(global_step + global_step_offset, "milestone") |
984 | min_val_loss = val_loss | ||
985 | |||
986 | if epoch % args.checkpoint_frequency == 0: | ||
987 | checkpointer.checkpoint(global_step + global_step_offset, "training") | ||
988 | save_args(basepath, args, { | ||
989 | "global_step": global_step + global_step_offset | ||
990 | }) | ||
1000 | 991 | ||
1001 | if sample_checkpoint and accelerator.is_main_process: | 992 | if epoch % args.sample_frequency == 0: |
1002 | checkpointer.save_samples( | 993 | checkpointer.save_samples( |
1003 | global_step + global_step_offset, | 994 | global_step + global_step_offset, |
1004 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 995 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
1005 | 996 | ||
1006 | # Create the pipeline using using the trained modules and save it. | 997 | # Create the pipeline using using the trained modules and save it. |
1007 | if accelerator.is_main_process: | 998 | if accelerator.is_main_process: |
1008 | print("Finished! Saving final checkpoint and resume state.") | 999 | print("Finished! Saving final checkpoint and resume state.") |
1009 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 1000 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
1010 | save_args(basepath, args, { | 1001 | save_args(basepath, args, { |
1011 | "global_step": global_step + global_step_offset, | 1002 | "global_step": global_step + global_step_offset |
1012 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
1013 | }) | 1003 | }) |
1014 | accelerator.end_training() | 1004 | accelerator.end_training() |
1015 | 1005 | ||
@@ -1018,8 +1008,7 @@ def main(): | |||
1018 | print("Interrupted, saving checkpoint and resume state...") | 1008 | print("Interrupted, saving checkpoint and resume state...") |
1019 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 1009 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
1020 | save_args(basepath, args, { | 1010 | save_args(basepath, args, { |
1021 | "global_step": global_step + global_step_offset, | 1011 | "global_step": global_step + global_step_offset |
1022 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | ||
1023 | }) | 1012 | }) |
1024 | accelerator.end_training() | 1013 | accelerator.end_training() |
1025 | quit() | 1014 | quit() |