path: root/train_ti.py
author	Volpeon <git@volpeon.ink>	2022-12-19 21:10:58 +0100
committer	Volpeon <git@volpeon.ink>	2022-12-19 21:10:58 +0100
commit	9b808b6ca102cfec0c273626a0bcadf897b7c942 (patch)
tree	446311b3c6dca74ac9f9f4e055e2eba5f9cae9e5 /train_ti.py
parent	Avoid increased VRAM usage on validation (diff)
Improved dataset prompt handling, fixed
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  1032
1 file changed, 1032 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
new file mode 100644
index 0000000..dbfe58c
--- /dev/null
+++ b/train_ti.py
@@ -0,0 +1,1032 @@
1import argparse
2import itertools
3import math
4import os
5import datetime
6import logging
7import json
8from pathlib import Path
9
10import numpy as np
11import torch
12import torch.nn.functional as F
13import torch.utils.checkpoint
14
15from accelerate import Accelerator
16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
20from PIL import Image
21from tqdm.auto import tqdm
22from transformers import CLIPTextModel, CLIPTokenizer
23from slugify import slugify
24
25from common import load_text_embeddings, load_text_embedding
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from pipelines.util import set_use_memory_efficient_attention_xformers
28from data.csv import CSVDataModule, CSVDataItem
29from training.optimization import get_one_cycle_schedule
30from models.clip.prompt import PromptProcessor
31
32logger = get_logger(__name__)
33
34
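# Allow TF32 matmuls and enable the cudnn autotuner for faster training on Ampere and newer GPUs.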
35torch.backends.cuda.matmul.allow_tf32 = True
36torch.backends.cudnn.benchmark = True
37
38
39def parse_args():
40 parser = argparse.ArgumentParser(
41        description="Textual inversion training script for Stable Diffusion."
42 )
43 parser.add_argument(
44 "--pretrained_model_name_or_path",
45 type=str,
46 default=None,
47 help="Path to pretrained model or model identifier from huggingface.co/models.",
48 )
49 parser.add_argument(
50 "--tokenizer_name",
51 type=str,
52 default=None,
53 help="Pretrained tokenizer name or path if not the same as model_name",
54 )
55 parser.add_argument(
56 "--train_data_file",
57 type=str,
58 default=None,
59 help="A CSV file containing the training data."
60 )
61 parser.add_argument(
62 "--train_data_template",
63 type=str,
64 default="template",
65 )
66 parser.add_argument(
67 "--instance_identifier",
68 type=str,
69 default=None,
70        help="The identifier inserted into prompts for instance (training) images.",
71 )
72 parser.add_argument(
73 "--class_identifier",
74 type=str,
75 default=None,
76        help="The identifier inserted into prompts when generating class images for prior preservation.",
77 )
78 parser.add_argument(
79 "--placeholder_token",
80 type=str,
81 nargs='*',
82        help="One or more tokens to use as placeholders for the concept.",
83 )
84 parser.add_argument(
85 "--initializer_token",
86 type=str,
87 nargs='*',
88        help="One or more tokens to use as initializer words."
89 )
90 parser.add_argument(
91 "--num_class_images",
92 type=int,
93 default=400,
94 help="How many class images to generate."
95 )
96 parser.add_argument(
97 "--repeats",
98 type=int,
99 default=1,
100 help="How many times to repeat the training data."
101 )
102 parser.add_argument(
103 "--output_dir",
104 type=str,
105 default="output/text-inversion",
106 help="The output directory where the model predictions and checkpoints will be written.",
107 )
108 parser.add_argument(
109 "--embeddings_dir",
110 type=str,
111 default=None,
112 help="The embeddings directory where Textual Inversion embeddings are stored.",
113 )
114 parser.add_argument(
115 "--mode",
116 type=str,
117 default=None,
118 help="A mode to filter the dataset.",
119 )
120 parser.add_argument(
121 "--seed",
122 type=int,
123 default=None,
124 help="A seed for reproducible training.")
125 parser.add_argument(
126 "--resolution",
127 type=int,
128 default=768,
129 help=(
130            "The resolution for input images; all images in the train/validation dataset will be resized to this"
131            " resolution."
132 ),
133 )
134 parser.add_argument(
135 "--center_crop",
136 action="store_true",
137 help="Whether to center crop images before resizing to resolution"
138 )
139 parser.add_argument(
140 "--tag_dropout",
141 type=float,
142 default=0,
143 help="Tag dropout probability.",
144 )
145 parser.add_argument(
146 "--dataloader_num_workers",
147 type=int,
148 default=0,
149 help=(
150 "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
151 " process."
152 ),
153 )
154 parser.add_argument(
155 "--num_train_epochs",
156 type=int,
157 default=100
158 )
159 parser.add_argument(
160 "--max_train_steps",
161 type=int,
162 default=None,
163 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
164 )
165 parser.add_argument(
166 "--gradient_accumulation_steps",
167 type=int,
168 default=1,
169        help="Number of update steps to accumulate before performing a backward/update pass.",
170 )
171 parser.add_argument(
172 "--gradient_checkpointing",
173 action="store_true",
174 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
175 )
176 parser.add_argument(
177 "--learning_rate",
178 type=float,
179 default=1e-4,
180 help="Initial learning rate (after the potential warmup period) to use.",
181 )
182 parser.add_argument(
183 "--scale_lr",
184 action="store_true",
185 default=True,
186 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
187 )
188 parser.add_argument(
189 "--lr_scheduler",
190 type=str,
191 default="one_cycle",
192 help=(
193 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
194 ' "constant", "constant_with_warmup", "one_cycle"]'
195 ),
196 )
197 parser.add_argument(
198 "--lr_warmup_epochs",
199 type=int,
200 default=10,
201        help="Number of epochs for the warmup in the lr scheduler."
202 )
203 parser.add_argument(
204 "--lr_cycles",
205 type=int,
206 default=None,
207 help="Number of restart cycles in the lr scheduler."
208 )
209 parser.add_argument(
210 "--use_8bit_adam",
211 action="store_true",
212 help="Whether or not to use 8-bit Adam from bitsandbytes."
213 )
214 parser.add_argument(
215 "--adam_beta1",
216 type=float,
217 default=0.9,
218 help="The beta1 parameter for the Adam optimizer."
219 )
220 parser.add_argument(
221 "--adam_beta2",
222 type=float,
223 default=0.999,
224 help="The beta2 parameter for the Adam optimizer."
225 )
226 parser.add_argument(
227 "--adam_weight_decay",
228 type=float,
229 default=1e-2,
230 help="Weight decay to use."
231 )
232 parser.add_argument(
233 "--adam_epsilon",
234 type=float,
235 default=1e-08,
236 help="Epsilon value for the Adam optimizer"
237 )
238 parser.add_argument(
239 "--mixed_precision",
240 type=str,
241 default="no",
242 choices=["no", "fp16", "bf16"],
243 help=(
244            "Whether to use mixed precision. Choose "
245            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
246            "and an Nvidia Ampere GPU."
247 ),
248 )
249 parser.add_argument(
250 "--checkpoint_frequency",
251 type=int,
252 default=5,
253        help="How often to save a checkpoint (in epochs)",
254 )
255 parser.add_argument(
256 "--sample_frequency",
257 type=int,
258 default=1,
259        help="How often to generate and save sample images (in epochs)",
260 )
261 parser.add_argument(
262 "--sample_image_size",
263 type=int,
264 default=768,
265 help="Size of sample images",
266 )
267 parser.add_argument(
268 "--sample_batches",
269 type=int,
270 default=1,
271 help="Number of sample batches to generate per checkpoint",
272 )
273 parser.add_argument(
274 "--sample_batch_size",
275 type=int,
276 default=1,
277 help="Number of samples to generate per batch",
278 )
279 parser.add_argument(
280 "--valid_set_size",
281 type=int,
282 default=None,
283 help="Number of images in the validation dataset."
284 )
285 parser.add_argument(
286 "--train_batch_size",
287 type=int,
288 default=1,
289 help="Batch size (per device) for the training dataloader."
290 )
291 parser.add_argument(
292 "--sample_steps",
293 type=int,
294 default=15,
295 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
296 )
297 parser.add_argument(
298 "--prior_loss_weight",
299 type=float,
300 default=1.0,
301 help="The weight of prior preservation loss."
302 )
303 parser.add_argument(
304 "--noise_timesteps",
305 type=int,
306 default=1000,
307 )
308 parser.add_argument(
309 "--resume_from",
310 type=str,
311 default=None,
312        help="Path to a directory to resume training from (e.g. logs/token_name/2022-09-22T23-36-27)"
313 )
314 parser.add_argument(
315 "--global_step",
316 type=int,
317 default=0,
318 )
319 parser.add_argument(
320 "--config",
321 type=str,
322 default=None,
323 help="Path to a JSON configuration file containing arguments for invoking this script."
324 )
325
326 args = parser.parse_args()
327 if args.config is not None:
328 with open(args.config, 'rt') as f:
329 args = parser.parse_args(
330 namespace=argparse.Namespace(**json.load(f)["args"]))
331
332 if args.train_data_file is None:
333 raise ValueError("You must specify --train_data_file")
334
335 if args.pretrained_model_name_or_path is None:
336 raise ValueError("You must specify --pretrained_model_name_or_path")
337
338 if isinstance(args.initializer_token, str):
339 args.initializer_token = [args.initializer_token]
340
341 if len(args.initializer_token) == 0:
342 raise ValueError("You must specify --initializer_token")
343
344 if isinstance(args.placeholder_token, str):
345 args.placeholder_token = [args.placeholder_token]
346
347 if len(args.placeholder_token) == 0:
348        args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
349
350 if len(args.placeholder_token) != len(args.initializer_token):
351        raise ValueError("--placeholder_token and --initializer_token must have the same number of items")
352
353 if args.output_dir is None:
354 raise ValueError("You must specify --output_dir")
355
356 return args
357
358
359def freeze_params(params):
360 for param in params:
361 param.requires_grad = False
362
363
364def save_args(basepath: Path, args, extra={}):
365 info = {"args": vars(args)}
366 info["args"].update(extra)
367 with open(basepath.joinpath("args.json"), "w") as f:
368 json.dump(info, f, indent=4)
369
370
371def make_grid(images, rows, cols):
372 w, h = images[0].size
373 grid = Image.new('RGB', size=(cols*w, rows*h))
374 for i, image in enumerate(images):
375 grid.paste(image, box=(i % cols*w, i//cols*h))
376 return grid
377
378
379class Checkpointer:
380 def __init__(
381 self,
382 datamodule,
383 accelerator,
384 vae,
385 unet,
386 tokenizer,
387 text_encoder,
388 scheduler,
389 instance_identifier,
390 placeholder_token,
391 placeholder_token_id,
392 output_dir: Path,
393 sample_image_size,
394 sample_batches,
395 sample_batch_size,
396 seed
397 ):
398 self.datamodule = datamodule
399 self.accelerator = accelerator
400 self.vae = vae
401 self.unet = unet
402 self.tokenizer = tokenizer
403 self.text_encoder = text_encoder
404 self.scheduler = scheduler
405 self.instance_identifier = instance_identifier
406 self.placeholder_token = placeholder_token
407 self.placeholder_token_id = placeholder_token_id
408 self.output_dir = output_dir
409 self.sample_image_size = sample_image_size
410 self.seed = seed or torch.random.seed()
411 self.sample_batches = sample_batches
412 self.sample_batch_size = sample_batch_size
413
414 @torch.no_grad()
415 def checkpoint(self, step, postfix):
416 print("Saving checkpoint for step %d..." % step)
417
418 checkpoints_path = self.output_dir.joinpath("checkpoints")
419 checkpoints_path.mkdir(parents=True, exist_ok=True)
420
421 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
422
423 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
424 # Save a checkpoint
425 learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
426 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
427
428            filename = "%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
429 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
430
431 del text_encoder
432 del learned_embeds
433
434 @torch.no_grad()
435 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
436 samples_path = Path(self.output_dir).joinpath("samples")
437
438 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
439
440 # Save a sample image
441 pipeline = VlpnStableDiffusion(
442 text_encoder=text_encoder,
443 vae=self.vae,
444 unet=self.unet,
445 tokenizer=self.tokenizer,
446 scheduler=self.scheduler,
447 ).to(self.accelerator.device)
448 pipeline.set_progress_bar_config(dynamic_ncols=True)
449
450 train_data = self.datamodule.train_dataloader()
451 val_data = self.datamodule.val_dataloader()
452
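        # A fixed generator and fixed latents keep the "stable" sample grid comparable across checkpoints.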
453 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
454 stable_latents = torch.randn(
455 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
456 device=pipeline.device,
457 generator=generator,
458 )
459
460 with torch.autocast("cuda"), torch.inference_mode():
461 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
462 all_samples = []
463 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
464 file_path.parent.mkdir(parents=True, exist_ok=True)
465
466 data_enum = enumerate(data)
467
468 batches = [
469 batch
470 for j, batch in data_enum
471 if j * data.batch_size < self.sample_batch_size * self.sample_batches
472 ]
473 prompts = [
474 prompt.format(identifier=self.instance_identifier)
475 for batch in batches
476 for prompt in batch["prompts"]
477 ]
478 nprompts = [
479 prompt
480 for batch in batches
481 for prompt in batch["nprompts"]
482 ]
483
484 for i in range(self.sample_batches):
485 prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
486 nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
487
488 samples = pipeline(
489 prompt=prompt,
490 negative_prompt=nprompt,
491 height=self.sample_image_size,
492 width=self.sample_image_size,
493 image=latents[:len(prompt)] if latents is not None else None,
494 generator=generator if latents is not None else None,
495 guidance_scale=guidance_scale,
496 eta=eta,
497 num_inference_steps=num_inference_steps,
498 output_type='pil'
499 ).images
500
501 all_samples += samples
502
503 del samples
504
505 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
506 image_grid.save(file_path, quality=85)
507
508 del all_samples
509 del image_grid
510
511 del text_encoder
512 del pipeline
513 del generator
514 del stable_latents
515
516 if torch.cuda.is_available():
517 torch.cuda.empty_cache()
518
519
520def main():
521 args = parse_args()
522
523 instance_identifier = args.instance_identifier
524
525 if len(args.placeholder_token) != 0:
526 instance_identifier = instance_identifier.format(args.placeholder_token[0])
527
528 global_step_offset = args.global_step
529 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
530 basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
531 basepath.mkdir(parents=True, exist_ok=True)
532
533 accelerator = Accelerator(
534 log_with=LoggerType.TENSORBOARD,
535 logging_dir=f"{basepath}",
536 gradient_accumulation_steps=args.gradient_accumulation_steps,
537 mixed_precision=args.mixed_precision
538 )
539
540 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
541
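    # torch.random.seed() returns a 64-bit value; keep only 32 bits so set_seed (which also seeds NumPy) accepts it.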
542 args.seed = args.seed or (torch.random.seed() >> 32)
543 set_seed(args.seed)
544
545    # Load the tokenizer and add the placeholder tokens as additional special tokens
546 if args.tokenizer_name:
547 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
548 elif args.pretrained_model_name_or_path:
549 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
550
551 # Load models and create wrapper for stable diffusion
552 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
553 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
554 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
555 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
556 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
557 args.pretrained_model_name_or_path, subfolder='scheduler')
558
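    # Reduce VRAM usage: sliced VAE decoding and xFormers memory-efficient attention for the unet and vae.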
559 vae.enable_slicing()
560 set_use_memory_efficient_attention_xformers(unet, True)
561 set_use_memory_efficient_attention_xformers(vae, True)
562
563 if args.gradient_checkpointing:
564 unet.enable_gradient_checkpointing()
565 text_encoder.gradient_checkpointing_enable()
566
567 if args.embeddings_dir is not None:
568 embeddings_dir = Path(args.embeddings_dir)
569 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
570 raise ValueError("--embeddings_dir must point to an existing directory")
571 added_tokens_from_dir = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
572 print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
573
574 # Convert the initializer_token, placeholder_token to ids
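    # An initializer token may split into multiple sub-tokens; only the first sub-token's id is kept.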
575 initializer_token_ids = torch.stack([
576 torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
577 for token in args.initializer_token
578 ])
579
580 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
581 print(f"Added {num_added_tokens} new tokens.")
582
583 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
584
585 # Resize the token embeddings as we are adding new special tokens to the tokenizer
586 text_encoder.resize_token_embeddings(len(tokenizer))
587
588 # Initialise the newly added placeholder token with the embeddings of the initializer token
589 token_embeds = text_encoder.get_input_embeddings().weight.data
590
591 if args.resume_from is not None:
592 resumepath = Path(args.resume_from).joinpath("checkpoints")
593
594 for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
595            load_text_embedding(token_embeds, token_id, resumepath.joinpath(f"{slugify(token)}_{args.global_step}_end.bin"))
596
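    # Keep a copy of the original embeddings so the frozen rows can be restored after each optimizer step.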
597 original_token_embeds = token_embeds.clone().to(accelerator.device)
598
599 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
600 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
601 token_embeds[token_id] = embeddings
602
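    # Indices of all token embeddings except the newly added placeholder tokens; these rows stay frozen during training.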
603 index_fixed_tokens = torch.arange(len(tokenizer))
604 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
605
606 # Freeze vae and unet
607 freeze_params(vae.parameters())
608 freeze_params(unet.parameters())
609 # Freeze all parameters except for the token embeddings in text encoder
610 freeze_params(itertools.chain(
611 text_encoder.text_model.encoder.parameters(),
612 text_encoder.text_model.final_layer_norm.parameters(),
613 text_encoder.text_model.embeddings.position_embedding.parameters(),
614 ))
615
616 prompt_processor = PromptProcessor(tokenizer, text_encoder)
617
618 if args.scale_lr:
619 args.learning_rate = (
620 args.learning_rate * args.gradient_accumulation_steps *
621 args.train_batch_size * accelerator.num_processes
622 )
623
624    # Use 8-bit Adam for lower memory usage or to fine-tune the model on GPUs with 16 GB of VRAM
625 if args.use_8bit_adam:
626 try:
627 import bitsandbytes as bnb
628 except ImportError:
629 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
630
631 optimizer_class = bnb.optim.AdamW8bit
632 else:
633 optimizer_class = torch.optim.AdamW
634
635 # Initialize the optimizer
636 optimizer = optimizer_class(
637 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
638 lr=args.learning_rate,
639 betas=(args.adam_beta1, args.adam_beta2),
640 weight_decay=args.adam_weight_decay,
641 eps=args.adam_epsilon,
642 )
643
644 weight_dtype = torch.float32
645 if args.mixed_precision == "fp16":
646 weight_dtype = torch.float16
647 elif args.mixed_precision == "bf16":
648 weight_dtype = torch.bfloat16
649
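    # Train only on dataset items whose prompt contains one of the placeholder tokens.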
650 def keyword_filter(item: CSVDataItem):
651 return any(keyword in item.prompt for keyword in args.placeholder_token)
652
653 def collate_fn(examples):
654 prompts = [example["prompts"] for example in examples]
655 nprompts = [example["nprompts"] for example in examples]
656 input_ids = [example["instance_prompt_ids"] for example in examples]
657 pixel_values = [example["instance_images"] for example in examples]
658
659 # concat class and instance examples for prior preservation
660 if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
661 input_ids += [example["class_prompt_ids"] for example in examples]
662 pixel_values += [example["class_images"] for example in examples]
663
664 pixel_values = torch.stack(pixel_values)
665 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
666
667 inputs = prompt_processor.unify_input_ids(input_ids)
668
669 batch = {
670 "prompts": prompts,
671 "nprompts": nprompts,
672 "input_ids": inputs.input_ids,
673 "pixel_values": pixel_values,
674 "attention_mask": inputs.attention_mask,
675 }
676 return batch
677
678 datamodule = CSVDataModule(
679 data_file=args.train_data_file,
680 batch_size=args.train_batch_size,
681 prompt_processor=prompt_processor,
682 instance_identifier=args.instance_identifier,
683 class_identifier=args.class_identifier,
684 class_subdir="cls",
685 num_class_images=args.num_class_images,
686 size=args.resolution,
687 repeats=args.repeats,
688 mode=args.mode,
689 dropout=args.tag_dropout,
690 center_crop=args.center_crop,
691 template_key=args.train_data_template,
692 valid_set_size=args.valid_set_size,
693 num_workers=args.dataloader_num_workers,
694 filter=keyword_filter,
695 collate_fn=collate_fn
696 )
697
698 datamodule.prepare_data()
699 datamodule.setup()
700
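    # Generate any class images that are still missing before training starts (used for prior preservation).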
701 if args.num_class_images != 0:
702 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
703
704 if len(missing_data) != 0:
705 batched_data = [
706 missing_data[i:i+args.sample_batch_size]
707 for i in range(0, len(missing_data), args.sample_batch_size)
708 ]
709
710 pipeline = VlpnStableDiffusion(
711 text_encoder=text_encoder,
712 vae=vae,
713 unet=unet,
714 tokenizer=tokenizer,
715 scheduler=checkpoint_scheduler,
716 ).to(accelerator.device)
717 pipeline.set_progress_bar_config(dynamic_ncols=True)
718
719 with torch.autocast("cuda"), torch.inference_mode():
720 for batch in batched_data:
721 image_name = [item.class_image_path for item in batch]
722 prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
723 nprompt = [item.nprompt for item in batch]
724
725 images = pipeline(
726 prompt=prompt,
727 negative_prompt=nprompt,
728 num_inference_steps=args.sample_steps
729 ).images
730
731 for i, image in enumerate(images):
732 image.save(image_name[i])
733
734 del pipeline
735
736 if torch.cuda.is_available():
737 torch.cuda.empty_cache()
738
739 train_dataloader = datamodule.train_dataloader()
740 val_dataloader = datamodule.val_dataloader()
741
742 # Scheduler and math around the number of training steps.
743 overrode_max_train_steps = False
744 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
745 if args.max_train_steps is None:
746 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
747 overrode_max_train_steps = True
748 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
749
750 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
751
752 if args.lr_scheduler == "one_cycle":
753 lr_scheduler = get_one_cycle_schedule(
754 optimizer=optimizer,
755 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
756 )
757 elif args.lr_scheduler == "cosine_with_restarts":
758 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
759 optimizer=optimizer,
760 num_warmup_steps=warmup_steps,
761 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
762 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
763 ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
764 )
765 else:
766 lr_scheduler = get_scheduler(
767 args.lr_scheduler,
768 optimizer=optimizer,
769 num_warmup_steps=warmup_steps,
770 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
771 )
772
773 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
774 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
775 )
776
777 # Move vae and unet to device
778 vae.to(accelerator.device, dtype=weight_dtype)
779 unet.to(accelerator.device, dtype=weight_dtype)
780
781 # Keep vae and unet in eval mode as we don't train these
782 vae.eval()
783 unet.eval()
784
785 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
786 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
787 if overrode_max_train_steps:
788 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
789
790 num_val_steps_per_epoch = len(val_dataloader)
791 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
792 val_steps = num_val_steps_per_epoch * num_epochs
793
794 # We need to initialize the trackers we use, and also store our configuration.
795    # The trackers are initialized automatically on the main process.
796 if accelerator.is_main_process:
797 config = vars(args).copy()
798 config["initializer_token"] = " ".join(config["initializer_token"])
799 config["placeholder_token"] = " ".join(config["placeholder_token"])
800 accelerator.init_trackers("textual_inversion", config=config)
801
802 # Train!
803 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
804
805 logger.info("***** Running training *****")
806 logger.info(f" Num Epochs = {num_epochs}")
807 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
808 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
809 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
810 logger.info(f" Total optimization steps = {args.max_train_steps}")
811 # Only show the progress bar once on each machine.
812
813 global_step = 0
814 min_val_loss = np.inf
815
816 checkpointer = Checkpointer(
817 datamodule=datamodule,
818 accelerator=accelerator,
819 vae=vae,
820 unet=unet,
821 tokenizer=tokenizer,
822 text_encoder=text_encoder,
823 scheduler=checkpoint_scheduler,
824 instance_identifier=args.instance_identifier,
825 placeholder_token=args.placeholder_token,
826 placeholder_token_id=placeholder_token_id,
827 output_dir=basepath,
828 sample_image_size=args.sample_image_size,
829 sample_batch_size=args.sample_batch_size,
830 sample_batches=args.sample_batches,
831 seed=args.seed
832 )
833
834 if accelerator.is_main_process:
835 checkpointer.save_samples(
836 0,
837 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
838
839 local_progress_bar = tqdm(
840 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
841 disable=not accelerator.is_local_main_process,
842 dynamic_ncols=True
843 )
844 local_progress_bar.set_description("Epoch X / Y")
845
846 global_progress_bar = tqdm(
847 range(args.max_train_steps + val_steps),
848 disable=not accelerator.is_local_main_process,
849 dynamic_ncols=True
850 )
851 global_progress_bar.set_description("Total progress")
852
853 try:
854 for epoch in range(num_epochs):
855 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
856 local_progress_bar.reset()
857
858 text_encoder.train()
859 train_loss = 0.0
860
861 for step, batch in enumerate(train_dataloader):
862 with accelerator.accumulate(text_encoder):
863 # Convert images to latent space
864 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
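                    # 0.18215 is the latent scaling factor of the Stable Diffusion VAE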
865 latents = latents * 0.18215
866
867 # Sample noise that we'll add to the latents
868 noise = torch.randn_like(latents)
869 bsz = latents.shape[0]
870 # Sample a random timestep for each image
871 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
872 (bsz,), device=latents.device)
873 timesteps = timesteps.long()
874
875 # Add noise to the latents according to the noise magnitude at each timestep
876 # (this is the forward diffusion process)
877 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
878
879 # Get the text embedding for conditioning
880 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
881 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
882
883 # Predict the noise residual
884 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
885
886 # Get the target for loss depending on the prediction type
887 if noise_scheduler.config.prediction_type == "epsilon":
888 target = noise
889 elif noise_scheduler.config.prediction_type == "v_prediction":
890 target = noise_scheduler.get_velocity(latents, noise, timesteps)
891 else:
892 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
893
894 if args.num_class_images != 0:
895                        # Chunk model_pred and target into instance and prior parts and compute the loss on each separately.
896 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
897 target, target_prior = torch.chunk(target, 2, dim=0)
898
899 # Compute instance loss
900 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
901
902 # Compute prior loss
903 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
904
905 # Add the prior loss to the instance loss.
906 loss = loss + args.prior_loss_weight * prior_loss
907 else:
908 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
909
910 accelerator.backward(loss)
911
912 optimizer.step()
913 if not accelerator.optimizer_step_was_skipped:
914 lr_scheduler.step()
915 optimizer.zero_grad(set_to_none=True)
916
917                    # Make sure we don't update any embedding weights besides the newly added tokens
918 with torch.no_grad():
919 text_encoder.get_input_embeddings(
920 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
921
922 loss = loss.detach().item()
923 train_loss += loss
924
925 # Checks if the accelerator has performed an optimization step behind the scenes
926 if accelerator.sync_gradients:
927 local_progress_bar.update(1)
928 global_progress_bar.update(1)
929
930 global_step += 1
931
932 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
933
934 accelerator.log(logs, step=global_step)
935
936 local_progress_bar.set_postfix(**logs)
937
938 if global_step >= args.max_train_steps:
939 break
940
941 train_loss /= len(train_dataloader)
942
943 accelerator.wait_for_everyone()
944
945 text_encoder.eval()
946 val_loss = 0.0
947
948 with torch.inference_mode():
949 for step, batch in enumerate(val_dataloader):
950 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
951 latents = latents * 0.18215
952
953 noise = torch.randn_like(latents)
954 bsz = latents.shape[0]
955 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
956 (bsz,), device=latents.device)
957 timesteps = timesteps.long()
958
959 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
960
961 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
962 encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
963
964 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
965
966 # Get the target for loss depending on the prediction type
967 if noise_scheduler.config.prediction_type == "epsilon":
968 target = noise
969 elif noise_scheduler.config.prediction_type == "v_prediction":
970 target = noise_scheduler.get_velocity(latents, noise, timesteps)
971 else:
972 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
973
974 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
975
976 loss = loss.detach().item()
977 val_loss += loss
978
979 if accelerator.sync_gradients:
980 local_progress_bar.update(1)
981 global_progress_bar.update(1)
982
983 logs = {"val/loss": loss}
984 local_progress_bar.set_postfix(**logs)
985
986 val_loss /= len(val_dataloader)
987
988 accelerator.log({"val/loss": val_loss}, step=global_step)
989
990 local_progress_bar.clear()
991 global_progress_bar.clear()
992
993 if accelerator.is_main_process:
994 if min_val_loss > val_loss:
995 accelerator.print(
996 f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
997 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
998 min_val_loss = val_loss
999
1000 if (epoch + 1) % args.checkpoint_frequency == 0:
1001 checkpointer.checkpoint(global_step + global_step_offset, "training")
1002 save_args(basepath, args, {
1003 "global_step": global_step + global_step_offset
1004 })
1005
1006 if (epoch + 1) % args.sample_frequency == 0:
1007 checkpointer.save_samples(
1008 global_step + global_step_offset,
1009 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
1010
1011        # Training finished; save the final checkpoint and resume state.
1012 if accelerator.is_main_process:
1013 print("Finished! Saving final checkpoint and resume state.")
1014 checkpointer.checkpoint(global_step + global_step_offset, "end")
1015 save_args(basepath, args, {
1016 "global_step": global_step + global_step_offset
1017 })
1018 accelerator.end_training()
1019
1020 except KeyboardInterrupt:
1021 if accelerator.is_main_process:
1022 print("Interrupted, saving checkpoint and resume state...")
1023 checkpointer.checkpoint(global_step + global_step_offset, "end")
1024 save_args(basepath, args, {
1025 "global_step": global_step + global_step_offset
1026 })
1027 accelerator.end_training()
1028 quit()
1029
1030
1031if __name__ == "__main__":
1032 main()