summaryrefslogtreecommitdiffstats
path: root/training/strategy/lora.py
blob: 88d18242fd839f90bc06a27b4ac4c4addfb6ef63 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from contextlib import nullcontext
from typing import Optional
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from diffusers.loaders import AttnProcsLayers

from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def lora_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    lora_layers: AttnProcsLayers,
    max_grad_norm: float = 1.0,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    """Build the ``TrainingCallbacks`` for LoRA training of the UNet.

    Creates ``sample_output_dir`` and ``checkpoint_output_dir`` (with parents)
    if they do not exist yet, then returns callbacks that:

    * ``on_prepare``: enable ``requires_grad`` on ``lora_layers``;
    * ``on_accum_model``: return ``unet``;
    * ``on_train`` / ``on_eval``: switch the tokenizer to train / eval mode;
    * ``on_before_optimize``: clip the LoRA gradients to ``max_grad_norm``
      when ``accelerator.sync_gradients`` is set;
    * ``on_checkpoint``: save the UNet attention processors in float32 to
      ``checkpoint_output_dir / f"{step}_{postfix}"``;
    * ``on_sample``: render samples via ``save_samples`` with the UNet cast to
      the dtype matching the accelerator's mixed-precision mode.

    The ``sample_*`` arguments, ``seed`` and the dataloaders are forwarded to
    ``save_samples``; ``max_grad_norm`` is the clipping threshold.

    Returns:
        A ``TrainingCallbacks`` instance wired to the closures above.
    """
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

    # Dtype used while sampling: follows the accelerator's mixed-precision mode.
    mixed_precision = accelerator.state.mixed_precision
    if mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32

    # Everything except the step is fixed for the whole run.
    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

    def on_prepare():
        lora_layers.requires_grad_(True)

    def on_accum_model():
        # Presumably the model the training loop uses for gradient
        # accumulation — confirm against training.functional.
        return unet

    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()
        yield

    def on_before_optimize(lr: float, epoch: int):
        # sync_gradients is False on intermediate gradient-accumulation
        # micro-steps, so clipping happens only once per real update.
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)

    @torch.no_grad()
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")
        orig_unet_dtype = unet.dtype
        # Attention processors are written out in full precision.
        unet.to(dtype=torch.float32)
        try:
            unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
        finally:
            # Restore even if saving fails, so training never silently
            # continues with the UNet left in float32.
            unet.to(dtype=orig_unet_dtype)

    @torch.no_grad()
    def on_sample(step):
        orig_unet_dtype = unet.dtype
        unet.to(dtype=weight_dtype)
        try:
            save_samples_(step=step)
        finally:
            # Restore the training dtype and release cached GPU memory even
            # if sampling raised.
            unet.to(dtype=orig_unet_dtype)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_prepare=on_prepare,
        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def lora_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    lora_layers: AttnProcsLayers,
    **kwargs
):
    """Hand the LoRA training objects to ``accelerator`` and place the models.

    ``lora_layers``, the optimizer, both dataloaders and the LR scheduler go
    through ``accelerator.prepare``. ``unet`` and ``text_encoder`` are not
    wrapped. They are only moved to ``accelerator.device`` and cast to the
    dtype matching the accelerator's mixed-precision mode: float16 for
    "fp16", bfloat16 for "bf16", float32 otherwise.

    Any extra keyword arguments are accepted and ignored.

    Returns:
        ``(text_encoder, unet, optimizer, train_dataloader, val_dataloader,
        lr_scheduler, extra)``, where ``extra`` is
        ``{"lora_layers": <prepared lora_layers>}``.
    """
    precision = accelerator.state.mixed_precision
    if precision == "fp16":
        weight_dtype = torch.float16
    elif precision == "bf16":
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32

    prepared = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )
    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = prepared

    # NOTE(review): if the attention processors behind lora_layers are
    # submodules of `unet`, this cast also converts the trainable LoRA weights
    # to weight_dtype — confirm that is intended under fp16/bf16.
    for model in (unet, text_encoder):
        model.to(accelerator.device, dtype=weight_dtype)

    extra = {"lora_layers": lora_layers}
    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra


# Strategy object pairing the LoRA prepare step with its callback factory.
lora_strategy = TrainingStrategy(
    prepare=lora_prepare,
    callbacks=lora_strategy_callbacks,
)