summaryrefslogtreecommitdiffstats
path: root/train_lora.py
blob: 34e1008ea982df9efb172c4b0baee7fdd9172a0d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
import argparse
import itertools
import math
import datetime
import logging
import json
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
from diffusers.training_utils import EMAModel
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from slugify import slugify

from common import load_text_embeddings
from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from data.csv import CSVDataModule
from training.lora import LoraAttnProcessor
from training.optimization import get_one_cycle_schedule
from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
from models.clip.prompt import PromptProcessor

logger = get_logger(__name__)


# Global CUDA speed knobs: let cuDNN autotune its kernels, and permit TF32 math for matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True


def parse_args():
    """Parse and validate the command-line arguments for LoRA training.

    If ``--config`` is given, the JSON file's ``"args"`` object is loaded into a
    namespace first and the command line is parsed on top of it. Values from the
    file therefore replace argparse defaults, and explicit CLI flags still win.

    Returns:
        argparse.Namespace: the validated arguments. ``initializer_token`` and
        ``placeholder_token`` are always lists. ``placeholder_token`` is
        auto-filled as ``<*0>``, ``<*1>``, ... when not given.

    Raises:
        ValueError: if a required argument is missing, or the numbers of
            placeholder and initializer tokens differ.
    """
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file describing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--instance_identifier",
        type=str,
        default=None,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--class_identifier",
        type=str,
        default=None,
        help="A token used as the identifier in prompts for the class (prior preservation) images.",
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        nargs='*',
        default=[],
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        nargs='*',
        default=[],
        help="A token to use as initializer word."
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0.1,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=400,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=1,
        help="How many times to repeat the training data."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default=None,
        help="A mode to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    # `--scale_lr` defaults to True, so a store_true flag alone could never
    # disable it; this companion flag shares the same destination.
    parser.add_argument(
        "--no_scale_lr",
        dest="scale_lr",
        action="store_false",
        help="Disable learning rate scaling (see --scale_lr).",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler (if supported)."
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        default=True,
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    # Same reasoning as --no_scale_lr: lets users fall back to torch.optim.AdamW.
    parser.add_argument(
        "--no_use_8bit_adam",
        dest="use_8bit_adam",
        action="store_false",
        help="Use the standard AdamW optimizer instead of 8-bit Adam.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to save a checkpoint and sample image",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=15,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args()
    if args.config is not None:
        # argparse only fills in defaults for attributes missing from the
        # namespace, so config values override defaults while explicit CLI
        # flags (parsed again here) override the config.
        with open(args.config, 'rt', encoding='utf-8') as f:
            args = parser.parse_args(
                namespace=argparse.Namespace(**json.load(f)["args"]))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.instance_identifier is None:
        raise ValueError("You must specify --instance_identifier")

    # A JSON config may supply a single string instead of a list.
    if isinstance(args.initializer_token, str):
        args.initializer_token = [args.initializer_token]

    if isinstance(args.placeholder_token, str):
        args.placeholder_token = [args.placeholder_token]

    if len(args.placeholder_token) == 0:
        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]

    if len(args.placeholder_token) != len(args.initializer_token):
        raise ValueError("Number of items in --placeholder_token and --initializer_token must match")

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


class Checkpointer(CheckpointerBase):
    """Checkpoint/sample helper for LoRA training.

    Holds the models needed to assemble a ``VlpnStableDiffusion`` pipeline for
    sample generation. On ``save_model`` it persists only the trained
    ``unet_lora`` attention processor.
    """

    def __init__(
        self,
        datamodule,
        accelerator,
        vae,
        unet,
        tokenizer,
        text_encoder,
        unet_lora,
        scheduler,
        instance_identifier,
        placeholder_token,
        placeholder_token_id,
        output_dir: Path,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed
    ):
        super().__init__(
            datamodule=datamodule,
            output_dir=output_dir,
            instance_identifier=instance_identifier,
            placeholder_token=placeholder_token,
            placeholder_token_id=placeholder_token_id,
            sample_image_size=sample_image_size,
            # Fall back to a random seed only when none was supplied; an
            # explicit seed of 0 is valid and must be kept.
            seed=seed if seed is not None else torch.random.seed(),
            sample_batches=sample_batches,
            sample_batch_size=sample_batch_size
        )

        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet_lora = unet_lora
        self.scheduler = scheduler

    @torch.no_grad()
    def save_model(self):
        """Save the unwrapped LoRA processor via ``save_pretrained`` into ``output_dir/model``."""
        print("Saving model...")

        # The processor may be wrapped by accelerator.prepare(); save the raw module.
        unet_lora = self.accelerator.unwrap_model(self.unet_lora)
        unet_lora.save_pretrained(self.output_dir.joinpath("model"))

        del unet_lora

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @torch.no_grad()
    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Build a sampling pipeline and delegate to ``CheckpointerBase.save_samples``.

        Args:
            step: training step, forwarded to the base implementation.
            num_inference_steps: number of denoising steps for sampling.
            guidance_scale: classifier-free guidance scale (forwarded as-is).
            eta: forwarded as-is to the base implementation.
        """
        pipeline = VlpnStableDiffusion(
            text_encoder=self.text_encoder,
            vae=self.vae,
            unet=self.unet,
            tokenizer=self.tokenizer,
            scheduler=self.scheduler,
        ).to(self.accelerator.device)
        pipeline.set_progress_bar_config(dynamic_ncols=True)

        try:
            super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
        finally:
            # Release the pipeline and cached GPU memory even if sampling fails.
            # Only names bound in this method may be deleted here; deleting
            # unbound locals raises UnboundLocalError.
            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()


def main():
    args = parse_args()

    instance_identifier = args.instance_identifier

    if len(args.placeholder_token) != 0:
        instance_identifier = instance_identifier.format(args.placeholder_token[0])

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
    basepath.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{basepath}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)

    args.seed = args.seed or (torch.random.seed() >> 32)
    set_seed(args.seed)

    save_args(basepath, args)

    # Load the tokenizer and add the placeholder token as a additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
    elif args.pretrained_model_name_or_path:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder='scheduler')

    unet_lora = LoraAttnProcessor(
        cross_attention_dim=unet.cross_attention_dim,
        inner_dim=unet.in_channels,
        r=4,
    )

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.set_use_memory_efficient_attention_xformers(True)
    unet.set_attn_processor(unet_lora)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    # Freeze text_encoder and vae
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")
        added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")

    if len(args.placeholder_token) != 0:
        # Convert the initializer_token, placeholder_token to ids
        initializer_token_ids = torch.stack([
            torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
            for token in args.initializer_token
        ])

        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
        print(f"Added {num_added_tokens} new tokens.")

        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        text_encoder.resize_token_embeddings(len(tokenizer))

        token_embeds = text_encoder.get_input_embeddings().weight.data
        original_token_embeds = token_embeds.clone().to(accelerator.device)
        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)

        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
            token_embeds[token_id] = embeddings
    else:
        placeholder_token_id = []

    print(f"Training added text embeddings")

    freeze_params(itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    ))

    index_fixed_tokens = torch.arange(len(tokenizer))
    index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]

    prompt_processor = PromptProcessor(tokenizer, text_encoder)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Initialize the optimizer
    optimizer = optimizer_class(
        [
            {
                'params': unet_lora.parameters(),
                'lr': args.learning_rate,
            },
        ],
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    def collate_fn(examples):
        prompts = [example["prompts"] for example in examples]
        nprompts = [example["nprompts"] for example in examples]
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # concat class and instance examples for prior preservation
        if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

        inputs = prompt_processor.unify_input_ids(input_ids)

        batch = {
            "prompts": prompts,
            "nprompts": nprompts,
            "input_ids": inputs.input_ids,
            "pixel_values": pixel_values,
            "attention_mask": inputs.attention_mask,
        }
        return batch

    datamodule = CSVDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        prompt_processor=prompt_processor,
        instance_identifier=instance_identifier,
        class_identifier=args.class_identifier,
        class_subdir="cls",
        num_class_images=args.num_class_images,
        size=args.resolution,
        repeats=args.repeats,
        mode=args.mode,
        dropout=args.tag_dropout,
        center_crop=args.center_crop,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        num_workers=args.dataloader_num_workers,
        collate_fn=collate_fn
    )

    datamodule.prepare_data()
    datamodule.setup()

    if args.num_class_images != 0:
        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]

        if len(missing_data) != 0:
            batched_data = [
                missing_data[i:i+args.sample_batch_size]
                for i in range(0, len(missing_data), args.sample_batch_size)
            ]

            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet,
                tokenizer=tokenizer,
                scheduler=checkpoint_scheduler,
            ).to(accelerator.device)
            pipeline.set_progress_bar_config(dynamic_ncols=True)

            with torch.autocast("cuda"), torch.inference_mode():
                for batch in batched_data:
                    image_name = [item.class_image_path for item in batch]
                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
                    nprompt = [item.nprompt for item in batch]

                    images = pipeline(
                        prompt=prompt,
                        negative_prompt=nprompt,
                        num_inference_steps=args.sample_steps
                    ).images

                    for i, image in enumerate(images):
                        image.save(image_name[i])

            del pipeline

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    train_dataloader = datamodule.train_dataloader()
    val_dataloader = datamodule.val_dataloader()

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps

    if args.lr_scheduler == "one_cycle":
        lr_scheduler = get_one_cycle_schedule(
            optimizer=optimizer,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        )
    elif args.lr_scheduler == "cosine_with_restarts":
        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
            num_cycles=args.lr_cycles or math.ceil(math.sqrt(
                ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
        )
    else:
        lr_scheduler = get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        )

    unet_lora, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        unet_lora, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    # Move text_encoder and vae to device
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # Keep text_encoder and vae in eval mode as we don't train these
    vae.eval()
    unet.eval()
    text_encoder.eval()

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    num_val_steps_per_epoch = len(val_dataloader)
    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    val_steps = num_val_steps_per_epoch * num_epochs

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        config = vars(args).copy()
        config["initializer_token"] = " ".join(config["initializer_token"])
        config["placeholder_token"] = " ".join(config["placeholder_token"])
        accelerator.init_trackers("lora", config=config)

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.

    global_step = 0

    avg_loss = AverageMeter()
    avg_acc = AverageMeter()

    avg_loss_val = AverageMeter()
    avg_acc_val = AverageMeter()

    max_acc_val = 0.0

    checkpointer = Checkpointer(
        datamodule=datamodule,
        accelerator=accelerator,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        scheduler=checkpoint_scheduler,
        unet_lora=unet_lora,
        output_dir=basepath,
        instance_identifier=instance_identifier,
        placeholder_token=args.placeholder_token,
        placeholder_token_id=placeholder_token_id,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        seed=args.seed
    )

    if accelerator.is_main_process:
        checkpointer.save_samples(0, args.sample_steps)

    local_progress_bar = tqdm(
        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    local_progress_bar.set_description("Epoch X / Y")

    global_progress_bar = tqdm(
        range(args.max_train_steps + val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    global_progress_bar.set_description("Total progress")

    try:
        for epoch in range(num_epochs):
            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            unet_lora.train()

            for step, batch in enumerate(train_dataloader):
                with accelerator.accumulate(unet_lora):
                    # Convert images to latent space
                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                    latents = latents * 0.18215

                    # Sample noise that we'll add to the latents
                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]
                    # Sample a random timestep for each image
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                              (bsz,), device=latents.device)
                    timesteps = timesteps.long()

                    # Add noise to the latents according to the noise magnitude at each timestep
                    # (this is the forward diffusion process)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Get the text embedding for conditioning
                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])

                    # Predict the noise residual
                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    # Get the target for loss depending on the prediction type
                    if noise_scheduler.config.prediction_type == "epsilon":
                        target = noise
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
                    else:
                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                    if args.num_class_images != 0:
                        # Split model_pred and target along the batch dimension into instance and prior halves, and compute the loss on each half separately.
                        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                        target, target_prior = torch.chunk(target, 2, dim=0)

                        # Compute instance loss
                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                        # Compute prior loss
                        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                        # Add the prior loss to the instance loss.
                        loss = loss + args.prior_loss_weight * prior_loss
                    else:
                        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    acc = (model_pred == latents).float().mean()

                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(unet_lora.parameters(), args.max_grad_norm)

                    optimizer.step()
                    if not accelerator.optimizer_step_was_skipped:
                        lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                    with torch.no_grad():
                        text_encoder.get_input_embeddings(
                        ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]

                    avg_loss.update(loss.detach_(), bsz)
                    avg_acc.update(acc.detach_(), bsz)

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    local_progress_bar.update(1)
                    global_progress_bar.update(1)

                    global_step += 1

                logs = {
                    "train/loss": avg_loss.avg.item(),
                    "train/acc": avg_acc.avg.item(),
                    "train/cur_loss": loss.item(),
                    "train/cur_acc": acc.item(),
                    "lr/unet": lr_scheduler.get_last_lr()[0],
                    "lr/text": lr_scheduler.get_last_lr()[1]
                }

                accelerator.log(logs, step=global_step)

                local_progress_bar.set_postfix(**logs)

                if global_step >= args.max_train_steps:
                    break

            accelerator.wait_for_everyone()

            unet_lora.eval()

            with torch.inference_mode():
                for step, batch in enumerate(val_dataloader):
                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                    latents = latents * 0.18215

                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]
                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                              (bsz,), device=latents.device)
                    timesteps = timesteps.long()

                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])

                    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    # Get the target for loss depending on the prediction type
                    if noise_scheduler.config.prediction_type == "epsilon":
                        target = noise
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
                    else:
                        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    acc = (model_pred == latents).float().mean()

                    avg_loss_val.update(loss.detach_(), bsz)
                    avg_acc_val.update(acc.detach_(), bsz)

                    if accelerator.sync_gradients:
                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                    logs = {
                        "val/loss": avg_loss_val.avg.item(),
                        "val/acc": avg_acc_val.avg.item(),
                        "val/cur_loss": loss.item(),
                        "val/cur_acc": acc.item(),
                    }
                    local_progress_bar.set_postfix(**logs)

            accelerator.log({
                "val/loss": avg_loss_val.avg.item(),
                "val/acc": avg_acc_val.avg.item(),
            }, step=global_step)

            local_progress_bar.clear()
            global_progress_bar.clear()

            if avg_acc_val.avg.item() > max_acc_val:
                accelerator.print(
                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                max_acc_val = avg_acc_val.avg.item()

            if accelerator.is_main_process:
                if (epoch + 1) % args.sample_frequency == 0:
                    checkpointer.save_samples(global_step, args.sample_steps)

        # Save the trained modules as the final checkpoint.
        if accelerator.is_main_process:
            print("Finished! Saving final checkpoint and resume state.")
            checkpointer.save_model()

            accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted, saving checkpoint and resume state...")
            checkpointer.save_model()
            accelerator.end_training()
        quit()


if __name__ == "__main__":
    # Script entry point: run training only when this file is executed directly,
    # not when it is imported as a module.
    main()