summaryrefslogtreecommitdiffstats
path: root/training/common.py
blob: 73ce814ae464ff417f9d5153f7239596239782c7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
import math
from pathlib import Path
from contextlib import _GeneratorContextManager, nullcontext
from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
import datetime
import logging

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.utils import LoggerType, set_seed
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup

from tqdm.auto import tqdm
from slugify import slugify

from data.csv import VlpnDataModule, VlpnDataItem
from util import load_embeddings_from_dir
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.embeddings import patch_managed_embeddings
from models.clip.util import get_extended_embeddings
from models.clip.tokenizer import MultiCLIPTokenizer
from training.optimization import get_one_cycle_schedule
from training.util import AverageMeter, CheckpointerBase


# Everything ``train_setup`` builds for a run, bundled as an immutable tuple
# with named fields. Field order is part of the positional interface.
TrainingSetup = NamedTuple(
    "TrainingSetup",
    [
        ("accelerator", Accelerator),
        ("tokenizer", MultiCLIPTokenizer),
        ("text_encoder", CLIPTextModel),
        ("vae", AutoencoderKL),
        ("unet", UNet2DConditionModel),
        ("noise_scheduler", DDPMScheduler),
        # Scheduler that train_setup also hands to generate_class_images.
        ("checkpoint_scheduler", DPMSolverMultistepScheduler),
        # Optimizer class (not an instance), e.g. torch.optim.AdamW.
        ("optimizer_class", Callable),
        # Learning rate after optional scaling in train_setup.
        ("learning_rate", float),
        ("weight_dtype", torch.dtype),
        # Timestamped run directory.
        ("output_dir", Path),
        ("seed", int),
        ("train_dataloader", DataLoader),
        ("val_dataloader", DataLoader),
        ("placeholder_token", list[str]),
        # One id list per placeholder token.
        ("placeholder_token_ids", list[list[int]]),
    ],
)


def noop(*args, **kwargs):
    """Accept any positional/keyword arguments and do nothing.

    Used as the default for optional callback hooks.
    """
    return None


def noop_ctx(*args, **kwargs):
    """Ignore all arguments and return an inert context manager.

    Used as the default for hooks that must be usable in a ``with`` statement.
    """
    inert = nullcontext()
    return inert


def noop_on_log():
    """Default ``on_log`` hook: contribute no extra log entries (fresh dict)."""
    return dict()


def get_scheduler(
    id: str,
    optimizer: torch.optim.Optimizer,
    num_training_steps_per_epoch: int,
    gradient_accumulation_steps: int,
    min_lr: float = 0.04,
    warmup_func: str = "cos",
    annealing_func: str = "cos",
    warmup_exp: int = 1,
    annealing_exp: int = 1,
    cycles: Optional[int] = 1,
    train_epochs: int = 100,
    warmup_epochs: int = 10,
):
    """Build a learning-rate scheduler for ``optimizer``.

    Args:
        id: Scheduler name. ``"one_cycle"`` uses ``get_one_cycle_schedule``,
            ``"cosine_with_restarts"`` uses diffusers'
            ``get_cosine_with_hard_restarts_schedule_with_warmup``; any other
            value is forwarded to diffusers' ``get_scheduler``.
        optimizer: Optimizer whose learning rate is scheduled.
        num_training_steps_per_epoch: Batches per epoch. Rounded up to a
            multiple of ``gradient_accumulation_steps`` before use.
        gradient_accumulation_steps: Accumulation factor used for the rounding.
        min_lr, warmup_func, annealing_func, warmup_exp, annealing_exp:
            Forwarded to the one-cycle schedule only.
        cycles: Number of hard restarts for ``"cosine_with_restarts"``. ``None``
            selects ``ceil(sqrt(train_epochs - warmup_epochs))``.
        train_epochs: Total epochs; total steps = epochs * steps per epoch.
        warmup_epochs: Warmup length in epochs. Note: not forwarded to the
            one-cycle schedule.

    Returns:
        The scheduler object created by the selected factory.
    """
    num_training_steps_per_epoch = math.ceil(
        num_training_steps_per_epoch / gradient_accumulation_steps
    ) * gradient_accumulation_steps
    num_training_steps = train_epochs * num_training_steps_per_epoch
    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch

    if id == "one_cycle":
        lr_scheduler = get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=num_training_steps,
            warmup=warmup_func,
            annealing=annealing_func,
            warmup_exp=warmup_exp,
            annealing_exp=annealing_exp,
            min_lr=min_lr,
        )
    elif id == "cosine_with_restarts":
        if cycles is None:
            # Equivalent to ceil(sqrt(train_epochs - warmup_epochs)).
            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))

        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=cycles,
        )
    else:
        lr_scheduler = get_scheduler_(
            id,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )

    return lr_scheduler


