diff options
author | Volpeon <git@volpeon.ink> | 2023-02-07 20:44:43 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-02-07 20:44:43 +0100 |
commit | 7ccd4614a56cfd6ecacba85605f338593f1059f0 (patch) | |
tree | fa9882b256c752705bc42229bac4e00ed7088643 /train_lora.py | |
parent | Restored LR finder (diff) | |
download | textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.tar.gz textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.tar.bz2 textual-inversion-diff-7ccd4614a56cfd6ecacba85605f338593f1059f0.zip |
Add Lora
Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 566 |
1 file changed, 566 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py new file mode 100644 index 0000000..2cb85cc --- /dev/null +++ b/train_lora.py | |||
@@ -0,0 +1,566 @@ | |||
1 | import argparse | ||
2 | import datetime | ||
3 | import logging | ||
4 | import itertools | ||
5 | from pathlib import Path | ||
6 | from functools import partial | ||
7 | |||
8 | import torch | ||
9 | import torch.utils.checkpoint | ||
10 | |||
11 | from accelerate import Accelerator | ||
12 | from accelerate.logging import get_logger | ||
13 | from accelerate.utils import LoggerType, set_seed | ||
14 | from slugify import slugify | ||
15 | from diffusers.loaders import AttnProcsLayers | ||
16 | from diffusers.models.cross_attention import LoRACrossAttnProcessor | ||
17 | |||
18 | from util import load_config, load_embeddings_from_dir | ||
19 | from data.csv import VlpnDataModule, keyword_filter | ||
20 | from training.functional import train, get_models | ||
21 | from training.lr import plot_metrics | ||
22 | from training.strategy.lora import lora_strategy | ||
23 | from training.optimization import get_scheduler | ||
24 | from training.util import save_args | ||
25 | |||
# Module-level logger routed through accelerate so messages respect process rank.
logger = get_logger(__name__)


