Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  566
1 files changed, 566 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py
new file mode 100644
index 0000000..2cb85cc
--- /dev/null
+++ b/train_lora.py
@@ -0,0 +1,566 @@
import argparse
import datetime
import logging
import itertools
from pathlib import Path
from functools import partial

import torch
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from slugify import slugify
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor

from util import load_config, load_embeddings_from_dir
from data.csv import VlpnDataModule, keyword_filter
from training.functional import train, get_models
from training.lr import plot_metrics
from training.strategy.lora import lora_strategy
from training.optimization import get_scheduler
from training.util import save_args

logger = get_logger(__name__)

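# TF32 matmuls and cudnn autotuning trade a little numeric precision and
# determinism for throughput on recent Nvidia GPUs.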
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A file containing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--train_set_pad",
        type=int,
        default=None,
        help="Pad the train dataset up to this number of items."
    )
    parser.add_argument(
        "--valid_set_pad",
        type=int,
        default=None,
        help="Pad the validation dataset up to this number of items."
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=4,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0.1,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        help="Don't shuffle tags.",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=0,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler (if supported)."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=3,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_amsgrad",
        action="store_true",
        help="Whether to use the AMSGrad variant of the Adam optimizer."
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
            "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to save a checkpoint and sample image",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

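    # Parse once to pick up --config; if present, load values from the JSON
    # file and re-parse so explicit command-line flags take precedence.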
    args = parser.parse_args()
    if args.config is not None:
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def main():
    args = parse_args()

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{output_dir}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

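    # No seed given: draw one from torch's 64-bit RNG and shift it into
    # 32-bit range so it is valid for every library set_seed() touches.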
    if args.seed is None:
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        args.pretrained_model_name_or_path)

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.enable_xformers_memory_efficient_attention()

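    # Build one LoRA processor per attention module. Self-attention layers
    # ("attn1") get no cross_attention_dim; the hidden size follows the UNet's
    # block_out_channels, reversed for up blocks.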
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )

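    # Register the processors on the UNet and wrap them in a single module so
    # the optimizer and Accelerate see only the LoRA weights.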
    unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        embeddings.persist()

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

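    # Linear scaling rule: grow the base LR with the effective batch size
    # (batch size × gradient accumulation steps × number of processes).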
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

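    # LR-finder mode: start from a tiny LR and sweep upward with the
    # exponential-growth schedule; plot_metrics() below charts the result.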
    if args.find_lr:
        args.learning_rate = 1e-6
        args.lr_scheduler = "exponential_growth"

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

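    # dtype for the frozen model weights and the data pipeline, matched to
    # the accelerator's mixed-precision setting.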
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

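    # Pre-bind the arguments shared by every trainer invocation; the
    # strategy-specific ones are supplied at the call site below.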
    trainer = partial(
        train,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        lora_layers=lora_layers,
        noise_scheduler=noise_scheduler,
        dtype=weight_dtype,
        with_prior_preservation=args.num_class_images != 0,
        prior_loss_weight=args.prior_loss_weight,
    )

    checkpoint_output_dir = output_dir.joinpath("model")
    sample_output_dir = output_dir.joinpath("samples")

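    # The data module buckets images by aspect ratio, applies tag dropout and
    # shuffling, and pulls in class images when prior preservation is enabled.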
    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        train_set_pad=args.train_set_pad,
        valid_set_pad=args.valid_set_pad,
        seed=args.seed,
        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
        dtype=weight_dtype
    )
    datamodule.setup()

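    # Only the LoRA attention layers are handed to the optimizer; the base
    # UNet and text encoder weights stay frozen.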
    optimizer = optimizer_class(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        amsgrad=args.adam_amsgrad,
    )

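    # The scheduler derives its step count from the dataloader length, so it
    # must be built after datamodule.setup().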
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_training_steps_per_epoch=len(datamodule.train_dataloader),
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        min_lr=args.lr_min_lr,
        warmup_func=args.lr_warmup_func,
        annealing_func=args.lr_annealing_func,
        warmup_exp=args.lr_warmup_exp,
        annealing_exp=args.lr_annealing_exp,
        cycles=args.lr_cycles,
        end_lr=1e2,
        train_epochs=args.num_train_epochs,
        warmup_epochs=args.lr_warmup_epochs,
    )

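    # Run the training loop with the LoRA strategy; checkpoints and sample
    # images are emitted at the cadence set by --sample_frequency.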
    metrics = trainer(
        strategy=lora_strategy,
        project="lora",
        train_dataloader=datamodule.train_dataloader,
        val_dataloader=datamodule.val_dataloader,
        seed=args.seed,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        num_train_epochs=args.num_train_epochs,
        sample_frequency=args.sample_frequency,
        # --
        tokenizer=tokenizer,
        sample_scheduler=sample_scheduler,
        sample_output_dir=sample_output_dir,
        checkpoint_output_dir=checkpoint_output_dir,
        max_grad_norm=args.max_grad_norm,
        sample_batch_size=args.sample_batch_size,
        sample_num_batches=args.sample_batches,
        sample_num_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
    )

    plot_metrics(metrics, output_dir.joinpath("lr.png"))


if __name__ == "__main__":
    main()