def generate_class_images(
    accelerator,
    text_encoder,
    vae,
    unet,
    tokenizer,
    scheduler,
    data_train,
    sample_batch_size,
    sample_image_size,
    sample_steps
):
    """Render the class images that are missing for ``data_train``.

    Items whose ``class_image_path`` already exists are skipped. If nothing is
    missing, the function returns immediately without building a pipeline.
    Otherwise a temporary ``VlpnStableDiffusion`` pipeline is moved to
    ``accelerator.device``. It processes the missing items in chunks of
    ``sample_batch_size``, using each item's ``cprompt`` / ``nprompt``, and
    each resulting image is saved to the matching ``class_image_path``.
    """
    pending = [entry for entry in data_train if not entry.class_image_path.exists()]

    if not pending:
        return

    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
    ).to(accelerator.device)
    pipeline.set_progress_bar_config(dynamic_ncols=True)

    with torch.inference_mode():
        for start in range(0, len(pending), sample_batch_size):
            chunk = pending[start:start + sample_batch_size]
            targets = [entry.class_image_path for entry in chunk]

            generated = pipeline(
                prompt=[entry.cprompt for entry in chunk],
                negative_prompt=[entry.nprompt for entry in chunk],
                height=sample_image_size,
                width=sample_image_size,
                num_inference_steps=sample_steps
            ).images

            for index, picture in enumerate(generated):
                picture.save(targets[index])

    del pipeline

    # Hand cached, now-unused CUDA memory back before training continues.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def train_setup(
    output_dir: str,
    project: str,
    pretrained_model_name_or_path: str,
    learning_rate: float,
    data_file: str,
    gradient_accumulation_steps: int = 1,
    mixed_precision: Literal["no", "fp16", "bf16"] = "no",
    seed: Optional[int] = None,
    vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "auto", "off"]] = "auto",
    vector_dropout: float = 0.1,
    gradient_checkpointing: bool = True,
    embeddings_dir: Optional[str] = None,
    placeholder_token: Optional[list[str]] = None,
    initializer_token: Optional[list[str]] = None,
    num_vectors: int = 1,
    scale_lr: bool = False,
    use_8bit_adam: bool = False,
    train_batch_size: int = 1,
    class_image_dir: Optional[str] = None,
    num_class_images: int = 0,
    resolution: int = 768,
    num_buckets: int = 0,
    progressive_buckets: bool = False,
    bucket_step_size: int = 64,
    bucket_max_pixels: Optional[int] = None,
    tag_dropout: float = 0.1,
    tag_shuffle: bool = True,
    data_template: str = "template",
    valid_set_size: Optional[int] = None,
    valid_set_repeat: int = 1,
    data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
    sample_batch_size: int = 1,
    sample_image_size: int = 768,
    sample_steps: int = 20,
) -> TrainingSetup:
    """Prepare everything needed for a training run and return it as a ``TrainingSetup``.

    In order, this function:

    - creates ``<output_dir>/<slugified project>/<timestamp>`` and directs
      TensorBoard logging and a DEBUG-level ``log.txt`` there;
    - seeds all RNGs with ``seed``. When ``seed`` is None it uses the upper
      32 bits of ``torch.random.seed()``;
    - loads the tokenizer, text encoder, VAE, UNet, a DDPM noise scheduler
      and a DPM-Solver scheduler from ``pretrained_model_name_or_path``;
    - enables VAE slicing and xformers attention, plus gradient checkpointing
      when ``gradient_checkpointing`` is set;
    - optionally loads extra embeddings from ``embeddings_dir``;
    - registers each ``placeholder_token`` (``num_vectors`` vectors each) and
      initialises it from the matching ``initializer_token``;
    - freezes the VAE, UNet and text encoder (``requires_grad_(False)``);
    - optionally scales ``learning_rate`` by accumulation steps x batch size x
      process count;
    - builds a ``VlpnDataModule`` and prepares its loaders with the accelerator;
    - generates missing class images when ``num_class_images != 0``.

    Raises:
        ValueError: ``embeddings_dir`` is given but is not an existing directory.
        ImportError: ``use_8bit_adam`` is set but bitsandbytes is not installed.
    """
    # Use fresh lists per call rather than shared mutable defaults.
    if placeholder_token is None:
        placeholder_token = []
    if initializer_token is None:
        initializer_token = []

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(output_dir).joinpath(slugify(project), now)
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{output_dir}",
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision
    )

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

    # Use an explicit None check so that seed=0 is honoured instead of being
    # replaced by a random seed (a plain `or` treats 0 as "unset").
    if seed is None:
        seed = torch.random.seed() >> 32
    set_seed(seed)

    # Load the tokenizer and configure vector shuffling / dropout
    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
    tokenizer.set_use_vector_shuffle(vector_shuffle)
    tokenizer.set_dropout(vector_dropout)

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder='scheduler')

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.set_use_memory_efficient_attention_xformers(True)

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    embeddings = patch_managed_embeddings(text_encoder)

    if embeddings_dir is not None:
        embeddings_dir = Path(embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    # Convert the initializer_token, placeholder_token to ids
    initializer_token_ids = [
        tokenizer.encode(token, add_special_tokens=False)
        for token in initializer_token
    ]

    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
    embeddings.resize(len(tokenizer))

    # NOTE(review): zip() stops at the shorter list, so placeholder tokens
    # without a matching initializer token are silently not initialised here.
    for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
        embeddings.add_embed(new_id, init_ids)

    init_ratios = [
        f"{len(init_ids)} / {len(new_id)}"
        for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
    ]

    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")

    vae.requires_grad_(False)
    unet.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps *
            train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    weight_dtype = torch.float32
    if mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    datamodule = VlpnDataModule(
        data_file=data_file,
        batch_size=train_batch_size,
        tokenizer=tokenizer,
        class_subdir=class_image_dir,
        num_class_images=num_class_images,
        size=resolution,
        num_buckets=num_buckets,
        progressive_buckets=progressive_buckets,
        bucket_step_size=bucket_step_size,
        bucket_max_pixels=bucket_max_pixels,
        dropout=tag_dropout,
        shuffle=tag_shuffle,
        template_key=data_template,
        valid_set_size=valid_set_size,
        valid_set_repeat=valid_set_repeat,
        seed=seed,
        filter=data_filter,
        dtype=weight_dtype
    )
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader
    val_dataloader = datamodule.val_dataloader

    train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)

    if num_class_images != 0:
        generate_class_images(
            accelerator,
            text_encoder,
            vae,
            unet,
            tokenizer,
            checkpoint_scheduler,
            datamodule.data_train,
            sample_batch_size,
            sample_image_size,
            sample_steps
        )

    return TrainingSetup(
        accelerator=accelerator,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        noise_scheduler=noise_scheduler,
        checkpoint_scheduler=checkpoint_scheduler,
        optimizer_class=optimizer_class,
        learning_rate=learning_rate,
        output_dir=output_dir,
        weight_dtype=weight_dtype,
        seed=seed,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        placeholder_token=placeholder_token,
        placeholder_token_ids=placeholder_token_ids
    )


