From d5696615a84a768307e82d13e50b4aef64f69dbd Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 10:03:05 +0100 Subject: Extended Dreambooth: Train TI tokens separately --- train_dreambooth.py | 147 +++++++++++++++++++++++++++------------------------- 1 file changed, 76 insertions(+), 71 deletions(-) (limited to 'train_dreambooth.py') diff --git a/train_dreambooth.py b/train_dreambooth.py index 944256c..05777d0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -510,19 +510,6 @@ def main(): added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") - placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( - tokenizer=tokenizer, - embeddings=embeddings, - placeholder_tokens=args.placeholder_tokens, - initializer_tokens=args.initializer_tokens, - num_vectors=args.num_vectors - ) - - if len(placeholder_token_ids) != 0: - initializer_token_id_lens = [len(id) for id in initializer_token_ids] - placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) - print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") - if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * @@ -554,79 +541,98 @@ def main(): noise_scheduler=noise_scheduler, dtype=weight_dtype, seed=args.seed, - callbacks_fn=textual_inversion_strategy ) # Initial TI print("Phase 1: Textual Inversion") + ti_lr = 1e-1 + ti_train_epochs = 10 + cur_dir = output_dir.joinpath("1-ti") cur_dir.mkdir(parents=True, exist_ok=True) - datamodule = VlpnDataModule( - data_file=args.train_data_file, - batch_size=args.train_batch_size, - tokenizer=tokenizer, - class_subdir=args.class_image_dir, - num_class_images=args.num_class_images, - size=args.resolution, - shuffle=not args.no_tag_shuffle, - template_key=args.train_data_template, - valid_set_size=args.valid_set_size, - valid_set_repeat=args.valid_set_repeat, - seed=args.seed, - filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), - dtype=weight_dtype - ) - datamodule.setup() + for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors): + print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})") - optimizer = optimizer_class( - text_encoder.text_model.embeddings.temp_token_embedding.parameters(), - lr=2e-1, - weight_decay=0.0, - ) + cur_subdir = cur_dir.joinpath(placeholder_token) + cur_subdir.mkdir(parents=True, exist_ok=True) - lr_scheduler = get_scheduler( - "linear", - optimizer=optimizer, - num_training_steps_per_epoch=len(datamodule.train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps, - train_epochs=30, - warmup_epochs=10, - ) + placeholder_token_ids, _ = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=[placeholder_token], + initializer_tokens=[initializer_token], + num_vectors=num_vectors + ) - trainer( - project="textual_inversion", - train_dataloader=datamodule.train_dataloader, - val_dataloader=datamodule.val_dataloader, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - num_train_epochs=30, - sample_frequency=5, - checkpoint_frequency=9999999, - with_prior_preservation=args.num_class_images != 0, - prior_loss_weight=args.prior_loss_weight, - # -- - tokenizer=tokenizer, - sample_scheduler=sample_scheduler, - output_dir=cur_dir, - placeholder_tokens=args.placeholder_tokens, - placeholder_token_ids=placeholder_token_ids, - learning_rate=2e-1, - gradient_checkpointing=args.gradient_checkpointing, - use_emb_decay=True, - sample_batch_size=args.sample_batch_size, - sample_num_batches=args.sample_batches, - sample_num_steps=args.sample_steps, - sample_image_size=args.sample_image_size, - ) + datamodule = VlpnDataModule( + data_file=args.train_data_file, + batch_size=args.train_batch_size, + tokenizer=tokenizer, + class_subdir=args.class_image_dir, + num_class_images=args.num_class_images, + size=args.resolution, + shuffle=not args.no_tag_shuffle, + template_key=args.train_data_template, + valid_set_size=args.valid_set_size, + valid_set_repeat=args.valid_set_repeat, + seed=args.seed, + filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections), + dtype=weight_dtype + ) + datamodule.setup() + + optimizer = optimizer_class( + text_encoder.text_model.embeddings.temp_token_embedding.parameters(), + lr=ti_lr, + weight_decay=0.0, + ) + + lr_scheduler = get_scheduler( + "one_cycle", + optimizer=optimizer, + num_training_steps_per_epoch=len(datamodule.train_dataloader), + gradient_accumulation_steps=args.gradient_accumulation_steps, + train_epochs=ti_train_epochs, + warmup_epochs=ti_train_epochs // 4, + ) + + trainer( + callbacks_fn=textual_inversion_strategy, + project="textual_inversion", + train_dataloader=datamodule.train_dataloader, + val_dataloader=datamodule.val_dataloader, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + num_train_epochs=ti_train_epochs, + sample_frequency=1, + checkpoint_frequency=9999999, + with_prior_preservation=args.num_class_images != 0, + prior_loss_weight=args.prior_loss_weight, + # -- + tokenizer=tokenizer, + sample_scheduler=sample_scheduler, + output_dir=cur_subdir, + placeholder_tokens=[placeholder_token], + placeholder_token_ids=placeholder_token_ids, + learning_rate=ti_lr, + gradient_checkpointing=args.gradient_checkpointing, + use_emb_decay=True, + sample_batch_size=args.sample_batch_size, + sample_num_batches=args.sample_batches, + sample_num_steps=args.sample_steps, + sample_image_size=args.sample_image_size, + ) + + embeddings.persist() # Dreambooth print("Phase 2: Dreambooth") - cur_dir = output_dir.joinpath("2db") + cur_dir = output_dir.joinpath("2-db") cur_dir.mkdir(parents=True, exist_ok=True) args.seed = (args.seed + 28635) >> 32 @@ -685,6 +691,7 @@ def main(): ) trainer( + callbacks_fn=dreambooth_strategy, project="dreambooth", train_dataloader=datamodule.train_dataloader, val_dataloader=datamodule.val_dataloader, @@ -692,14 +699,12 @@ def main(): lr_scheduler=lr_scheduler, num_train_epochs=args.num_train_epochs, sample_frequency=args.sample_frequency, - checkpoint_frequency=args.checkpoint_frequency, with_prior_preservation=args.num_class_images != 0, prior_loss_weight=args.prior_loss_weight, # -- tokenizer=tokenizer, sample_scheduler=sample_scheduler, output_dir=cur_dir, - gradient_checkpointing=args.gradient_checkpointing, train_text_encoder_epochs=args.train_text_encoder_epochs, max_grad_norm=args.max_grad_norm, use_ema=args.use_ema, -- cgit v1.2.3-54-g00ecf