# Global performance knobs: allow TF32 on matmuls (faster on supporting GPUs at
# slightly reduced float32 precision) and let cuDNN auto-tune conv algorithms.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
31 | |||
32 | |||
def parse_args(argv=None):
    """Parse command-line arguments for the LoRA training script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in which
            case ``sys.argv`` is used (backward compatible with the original
            zero-argument call).

    Returns:
        argparse.Namespace with all training options. If ``--config`` is given,
        the JSON file's values are loaded first and then overridden by any
        explicitly passed command-line flags.

    Raises:
        ValueError: if a required option (--train_data_file,
            --pretrained_model_name_or_path, --project, --output_dir) is missing.
    """
    def _str_to_bool(value):
        # argparse's `type=bool` treats ANY non-empty string (including
        # "False") as True; parse common boolean spellings explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "1", "yes", "y"):
            return True
        if lowered in ("false", "0", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    # --- Model / data locations ---
    parser.add_argument(
        "--pretrained_model_name_or_path", type=str, default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name", type=str, default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file", type=str, default=None,
        help="A folder containing the training data."
    )
    parser.add_argument(
        "--train_data_template", type=str, default="template",
    )
    parser.add_argument(
        "--train_set_pad", type=int, default=None,
        help="The number to fill train dataset items up to."
    )
    parser.add_argument(
        "--valid_set_pad", type=int, default=None,
        help="The number to fill validation dataset items up to."
    )
    parser.add_argument(
        "--project", type=str, default=None,
        help="The name of the current project.",
    )
    # --- Dataset filtering / bucketing ---
    parser.add_argument(
        "--exclude_collections", type=str, nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--num_buckets", type=int, default=4,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets", action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size", type=int, default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels", type=int, default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout", type=float, default=0.1,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle", action="store_true",
        # Help text used to read "Shuffle tags.", the opposite of the effect.
        help="Don't shuffle tags.",
    )
    parser.add_argument(
        "--num_class_images", type=int, default=0,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir", type=str, default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="output/lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir", type=str, default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection", type=str, nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution", type=int, default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    # --- Training schedule ---
    parser.add_argument(
        "--num_train_epochs", type=int, default=100
    )
    parser.add_argument(
        "--max_train_steps", type=int, default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps", type=int, default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--find_lr", action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=2e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr", action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    # --- LR scheduler ---
    parser.add_argument(
        "--lr_scheduler", type=str, default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs", type=int, default=10,
        # The value is consumed as epochs (see get_scheduler call), not steps.
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles", type=int, default=None,
        help="Number of restart cycles in the lr scheduler (if supported)."
    )
    parser.add_argument(
        "--lr_warmup_func", type=str, default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp", type=int, default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func", type=str, default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp", type=int, default=3,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr", type=float, default=0.04,
        help="Minimum learning rate in the lr scheduler."
    )
    # --- Optimizer ---
    parser.add_argument(
        "--use_8bit_adam", action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1", type=float, default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2", type=float, default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon", type=float, default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        # `type=bool` made "--adam_amsgrad False" evaluate to True; use an
        # explicit string-to-bool converter that keeps the same CLI syntax.
        "--adam_amsgrad", type=_str_to_bool, default=False,
        help="Amsgrad value for the Adam optimizer"
    )
    parser.add_argument(
        "--mixed_precision", type=str, default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    # --- Sampling / validation ---
    parser.add_argument(
        "--sample_frequency", type=int, default=1,
        help="How often to save a checkpoint and sample image",
    )
    parser.add_argument(
        "--sample_image_size", type=int, default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches", type=int, default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size", type=int, default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat", type=int, default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps", type=int, default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight", type=float, default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--noise_timesteps", type=int, default=1000,
    )
    parser.add_argument(
        "--config", type=str, default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args(argv)
    if args.config is not None:
        # Load defaults from the JSON config, then re-parse so explicit CLI
        # flags win over config-file values.
        args = load_config(args.config)
        args = parser.parse_args(argv, namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    # Config files may provide a bare string where a list is expected.
    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args
389 | |||
390 | |||
def main():
    """Entry point: prepare models, LoRA attention processors, data, optimizer
    and lr scheduler, then run the shared training loop via `train`."""
    args = parse_args()

    # Per-run output directory: <output_dir>/<project-slug>/<timestamp>.
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{output_dir}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

    # Derive a 32-bit seed from torch's RNG when none was given; save_args
    # below records it so the run stays reproducible.
    if args.seed is None:
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        args.pretrained_model_name_or_path)

    # Memory savers: sliced VAE decoding plus xformers attention everywhere.
    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.enable_xformers_memory_efficient_attention()

    # Attach a LoRA processor to every attention module in the UNet. Self
    # attention ("attn1") takes no cross-attention dim; the hidden size is
    # derived from the UNet block the processor lives in.
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            # Previously an unmatched name silently reused the prior
            # iteration's hidden_size (or raised UnboundLocalError); fail
            # loudly instead.
            raise ValueError(f"Unexpected attention processor name: {name}")

        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )

    unet.set_attn_processor(lora_attn_procs)
    # Only these LoRA weights are handed to the optimizer below.
    lora_layers = AttnProcsLayers(unet.attn_processors)

    # Optionally load pre-trained Textual Inversion embeddings into the
    # tokenizer/embedding table before training.
    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        embeddings.persist()

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    # LR finder mode: start tiny and grow exponentially instead of training.
    if args.find_lr:
        args.learning_rate = 1e-6
        args.lr_scheduler = "exponential_growth"

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Pre-bind the invariant training arguments; strategy-specific kwargs are
    # supplied at the call site below.
    trainer = partial(
        train,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        lora_layers=lora_layers,
        noise_scheduler=noise_scheduler,
        dtype=weight_dtype,
        with_prior_preservation=args.num_class_images != 0,
        prior_loss_weight=args.prior_loss_weight,
    )

    checkpoint_output_dir = output_dir.joinpath("model")
    sample_output_dir = output_dir.joinpath("samples")

    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        train_set_pad=args.train_set_pad,
        valid_set_pad=args.valid_set_pad,
        seed=args.seed,
        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
        dtype=weight_dtype
    )
    datamodule.setup()

    # Only the LoRA layers are optimized; the base model stays frozen.
    optimizer = optimizer_class(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        amsgrad=args.adam_amsgrad,
    )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_training_steps_per_epoch=len(datamodule.train_dataloader),
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        min_lr=args.lr_min_lr,
        warmup_func=args.lr_warmup_func,
        annealing_func=args.lr_annealing_func,
        warmup_exp=args.lr_warmup_exp,
        annealing_exp=args.lr_annealing_exp,
        cycles=args.lr_cycles,
        end_lr=1e2,  # upper bound for the exponential-growth LR finder
        train_epochs=args.num_train_epochs,
        warmup_epochs=args.lr_warmup_epochs,
    )

    metrics = trainer(
        strategy=lora_strategy,
        project="lora",
        train_dataloader=datamodule.train_dataloader,
        val_dataloader=datamodule.val_dataloader,
        seed=args.seed,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        num_train_epochs=args.num_train_epochs,
        sample_frequency=args.sample_frequency,
        # -- strategy-specific kwargs below
        tokenizer=tokenizer,
        sample_scheduler=sample_scheduler,
        sample_output_dir=sample_output_dir,
        checkpoint_output_dir=checkpoint_output_dir,
        max_grad_norm=args.max_grad_norm,
        sample_batch_size=args.sample_batch_size,
        sample_num_batches=args.sample_batches,
        sample_num_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
    )

    plot_metrics(metrics, output_dir.joinpath("lr.png"))
563 | |||
564 | |||
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()