def loss_step(
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    with_prior: bool,
    prior_loss_weight: float,
    seed: int,
    step: int,
    batch: dict[str, Any],
    eval: bool = False
):
    """Compute the denoising loss for one batch.

    The batch's ``pixel_values`` are encoded to latents and scaled by
    0.18215. Noise is then added at a random timestep per example, and the
    UNet's prediction, conditioned on the text embeddings of ``input_ids`` /
    ``attention_mask``, is compared against the target with MSE. The target is
    the noise for ``"epsilon"`` prediction and the velocity for
    ``"v_prediction"``.

    Args:
        with_prior: If True, the batch is split in half along dim 0. The first
            half is treated as instance examples, the second as prior
            examples. Presumably the data module stacks them in that order —
            verify against the collate function.
        prior_loss_weight: Weight of the prior-half loss when ``with_prior``.
        seed, step: In eval mode, the noise and timesteps are drawn from a
            generator seeded with ``seed + step``, so each validation step is
            reproducible across epochs.
        batch: Must contain ``pixel_values``, ``input_ids`` and ``attention_mask``.
        eval: Selects the seeded (deterministic) sampling described above.

    Returns:
        ``(loss, acc, bsz)``. ``acc`` is the fraction of elements where the
        (instance) prediction exactly equals the target. ``bsz`` is the full
        batch size, including any prior half.

    Raises:
        ValueError: Unknown ``noise_scheduler.config.prediction_type``.
    """
    # Convert images to latent space.
    # NOTE(review): the VAE posterior sample still uses the global RNG, even in
    # eval mode.
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    latents = latents * 0.18215

    # Eval mode: one generator seeded per step for both noise and timesteps,
    # so validation loss is comparable between epochs.
    # Training mode: the global RNG, in the same order as before.
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

    # Sample noise that we'll add to the latents
    if generator is None:
        noise = torch.randn_like(latents)
    else:
        noise = torch.randn(
            latents.shape,
            dtype=latents.dtype,
            device=latents.device,
            generator=generator,
        )
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=generator,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = get_extended_embeddings(
        text_encoder,
        batch["input_ids"],
        batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    if with_prior:
        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # NOTE(review): exact float equality, so this is expected to be ~0 almost
    # always — confirm this "accuracy" is the intended metric.
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz


def train_loop(
    accelerator: Accelerator,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    model: torch.nn.Module,
    checkpointer: CheckpointerBase,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
    sample_frequency: int = 10,
    sample_steps: int = 20,
    checkpoint_frequency: int = 50,
    global_step_offset: int = 0,
    num_epochs: int = 100,
    on_log: Callable[[], dict[str, Any]] = noop_on_log,
    on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
    on_before_optimize: Callable[[int], None] = noop,
    on_after_optimize: Callable[[float], None] = noop,
    on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
):
    """Run the epoch loop: periodic sampling/checkpointing, training, validation.

    Per epoch:

    - Main process: saves samples when ``epoch % sample_frequency == 0``
      (including epoch 0). Writes a ``"training"`` checkpoint when
      ``epoch % checkpoint_frequency == 0`` (excluding epoch 0).
    - Training: runs over ``train_dataloader`` inside ``on_train(epoch)`` and
      ``accelerator.accumulate(model)``, calling ``loss_step(step, batch)``.
      ``on_before_optimize(epoch)`` runs before each ``optimizer.step()``.
      ``on_after_optimize(lr)`` runs whenever gradients were synced.
      Training stops early once ``global_step`` reaches the planned number of
      optimizer steps.
    - Validation: runs over ``val_dataloader`` inside ``torch.inference_mode()``
      and ``on_eval()``, calling ``loss_step(step, batch, True)``.
    - Main process: writes a ``"milestone"`` checkpoint whenever the running
      validation accuracy exceeds its previous maximum.

    ``loss_step`` must return ``(loss, acc, bsz)`` with tensor ``loss``/``acc``.
    ``on_log()`` may contribute extra keys to each training log entry.

    After the final epoch, the main process writes an ``"end"`` checkpoint and
    final samples. On ``KeyboardInterrupt``, the main process writes an
    ``"end"`` checkpoint and the program exits via ``SystemExit``.
    """
    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
    num_val_steps_per_epoch = len(val_dataloader)

    num_training_steps = num_training_steps_per_epoch * num_epochs
    num_val_steps = num_val_steps_per_epoch * num_epochs

    global_step = 0

    avg_loss = AverageMeter()
    avg_acc = AverageMeter()

    # Running validation averages over ALL epochs (never reset).
    avg_loss_val = AverageMeter()
    avg_acc_val = AverageMeter()

    max_acc_val = 0.0

    local_progress_bar = tqdm(
        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

    global_progress_bar = tqdm(
        range(num_training_steps + num_val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    global_progress_bar.set_description("Total progress")

    try:
        for epoch in range(num_epochs):
            if accelerator.is_main_process:
                if epoch % sample_frequency == 0:
                    checkpointer.save_samples(global_step + global_step_offset, sample_steps)

                if epoch % checkpoint_frequency == 0 and epoch != 0:
                    checkpointer.checkpoint(global_step + global_step_offset, "training")

            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            model.train()

            with on_train(epoch):
                for step, batch in enumerate(train_dataloader):
                    with accelerator.accumulate(model):
                        loss, acc, bsz = loss_step(step, batch)

                        accelerator.backward(loss)

                        on_before_optimize(epoch)

                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                        avg_loss.update(loss.detach_(), bsz)
                        avg_acc.update(acc.detach_(), bsz)

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if accelerator.sync_gradients:
                        on_after_optimize(lr_scheduler.get_last_lr()[0])

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        global_step += 1

                    logs = {
                        "train/loss": avg_loss.avg.item(),
                        "train/acc": avg_acc.avg.item(),
                        "train/cur_loss": loss.item(),
                        "train/cur_acc": acc.item(),
                        "lr": lr_scheduler.get_last_lr()[0],
                    }
                    logs.update(on_log())

                    accelerator.log(logs, step=global_step)

                    local_progress_bar.set_postfix(**logs)

                    if global_step >= num_training_steps:
                        break

            accelerator.wait_for_everyone()

            model.eval()

            # Per-epoch validation averages.
            cur_loss_val = AverageMeter()
            cur_acc_val = AverageMeter()

            with torch.inference_mode():
                with on_eval():
                    for step, batch in enumerate(val_dataloader):
                        loss, acc, bsz = loss_step(step, batch, True)

                        loss = loss.detach_()
                        acc = acc.detach_()

                        cur_loss_val.update(loss, bsz)
                        cur_acc_val.update(acc, bsz)

                        avg_loss_val.update(loss, bsz)
                        avg_acc_val.update(acc, bsz)

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        logs = {
                            "val/loss": avg_loss_val.avg.item(),
                            "val/acc": avg_acc_val.avg.item(),
                            "val/cur_loss": loss.item(),
                            "val/cur_acc": acc.item(),
                        }
                        local_progress_bar.set_postfix(**logs)

            # Build the epoch-level validation log explicitly, so that keys
            # from the training loop's last log dict can never leak into it.
            logs = {
                "val/loss": avg_loss_val.avg.item(),
                "val/acc": avg_acc_val.avg.item(),
                "val/cur_loss": cur_loss_val.avg.item(),
                "val/cur_acc": cur_acc_val.avg.item(),
            }

            accelerator.log(logs, step=global_step)

            local_progress_bar.clear()
            global_progress_bar.clear()

            if accelerator.is_main_process:
                # NOTE(review): compares the running average over all epochs,
                # not this epoch's cur_acc_val — confirm that is intended.
                if avg_acc_val.avg.item() > max_acc_val:
                    accelerator.print(
                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                    max_acc_val = avg_acc_val.avg.item()

        # Create the pipeline using using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished!")
            checkpointer.checkpoint(global_step + global_step_offset, "end")
            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
            accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted")
            checkpointer.checkpoint(global_step + global_step_offset, "end")
            accelerator.end_training()
        # `quit()` is an interactive-shell helper injected by `site`; raise
        # SystemExit directly (same exit status as before).
        raise SystemExit