From ba9fd1a10746d85d2502c8a79ac49db63d346b04 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 9 Apr 2023 11:29:31 +0200
Subject: Update

---
 infer.py                  |   1 +
 models/clip/embeddings.py |   7 ++-
 models/sparse.py          |  13 ++++-
 train_dreambooth.py       |   1 +
 train_lora.py             | 140 ++++------------------------------------
 train_ti.py               |  66 ++++++++++++++--------
 training/functional.py    |   4 +-
 7 files changed, 72 insertions(+), 160 deletions(-)

diff --git a/infer.py b/infer.py
index 8fdf63d..4648c0a 100644
--- a/infer.py
+++ b/infer.py
@@ -235,6 +235,7 @@ def load_embeddings(pipeline, embeddings_dir):
         pipeline.text_encoder.text_model.embeddings,
         Path(embeddings_dir)
     )
+    pipeline.text_encoder.text_model.embeddings.persist()
 
     print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 6fda33c..dc4708a 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -37,7 +37,7 @@ def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initi
 
 
 class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
-    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings):
+    def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, dropout_p: float = 0.0):
         super().__init__(config)
 
         self.token_embedding = embeddings.token_embedding
@@ -46,6 +46,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
 
         self.token_override_embedding = PseudoSparseEmbedding(
             self.token_embedding.embedding_dim,
+            dropout_p=dropout_p,
             device=self.token_embedding.weight.device,
             dtype=self.token_embedding.weight.dtype,
         )
@@ -134,7 +135,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         return embeddings
 
 
-def patch_managed_embeddings(text_encoder: CLIPTextModel) -> ManagedCLIPTextEmbeddings:
-    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings)
+def patch_managed_embeddings(text_encoder: CLIPTextModel, dropout_p: float = 0.0) -> ManagedCLIPTextEmbeddings:
+    text_embeddings = ManagedCLIPTextEmbeddings(text_encoder.config, text_encoder.text_model.embeddings, dropout_p)
     text_encoder.text_model.embeddings = text_embeddings
     return text_embeddings
diff --git a/models/sparse.py b/models/sparse.py
index d706db5..bcb2897 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -5,22 +5,29 @@ import torch.nn as nn
 
 
 class PseudoSparseEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, device=None, dtype=torch.float32):
+    def __init__(self, embedding_dim: int, dropout_p: float = 0.0, device=None, dtype=torch.float32):
         super().__init__()
 
         self.embedding_dim = embedding_dim
         self.dtype = dtype
         self.params = nn.ParameterList()
+
+        if dropout_p > 0.0:
+            self.dropout = nn.Dropout(p=dropout_p)
+        else:
+            self.dropout = lambda x: x
+
         self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
     def forward(self, input_ids: torch.LongTensor):
-        ids = self.mapping[input_ids.to(self.mapping.device)]
+        input_ids = input_ids.to(self.mapping.device)
+        ids = self.mapping[input_ids]
         mask = ~(ids == -1)
 
         if torch.all(~mask):
             embs = None
         else:
-            embs = torch.stack([self.params[id] for id in ids[mask]])
+            embs = self.dropout(torch.stack([self.params[id] for id in ids[mask]]))
 
         return embs, mask
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index f4d4cbb..2aca1e7 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -513,6 +513,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     if args.scale_lr:
diff --git a/train_lora.py b/train_lora.py
index 8dbe45b..6e21634 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -158,12 +158,6 @@ def parse_args():
         default=0,
         help="Tag dropout probability.",
     )
-    parser.add_argument(
-        "--pti_tag_dropout",
-        type=float,
-        default=0,
-        help="Tag dropout probability.",
-    )
     parser.add_argument(
         "--no_tag_shuffle",
         action="store_true",
@@ -235,28 +229,12 @@ def parse_args():
         type=int,
         default=2000
     )
-    parser.add_argument(
-        "--num_pti_epochs",
-        type=int,
-        default=None
-    )
-    parser.add_argument(
-        "--num_pti_steps",
-        type=int,
-        default=500
-    )
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
-    parser.add_argument(
-        "--pti_gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
     parser.add_argument(
         "--lora_r",
         type=int,
@@ -322,12 +300,6 @@ def parse_args():
         default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--learning_rate_pti",
-        type=float,
-        default=1e-4,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
     parser.add_argument(
         "--learning_rate_emb",
         type=float,
@@ -466,12 +438,6 @@ def parse_args():
         default=1,
         help="How often to save a checkpoint and sample image",
     )
-    parser.add_argument(
-        "--pti_sample_frequency",
-        type=int,
-        default=1,
-        help="How often to save a checkpoint and sample image",
-    )
     parser.add_argument(
         "--sample_image_size",
         type=int,
@@ -508,12 +474,6 @@ def parse_args():
         default=1,
         help="Batch size (per device) for the training dataloader."
     )
-    parser.add_argument(
-        "--pti_batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader."
-    )
     parser.add_argument(
         "--sample_steps",
         type=int,
@@ -526,6 +486,12 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
     parser.add_argument(
         "--use_emb_decay",
         action="store_true",
@@ -674,7 +640,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     unet_config = LoraConfig(
         r=args.lora_r,
@@ -720,6 +686,7 @@ def main():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        embeddings.persist()
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
@@ -744,19 +711,14 @@ def main():
             args.learning_rate_text * args.gradient_accumulation_steps *
             args.train_batch_size * accelerator.num_processes
         )
-        args.learning_rate_pti = (
-            args.learning_rate_pti * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
-        )
        args.learning_rate_emb = (
-            args.learning_rate_emb * args.pti_gradient_accumulation_steps *
-            args.pti_batch_size * accelerator.num_processes
+            args.learning_rate_emb * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
         )
 
     if args.find_lr:
         args.learning_rate_unet = 1e-6
         args.learning_rate_text = 1e-6
-        args.learning_rate_pti = 1e-6
         args.learning_rate_emb = 1e-6
         args.lr_scheduler = "exponential_growth"
 
@@ -817,7 +779,6 @@ def main():
         args.lr_min_lr = args.learning_rate_unet
         args.learning_rate_unet = None
         args.learning_rate_text = None
-        args.learning_rate_pti = None
         args.learning_rate_emb = None
     elif args.optimizer == 'dadam':
         try:
@@ -836,7 +797,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     elif args.optimizer == 'dadan':
         try:
@@ -853,7 +813,6 @@ def main():
 
         args.learning_rate_unet = 1.0
         args.learning_rate_text = 1.0
-        args.learning_rate_pti = 1.0
         args.learning_rate_emb = 1.0
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
@@ -920,80 +879,6 @@ def main():
         mid_point=args.lr_mid_point,
     )
 
-    # PTI
-    # --------------------------------------------------------------------------------
-
-    if len(args.placeholder_tokens) != 0:
-        pti_datamodule = create_datamodule(
-            batch_size=args.pti_batch_size,
-            dropout=args.pti_tag_dropout,
-            filter=partial(keyword_filter, args.filter_tokens, args.collection, args.exclude_collections),
-        )
-        pti_datamodule.setup()
-
-        num_pti_epochs = args.num_pti_epochs
-        pti_sample_frequency = args.pti_sample_frequency
-        if num_pti_epochs is None:
-            num_pti_epochs = math.ceil(
-                args.num_pti_steps / len(pti_datamodule.train_dataset)
-            ) * args.pti_gradient_accumulation_steps
-            pti_sample_frequency = math.ceil(num_pti_epochs * (pti_sample_frequency / args.num_pti_steps))
-
-        if num_pti_epochs > 0:
-            pti_optimizer = create_optimizer(
-                [
-                    {
-                        "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-                        "lr": args.learning_rate_pti,
-                        "weight_decay": 0,
-                    },
-                ]
-            )
-
-            pti_lr_scheduler = create_lr_scheduler(
-                gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                optimizer=pti_optimizer,
-                num_training_steps_per_epoch=len(pti_datamodule.train_dataloader),
-                train_epochs=num_pti_epochs,
-            )
-
-            continue_training = True
-            training_iter = 1
-
-            while continue_training:
-                print("")
-                print(f"============ PTI cycle {training_iter} ============")
-                print("")
-
-                pti_project = f"pti_{training_iter}"
-                pti_output_dir = output_dir / pti_project
-                pti_checkpoint_output_dir = pti_output_dir / "model"
-                pti_sample_output_dir = pti_output_dir / "samples"
-
-                trainer(
-                    strategy=lora_strategy,
-                    pti_mode=True,
-                    project=pti_project,
-                    train_dataloader=pti_datamodule.train_dataloader,
-                    val_dataloader=pti_datamodule.val_dataloader,
-                    optimizer=pti_optimizer,
-                    lr_scheduler=pti_lr_scheduler,
-                    num_train_epochs=num_pti_epochs,
-                    gradient_accumulation_steps=args.pti_gradient_accumulation_steps,
-                    # --
-                    group_labels=["emb"],
-                    sample_output_dir=pti_sample_output_dir,
-                    checkpoint_output_dir=pti_checkpoint_output_dir,
-                    sample_frequency=pti_sample_frequency,
-                )
-
-                response = input("Run another cycle? [y/n] ")
-                continue_training = response.lower().strip() != "n"
-                training_iter += 1
-
-        if not args.train_emb:
-            embeddings.persist()
-
     # LORA
     # --------------------------------------------------------------------------------
 
@@ -1062,9 +947,8 @@ def main():
             print("")
 
             lora_project = f"lora_{training_iter}"
-            lora_output_dir = output_dir / lora_project
-            lora_checkpoint_output_dir = lora_output_dir / "model"
-            lora_sample_output_dir = lora_output_dir / "samples"
+            lora_checkpoint_output_dir = output_dir / lora_project / "model"
+            lora_sample_output_dir = output_dir / lora_project / "samples"
 
             trainer(
                 strategy=lora_strategy,
diff --git a/train_ti.py b/train_ti.py
index daf8bc5..2d51800 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -457,6 +457,12 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss."
     )
+    parser.add_argument(
+        "--emb_dropout",
+        type=float,
+        default=0,
+        help="Embedding dropout probability.",
+    )
     parser.add_argument(
         "--use_emb_decay",
         action="store_true",
@@ -624,7 +630,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, args.emb_dropout)
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -755,8 +761,6 @@ def main():
     else:
         raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
-    checkpoint_output_dir = output_dir / "checkpoints"
-
     trainer = partial(
         train,
         accelerator=accelerator,
@@ -777,7 +781,6 @@ def main():
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
         # --
-        checkpoint_output_dir=checkpoint_output_dir,
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
         emb_decay=args.emb_decay,
@@ -793,11 +796,6 @@ def main():
     )
 
     def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
-        if len(placeholder_tokens) == 1:
-            sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
-        else:
-            sample_output_dir = output_dir / "samples"
-
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -809,7 +807,11 @@ def main():
 
         stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
 
-        print(f"{i + 1}: {stats}")
+        print("")
+        print(f"============ TI batch {i + 1} ============")
+        print("")
+        print(stats)
+        print("")
 
         filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
 
@@ -868,20 +870,36 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        trainer(
-            project="textual_inversion",
-            train_dataloader=datamodule.train_dataloader,
-            val_dataloader=datamodule.val_dataloader,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            num_train_epochs=num_train_epochs,
-            # --
-            group_labels=["emb"],
-            sample_output_dir=sample_output_dir,
-            sample_frequency=sample_frequency,
-            placeholder_tokens=placeholder_tokens,
-            placeholder_token_ids=placeholder_token_ids,
-        )
+        continue_training = True
+        training_iter = 1
+
+        while continue_training:
+            print(f"------------ TI cycle {training_iter} ------------")
+            print("")
+
+            project = f"{placeholder_tokens[0]}_{training_iter}" if len(placeholder_tokens) == 1 else f"{training_iter}"
+            sample_output_dir = output_dir / project / "samples"
+            checkpoint_output_dir = output_dir / project / "checkpoints"
+
+            trainer(
+                project=project,
+                train_dataloader=datamodule.train_dataloader,
+                val_dataloader=datamodule.val_dataloader,
+                optimizer=optimizer,
+                lr_scheduler=lr_scheduler,
+                num_train_epochs=num_train_epochs,
+                # --
+                group_labels=["emb"],
+                checkpoint_output_dir=checkpoint_output_dir,
+                sample_output_dir=sample_output_dir,
+                sample_frequency=sample_frequency,
+                placeholder_tokens=placeholder_tokens,
+                placeholder_token_ids=placeholder_token_ids,
+            )
+
+            response = input("Run another cycle? [y/n] ")
+            continue_training = response.lower().strip() != "n"
+            training_iter += 1
 
     if not args.sequential:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
diff --git a/training/functional.py b/training/functional.py
index 7d49782..e14aeea 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -72,7 +72,7 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str):
+def get_models(pretrained_model_name_or_path: str, emb_dropout: float = 0.0):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
@@ -81,7 +81,7 @@ def get_models(pretrained_model_name_or_path: str):
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder='scheduler')
 
-    embeddings = patch_managed_embeddings(text_encoder)
+    embeddings = patch_managed_embeddings(text_encoder, emb_dropout)
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
-- 
cgit v1.2.3-70-g09d2
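
Below is a minimal sketch (not part of the commit) of how the new dropout_p argument to PseudoSparseEmbedding is expected to behave: dropout is applied only to the stacked override vectors returned by forward(), never to the base CLIP token embedding, and it reduces to a no-op when dropout_p is 0 or the module is in eval mode. It assumes the patched repository is on the Python path; the manual population of params/mapping is purely illustrative (the repo fills them through its own helpers), and the vocabulary size and token id are arbitrary.

    import torch
    import torch.nn as nn

    from models.sparse import PseudoSparseEmbedding  # as patched above

    emb = PseudoSparseEmbedding(embedding_dim=768, dropout_p=0.1)

    # Register one override vector for token id 42; every other id maps to -1,
    # i.e. "no override". These lines stand in for the repo's own helpers.
    emb.params.append(nn.Parameter(torch.randn(768)))
    emb.mapping = torch.full((49408,), -1, dtype=torch.long)
    emb.mapping[42] = 0

    emb.train()
    embs, mask = emb(torch.tensor([[1, 42, 2]]))
    print(mask)        # tensor([[False,  True, False]]) -- positions that have an override
    print(embs.shape)  # torch.Size([1, 768]) -- only the overridden position, after dropout

    emb.eval()         # nn.Dropout becomes the identity at eval time; with
                       # dropout_p=0.0 a plain identity lambda is used instead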
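
And a short sketch of the call chain the flag travels after this patch, paraphrased from the diff: the training scripts pass args.emb_dropout into get_models(), which forwards it to patch_managed_embeddings(), which hands it to ManagedCLIPTextEmbeddings as dropout_p and ultimately to the PseudoSparseEmbedding shown above. The model path is a placeholder; any Stable Diffusion checkout with the usual text_encoder subfolder would do.

    from transformers import CLIPTextModel

    from models.clip.embeddings import patch_managed_embeddings

    pretrained = "/path/to/stable-diffusion"  # placeholder, not a real checkpoint id
    text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")

    # Equivalent to what get_models(..., emb_dropout=0.1) does internally.
    embeddings = patch_managed_embeddings(text_encoder, 0.1)

    # The patched embeddings module replaces the original one in place.
    assert text_encoder.text_model.embeddings is embeddings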