-rw-r--r--  pipelines/util.py    |    9 -
-rw-r--r--  train_dreambooth.py  |    7 +-
-rw-r--r--  train_lora.py        | 1040 +
-rw-r--r--  train_ti.py          |    5 +-
-rw-r--r--  training/lora.py     |   27 +
5 files changed, 1073 insertions, 15 deletions
diff --git a/pipelines/util.py b/pipelines/util.py
deleted file mode 100644
index 661dbee..0000000
--- a/pipelines/util.py
+++ /dev/null
@@ -1,9 +0,0 @@
1import torch
2
3
4def set_use_memory_efficient_attention_xformers(module: torch.nn.Module, valid: bool) -> None:
5 if hasattr(module, "set_use_memory_efficient_attention_xformers"):
6 module.set_use_memory_efficient_attention_xformers(valid)
7
8 for child in module.children():
9 set_use_memory_efficient_attention_xformers(child, valid)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 3eecf9c..0f8fece 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -23,7 +23,6 @@ from slugify import slugify
 
 from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
@@ -614,8 +613,8 @@ def main():
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
-    set_use_memory_efficient_attention_xformers(unet, True)
-    set_use_memory_efficient_attention_xformers(vae, True)
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -964,6 +963,7 @@ def main():
                     # Add noise to the latents according to the noise magnitude at each timestep
                     # (this is the forward diffusion process)
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                    noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
                     # Get the text embedding for conditioning
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
@@ -1065,6 +1065,7 @@ def main():
                    timesteps = timesteps.long()
 
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                   noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
 
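The added casts above guard against a dtype mismatch under mixed precision: the scheduler's add_noise step mixes in float32 coefficients, which can promote the fp16/bf16 latents back to float32 while the UNet weights stay in the reduced precision. A minimal sketch of the pattern (tensor shapes and names are illustrative, not taken from the scripts):

    import torch

    unet_dtype = torch.float16                           # weight dtype under --mixed_precision=fp16
    latents = torch.randn(1, 4, 96, 96).to(unet_dtype)   # latents produced in the reduced precision
    # stand-in for add_noise: float32 coefficients promote the result to float32
    noisy_latents = latents.float() + torch.randn(1, 4, 96, 96)
    noisy_latents = noisy_latents.to(dtype=unet_dtype)   # cast back before calling the UNet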
diff --git a/train_lora.py b/train_lora.py
new file mode 100644
index 0000000..ffc1d10
--- /dev/null
+++ b/train_lora.py
@@ -0,0 +1,1040 @@
1import argparse
2import itertools
3import math
4import datetime
5import logging
6import json
7from pathlib import Path
8
9import torch
10import torch.nn.functional as F
11import torch.utils.checkpoint
12
13from accelerate import Accelerator
14from accelerate.logging import get_logger
15from accelerate.utils import LoggerType, set_seed
16from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
17from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
18from diffusers.training_utils import EMAModel
19from PIL import Image
20from tqdm.auto import tqdm
21from transformers import CLIPTextModel, CLIPTokenizer
22from slugify import slugify
23
24from common import load_text_embeddings
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from data.csv import CSVDataModule
27from training.lora import LoraAttention
28from training.optimization import get_one_cycle_schedule
29from models.clip.prompt import PromptProcessor
30
31logger = get_logger(__name__)
32
33
34torch.backends.cuda.matmul.allow_tf32 = True
35torch.backends.cudnn.benchmark = True
36
37
38def parse_args():
39 parser = argparse.ArgumentParser(
40 description="Simple example of a training script."
41 )
42 parser.add_argument(
43 "--pretrained_model_name_or_path",
44 type=str,
45 default=None,
46 help="Path to pretrained model or model identifier from huggingface.co/models.",
47 )
48 parser.add_argument(
49 "--tokenizer_name",
50 type=str,
51 default=None,
52 help="Pretrained tokenizer name or path if not the same as model_name",
53 )
54 parser.add_argument(
55 "--train_data_file",
56 type=str,
57 default=None,
58 help="A folder containing the training data."
59 )
60 parser.add_argument(
61 "--train_data_template",
62 type=str,
63 default="template",
64 )
65 parser.add_argument(
66 "--instance_identifier",
67 type=str,
68 default=None,
69 help="A token to use as a placeholder for the concept.",
70 )
71 parser.add_argument(
72 "--class_identifier",
73 type=str,
74 default=None,
75 help="A token to use as a placeholder for the concept.",
76 )
77 parser.add_argument(
78 "--placeholder_token",
79 type=str,
80 nargs='*',
81 default=[],
82 help="A token to use as a placeholder for the concept.",
83 )
84 parser.add_argument(
85 "--initializer_token",
86 type=str,
87 nargs='*',
88 default=[],
89 help="A token to use as initializer word."
90 )
91 parser.add_argument(
92 "--tag_dropout",
93 type=float,
94 default=0.1,
95 help="Tag dropout probability.",
96 )
97 parser.add_argument(
98 "--num_class_images",
99 type=int,
100 default=400,
101 help="How many class images to generate."
102 )
103 parser.add_argument(
104 "--repeats",
105 type=int,
106 default=1,
107 help="How many times to repeat the training data."
108 )
109 parser.add_argument(
110 "--output_dir",
111 type=str,
112 default="output/dreambooth",
113 help="The output directory where the model predictions and checkpoints will be written.",
114 )
115 parser.add_argument(
116 "--embeddings_dir",
117 type=str,
118 default=None,
119 help="The embeddings directory where Textual Inversion embeddings are stored.",
120 )
121 parser.add_argument(
122 "--mode",
123 type=str,
124 default=None,
125 help="A mode to filter the dataset.",
126 )
127 parser.add_argument(
128 "--seed",
129 type=int,
130 default=None,
131 help="A seed for reproducible training."
132 )
133 parser.add_argument(
134 "--resolution",
135 type=int,
136 default=768,
137 help=(
138 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
139 " resolution"
140 ),
141 )
142 parser.add_argument(
143 "--center_crop",
144 action="store_true",
145 help="Whether to center crop images before resizing to resolution"
146 )
147 parser.add_argument(
148 "--dataloader_num_workers",
149 type=int,
150 default=0,
151 help=(
152 "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
153 " process."
154 ),
155 )
156 parser.add_argument(
157 "--num_train_epochs",
158 type=int,
159 default=100
160 )
161 parser.add_argument(
162 "--max_train_steps",
163 type=int,
164 default=None,
165 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
166 )
167 parser.add_argument(
168 "--gradient_accumulation_steps",
169 type=int,
170 default=1,
171 help="Number of updates steps to accumulate before performing a backward/update pass.",
172 )
173 parser.add_argument(
174 "--gradient_checkpointing",
175 action="store_true",
176 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
177 )
178 parser.add_argument(
179 "--learning_rate_unet",
180 type=float,
181 default=2e-6,
182 help="Initial learning rate (after the potential warmup period) to use.",
183 )
184 parser.add_argument(
185 "--scale_lr",
186 action="store_true",
187 default=True,
188 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
189 )
190 parser.add_argument(
191 "--lr_scheduler",
192 type=str,
193 default="one_cycle",
194 help=(
195 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
196 ' "constant", "constant_with_warmup", "one_cycle"]'
197 ),
198 )
199 parser.add_argument(
200 "--lr_warmup_epochs",
201 type=int,
202 default=10,
203 help="Number of steps for the warmup in the lr scheduler."
204 )
205 parser.add_argument(
206 "--lr_cycles",
207 type=int,
208 default=None,
209 help="Number of restart cycles in the lr scheduler (if supported)."
210 )
211 parser.add_argument(
212 "--use_8bit_adam",
213 action="store_true",
214 default=True,
215 help="Whether or not to use 8-bit Adam from bitsandbytes."
216 )
217 parser.add_argument(
218 "--adam_beta1",
219 type=float,
220 default=0.9,
221 help="The beta1 parameter for the Adam optimizer."
222 )
223 parser.add_argument(
224 "--adam_beta2",
225 type=float,
226 default=0.999,
227 help="The beta2 parameter for the Adam optimizer."
228 )
229 parser.add_argument(
230 "--adam_weight_decay",
231 type=float,
232 default=1e-2,
233 help="Weight decay to use."
234 )
235 parser.add_argument(
236 "--adam_epsilon",
237 type=float,
238 default=1e-08,
239 help="Epsilon value for the Adam optimizer"
240 )
241 parser.add_argument(
242 "--mixed_precision",
243 type=str,
244 default="no",
245 choices=["no", "fp16", "bf16"],
246 help=(
247 "Whether to use mixed precision. Choose"
248 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
249 "and an Nvidia Ampere GPU."
250 ),
251 )
252 parser.add_argument(
253 "--sample_frequency",
254 type=int,
255 default=1,
256 help="How often to save a checkpoint and sample image",
257 )
258 parser.add_argument(
259 "--sample_image_size",
260 type=int,
261 default=768,
262 help="Size of sample images",
263 )
264 parser.add_argument(
265 "--sample_batches",
266 type=int,
267 default=1,
268 help="Number of sample batches to generate per checkpoint",
269 )
270 parser.add_argument(
271 "--sample_batch_size",
272 type=int,
273 default=1,
274 help="Number of samples to generate per batch",
275 )
276 parser.add_argument(
277 "--valid_set_size",
278 type=int,
279 default=None,
280 help="Number of images in the validation dataset."
281 )
282 parser.add_argument(
283 "--train_batch_size",
284 type=int,
285 default=1,
286 help="Batch size (per device) for the training dataloader."
287 )
288 parser.add_argument(
289 "--sample_steps",
290 type=int,
291 default=15,
292 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
293 )
294 parser.add_argument(
295 "--prior_loss_weight",
296 type=float,
297 default=1.0,
298 help="The weight of prior preservation loss."
299 )
300 parser.add_argument(
301 "--max_grad_norm",
302 default=1.0,
303 type=float,
304 help="Max gradient norm."
305 )
306 parser.add_argument(
307 "--noise_timesteps",
308 type=int,
309 default=1000,
310 )
311 parser.add_argument(
312 "--config",
313 type=str,
314 default=None,
315 help="Path to a JSON configuration file containing arguments for invoking this script."
316 )
317
318 args = parser.parse_args()
319 if args.config is not None:
320 with open(args.config, 'rt') as f:
321 args = parser.parse_args(
322 namespace=argparse.Namespace(**json.load(f)["args"]))
323
324 if args.train_data_file is None:
325 raise ValueError("You must specify --train_data_file")
326
327 if args.pretrained_model_name_or_path is None:
328 raise ValueError("You must specify --pretrained_model_name_or_path")
329
330 if args.instance_identifier is None:
331 raise ValueError("You must specify --instance_identifier")
332
333 if isinstance(args.initializer_token, str):
334 args.initializer_token = [args.initializer_token]
335
336 if isinstance(args.placeholder_token, str):
337 args.placeholder_token = [args.placeholder_token]
338
339 if len(args.placeholder_token) == 0:
340 args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
341
342 if len(args.placeholder_token) != len(args.initializer_token):
343 raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
344
345 if args.output_dir is None:
346 raise ValueError("You must specify --output_dir")
347
348 return args
349
350
351def save_args(basepath: Path, args, extra={}):
352 info = {"args": vars(args)}
353 info["args"].update(extra)
354 with open(basepath.joinpath("args.json"), "w") as f:
355 json.dump(info, f, indent=4)
356
357
358def freeze_params(params):
359 for param in params:
360 param.requires_grad = False
361
362
363def make_grid(images, rows, cols):
364 w, h = images[0].size
365 grid = Image.new('RGB', size=(cols*w, rows*h))
366 for i, image in enumerate(images):
367 grid.paste(image, box=(i % cols*w, i//cols*h))
368 return grid
369
370
371class AverageMeter:
372 def __init__(self, name=None):
373 self.name = name
374 self.reset()
375
376 def reset(self):
377 self.sum = self.count = self.avg = 0
378
379 def update(self, val, n=1):
380 self.sum += val * n
381 self.count += n
382 self.avg = self.sum / self.count
383
384
385class Checkpointer:
386 def __init__(
387 self,
388 datamodule,
389 accelerator,
390 vae,
391 unet,
392 unet_lora,
393 tokenizer,
394 text_encoder,
395 scheduler,
396 output_dir: Path,
397 instance_identifier,
398 placeholder_token,
399 placeholder_token_id,
400 sample_image_size,
401 sample_batches,
402 sample_batch_size,
403 seed
404 ):
405 self.datamodule = datamodule
406 self.accelerator = accelerator
407 self.vae = vae
408 self.unet = unet
409 self.unet_lora = unet_lora
410 self.tokenizer = tokenizer
411 self.text_encoder = text_encoder
412 self.scheduler = scheduler
413 self.output_dir = output_dir
414 self.instance_identifier = instance_identifier
415 self.placeholder_token = placeholder_token
416 self.placeholder_token_id = placeholder_token_id
417 self.sample_image_size = sample_image_size
418 self.seed = seed or torch.random.seed()
419 self.sample_batches = sample_batches
420 self.sample_batch_size = sample_batch_size
421
422 @torch.no_grad()
423 def save_model(self):
424 print("Saving model...")
425
426 unet_lora = self.accelerator.unwrap_model(self.unet_lora)
427 unet_lora.save_pretrained(self.output_dir.joinpath("model"))
428
429 del unet_lora
430
431 if torch.cuda.is_available():
432 torch.cuda.empty_cache()
433
434 @torch.no_grad()
435 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
436 samples_path = Path(self.output_dir).joinpath("samples")
437
438 unet = self.accelerator.unwrap_model(self.unet)
439 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
440
441 pipeline = VlpnStableDiffusion(
442 text_encoder=text_encoder,
443 vae=self.vae,
444 unet=unet,
445 tokenizer=self.tokenizer,
446 scheduler=self.scheduler,
447 ).to(self.accelerator.device)
448 pipeline.set_progress_bar_config(dynamic_ncols=True)
449
450 train_data = self.datamodule.train_dataloader()
451 val_data = self.datamodule.val_dataloader()
452
453 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
454 stable_latents = torch.randn(
455 (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
456 device=pipeline.device,
457 generator=generator,
458 )
459
460 with torch.autocast("cuda"), torch.inference_mode():
461 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
462 all_samples = []
463 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
464 file_path.parent.mkdir(parents=True, exist_ok=True)
465
466 data_enum = enumerate(data)
467
468 batches = [
469 batch
470 for j, batch in data_enum
471 if j * data.batch_size < self.sample_batch_size * self.sample_batches
472 ]
473 prompts = [
474 prompt.format(identifier=self.instance_identifier)
475 for batch in batches
476 for prompt in batch["prompts"]
477 ]
478 nprompts = [
479 prompt
480 for batch in batches
481 for prompt in batch["nprompts"]
482 ]
483
484 for i in range(self.sample_batches):
485 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
486 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
487
488 samples = pipeline(
489 prompt=prompt,
490 negative_prompt=nprompt,
491 height=self.sample_image_size,
492 width=self.sample_image_size,
493 image=latents[:len(prompt)] if latents is not None else None,
494 generator=generator if latents is not None else None,
495 guidance_scale=guidance_scale,
496 eta=eta,
497 num_inference_steps=num_inference_steps,
498 output_type='pil'
499 ).images
500
501 all_samples += samples
502
503 del samples
504
505 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
506 image_grid.save(file_path, quality=85)
507
508 del all_samples
509 del image_grid
510
511 del unet
512 del text_encoder
513 del pipeline
514 del generator
515 del stable_latents
516
517 if torch.cuda.is_available():
518 torch.cuda.empty_cache()
519
520
521def main():
522 args = parse_args()
523
524 instance_identifier = args.instance_identifier
525
526 if len(args.placeholder_token) != 0:
527 instance_identifier = instance_identifier.format(args.placeholder_token[0])
528
529 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
530 basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
531 basepath.mkdir(parents=True, exist_ok=True)
532
533 accelerator = Accelerator(
534 log_with=LoggerType.TENSORBOARD,
535 logging_dir=f"{basepath}",
536 gradient_accumulation_steps=args.gradient_accumulation_steps,
537 mixed_precision=args.mixed_precision
538 )
539
540 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
541
542 args.seed = args.seed or (torch.random.seed() >> 32)
543 set_seed(args.seed)
544
545 save_args(basepath, args)
546
547 # Load the tokenizer and add the placeholder token as an additional special token
548 if args.tokenizer_name:
549 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
550 elif args.pretrained_model_name_or_path:
551 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
552
553 # Load models and create wrapper for stable diffusion
554 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
555 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
556 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
557 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
558 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
559 args.pretrained_model_name_or_path, subfolder='scheduler')
560
561 unet_lora = LoraAttention()
562
563 vae.enable_slicing()
564 vae.set_use_memory_efficient_attention_xformers(True)
565 unet.set_use_memory_efficient_attention_xformers(True)
566 unet.set_attn_processor(unet_lora)
567
568 if args.gradient_checkpointing:
569 unet.enable_gradient_checkpointing()
570 text_encoder.gradient_checkpointing_enable()
571
572 # Freeze text_encoder and vae
573 vae.requires_grad_(False)
574 unet.requires_grad_(False)
575
576 if args.embeddings_dir is not None:
577 embeddings_dir = Path(args.embeddings_dir)
578 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
579 raise ValueError("--embeddings_dir must point to an existing directory")
580 added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
581 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
582
583 if len(args.placeholder_token) != 0:
584 # Convert the initializer_token, placeholder_token to ids
585 initializer_token_ids = torch.stack([
586 torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
587 for token in args.initializer_token
588 ])
589
590 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
591 print(f"Added {num_added_tokens} new tokens.")
592
593 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
594
595 # Resize the token embeddings as we are adding new special tokens to the tokenizer
596 text_encoder.resize_token_embeddings(len(tokenizer))
597
598 token_embeds = text_encoder.get_input_embeddings().weight.data
599 original_token_embeds = token_embeds.clone().to(accelerator.device)
600 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
601
602 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
603 token_embeds[token_id] = embeddings
604 else:
605 placeholder_token_id = []
606
607 print(f"Training added text embeddings")
608
609 freeze_params(itertools.chain(
610 text_encoder.text_model.encoder.parameters(),
611 text_encoder.text_model.final_layer_norm.parameters(),
612 text_encoder.text_model.embeddings.position_embedding.parameters(),
613 ))
614
615 index_fixed_tokens = torch.arange(len(tokenizer))
616 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
617
618 prompt_processor = PromptProcessor(tokenizer, text_encoder)
619
620 if args.scale_lr:
621 args.learning_rate_unet = (
622 args.learning_rate_unet * args.gradient_accumulation_steps *
623 args.train_batch_size * accelerator.num_processes
624 )
625
626 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
627 if args.use_8bit_adam:
628 try:
629 import bitsandbytes as bnb
630 except ImportError:
631 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
632
633 optimizer_class = bnb.optim.AdamW8bit
634 else:
635 optimizer_class = torch.optim.AdamW
636
637 # Initialize the optimizer
638 optimizer = optimizer_class(
639 [
640 {
641 'params': unet_lora.parameters(),
642 'lr': args.learning_rate_unet,
643 },
644 ],
645 betas=(args.adam_beta1, args.adam_beta2),
646 weight_decay=args.adam_weight_decay,
647 eps=args.adam_epsilon,
648 )
649
650 weight_dtype = torch.float32
651 if args.mixed_precision == "fp16":
652 weight_dtype = torch.float16
653 elif args.mixed_precision == "bf16":
654 weight_dtype = torch.bfloat16
655
656 def collate_fn(examples):
657 prompts = [example["prompts"] for example in examples]
658 nprompts = [example["nprompts"] for example in examples]
659 input_ids = [example["instance_prompt_ids"] for example in examples]
660 pixel_values = [example["instance_images"] for example in examples]
661
662 # concat class and instance examples for prior preservation
663 if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
664 input_ids += [example["class_prompt_ids"] for example in examples]
665 pixel_values += [example["class_images"] for example in examples]
666
667 pixel_values = torch.stack(pixel_values)
668 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
669
670 inputs = prompt_processor.unify_input_ids(input_ids)
671
672 batch = {
673 "prompts": prompts,
674 "nprompts": nprompts,
675 "input_ids": inputs.input_ids,
676 "pixel_values": pixel_values,
677 "attention_mask": inputs.attention_mask,
678 }
679 return batch
680
681 datamodule = CSVDataModule(
682 data_file=args.train_data_file,
683 batch_size=args.train_batch_size,
684 prompt_processor=prompt_processor,
685 instance_identifier=instance_identifier,
686 class_identifier=args.class_identifier,
687 class_subdir="cls",
688 num_class_images=args.num_class_images,
689 size=args.resolution,
690 repeats=args.repeats,
691 mode=args.mode,
692 dropout=args.tag_dropout,
693 center_crop=args.center_crop,
694 template_key=args.train_data_template,
695 valid_set_size=args.valid_set_size,
696 num_workers=args.dataloader_num_workers,
697 collate_fn=collate_fn
698 )
699
700 datamodule.prepare_data()
701 datamodule.setup()
702
703 if args.num_class_images != 0:
704 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
705
706 if len(missing_data) != 0:
707 batched_data = [
708 missing_data[i:i+args.sample_batch_size]
709 for i in range(0, len(missing_data), args.sample_batch_size)
710 ]
711
712 pipeline = VlpnStableDiffusion(
713 text_encoder=text_encoder,
714 vae=vae,
715 unet=unet,
716 tokenizer=tokenizer,
717 scheduler=checkpoint_scheduler,
718 ).to(accelerator.device)
719 pipeline.set_progress_bar_config(dynamic_ncols=True)
720
721 with torch.autocast("cuda"), torch.inference_mode():
722 for batch in batched_data:
723 image_name = [item.class_image_path for item in batch]
724 prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
725 nprompt = [item.nprompt for item in batch]
726
727 images = pipeline(
728 prompt=prompt,
729 negative_prompt=nprompt,
730 num_inference_steps=args.sample_steps
731 ).images
732
733 for i, image in enumerate(images):
734 image.save(image_name[i])
735
736 del pipeline
737
738 if torch.cuda.is_available():
739 torch.cuda.empty_cache()
740
741 train_dataloader = datamodule.train_dataloader()
742 val_dataloader = datamodule.val_dataloader()
743
744 # Scheduler and math around the number of training steps.
745 overrode_max_train_steps = False
746 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
747 if args.max_train_steps is None:
748 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
749 overrode_max_train_steps = True
750 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
751
752 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
753
754 if args.lr_scheduler == "one_cycle":
755 lr_scheduler = get_one_cycle_schedule(
756 optimizer=optimizer,
757 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
758 )
759 elif args.lr_scheduler == "cosine_with_restarts":
760 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
761 optimizer=optimizer,
762 num_warmup_steps=warmup_steps,
763 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
764 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
765 ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
766 )
767 else:
768 lr_scheduler = get_scheduler(
769 args.lr_scheduler,
770 optimizer=optimizer,
771 num_warmup_steps=warmup_steps,
772 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
773 )
774
775 unet_lora, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
776 unet_lora, optimizer, train_dataloader, val_dataloader, lr_scheduler
777 )
778
779 # Move text_encoder and vae to device
780 vae.to(accelerator.device, dtype=weight_dtype)
781 unet.to(accelerator.device, dtype=weight_dtype)
782 text_encoder.to(accelerator.device, dtype=weight_dtype)
783
784 # Keep text_encoder and vae in eval mode as we don't train these
785 vae.eval()
786 unet.eval()
787 text_encoder.eval()
788
789 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
790 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
791 if overrode_max_train_steps:
792 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
793
794 num_val_steps_per_epoch = len(val_dataloader)
795 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
796 val_steps = num_val_steps_per_epoch * num_epochs
797
798 # We need to initialize the trackers we use, and also store our configuration.
800 # The trackers are initialized automatically on the main process.
800 if accelerator.is_main_process:
801 config = vars(args).copy()
802 config["initializer_token"] = " ".join(config["initializer_token"])
803 config["placeholder_token"] = " ".join(config["placeholder_token"])
804 accelerator.init_trackers("dreambooth", config=config)
805
806 # Train!
807 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
808
809 logger.info("***** Running training *****")
810 logger.info(f" Num Epochs = {num_epochs}")
811 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
812 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
813 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
814 logger.info(f" Total optimization steps = {args.max_train_steps}")
815 # Only show the progress bar once on each machine.
816
817 global_step = 0
818
819 avg_loss = AverageMeter()
820 avg_acc = AverageMeter()
821
822 avg_loss_val = AverageMeter()
823 avg_acc_val = AverageMeter()
824
825 max_acc_val = 0.0
826
827 checkpointer = Checkpointer(
828 datamodule=datamodule,
829 accelerator=accelerator,
830 vae=vae,
831 unet=unet, unet_lora=unet_lora,
832 tokenizer=tokenizer,
833 text_encoder=text_encoder,
834 scheduler=checkpoint_scheduler,
835 output_dir=basepath,
836 instance_identifier=instance_identifier,
837 placeholder_token=args.placeholder_token,
838 placeholder_token_id=placeholder_token_id,
839 sample_image_size=args.sample_image_size,
840 sample_batch_size=args.sample_batch_size,
841 sample_batches=args.sample_batches,
842 seed=args.seed
843 )
844
845 if accelerator.is_main_process:
846 checkpointer.save_samples(0, args.sample_steps)
847
848 local_progress_bar = tqdm(
849 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
850 disable=not accelerator.is_local_main_process,
851 dynamic_ncols=True
852 )
853 local_progress_bar.set_description("Epoch X / Y")
854
855 global_progress_bar = tqdm(
856 range(args.max_train_steps + val_steps),
857 disable=not accelerator.is_local_main_process,
858 dynamic_ncols=True
859 )
860 global_progress_bar.set_description("Total progress")
861
862 try:
863 for epoch in range(num_epochs):
864 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
865 local_progress_bar.reset()
866
867 unet_lora.train()
868
869 for step, batch in enumerate(train_dataloader):
870 with accelerator.accumulate(unet_lora):
871 # Convert images to latent space
872 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
873 latents = latents * 0.18215
874
875 # Sample noise that we'll add to the latents
876 noise = torch.randn_like(latents)
877 bsz = latents.shape[0]
878 # Sample a random timestep for each image
879 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
880 (bsz,), device=latents.device)
881 timesteps = timesteps.long()
882
883 # Add noise to the latents according to the noise magnitude at each timestep
884 # (this is the forward diffusion process)
885 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
886
887 # Get the text embedding for conditioning
888 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
889
890 # Predict the noise residual
891 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
892
893 # Get the target for loss depending on the prediction type
894 if noise_scheduler.config.prediction_type == "epsilon":
895 target = noise
896 elif noise_scheduler.config.prediction_type == "v_prediction":
897 target = noise_scheduler.get_velocity(latents, noise, timesteps)
898 else:
899 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
900
901 if args.num_class_images != 0:
902 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
903 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
904 target, target_prior = torch.chunk(target, 2, dim=0)
905
906 # Compute instance loss
907 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
908
909 # Compute prior loss
910 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
911
912 # Add the prior loss to the instance loss.
913 loss = loss + args.prior_loss_weight * prior_loss
914 else:
915 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
916
917 acc = (model_pred == target).float().mean()
918
919 accelerator.backward(loss)
920
921 if accelerator.sync_gradients:
922 accelerator.clip_grad_norm_(unet_lora.parameters(), args.max_grad_norm)
923
924 optimizer.step()
925 if not accelerator.optimizer_step_was_skipped:
926 lr_scheduler.step()
927 optimizer.zero_grad(set_to_none=True)
928
929 with torch.no_grad():
930 text_encoder.get_input_embeddings(
931 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
932
933 avg_loss.update(loss.detach_(), bsz)
934 avg_acc.update(acc.detach_(), bsz)
935
936 # Checks if the accelerator has performed an optimization step behind the scenes
937 if accelerator.sync_gradients:
938 local_progress_bar.update(1)
939 global_progress_bar.update(1)
940
941 global_step += 1
942
943 logs = {
944 "train/loss": avg_loss.avg.item(),
945 "train/acc": avg_acc.avg.item(),
946 "train/cur_loss": loss.item(),
947 "train/cur_acc": acc.item(),
948 "lr/unet": lr_scheduler.get_last_lr()[0],
949 "lr/text": lr_scheduler.get_last_lr()[1]
950 }
951
952 accelerator.log(logs, step=global_step)
953
954 local_progress_bar.set_postfix(**logs)
955
956 if global_step >= args.max_train_steps:
957 break
958
959 accelerator.wait_for_everyone()
960
961 unet_lora.eval()
962
963 with torch.inference_mode():
964 for step, batch in enumerate(val_dataloader):
965 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
966 latents = latents * 0.18215
967
968 noise = torch.randn_like(latents)
969 bsz = latents.shape[0]
970 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
971 (bsz,), device=latents.device)
972 timesteps = timesteps.long()
973
974 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
975
976 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
977
978 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
979
980 # Get the target for loss depending on the prediction type
981 if noise_scheduler.config.prediction_type == "epsilon":
982 target = noise
983 elif noise_scheduler.config.prediction_type == "v_prediction":
984 target = noise_scheduler.get_velocity(latents, noise, timesteps)
985 else:
986 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
987
988 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
989
990 acc = (model_pred == target).float().mean()
991
992 avg_loss_val.update(loss.detach_(), bsz)
993 avg_acc_val.update(acc.detach_(), bsz)
994
995 if accelerator.sync_gradients:
996 local_progress_bar.update(1)
997 global_progress_bar.update(1)
998
999 logs = {
1000 "val/loss": avg_loss_val.avg.item(),
1001 "val/acc": avg_acc_val.avg.item(),
1002 "val/cur_loss": loss.item(),
1003 "val/cur_acc": acc.item(),
1004 }
1005 local_progress_bar.set_postfix(**logs)
1006
1007 accelerator.log({
1008 "val/loss": avg_loss_val.avg.item(),
1009 "val/acc": avg_acc_val.avg.item(),
1010 }, step=global_step)
1011
1012 local_progress_bar.clear()
1013 global_progress_bar.clear()
1014
1015 if avg_acc_val.avg.item() > max_acc_val:
1016 accelerator.print(
1017 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
1018 max_acc_val = avg_acc_val.avg.item()
1019
1020 if accelerator.is_main_process:
1021 if (epoch + 1) % args.sample_frequency == 0:
1022 checkpointer.save_samples(global_step, args.sample_steps)
1023
1024 # Create the pipeline using the trained modules and save it.
1025 if accelerator.is_main_process:
1026 print("Finished! Saving final checkpoint and resume state.")
1027 checkpointer.save_model()
1028
1029 accelerator.end_training()
1030
1031 except KeyboardInterrupt:
1032 if accelerator.is_main_process:
1033 print("Interrupted, saving checkpoint and resume state...")
1034 checkpointer.save_model()
1035 accelerator.end_training()
1036 quit()
1037
1038
1039if __name__ == "__main__":
1040 main()
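For reference, the prior-preservation branch in the training loop above relies on the collate function stacking instance and class examples along the batch axis, so the prediction and target can be split back into two halves before weighting. A standalone sketch of that loss under the same layout (the helper name is illustrative):

    import torch
    import torch.nn.functional as F

    def dreambooth_loss(model_pred, target, prior_loss_weight=1.0):
        # first half of the batch: instance images, second half: class images
        pred_inst, pred_prior = torch.chunk(model_pred, 2, dim=0)
        tgt_inst, tgt_prior = torch.chunk(target, 2, dim=0)
        inst_loss = F.mse_loss(pred_inst.float(), tgt_inst.float(), reduction="none").mean([1, 2, 3]).mean()
        prior_loss = F.mse_loss(pred_prior.float(), tgt_prior.float(), reduction="mean")
        return inst_loss + prior_loss_weight * prior_loss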
diff --git a/train_ti.py b/train_ti.py
index 5c0299e..9616db6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -24,7 +24,6 @@ from slugify import slugify
 
 from common import load_text_embeddings, load_text_embedding
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from pipelines.util import set_use_memory_efficient_attention_xformers
 from data.csv import CSVDataModule, CSVDataItem
 from training.optimization import get_one_cycle_schedule
 from models.clip.prompt import PromptProcessor
@@ -557,8 +556,8 @@ def main():
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
-    set_use_memory_efficient_attention_xformers(unet, True)
-    set_use_memory_efficient_attention_xformers(vae, True)
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
diff --git a/training/lora.py b/training/lora.py
new file mode 100644
index 0000000..d8dc147
--- /dev/null
+++ b/training/lora.py
@@ -0,0 +1,27 @@
1import torch.nn as nn
2from diffusers import ModelMixin, ConfigMixin
3from diffusers.configuration_utils import register_to_config
4from diffusers.models.cross_attention import XFormersCrossAttnProcessor
5class LoraAttention(ModelMixin, ConfigMixin):
6 @register_to_config
7 def __init__(self, in_features, out_features, r=4):
8 super().__init__()
9
10 if r > min(in_features, out_features):
11 raise ValueError(
12 f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
13 )
14
15 self.lora_down = nn.Linear(in_features, r, bias=False)
16 self.lora_up = nn.Linear(r, out_features, bias=False)
17 self.scale = 1.0
18
19 self.processor = XFormersCrossAttnProcessor()
20
21 nn.init.normal_(self.lora_down.weight, std=1 / r**2)
22 nn.init.zeros_(self.lora_up.weight)
23
24 def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
25 hidden_states = self.processor(attn, hidden_states, encoder_hidden_states, attention_mask, number)
26 hidden_states = hidden_states + self.lora_up(self.lora_down(hidden_states)) * self.scale
27 return hidden_states
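The new LoraAttention module follows the standard LoRA recipe: a rank-r down/up projection pair whose up-projection starts at zero, added as a scaled residual on top of the frozen attention output, so training starts from the unmodified pretrained behaviour. A minimal pure-PyTorch sketch of the same idea, independent of any particular diffusers version (class and argument names are illustrative):

    import torch
    import torch.nn as nn

    class LoraLinear(nn.Module):
        """Wraps a frozen nn.Linear and adds a trainable low-rank residual."""
        def __init__(self, base: nn.Linear, r: int = 4, scale: float = 1.0):
            super().__init__()
            self.base = base.requires_grad_(False)          # frozen pretrained weight
            self.lora_down = nn.Linear(base.in_features, r, bias=False)
            self.lora_up = nn.Linear(r, base.out_features, bias=False)
            self.scale = scale
            nn.init.normal_(self.lora_down.weight, std=1.0 / r)
            nn.init.zeros_(self.lora_up.weight)             # residual starts as a no-op

        def forward(self, x):
            return self.base(x) + self.lora_up(self.lora_down(x)) * self.scale

Because only the down/up projections carry gradients, the optimizer only needs the adapter's parameters, which matches what train_lora.py registers via unet_lora.parameters().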