From 6d46bf79bd7710cea799fbfe27c12d06d12cd53f Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Thu, 27 Apr 2023 07:47:59 +0200
Subject: Update

---
 infer.py                  | 22 +++++++++++++++---
 train_lora.py             | 13 +++++++++--
 train_ti.py               | 13 +++++++++--
 training/functional.py    | 58 ++++++++++++++++++++++++++---------------------
 training/strategy/lora.py |  8 +++++++
 training/util.py          | 22 +++++++++++-------
 util/ti.py                | 24 ++++++++++++++++++++
 7 files changed, 119 insertions(+), 41 deletions(-)
 create mode 100644 util/ti.py

diff --git a/infer.py b/infer.py
index 4648c0a..7346de9 100644
--- a/infer.py
+++ b/infer.py
@@ -35,6 +35,7 @@ from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from util.files import load_config, load_embeddings_from_dir
+from util.ti import load_embeddings
 
 torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -229,7 +230,7 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(pipeline, embeddings_dir):
+def load_embeddings_dir(pipeline, embeddings_dir):
     added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
@@ -258,6 +259,9 @@ def load_lora(pipeline, path):
     text_encoder_lora_ds = {
         k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
     }
+    ti_lora_ds = {
+        k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k
+    }
 
     unet_config = LoraConfig(**lora_config["peft_config"])
     pipeline.unet = LoraModel(unet_config, pipeline.unet)
@@ -268,6 +272,18 @@ def load_lora(pipeline, path):
     pipeline.text_encoder = LoraModel(text_encoder_config, pipeline.text_encoder)
     set_peft_model_state_dict(pipeline.text_encoder, text_encoder_lora_ds)
 
+    tokens = [k for k, _ in ti_lora_ds.items()]
+    token_embeddings = [v for _, v in ti_lora_ds.items()]
+
+    added_tokens, added_ids = load_embeddings(
+        tokenizer=pipeline.tokenizer,
+        embeddings=pipeline.text_encoder.text_model.embeddings,
+        tokens=tokens,
+        token_embeddings=token_embeddings,
+    )
+    pipeline.text_encoder.text_model.embeddings.persist()
+    print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}")
+
     return
 
 
@@ -435,7 +451,7 @@ class CmdParse(cmd.Cmd):
             return True
 
         if elements[0] == 'reload_embeddings':
-            load_embeddings(self.pipeline, self.ti_embeddings_dir)
+            load_embeddings_dir(self.pipeline, self.ti_embeddings_dir)
             return
 
         try:
@@ -475,7 +491,7 @@ def main():
 
     pipeline = create_pipeline(args.model, dtype)
 
-    load_embeddings(pipeline, args.ti_embeddings_dir)
+    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
     load_lora(pipeline, args.lora_embedding)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
diff --git a/train_lora.py b/train_lora.py
index c197206..9cf17c7 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.lora import lora_strategy
 from training.optimization import get_scheduler
-from training.util import save_args
+from training.util import AverageMeter, save_args
 
 # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
 UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
@@ -1035,6 +1035,11 @@ def main():
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
 
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -1122,7 +1127,7 @@ def main():
             warmup_epochs=lr_warmup_epochs,
         )
 
-        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter + 1}"
+        lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}"
 
         trainer(
             strategy=lora_strategy,
@@ -1142,6 +1147,10 @@ def main():
             sample_frequency=lora_sample_frequency,
             offset_noise_strength=args.offset_noise_strength,
             no_val=args.valid_set_size == 0,
+            avg_loss=avg_loss,
+            avg_acc=avg_acc,
+            avg_loss_val=avg_loss_val,
+            avg_acc_val=avg_acc_val,
         )
 
         training_iter += 1
diff --git a/train_ti.py b/train_ti.py
index d1e5467..fce4a5e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -23,7 +23,7 @@ from data.csv import VlpnDataModule, keyword_filter
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.util import save_args
+from training.util import AverageMeter, save_args
 
 logger = get_logger(__name__)
 
@@ -920,6 +920,11 @@ def main():
     lr_warmup_epochs = args.lr_warmup_epochs
     lr_cycles = args.lr_cycles
 
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -977,7 +982,7 @@ def main():
             mid_point=args.lr_mid_point,
         )
 
-        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter + 1}"
+        checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}"
 
         trainer(
             train_dataloader=datamodule.train_dataloader,
@@ -994,6 +999,10 @@ def main():
             sample_frequency=sample_frequency,
             placeholder_tokens=placeholder_tokens,
             placeholder_token_ids=placeholder_token_ids,
+            avg_loss=avg_loss,
+            avg_acc=avg_acc,
+            avg_loss_val=avg_loss_val,
+            avg_acc_val=avg_acc_val,
         )
 
         training_iter += 1
diff --git a/training/functional.py b/training/functional.py
index 695a24f..3036ed9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -461,6 +461,10 @@ def train_loop(
     num_epochs: int = 100,
     gradient_accumulation_steps: int = 1,
     group_labels: list[str] = [],
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
@@ -472,14 +476,8 @@
     global_step = 0
     cache = {}
 
-    avg_loss = AverageMeter()
-    avg_acc = AverageMeter()
-
-    avg_loss_val = AverageMeter()
-    avg_acc_val = AverageMeter()
-
-    best_acc = 0.0
-    best_acc_val = 0.0
+    best_acc = avg_acc.avg
+    best_acc_val = avg_acc_val.avg
 
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
@@ -544,12 +542,12 @@
 
                 accelerator.backward(loss)
 
-                avg_loss.update(loss.detach_(), bsz)
-                avg_acc.update(acc.detach_(), bsz)
+                avg_loss.update(loss.item(), bsz)
+                avg_acc.update(acc.item(), bsz)
 
                 logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
+                    "train/loss": avg_loss.avg,
+                    "train/acc": avg_acc.avg,
                     "train/cur_loss": loss.item(),
                     "train/cur_acc": acc.item(),
                 }
@@ -603,47 +601,47 @@
                     loss = loss.detach_()
                     acc = acc.detach_()
 
-                    cur_loss_val.update(loss, bsz)
-                    cur_acc_val.update(acc, bsz)
+                    cur_loss_val.update(loss.item(), bsz)
+                    cur_acc_val.update(acc.item(), bsz)
 
-                    avg_loss_val.update(loss, bsz)
-                    avg_acc_val.update(acc, bsz)
+                    avg_loss_val.update(loss.item(), bsz)
+                    avg_acc_val.update(acc.item(), bsz)
 
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)
 
                     logs = {
-                        "val/loss": avg_loss_val.avg.item(),
-                        "val/acc": avg_acc_val.avg.item(),
+                        "val/loss": avg_loss_val.avg,
+                        "val/acc": avg_acc_val.avg,
                         "val/cur_loss": loss.item(),
                         "val/cur_acc": acc.item(),
                     }
                     local_progress_bar.set_postfix(**logs)
 
-            logs["val/cur_loss"] = cur_loss_val.avg.item()
-            logs["val/cur_acc"] = cur_acc_val.avg.item()
+            logs["val/cur_loss"] = cur_loss_val.avg
+            logs["val/cur_acc"] = cur_acc_val.avg
 
             accelerator.log(logs, step=global_step)
 
             if accelerator.is_main_process:
-                if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints:
+                if avg_acc_val.avg > best_acc_val and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
-                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc_val = avg_acc_val.avg.item()
+                    best_acc_val = avg_acc_val.avg
         else:
             if accelerator.is_main_process:
-                if avg_acc.avg.item() > best_acc and milestone_checkpoints:
+                if avg_acc.avg > best_acc and milestone_checkpoints:
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
                     accelerator.print(
-                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                        f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
                     on_checkpoint(global_step, "milestone")
-                    best_acc = avg_acc.avg.item()
+                    best_acc = avg_acc.avg
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -688,6 +686,10 @@ def train(
     offset_noise_strength: float = 0.15,
     disc: Optional[ConvNeXtDiscriminator] = None,
     min_snr_gamma: int = 5,
+    avg_loss: AverageMeter = AverageMeter(),
+    avg_acc: AverageMeter = AverageMeter(),
+    avg_loss_val: AverageMeter = AverageMeter(),
+    avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
@@ -737,6 +739,10 @@ def train(
         num_epochs=num_train_epochs,
         gradient_accumulation_steps=gradient_accumulation_steps,
         group_labels=group_labels,
+        avg_loss=avg_loss,
+        avg_acc=avg_acc,
+        avg_loss_val=avg_loss_val,
+        avg_acc_val=avg_acc_val,
         callbacks=callbacks,
     )
 
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 1f0a117..3f4dbbc 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -138,6 +138,14 @@ def lora_strategy_callbacks(
         state_dict.update(text_encoder_state_dict)
         lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
 
+        if len(placeholder_tokens) != 0:
+            ti_state_dict = {
+                f"ti_{token}": text_encoder.text_model.embeddings.get_embed(ids)
+                for (token, ids)
+                in zip(placeholder_tokens, placeholder_token_ids)
+            }
+            state_dict.update(ti_state_dict)
+
         save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
diff --git a/training/util.py b/training/util.py
index 8bd8a83..61f1533 100644
--- a/training/util.py
+++ b/training/util.py
@@ -16,19 +16,25 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
-    avg: Any
-
-    def __init__(self, name=None):
-        self.name = name
+    def __init__(self, inv_gamma=1.0, power=2 / 3):
+        self.inv_gamma = inv_gamma
+        self.power = power
         self.reset()
 
     def reset(self):
-        self.sum = self.count = self.avg = 0
+        self.step = 0
+        self.avg = 0
+
+    def get_decay(self):
+        if self.step <= 0:
+            return 1
+
+        return (self.step / self.inv_gamma) ** -self.power
 
     def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
+        for _ in range(n):
+            self.step += 1
+            self.avg += (val - self.avg) * self.get_decay()
 
 
 class EMAModel(EMAModel_):
diff --git a/util/ti.py b/util/ti.py
new file mode 100644
index 0000000..4cc732e
--- /dev/null
+++ b/util/ti.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+
+import torch
+
+from models.clip.embeddings import ManagedCLIPTextEmbeddings
+from models.clip.tokenizer import MultiCLIPTokenizer
+
+
+def load_embeddings(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    tokens: list[str],
+    token_embeddings: list[torch.FloatTensor],
+):
+    num_vectors = [embedding.shape[0] for embedding in token_embeddings]
+
+    token_ids = tokenizer.add_multi_tokens(tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (new_id, embeds) in zip(token_ids, token_embeddings):
+        embeddings.add_embed(new_id, embeds)
+
+    return tokens, token_ids
-- 
cgit v1.2.3-70-g09d2
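
A minimal standalone sketch (not part of the commit) of what the reworked AverageMeter in training/util.py does: instead of a running arithmetic mean it keeps an exponentially weighted average whose step size decays as (step / inv_gamma) ** -power, so recent batches count more than old ones, and the meters are now passed from main() through train() into train_loop() so the value carries over between training cycles. The class body mirrors the patched code with a per-sample step increment; the input values and the quoted result below are illustrative only.

    class AverageMeter:
        def __init__(self, inv_gamma=1.0, power=2 / 3):
            self.inv_gamma = inv_gamma
            self.power = power
            self.step = 0
            self.avg = 0

        def get_decay(self):
            # The update step size shrinks as more samples have been seen.
            if self.step <= 0:
                return 1
            return (self.step / self.inv_gamma) ** -self.power

        def update(self, val, n=1):
            # Apply one decayed EMA step per sample in the batch.
            for _ in range(n):
                self.step += 1
                self.avg += (val - self.avg) * self.get_decay()

    meter = AverageMeter()
    for batch_loss, batch_size in [(1.0, 4), (0.5, 4), (0.25, 4)]:
        meter.update(batch_loss, batch_size)
    print(meter.avg)  # ~0.4, weighted toward recent batches (a plain mean would give ~0.58)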