-rw-r--r--  environment.yaml                                     |   4
-rw-r--r--  infer.py                                             |  21
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  14
-rw-r--r--  train_dreambooth.py                                  |  46
-rw-r--r--  train_lora.py                                        | 566
-rw-r--r--  train_ti.py                                          |  10
-rw-r--r--  training/functional.py                               |  31
-rw-r--r--  training/strategy/dreambooth.py                      |  35
-rw-r--r--  training/strategy/lora.py                            | 147
-rw-r--r--  training/strategy/ti.py                              |  38
10 files changed, 819 insertions, 93 deletions
diff --git a/environment.yaml b/environment.yaml
index c992759..f5632bf 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,11 +18,11 @@ dependencies:
 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
 - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion
 - accelerate==0.15.0
-- bitsandbytes==0.36.0.post2
+- bitsandbytes==0.37.0
 - python-slugify>=6.1.2
 - safetensors==0.2.7
 - setuptools==65.6.3
 - test-tube>=0.7.5
 - transformers==4.25.1
 - triton==2.0.0.dev20221202
-- xformers==0.0.16.dev430
+- xformers==0.0.17.dev443
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -39,7 +39,8 @@ torch.backends.cudnn.benchmark = True
 default_args = {
     "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
-    "ti_embeddings_dir": "embeddings",
+    "ti_embeddings_dir": "embeddings_ti",
+    "lora_embeddings_dir": "embeddings_lora",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -60,6 +61,7 @@ default_cmds = {
     "batch_num": 1,
     "steps": 30,
     "guidance_scale": 7.0,
+    "lora_scale": 0.5,
     "seed": None,
     "config": None,
 }
@@ -92,6 +94,10 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
+        "--lora_embeddings_dir",
+        type=str,
+    )
+    parser.add_argument(
         "--output_dir",
         type=str,
     )
@@ -169,6 +175,10 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
+        "--lora_scale",
+        type=float,
+    )
+    parser.add_argument(
         "--seed",
         type=int,
     )
@@ -315,6 +325,7 @@ def generate(output_dir: Path, pipeline, args):
             generator=generator,
             image=init_image,
             strength=args.image_noise,
+            cross_attention_kwargs={"scale": args.lora_scale},
        ).images

        for j, image in enumerate(images):
@@ -334,11 +345,12 @@ class CmdParse(cmd.Cmd):
     prompt = 'dream> '
     commands = []

-    def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser):
+    def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser):
         super().__init__()

         self.output_dir = output_dir
         self.ti_embeddings_dir = ti_embeddings_dir
+        self.lora_embeddings_dir = lora_embeddings_dir
         self.pipeline = pipeline
         self.parser = parser

@@ -394,9 +406,12 @@ def main():
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]

     pipeline = create_pipeline(args.model, dtype)
+
     load_embeddings(pipeline, args.ti_embeddings_dir)
+    pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
+
     cmd_parser = create_cmd_parser()
-    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser)
+    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()

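The infer.py changes wire LoRA into the existing REPL: the weights are loaded once at startup via `load_attn_procs`, and the per-prompt `--lora_scale` is forwarded to the UNet through `cross_attention_kwargs`. A minimal sketch of the resulting flow, reusing the script's own helpers (`create_pipeline` and `load_embeddings` appear in the context above); the model name and directories are placeholders:

    import torch
    from infer import create_pipeline, load_embeddings

    pipeline = create_pipeline("stabilityai/stable-diffusion-2-1", torch.float16)
    load_embeddings(pipeline, "embeddings_ti")
    pipeline.unet.load_attn_procs("embeddings_lora")  # new in this commit

    images = pipeline(
        "a photo of <token> in the woods",
        cross_attention_kwargs={"scale": 0.5},  # what --lora_scale becomes
    ).images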
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3027421..dab7878 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,6 +1,6 @@
 import inspect
 import warnings
-from typing import List, Optional, Union, Callable
+from typing import List, Dict, Any, Optional, Union, Callable

 import numpy as np
 import torch
@@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            noise_pred = self.unet(
+                latent_model_input,
+                t,
+                encoder_hidden_states=text_embeddings,
+                cross_attention_kwargs=cross_attention_kwargs,
+            ).sample

             # perform guidance
             if do_classifier_free_guidance:
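What `{"scale": s}` means once it reaches the attention processors: the LoRA processors in diffusers of this era add the low-rank update on top of the frozen projection, weighted by `scale`. A schematic, self-contained illustration (not the pipeline's actual code):

    import torch

    def lora_linear(x, w_base, w_down, w_up, scale):
        # base projection plus scaled low-rank update: x @ W^T + scale * x @ (W_up W_down)^T
        return x @ w_base.T + scale * ((x @ w_down.T) @ w_up.T)

    x = torch.randn(2, 320)
    w_base = torch.randn(320, 320)
    w_down, w_up = torch.randn(4, 320), torch.randn(320, 4)  # rank-4 LoRA pair
    print(lora_linear(x, w_base, w_down, w_up, scale=0.5).shape)  # torch.Size([2, 320])

At `scale=0.0` the LoRA contribution vanishes and the base model's behavior is recovered; at `1.0` the full learned update is applied.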
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a70c80e..5a4c47b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -74,26 +74,6 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--placeholder_tokens",
-        type=str,
-        nargs='*',
-        default=[],
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--initializer_tokens",
-        type=str,
-        nargs='*',
-        default=[],
-        help="A token to use as initializer word."
-    )
-    parser.add_argument(
-        "--num_vectors",
-        type=int,
-        nargs='*',
-        help="Number of vectors per embedding."
-    )
-    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
@@ -436,30 +416,6 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

-    if isinstance(args.placeholder_tokens, str):
-        args.placeholder_tokens = [args.placeholder_tokens]
-
-    if isinstance(args.initializer_tokens, str):
-        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
-
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
-
-    if len(args.placeholder_tokens) != len(args.initializer_tokens):
-        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
-
-    if args.num_vectors is None:
-        args.num_vectors = 1
-
-    if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
-
-    if len(args.placeholder_tokens) != len(args.num_vectors):
-        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]

@@ -503,7 +459,7 @@ def main():

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
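`enable_xformers_memory_efficient_attention()` is the newer diffusers spelling of the same switch the old call flipped; the VAE keeps the older API. A defensive variant (an assumption for illustration, not code from this repo) that degrades gracefully when xformers is not installed:

    def enable_xformers(unet):
        try:
            unet.enable_xformers_memory_efficient_attention()
        except ModuleNotFoundError:
            # xformers missing; fall back to standard attention
            print("xformers not installed; running with standard attention")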
diff --git a/train_lora.py b/train_lora.py
new file mode 100644
index 0000000..2cb85cc
--- /dev/null
+++ b/train_lora.py
@@ -0,0 +1,566 @@
+import argparse
+import datetime
+import logging
+import itertools
+from pathlib import Path
+from functools import partial
+
+import torch
+import torch.utils.checkpoint
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import LoggerType, set_seed
+from slugify import slugify
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
+
+from util import load_config, load_embeddings_from_dir
+from data.csv import VlpnDataModule, keyword_filter
+from training.functional import train, get_models
+from training.lr import plot_metrics
+from training.strategy.lora import lora_strategy
+from training.optimization import get_scheduler
+from training.util import save_args
+
+logger = get_logger(__name__)
+
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="A folder containing the training data."
+    )
+    parser.add_argument(
+        "--train_data_template",
+        type=str,
+        default="template",
+    )
+    parser.add_argument(
+        "--train_set_pad",
+        type=int,
+        default=None,
+        help="The number to fill train dataset items up to."
+    )
+    parser.add_argument(
+        "--valid_set_pad",
+        type=int,
+        default=None,
+        help="The number to fill validation dataset items up to."
+    )
+    parser.add_argument(
+        "--project",
+        type=str,
+        default=None,
+        help="The name of the current project.",
+    )
+    parser.add_argument(
+        "--exclude_collections",
+        type=str,
+        nargs='*',
+        help="Exclude all items with a listed collection.",
+    )
+    parser.add_argument(
+        "--num_buckets",
+        type=int,
+        default=4,
+        help="Number of aspect ratio buckets in either direction.",
+    )
+    parser.add_argument(
+        "--progressive_buckets",
+        action="store_true",
+        help="Include images in smaller buckets as well.",
+    )
+    parser.add_argument(
+        "--bucket_step_size",
+        type=int,
+        default=64,
+        help="Step size between buckets.",
+    )
+    parser.add_argument(
+        "--bucket_max_pixels",
+        type=int,
+        default=None,
+        help="Maximum pixels per bucket.",
+    )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
+        "--no_tag_shuffle",
+        action="store_true",
+        help="Don't shuffle tags.",
+    )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=0,
+        help="How many class images to generate."
+    )
+    parser.add_argument(
+        "--class_image_dir",
+        type=str,
+        default="cls",
+        help="The directory where class images will be saved.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/lora",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
+        "--collection",
+        type=str,
+        nargs='*',
+        help="A collection to filter the dataset.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for reproducible training."
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=768,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--num_train_epochs",
+        type=int,
+        default=100
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=2e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="one_cycle",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_epochs",
+        type=int,
+        default=10,
+        help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=None,
+        help="Number of restart cycles in the lr scheduler (if supported)."
+    )
+    parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=3,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_min_lr",
+        type=float,
+        default=0.04,
+        help="Minimum learning rate in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--adam_beta1",
+        type=float,
+        default=0.9,
+        help="The beta1 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_beta2",
+        type=float,
+        default=0.999,
+        help="The beta2 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_weight_decay",
+        type=float,
+        default=1e-2,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer"
+    )
+    parser.add_argument(
+        "--adam_amsgrad",
+        type=bool,
+        default=False,
+        help="Amsgrad value for the Adam optimizer"
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose "
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
+            "and an Nvidia Ampere GPU."
+        ),
+    )
+    parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=1,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
+        "--sample_image_size",
+        type=int,
+        default=768,
+        help="Size of sample images",
+    )
+    parser.add_argument(
+        "--sample_batches",
+        type=int,
+        default=1,
+        help="Number of sample batches to generate per checkpoint",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=1,
+        help="Number of samples to generate per batch",
+    )
+    parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
+    parser.add_argument(
+        "--valid_set_repeat",
+        type=int,
+        default=1,
+        help="Times the images in the validation dataset are repeated."
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_steps",
+        type=int,
+        default=20,
+        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
+    )
+    parser.add_argument(
+        "--prior_loss_weight",
+        type=float,
+        default=1.0,
+        help="The weight of prior preservation loss."
+    )
+    parser.add_argument(
+        "--max_grad_norm",
+        default=1.0,
+        type=float,
+        help="Max gradient norm."
+    )
+    parser.add_argument(
+        "--noise_timesteps",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script."
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        args = load_config(args.config)
+        args = parser.parse_args(namespace=argparse.Namespace(**args))
+
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
+
+    if args.pretrained_model_name_or_path is None:
+        raise ValueError("You must specify --pretrained_model_name_or_path")
+
+    if args.project is None:
+        raise ValueError("You must specify --project")
+
+    if isinstance(args.collection, str):
+        args.collection = [args.collection]
+
+    if isinstance(args.exclude_collections, str):
+        args.exclude_collections = [args.exclude_collections]
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{output_dir}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    if args.seed is None:
+        args.seed = torch.random.seed() >> 32
+
+    set_seed(args.seed)
+
+    save_args(output_dir, args)
+
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
+        args.pretrained_model_name_or_path)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
+    lora_attn_procs = {}
+    for name in unet.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = unet.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = unet.config.block_out_channels[block_id]
+
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
+
+    unet.set_attn_processor(lora_attn_procs)
+    lora_layers = AttnProcsLayers(unet.attn_processors)
+
+    if args.embeddings_dir is not None:
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        embeddings.persist()
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+
+    if args.find_lr:
+        args.learning_rate = 1e-6
+        args.lr_scheduler = "exponential_growth"
+
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        lora_layers=lora_layers,
+        noise_scheduler=noise_scheduler,
+        dtype=weight_dtype,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+    )
+
+    checkpoint_output_dir = output_dir.joinpath("model")
+    sample_output_dir = output_dir.joinpath("samples")
+
+    datamodule = VlpnDataModule(
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=args.class_image_dir,
+        num_class_images=args.num_class_images,
+        size=args.resolution,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
+        bucket_step_size=args.bucket_step_size,
+        bucket_max_pixels=args.bucket_max_pixels,
+        dropout=args.tag_dropout,
+        shuffle=not args.no_tag_shuffle,
+        template_key=args.train_data_template,
+        valid_set_size=args.valid_set_size,
+        train_set_pad=args.train_set_pad,
+        valid_set_pad=args.valid_set_pad,
+        seed=args.seed,
+        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    optimizer = optimizer_class(
+        lora_layers.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        end_lr=1e2,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
+
+    metrics = trainer(
+        strategy=lora_strategy,
+        project="lora",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        seed=args.seed,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        sample_output_dir=sample_output_dir,
+        checkpoint_output_dir=checkpoint_output_dir,
+        max_grad_norm=args.max_grad_norm,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )
+
+    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/train_ti.py b/train_ti.py
index c118aab..56f9e97 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -166,7 +166,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0,
+        default=0.1,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -414,7 +414,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e0,
+        default=1e-2,
         type=float,
         help="Embedding decay factor."
     )
@@ -530,7 +530,7 @@ def main():

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -612,8 +612,10 @@ def main():

     if len(placeholder_tokens) == 1:
         sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
+        metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png")
     else:
         sample_output_dir = output_dir.joinpath("samples")
+        metrics_output_file = output_dir.joinpath("lr.png")

     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
@@ -687,7 +689,7 @@ def main():
         placeholder_token_ids=placeholder_token_ids,
     )

-    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+    plot_metrics(metrics, metrics_output_file)

     if args.simultaneous:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
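The tag_dropout default changes from 0 to 0.1, so each training caption now randomly loses roughly 10% of its tags, a mild regularizer against memorizing fixed tag orderings. An illustrative sketch only, assuming tag dropout works per tag as the flag's description suggests; the real logic lives in data/csv.py's VlpnDataModule:

    import random

    def drop_tags(tags, p=0.1, rng=None):
        # keep each tag independently with probability 1 - p
        rng = rng or random.Random()
        return [tag for tag in tags if rng.random() >= p]

    print(drop_tags(["portrait", "red hair", "smiling", "outdoors"], rng=random.Random(0)))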
diff --git a/training/functional.py b/training/functional.py
index c373ac9..8f47734 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[], None] = const()
-    on_model: Callable[[], torch.nn.Module] = const(None)
+    on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[float, int], None] = const()
@@ -51,7 +51,11 @@ class TrainingStrategyPrepareCallable(Protocol):
         accelerator: Accelerator,
         text_encoder: CLIPTextModel,
         unet: UNet2DConditionModel,
-        *args
+        optimizer: torch.optim.Optimizer,
+        train_dataloader: DataLoader,
+        val_dataloader: Optional[DataLoader],
+        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+        **kwargs
     ) -> Tuple: ...


@@ -92,7 +96,6 @@ def save_samples(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     step: int,
@@ -107,15 +110,6 @@ def save_samples(
     grid_cols = min(batch_size, 4)
     grid_rows = (num_batches * batch_size) // grid_cols

-    unet = accelerator.unwrap_model(unet)
-    text_encoder = accelerator.unwrap_model(text_encoder)
-
-    orig_unet_dtype = unet.dtype
-    orig_text_encoder_dtype = text_encoder.dtype
-
-    unet.to(dtype=dtype)
-    text_encoder.to(dtype=dtype)
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
@@ -172,11 +166,6 @@ def save_samples(
         image_grid = make_grid(all_samples, grid_rows, grid_cols)
         image_grid.save(file_path, quality=85)

-    unet.to(dtype=orig_unet_dtype)
-    text_encoder.to(dtype=orig_text_encoder_dtype)
-
-    del unet
-    del text_encoder
     del generator
     del pipeline

@@ -393,7 +382,7 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")

-    model = callbacks.on_model()
+    model = callbacks.on_accum_model()
     on_log = callbacks.on_log
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
@@ -559,8 +548,10 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
+        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+
+    kwargs.update(extra)

     vae.to(accelerator.device, dtype=dtype)

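The prepare contract change is the load-bearing piece of this commit: strategies now receive the full training tuple plus **kwargs and return the prepared tuple plus a dict of extras, which train() merges back into kwargs before building the callbacks. That is how the LoRA strategy hands its accelerator-prepared lora_layers to its own callbacks. A toy sketch of the contract (not the repo's code):

    def toy_prepare(accelerator, text_encoder, unet, optimizer,
                    train_dataloader, val_dataloader, lr_scheduler, **kwargs):
        # a strategy may swap objects through accelerator.prepare and expose extras
        extras = {"lora_layers": kwargs.get("lora_layers")}
        return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extras

    kwargs = {"lora_layers": "layers"}
    *prepared, extra = toy_prepare(None, "te", "unet", "opt", "train", "val", "sched", **kwargs)
    kwargs.update(extra)  # what train() now does before invoking the strategy callbacks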
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e88bf90..b4c77f3 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -61,14 +61,11 @@ def dreambooth_strategy_callbacks(
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -94,7 +91,7 @@ def dreambooth_strategy_callbacks(
         else:
             return nullcontext()

-    def on_model():
+    def on_accum_model():
         return unet

     def on_prepare():
@@ -172,11 +169,29 @@ def dreambooth_strategy_callbacks(
     @torch.no_grad()
     def on_sample(step):
         with ema_context():
-            save_samples_(step=step)
+            unet_ = accelerator.unwrap_model(unet)
+            text_encoder_ = accelerator.unwrap_model(text_encoder)
+
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype
+
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)
+
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)
+
+            del unet_
+            del text_encoder_
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

     return TrainingCallbacks(
         on_prepare=on_prepare,
-        on_model=on_model,
+        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
@@ -191,9 +206,13 @@ def dreambooth_prepare(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
     unet: UNet2DConditionModel,
-    *args
+    optimizer: torch.optim.Optimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, *args)
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)


 dreambooth_strategy = TrainingStrategy(
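Note the `+ ({},)` at the end of dreambooth_prepare: it appends an empty extras dict to accelerator.prepare()'s result tuple so the return shape matches the new prepare contract. The trailing comma is what makes it a one-element tuple:

    prepped = ("text_encoder", "unet", "optimizer")
    assert prepped + ({},) == ("text_encoder", "unet", "optimizer", {})
    # without the comma, ({}) is just a dict, and concatenating it to a tuple raises TypeError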
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
new file mode 100644
index 0000000..88d1824
--- /dev/null
+++ b/training/strategy/lora.py
@@ -0,0 +1,147 @@
+from contextlib import nullcontext
+from typing import Optional
+from functools import partial
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+from diffusers.loaders import AttnProcsLayers
+
+from slugify import slugify
+
+from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
+
+
+def lora_strategy_callbacks(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    sample_output_dir: Path,
+    checkpoint_output_dir: Path,
+    seed: int,
+    lora_layers: AttnProcsLayers,
+    max_grad_norm: float = 1.0,
+    sample_batch_size: int = 1,
+    sample_num_batches: int = 1,
+    sample_num_steps: int = 20,
+    sample_guidance_scale: float = 7.5,
+    sample_image_size: Optional[int] = None,
+):
+    sample_output_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
+
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    save_samples_ = partial(
+        save_samples,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        vae=vae,
+        sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        output_dir=sample_output_dir,
+        seed=seed,
+        batch_size=sample_batch_size,
+        num_batches=sample_num_batches,
+        num_steps=sample_num_steps,
+        guidance_scale=sample_guidance_scale,
+        image_size=sample_image_size,
+    )
+
+    def on_prepare():
+        lora_layers.requires_grad_(True)
+
+    def on_accum_model():
+        return unet
+
+    @contextmanager
+    def on_train(epoch: int):
+        tokenizer.train()
+        yield
+
+    @contextmanager
+    def on_eval():
+        tokenizer.eval()
+        yield
+
+    def on_before_optimize(lr: float, epoch: int):
+        if accelerator.sync_gradients:
+            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+
+    @torch.no_grad()
+    def on_checkpoint(step, postfix):
+        print(f"Saving checkpoint for step {step}...")
+        orig_unet_dtype = unet.dtype
+        unet.to(dtype=torch.float32)
+        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        unet.to(dtype=orig_unet_dtype)
+
+    @torch.no_grad()
+    def on_sample(step):
+        orig_unet_dtype = unet.dtype
+        unet.to(dtype=weight_dtype)
+        save_samples_(step=step)
+        unet.to(dtype=orig_unet_dtype)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    return TrainingCallbacks(
+        on_prepare=on_prepare,
+        on_accum_model=on_accum_model,
+        on_train=on_train,
+        on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
+        on_checkpoint=on_checkpoint,
+        on_sample=on_sample,
+    )
+
+
+def lora_prepare(
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    unet: UNet2DConditionModel,
+    optimizer: torch.optim.Optimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    lora_layers: AttnProcsLayers,
+    **kwargs
+):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
+
+
+lora_strategy = TrainingStrategy(
+    callbacks=lora_strategy_callbacks,
+    prepare=lora_prepare,
+)
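Design note: lora_prepare passes only lora_layers through accelerator.prepare and moves the frozen unet and text_encoder over manually, since the LoRA processors hold the only trainable parameters. Those are tiny; a sketch (assuming the diffusers revision this repo pins, where LoRACrossAttnProcessor lives in diffusers.models.cross_attention) counting one processor's weights:

    from diffusers.loaders import AttnProcsLayers
    from diffusers.models.cross_attention import LoRACrossAttnProcessor

    procs = {
        "mid_block.attentions.0.transformer_blocks.0.attn1.processor":
            LoRACrossAttnProcessor(hidden_size=1280, cross_attention_dim=None),
    }
    lora_layers = AttnProcsLayers(procs)
    # four rank-4 q/k/v/out pairs at width 1280: roughly 41k floats, versus the
    # UNet's hundreds of millions of frozen parameters
    print(sum(p.numel() for p in lora_layers.parameters()))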
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 14bdafd..d306f18 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -59,14 +59,11 @@ def textual_inversion_strategy_callbacks(
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -94,7 +91,7 @@ def textual_inversion_strategy_callbacks(
         else:
             return nullcontext()

-    def on_model():
+    def on_accum_model():
         return text_encoder.text_model.embeddings.temp_token_embedding

     def on_prepare():
@@ -149,11 +146,29 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_sample(step):
         with ema_context():
-            save_samples_(step=step)
+            unet_ = accelerator.unwrap_model(unet)
+            text_encoder_ = accelerator.unwrap_model(text_encoder)
+
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype
+
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)
+
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)
+
+            del unet_
+            del text_encoder_
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

     return TrainingCallbacks(
         on_prepare=on_prepare,
-        on_model=on_model,
+        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
@@ -168,7 +183,11 @@ def textual_inversion_prepare(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
     unet: UNet2DConditionModel,
-    *args
+    optimizer: torch.optim.Optimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    **kwargs
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -176,9 +195,10 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    prepped = accelerator.prepare(text_encoder, *args)
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
     unet.to(accelerator.device, dtype=weight_dtype)
-    return (prepped[0], unet) + prepped[1:]
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


 textual_inversion_strategy = TrainingStrategy(