From ba9fd1a10746d85d2502c8a79ac49db63d346b04 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 9 Apr 2023 11:29:31 +0200
Subject: Update

---
 train_lora.py | 140 +++++-----------------------------------------------------
 1 file changed, 12 insertions(+), 128 deletions(-)

(limited to 'train_lora.py')

diff --git a/train_lora.py b/train_lora.py
index 8dbe45b..6e21634 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -158,12 +158,6 @@ def parse_args():
         default=0,
         help="Tag dropout probability.",
     )
-    parser.add_argument(
-        "--pti_tag_dropout",
-        type=float,
-        default=0,
-        help="Tag dropout probability.",
-    )
     parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
@@ -235,28 +229,12 @@ def parse_args():
         type=int,
         default=2000
     )
-    parser.add_argument(
-        "--num_pti_epochs",
-        type=int,
-        default=None
-    )
-    parser.add_argument(
-        "--num_pti_steps",
-        type=int,
-        default=500
-    )
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
-    parser.add_argument(
-        "--pti_gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
     parser.add_argument(
         "--lora_r",
         type=int,
@@ -322,12 +300,6 @@ def parse_args():
         default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--learning_rate_pti",
-        type=float,
-        default=1e-4,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
     parser.add_argument(
         "--learning_rate_emb",
         type=float,
@@ -466,12 +438,6 @@ def parse_args():
         default=1,
         help="How often to save a checkpoint and sample image",
     )
-    parser.add_argument(
-        "--pti_sample_frequency",
-        type=int,
-        default=1,
-        help="How often to save a checkpoint and sample image",
-    )
     parser.add_argument(
         "--sample_image_size",
         type=int,
@@ -508,12 +474,6 @@ def parse_args():
         default=1,
         help="Batch size (per device) for the training dataloader."
     )
-    parser.add_argument(
-        "--pti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
     parser.add_argument(
         "--sample_steps",
         type=int,
@@ -526,6 +486,12 @@
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
     parser.add_argument(
         "--use_emb_decay",
         action="store_true",
@@ -674,7 +640,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
@@ -720,6 +686,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -744,19 +711,14 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
-        args.learning_rate_pti = (
-            args.learning_rate_pti * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
-        )
         args.learning_rate_emb = (
-            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
+            args.learning_rate_emb * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
-        args.learning_rate_pti = 1e-6
         args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"
 
@@ -817,7 +779,6 @@ def main():
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
         args.learning_rate_text = None
-        args.learning_rate_pti = None
         args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
@@ -836,7 +797,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
@@ -853,7 +813,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -920,80 +879,6 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    # PTI
-    # --------------------------------------------------------------------------------
-
-    if len(args.placeholder_tokens) != 0:
-        pti_datamodule = create_datamodule(
-            batch_size=args.pti_batch_size,
-            dropout=args.pti_tag_dropout,
-            filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
-        )
-        pti_datamodule.setup()
-
-        num_pti_epochs = args.num_pti_epochs
-        pti_sample_frequency = args.pti_sample_frequency
-        if num_pti_epochs is None:
-            num_pti_epochs = math.ceil(
-                args.num_pti_steps / len(pti_datamodule.train_dataset)
-            ) * args.pti_gradient_accumulation_steps
-            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
-
-        if num_pti_epochs > 0:
-            pti_optimizer = create_optimizer(
-                [
-                    {
-                        "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-                        "lr": args.learning_rate_pti,
-                        "weight_decay": 0,
-                    },
-                ]
-            )
-
-            pti_lr_scheduler = create_lr_scheduler(
-                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                optimizer=pti_optimizer,
-                num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-                train_epochs=num_pti_epochs,
-            )
-
-            continue_training = True
-            training_iter = 1
-
-            while continue_training:
-                print("")
-                print(f"============ PTI cycle {training_iter} ============")
-                print("")
-
-                pti_project = f"pti_{training_iter}"
-                pti_output_dir = output_dir / pti_project
-                pti_checkpoint_output_dir = pti_output_dir / "model"
-                pti_sample_output_dir = pti_output_dir / "samples"
-
-                trainer(
-                    strategy=lora_strategy,
-                    pti_mode=True,
-                    project=pti_project,
-                    train_dataloader=pti_datamodule.train_dataloader,
-                    val_dataloader=pti_datamodule.val_dataloader,
-                    optimizer=pti_optimizer,
-                    lr_scheduler=pti_lr_scheduler,
-                    num_train_epochs=num_pti_epochs,
-                    gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                    # --
-                    group_labels=["emb"],
-                    sample_output_dir=pti_sample_output_dir,
-                    checkpoint_output_dir=pti_checkpoint_output_dir,
-                    sample_frequency=pti_sample_frequency,
-                )
-
-                response = input("Run another cycle? [y/n] ")
-                continue_training = response.lower().strip() != "n"
-                training_iter += 1
-
-        if not args.train_emb:
-            embeddings.persist()
-
     # LORA
     # --------------------------------------------------------------------------------
 
@@ -1062,9 +947,8 @@ def main():
         print("")
 
         lora_project = f"lora_{training_iter}"
-        lora_output_dir = output_dir / lora_project
-        lora_checkpoint_output_dir = lora_output_dir / "model"
-        lora_sample_output_dir = lora_output_dir / "samples"
+        lora_checkpoint_output_dir = output_dir / lora_project / "model"
+        lora_sample_output_dir = output_dir / lora_project / "samples"
 
         trainer(
             strategy=lora_strategy,
-- 
cgit v1.2.3-54-g00ecf