From 83808fe00ac891ad2f625388d144c318b2cb5bfe Mon Sep 17 00:00:00 2001
From: Volpeon <git@volpeon.ink>
Date: Sat, 14 Jan 2023 21:53:07 +0100
Subject: WIP: Modularization ("free(): invalid pointer" my ass)

---
 trainer/base.py       | 544 ++++++++++++++++++++++++++++++++++++++++++++++++++
 trainer/dreambooth.py |   0
 trainer/ti.py         | 164 +++++++++++++++
 3 files changed, 708 insertions(+)
 create mode 100644 trainer/base.py
 create mode 100644 trainer/dreambooth.py
 create mode 100644 trainer/ti.py

(limited to 'trainer')

diff --git a/trainer/base.py b/trainer/base.py
new file mode 100644
index 0000000..e700dd6
--- /dev/null
+++ b/trainer/base.py
@@ -0,0 +1,544 @@
+from pathlib import Path
+import math
+from contextlib import contextmanager
+from typing import Type, Optional
+import itertools
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from tqdm.auto import tqdm
+from PIL import Image
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+from models.clip.util import get_extended_embeddings
+from training.util import AverageMeter
+
+
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=((i % cols) * w, (i // cols) * h))
+    return grid
+
+
+class Checkpointer():
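+    """Renders periodic sample image grids from the training and validation data;
+    checkpoint() is implemented by strategy-specific subclasses."""
+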
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        text_encoder: CLIPTextModel,
+        tokenizer: MultiCLIPTokenizer,
+        sample_scheduler,
+        dtype,
+        train_dataloader: DataLoader,
+        val_dataloader: DataLoader,
+        output_dir: Path,
+        sample_steps: int = 20,
+        sample_guidance_scale: float = 7.5,
+        sample_image_size: int = 768,
+        sample_batches: int = 1,
+        sample_batch_size: int = 1,
+        seed: Optional[int] = None,
+        *args,
+        **kwargs,
+    ):
+        self.accelerator = accelerator
+        self.vae = vae
+        self.unet = unet
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+        self.sample_scheduler = sample_scheduler
+        self.dtype = dtype
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
+        self.output_dir = output_dir
+        self.sample_steps = sample_steps
+        self.sample_guidance_scale = sample_guidance_scale
+        self.sample_image_size = sample_image_size
+        self.sample_batches = sample_batches
+        self.sample_batch_size = sample_batch_size
+        self.seed = seed if seed is not None else torch.random.seed()
+
+    @torch.no_grad()
+    def checkpoint(self, step: int, postfix: str):
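+        """Hook for saving trained weights; the base implementation does nothing."""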
+        pass
+
+    @torch.inference_mode()
+    def save_samples(self, step: int):
+        print(f"Saving samples for step {step}...")
+
+        samples_path = self.output_dir.joinpath("samples")
+
+        grid_cols = min(self.sample_batch_size, 4)
+        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
+
+        unet = self.accelerator.unwrap_model(self.unet)
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        orig_unet_dtype = unet.dtype
+        orig_text_encoder_dtype = text_encoder.dtype
+
+        unet.to(dtype=self.dtype)
+        text_encoder.to(dtype=self.dtype)
+
+        pipeline = VlpnStableDiffusion(
+            text_encoder=text_encoder,
+            vae=self.vae,
+            unet=unet,
+            tokenizer=self.tokenizer,
+            scheduler=self.sample_scheduler,
+        ).to(self.accelerator.device)
+        pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+        generator = torch.Generator(device=self.accelerator.device).manual_seed(self.seed)
+
+        for pool, data, gen in [
+            ("stable", self.val_dataloader, generator),
+            ("val", self.val_dataloader, None),
+            ("train", self.train_dataloader, None)
+        ]:
+            all_samples = []
+            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
+
+            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
+            prompt_ids = [
+                prompt
+                for batch in batches
+                for prompt in batch["prompt_ids"]
+            ]
+            nprompt_ids = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompt_ids"]
+            ]
+
+            for i in range(self.sample_batches):
+                start = i * self.sample_batch_size
+                end = (i + 1) * self.sample_batch_size
+                prompt = prompt_ids[start:end]
+                nprompt = nprompt_ids[start:end]
+
+                samples = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    generator=gen,
+                    guidance_scale=self.sample_guidance_scale,
+                    num_inference_steps=self.sample_steps,
+                    output_type='pil'
+                ).images
+
+                all_samples += samples
+
+            image_grid = make_grid(all_samples, grid_rows, grid_cols)
+            image_grid.save(file_path, quality=85)
+
+        unet.to(dtype=orig_unet_dtype)
+        text_encoder.to(dtype=orig_text_encoder_dtype)
+
+        del unet
+        del text_encoder
+        del generator
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+class TrainingStrategy():
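+    """Hook interface (main model, train/eval contexts, optimizer and logging callbacks)
+    that specializes train_loop for a particular training method."""
+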
+    def __init__(
+        self,
+        tokenizer: MultiCLIPTokenizer,
+        *args,
+        **kwargs,
+    ):
+        self.tokenizer = tokenizer
+        self.checkpointer = Checkpointer(tokenizer=tokenizer, *args, **kwargs)
+
+    @property
+    def main_model(self) -> nn.Module:
+        ...
+
+    @contextmanager
+    def on_train(self, epoch: int):
+        try:
+            self.tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval(self):
+        try:
+            self.tokenizer.eval()
+            yield
+        finally:
+            pass
+
+    def on_before_optimize(self, epoch: int):
+        ...
+
+    def on_after_optimize(self, lr: float):
+        ...
+
+    def on_log(self):
+        return {}
+
+
+def loss_step(
+    vae: AutoencoderKL,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    seed: int,
+    noise_scheduler,
+    prior_loss_weight: float,
+    step: int,
+    batch: dict,
+    eval: bool = False
+):
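+    """Run one denoising step: encode images to latents, add noise at random timesteps,
+    predict it with the UNet, and return the (optionally prior-weighted) MSE loss,
+    an accuracy metric, and the batch size."""
+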
+    # Convert images to latent space
+    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+    latents = latents * 0.18215
+
+    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+
+    # Sample noise that we'll add to the latents
+    noise = torch.randn(
+        latents.shape,
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+        generator=generator
+    )
+    bsz = latents.shape[0]
+    # Sample a random timestep for each image
+    timesteps = torch.randint(
+        0,
+        noise_scheduler.config.num_train_timesteps,
+        (bsz,),
+        generator=generator,
+        device=latents.device,
+    )
+    timesteps = timesteps.long()
+
+    # Add noise to the latents according to the noise magnitude at each timestep
+    # (this is the forward diffusion process)
+    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+    noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+    # Get the text embedding for conditioning
+    encoder_hidden_states = get_extended_embeddings(
+        text_encoder,
+        batch["input_ids"],
+        batch["attention_mask"]
+    )
+    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
+
+    # Predict the noise residual
+    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+    # Get the target for loss depending on the prediction type
+    if noise_scheduler.config.prediction_type == "epsilon":
+        target = noise
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+    else:
+        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+    if batch["with_prior"].all():
+        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+        target, target_prior = torch.chunk(target, 2, dim=0)
+
+        # Compute instance loss
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        # Compute prior loss
+        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+        # Add the prior loss to the instance loss.
+        loss = loss + prior_loss_weight * prior_loss
+    else:
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+    acc = (model_pred == target).float().mean()
+
+    return loss, acc, bsz
+
+
+def train_loop(
+    strategy: TrainingStrategy,
+    accelerator: Accelerator,
+    vae: AutoencoderKL,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    noise_scheduler,
+    prior_loss_weight: float = 1.0,
+    sample_frequency: int = 10,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    num_epochs: int = 100,
+):
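+    """Run the full training/validation loop, driving the strategy's hooks, periodic
+    sampling and checkpointing, and milestone checkpoints on new best validation accuracy."""
+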
+    num_training_steps_per_epoch = math.ceil(
+        len(train_dataloader) / accelerator.gradient_accumulation_steps
+    )
+    num_val_steps_per_epoch = len(val_dataloader)
+
+    num_training_steps = num_training_steps_per_epoch * num_epochs
+    num_val_steps = num_val_steps_per_epoch * num_epochs
+
+    global_step = 0
+
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
+    max_acc_val = 0.0
+
+    local_progress_bar = tqdm(
+        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
+
+    global_progress_bar = tqdm(
+        range(num_training_steps + num_val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    global_progress_bar.set_description("Total progress")
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        unet,
+        text_encoder,
+        seed,
+        noise_scheduler,
+        prior_loss_weight
+    )
+
+    try:
+        for epoch in range(num_epochs):
+            if accelerator.is_main_process:
+                if epoch % sample_frequency == 0 and epoch != 0:
+                    strategy.checkpointer.save_samples(global_step + global_step_offset)
+
+                if epoch % checkpoint_frequency == 0 and epoch != 0:
+                    strategy.checkpointer.checkpoint(global_step + global_step_offset, "training")
+
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+            local_progress_bar.reset()
+
+            strategy.main_model.train()
+
+            with strategy.on_train(epoch):
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(strategy.main_model):
+                        loss, acc, bsz = loss_step_(step, batch)
+
+                        accelerator.backward(loss)
+
+                        strategy.on_before_optimize(epoch)
+
+                        optimizer.step()
+                        lr_scheduler.step()
+                        optimizer.zero_grad(set_to_none=True)
+
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)
+
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        strategy.on_after_optimize(lr_scheduler.get_last_lr()[0])
+
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    logs.update(strategy.on_log())
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)
+
+                    if global_step >= num_training_steps:
+                        break
+
+            accelerator.wait_for_everyone()
+
+            strategy.main_model.eval()
+
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
+            with torch.inference_mode(), strategy.on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loss_step_(step, batch, True)
+
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
+
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
+                    local_progress_bar.set_postfix(**logs)
+
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
+
+            local_progress_bar.clear()
+            global_progress_bar.clear()
+
+            if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    strategy.checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    max_acc_val = avg_acc_val.avg.item()
+
+        # Save the final checkpoint and samples using the trained modules.
+        if accelerator.is_main_process:
+            print("Finished!")
+            strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
+            strategy.checkpointer.save_samples(global_step + global_step_offset)
+            accelerator.end_training()
+
+    except KeyboardInterrupt:
+        if accelerator.is_main_process:
+            print("Interrupted")
+            strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
+            accelerator.end_training()
+
+
+class Trainer():
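+    """Prepares models and dataloaders with accelerate, builds the requested
+    TrainingStrategy, and runs train_loop."""
+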
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        unet: UNet2DConditionModel,
+        text_encoder: CLIPTextModel,
+        tokenizer: MultiCLIPTokenizer,
+        vae: AutoencoderKL,
+        noise_scheduler: DDPMScheduler,
+        sample_scheduler: DPMSolverMultistepScheduler,
+        train_dataloader: DataLoader,
+        val_dataloader: DataLoader,
+        dtype: torch.dtype,
+    ):
+        self.accelerator = accelerator
+        self.unet = unet
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+        self.vae = vae
+        self.noise_scheduler = noise_scheduler
+        self.sample_scheduler = sample_scheduler
+        self.train_dataloader = train_dataloader
+        self.val_dataloader = val_dataloader
+        self.dtype = dtype
+
+    def __call__(
+        self,
+        strategy_class: Type[TrainingStrategy],
+        optimizer,
+        lr_scheduler,
+        num_train_epochs: int = 100,
+        sample_frequency: int = 20,
+        checkpoint_frequency: int = 50,
+        global_step_offset: int = 0,
+        prior_loss_weight: float = 0,
+        seed: Optional[int] = None,
+        **kwargs,
+    ):
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = self.accelerator.prepare(
+            self.unet, self.text_encoder, optimizer, self.train_dataloader, self.val_dataloader, lr_scheduler
+        )
+
+        self.vae.to(self.accelerator.device, dtype=self.dtype)
+
+        for model in (unet, text_encoder, self.vae):
+            model.requires_grad_(False)
+            model.eval()
+
+        if seed is None:
+            seed = torch.random.seed()
+
+        strategy = strategy_class(
+            accelerator=self.accelerator,
+            vae=self.vae,
+            unet=unet,
+            text_encoder=text_encoder,
+            tokenizer=self.tokenizer,
+            sample_scheduler=self.sample_scheduler,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            dtype=self.dtype,
+            seed=seed,
+            **kwargs
+        )
+
+        if self.accelerator.is_main_process:
+            self.accelerator.init_trackers("textual_inversion")
+
+        train_loop(
+            strategy=strategy,
+            accelerator=self.accelerator,
+            vae=self.vae,
+            unet=unet,
+            text_encoder=text_encoder,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            seed=seed,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            noise_scheduler=self.noise_scheduler,
+            prior_loss_weight=prior_loss_weight,
+            sample_frequency=sample_frequency,
+            checkpoint_frequency=checkpoint_frequency,
+            global_step_offset=global_step_offset,
+            num_epochs=num_train_epochs,
+        )
+
+        self.accelerator.free_memory()
diff --git a/trainer/dreambooth.py b/trainer/dreambooth.py
new file mode 100644
index 0000000..e69de29
diff --git a/trainer/ti.py b/trainer/ti.py
new file mode 100644
index 0000000..15cf747
--- /dev/null
+++ b/trainer/ti.py
@@ -0,0 +1,164 @@
+from contextlib import contextmanager, nullcontext
+
+import torch
+
+from slugify import slugify
+
+from diffusers import UNet2DConditionModel
+from transformers import CLIPTextModel
+
+from trainer.base import TrainingStrategy, Checkpointer
+from training.util import EMAModel
+
+
+class TextualInversionCheckpointer(Checkpointer):
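+    """Saves the learned placeholder token embeddings, applying the EMA weights when available."""
+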
+    def __init__(
+        self,
+        ema_embeddings: EMAModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.ema_embeddings = ema_embeddings
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
+
+    @torch.no_grad()
+    def checkpoint(self, step, postfix):
+        print(f"Saving checkpoint for step {step}...")
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+
+        ema_context = self.ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            for (token, ids) in zip(self.placeholder_tokens, self.placeholder_token_ids):
+                text_encoder.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+    @torch.inference_mode()
+    def save_samples(self, step):
+        ema_context = self.ema_embeddings.apply_temporary(
+            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if self.ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            super().save_samples(step)
+
+
+class TextualInversionTrainingStrategy(TrainingStrategy):
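+    """Trains only the temporary token embeddings of the text encoder, with optional
+    EMA tracking and decay of the embedding norm towards a target value."""
+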
+    def __init__(
+        self,
+        unet: UNet2DConditionModel,
+        text_encoder: CLIPTextModel,
+        placeholder_tokens: list[str],
+        placeholder_token_ids: list[list[int]],
+        learning_rate: float,
+        gradient_checkpointing: bool = False,
+        use_emb_decay: bool = False,
+        emb_decay_target: float = 0.4,
+        emb_decay_factor: float = 1,
+        emb_decay_start: float = 1e-4,
+        use_ema: bool = False,
+        ema_inv_gamma: float = 1.0,
+        ema_power: int = 1,
+        ema_max_decay: float = 0.9999,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            unet=unet,
+            text_encoder=text_encoder,
+            *args,
+            **kwargs
+        )
+
+        self.text_encoder = text_encoder
+        self.unet = unet
+
+        self.placeholder_tokens = placeholder_tokens
+        self.placeholder_token_ids = placeholder_token_ids
+
+        self.gradient_checkpointing = gradient_checkpointing
+
+        self.learning_rate = learning_rate
+        self.use_emb_decay = use_emb_decay
+        self.emb_decay_target = emb_decay_target
+        self.emb_decay_factor = emb_decay_factor
+        self.emb_decay_start = emb_decay_start
+
+        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        self.ema_embeddings = None
+
+        if use_ema:
+            self.ema_embeddings = EMAModel(
+                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+                inv_gamma=ema_inv_gamma,
+                power=ema_power,
+                max_value=ema_max_decay,
+            )
+
+        self.checkpointer = TextualInversionCheckpointer(
+            unet=unet,
+            text_encoder=text_encoder,
+            ema_embeddings=self.ema_embeddings,
+            placeholder_tokens=placeholder_tokens,
+            placeholder_token_ids=placeholder_token_ids,
+            *args,
+            **kwargs
+        )
+
+    @property
+    def main_model(self):
+        return self.text_encoder
+
+    @contextmanager
+    def on_train(self, epoch: int):
+        try:
+            if self.gradient_checkpointing:
+                self.unet.train()
+
+            with super().on_train(epoch):
+                yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval(self):
+        try:
+            if self.gradient_checkpointing:
+                self.unet.eval()
+
+            ema_context = self.ema_embeddings.apply_temporary(
+                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            ) if self.ema_embeddings is not None else nullcontext()
+
+            with ema_context, super().on_eval():
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(self, lr: float):
+        if self.use_emb_decay:
+            self.text_encoder.text_model.embeddings.normalize(
+                self.emb_decay_target,
+                min(1.0, max(0.0, self.emb_decay_factor * ((lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start))))
+            )
+
+        if self.ema_embeddings is not None:
+            self.ema_embeddings.step(self.text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log(self):
+        log = super().on_log()
+        added = {}
+
+        if self.ema_embeddings is not None:
+            added = {"ema_decay": self.ema_embeddings.decay}
+
+        return {**log, **added}
-- 