author     Volpeon <git@volpeon.ink>  2023-02-07 20:44:43 +0100
committer  Volpeon <git@volpeon.ink>  2023-02-07 20:44:43 +0100
commit     7ccd4614a56cfd6ecacba85605f338593f1059f0 (patch)
tree       fa9882b256c752705bc42229bac4e00ed7088643
parent     Restored LR finder (diff)
Add Lora
 environment.yaml                                     |   4
 infer.py                                             |  21
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  14
 train_dreambooth.py                                  |  46
 train_lora.py                                        | 566
 train_ti.py                                          |  10
 training/functional.py                               |  31
 training/strategy/dreambooth.py                      |  35
 training/strategy/lora.py                            | 147
 training/strategy/ti.py                              |  38
 10 files changed, 819 insertions(+), 93 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index c992759..f5632bf 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -18,11 +18,11 @@ dependencies:
   - -e git+https://github.com/huggingface/diffusers#egg=diffusers
   - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion
   - accelerate==0.15.0
-  - bitsandbytes==0.36.0.post2
+  - bitsandbytes==0.37.0
   - python-slugify>=6.1.2
   - safetensors==0.2.7
   - setuptools==65.6.3
   - test-tube>=0.7.5
   - transformers==4.25.1
   - triton==2.0.0.dev20221202
-  - xformers==0.0.16.dev430
+  - xformers==0.0.17.dev443
diff --git a/infer.py b/infer.py
index 2b07b21..42b4e2d 100644
--- a/infer.py
+++ b/infer.py
@@ -39,7 +39,8 @@ torch.backends.cudnn.benchmark = True
 default_args = {
     "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
-    "ti_embeddings_dir": "embeddings",
+    "ti_embeddings_dir": "embeddings_ti",
+    "lora_embeddings_dir": "embeddings_lora",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -60,6 +61,7 @@ default_cmds = {
     "batch_num": 1,
     "steps": 30,
     "guidance_scale": 7.0,
+    "lora_scale": 0.5,
     "seed": None,
     "config": None,
 }
@@ -92,6 +94,10 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
+        "--lora_embeddings_dir",
+        type=str,
+    )
+    parser.add_argument(
         "--output_dir",
         type=str,
     )
@@ -169,6 +175,10 @@ def create_cmd_parser():
         type=float,
     )
     parser.add_argument(
+        "--lora_scale",
+        type=float,
+    )
+    parser.add_argument(
         "--seed",
         type=int,
     )
@@ -315,6 +325,7 @@ def generate(output_dir: Path, pipeline, args):
            generator=generator,
            image=init_image,
            strength=args.image_noise,
+           cross_attention_kwargs={"scale": args.lora_scale},
         ).images

         for j, image in enumerate(images):
@@ -334,11 +345,12 @@ class CmdParse(cmd.Cmd):
     prompt = 'dream> '
     commands = []

-    def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser):
+    def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser):
         super().__init__()

         self.output_dir = output_dir
         self.ti_embeddings_dir = ti_embeddings_dir
+        self.lora_embeddings_dir = lora_embeddings_dir
         self.pipeline = pipeline
         self.parser = parser

@@ -394,9 +406,12 @@ def main():
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]

     pipeline = create_pipeline(args.model, dtype)
+
     load_embeddings(pipeline, args.ti_embeddings_dir)
+    pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
+
     cmd_parser = create_cmd_parser()
-    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser)
+    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()

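Note (not part of the patch): at inference the two new options map onto standard diffusers calls. A minimal sketch, assuming a stock StableDiffusionPipeline and an attention-processor directory produced by this training code; the model id, directory name and scale value are illustrative:

# Sketch only: mirrors what infer.py does with --lora_embeddings_dir and --lora_scale.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")
pipe.unet.load_attn_procs("embeddings_lora")  # directory written by save_attn_procs
images = pipe(
    "a photo of a corgi",
    num_inference_steps=30,
    guidance_scale=7.0,
    cross_attention_kwargs={"scale": 0.5},  # LoRA strength, cf. --lora_scale
).images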
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 3027421..dab7878 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -1,6 +1,6 @@
 import inspect
 import warnings
-from typing import List, Optional, Union, Callable
+from typing import List, Dict, Any, Optional, Union, Callable

 import numpy as np
 import torch
@@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                 plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

         Returns:
             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=text_embeddings,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                ).sample

                 # perform guidance
                 if do_classifier_free_guidance:
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a70c80e..5a4c47b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -74,26 +74,6 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--placeholder_tokens",
-        type=str,
-        nargs='*',
-        default=[],
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--initializer_tokens",
-        type=str,
-        nargs='*',
-        default=[],
-        help="A token to use as initializer word."
-    )
-    parser.add_argument(
-        "--num_vectors",
-        type=int,
-        nargs='*',
-        help="Number of vectors per embedding."
-    )
-    parser.add_argument(
         "--exclude_collections",
         type=str,
         nargs='*',
@@ -436,30 +416,6 @@ def parse_args():
     if args.project is None:
         raise ValueError("You must specify --project")

-    if isinstance(args.placeholder_tokens, str):
-        args.placeholder_tokens = [args.placeholder_tokens]
-
-    if isinstance(args.initializer_tokens, str):
-        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
-
-    if len(args.initializer_tokens) == 0:
-        raise ValueError("You must specify --initializer_tokens")
-
-    if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
-
-    if len(args.placeholder_tokens) != len(args.initializer_tokens):
-        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
-
-    if args.num_vectors is None:
-        args.num_vectors = 1
-
-    if isinstance(args.num_vectors, int):
-        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)
-
-    if len(args.placeholder_tokens) != len(args.num_vectors):
-        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
-
     if isinstance(args.collection, str):
         args.collection = [args.collection]

@@ -503,7 +459,7 @@ def main():

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
diff --git a/train_lora.py b/train_lora.py
new file mode 100644
index 0000000..2cb85cc
--- /dev/null
+++ b/train_lora.py
@@ -0,0 +1,566 @@
+import argparse
+import datetime
+import logging
+import itertools
+from pathlib import Path
+from functools import partial
+
+import torch
+import torch.utils.checkpoint
+
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import LoggerType, set_seed
+from slugify import slugify
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
+
+from util import load_config, load_embeddings_from_dir
+from data.csv import VlpnDataModule, keyword_filter
+from training.functional import train, get_models
+from training.lr import plot_metrics
+from training.strategy.lora import lora_strategy
+from training.optimization import get_scheduler
+from training.util import save_args
+
+logger = get_logger(__name__)
+
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="A folder containing the training data."
+    )
+    parser.add_argument(
+        "--train_data_template",
+        type=str,
+        default="template",
+    )
+    parser.add_argument(
+        "--train_set_pad",
+        type=int,
+        default=None,
+        help="The number to fill train dataset items up to."
+    )
+    parser.add_argument(
+        "--valid_set_pad",
+        type=int,
+        default=None,
+        help="The number to fill validation dataset items up to."
+    )
+    parser.add_argument(
+        "--project",
+        type=str,
+        default=None,
+        help="The name of the current project.",
+    )
+    parser.add_argument(
+        "--exclude_collections",
+        type=str,
+        nargs='*',
+        help="Exclude all items with a listed collection.",
+    )
+    parser.add_argument(
+        "--num_buckets",
+        type=int,
+        default=4,
+        help="Number of aspect ratio buckets in either direction.",
+    )
+    parser.add_argument(
+        "--progressive_buckets",
+        action="store_true",
+        help="Include images in smaller buckets as well.",
+    )
+    parser.add_argument(
+        "--bucket_step_size",
+        type=int,
+        default=64,
+        help="Step size between buckets.",
+    )
+    parser.add_argument(
+        "--bucket_max_pixels",
+        type=int,
+        default=None,
+        help="Maximum pixels per bucket.",
+    )
+    parser.add_argument(
+        "--tag_dropout",
+        type=float,
+        default=0.1,
+        help="Tag dropout probability.",
+    )
+    parser.add_argument(
+        "--no_tag_shuffle",
+        action="store_true",
+        help="Shuffle tags.",
+    )
+    parser.add_argument(
+        "--num_class_images",
+        type=int,
+        default=0,
+        help="How many class images to generate."
+    )
+    parser.add_argument(
+        "--class_image_dir",
+        type=str,
+        default="cls",
+        help="The directory where class images will be saved.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/lora",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--embeddings_dir",
+        type=str,
+        default=None,
+        help="The embeddings directory where Textual Inversion embeddings are stored.",
+    )
+    parser.add_argument(
+        "--collection",
+        type=str,
+        nargs='*',
+        help="A collection to filter the dataset.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for reproducible training."
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=768,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--num_train_epochs",
+        type=int,
+        default=100
+    )
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--find_lr",
+        action="store_true",
+        help="Automatically find a learning rate (no training).",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=2e-6,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="one_cycle",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup", "one_cycle"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_epochs",
+        type=int,
+        default=10,
+        help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--lr_cycles",
+        type=int,
+        default=None,
+        help="Number of restart cycles in the lr scheduler (if supported)."
+    )
+    parser.add_argument(
+        "--lr_warmup_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_warmup_exp",
+        type=int,
+        default=1,
+        help='If lr_warmup_func is "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_annealing_func",
+        type=str,
+        default="cos",
+        help='Choose between ["linear", "half_cos", "cos"]'
+    )
+    parser.add_argument(
+        "--lr_annealing_exp",
+        type=int,
+        default=3,
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+    )
+    parser.add_argument(
+        "--lr_min_lr",
+        type=float,
+        default=0.04,
+        help="Minimum learning rate in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--adam_beta1",
+        type=float,
+        default=0.9,
+        help="The beta1 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_beta2",
+        type=float,
+        default=0.999,
+        help="The beta2 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_weight_decay",
+        type=float,
+        default=1e-2,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer"
+    )
+    parser.add_argument(
+        "--adam_amsgrad",
+        type=bool,
+        default=False,
+        help="Amsgrad value for the Adam optimizer"
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+        ),
+    )
+    parser.add_argument(
+        "--sample_frequency",
+        type=int,
+        default=1,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
+        "--sample_image_size",
+        type=int,
+        default=768,
+        help="Size of sample images",
+    )
+    parser.add_argument(
+        "--sample_batches",
+        type=int,
+        default=1,
+        help="Number of sample batches to generate per checkpoint",
+    )
+    parser.add_argument(
+        "--sample_batch_size",
+        type=int,
+        default=1,
+        help="Number of samples to generate per batch",
+    )
+    parser.add_argument(
+        "--valid_set_size",
+        type=int,
+        default=None,
+        help="Number of images in the validation dataset."
+    )
+    parser.add_argument(
+        "--valid_set_repeat",
+        type=int,
+        default=1,
+        help="Times the images in the validation dataset are repeated."
+    )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--sample_steps",
+        type=int,
+        default=20,
+        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
+    )
+    parser.add_argument(
+        "--prior_loss_weight",
+        type=float,
+        default=1.0,
+        help="The weight of prior preservation loss."
+    )
+    parser.add_argument(
+        "--max_grad_norm",
+        default=1.0,
+        type=float,
+        help="Max gradient norm."
+    )
+    parser.add_argument(
+        "--noise_timesteps",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script."
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        args = load_config(args.config)
+        args = parser.parse_args(namespace=argparse.Namespace(**args))
+
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
+
+    if args.pretrained_model_name_or_path is None:
+        raise ValueError("You must specify --pretrained_model_name_or_path")
+
+    if args.project is None:
+        raise ValueError("You must specify --project")
+
+    if isinstance(args.collection, str):
+        args.collection = [args.collection]
+
+    if isinstance(args.exclude_collections, str):
+        args.exclude_collections = [args.exclude_collections]
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    accelerator = Accelerator(
+        log_with=LoggerType.TENSORBOARD,
+        logging_dir=f"{output_dir}",
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision
+    )
+
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+
+    if args.seed is None:
+        args.seed = torch.random.seed() >> 32
+
+    set_seed(args.seed)
+
+    save_args(output_dir, args)
+
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
+        args.pretrained_model_name_or_path)
+
+    vae.enable_slicing()
+    vae.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()
+
+    lora_attn_procs = {}
+    for name in unet.attn_processors.keys():
+        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+        if name.startswith("mid_block"):
+            hidden_size = unet.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = unet.config.block_out_channels[block_id]
+
+        lora_attn_procs[name] = LoRACrossAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
+
+    unet.set_attn_processor(lora_attn_procs)
+    lora_layers = AttnProcsLayers(unet.attn_processors)
+
+    if args.embeddings_dir is not None:
+        embeddings_dir = Path(args.embeddings_dir)
+        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
+            raise ValueError("--embeddings_dir must point to an existing directory")
+
+        embeddings.persist()
+
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps *
+            args.train_batch_size * accelerator.num_processes
+        )
+
+    if args.find_lr:
+        args.learning_rate = 1e-6
+        args.lr_scheduler = "exponential_growth"
+
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        lora_layers=lora_layers,
+        noise_scheduler=noise_scheduler,
+        dtype=weight_dtype,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+    )
+
+    checkpoint_output_dir = output_dir.joinpath("model")
+    sample_output_dir = output_dir.joinpath(f"samples")
+
+    datamodule = VlpnDataModule(
+        data_file=args.train_data_file,
+        batch_size=args.train_batch_size,
+        tokenizer=tokenizer,
+        class_subdir=args.class_image_dir,
+        num_class_images=args.num_class_images,
+        size=args.resolution,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
+        bucket_step_size=args.bucket_step_size,
+        bucket_max_pixels=args.bucket_max_pixels,
+        dropout=args.tag_dropout,
+        shuffle=not args.no_tag_shuffle,
+        template_key=args.train_data_template,
+        valid_set_size=args.valid_set_size,
+        train_set_pad=args.train_set_pad,
+        valid_set_pad=args.valid_set_pad,
+        seed=args.seed,
+        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
+        dtype=weight_dtype
+    )
+    datamodule.setup()
+
+    optimizer = optimizer_class(
+        lora_layers.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        min_lr=args.lr_min_lr,
+        warmup_func=args.lr_warmup_func,
+        annealing_func=args.lr_annealing_func,
+        warmup_exp=args.lr_warmup_exp,
+        annealing_exp=args.lr_annealing_exp,
+        cycles=args.lr_cycles,
+        end_lr=1e2,
+        train_epochs=args.num_train_epochs,
+        warmup_epochs=args.lr_warmup_epochs,
+    )
+
+    metrics = trainer(
+        strategy=lora_strategy,
+        project="lora",
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
+        seed=args.seed,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        sample_output_dir=sample_output_dir,
+        checkpoint_output_dir=checkpoint_output_dir,
+        max_grad_norm=args.max_grad_norm,
+        sample_batch_size=args.sample_batch_size,
+        sample_num_batches=args.sample_batches,
+        sample_num_steps=args.sample_steps,
+        sample_image_size=args.sample_image_size,
+    )
+
+    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+
+
+if __name__ == "__main__":
+    main()
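Aside (not part of the patch): train_lora.py optimizes only `lora_layers`, i.e. the LoRACrossAttnProcessor weights, while the UNet and text encoder stay frozen. Conceptually each processor adds a scaled low-rank update on top of a frozen attention projection; a rough sketch of that idea, not the diffusers implementation, with the rank chosen arbitrarily:

# Sketch only: the low-rank update idea behind LoRACrossAttnProcessor.
import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base                                # frozen pretrained projection
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)                  # update starts as a no-op
        for p in self.base.parameters():
            p.requires_grad_(False)                     # only down/up are trained

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # `scale` is what cross_attention_kwargs={"scale": ...} controls at inference.
        return self.base(x) + scale * self.up(self.down(x))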
diff --git a/train_ti.py b/train_ti.py
index c118aab..56f9e97 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -166,7 +166,7 @@ def parse_args():
     parser.add_argument(
         "--tag_dropout",
         type=float,
-        default=0,
+        default=0.1,
         help="Tag dropout probability.",
     )
     parser.add_argument(
@@ -414,7 +414,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay",
-        default=1e0,
+        default=1e-2,
         type=float,
         help="Embedding decay factor."
     )
@@ -530,7 +530,7 @@ def main():

     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
+    unet.enable_xformers_memory_efficient_attention()

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -612,8 +612,10 @@ def main():

     if len(placeholder_tokens) == 1:
         sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
+        metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png")
     else:
         sample_output_dir = output_dir.joinpath("samples")
+        metrics_output_file = output_dir.joinpath(f"lr.png")

     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
@@ -687,7 +689,7 @@
         placeholder_token_ids=placeholder_token_ids,
     )

-    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+    plot_metrics(metrics, metrics_output_file)

     if args.simultaneous:
         run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
diff --git a/training/functional.py b/training/functional.py
index c373ac9..8f47734 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[], None] = const()
-    on_model: Callable[[], torch.nn.Module] = const(None)
+    on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[float, int], None] = const()
@@ -51,7 +51,11 @@ class TrainingStrategyPrepareCallable(Protocol):
         accelerator: Accelerator,
         text_encoder: CLIPTextModel,
         unet: UNet2DConditionModel,
-        *args
+        optimizer: torch.optim.Optimizer,
+        train_dataloader: DataLoader,
+        val_dataloader: Optional[DataLoader],
+        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+        **kwargs
     ) -> Tuple: ...


@@ -92,7 +96,6 @@ def save_samples(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     step: int,
@@ -107,15 +110,6 @@ def save_samples(
     grid_cols = min(batch_size, 4)
     grid_rows = (num_batches * batch_size) // grid_cols

-    unet = accelerator.unwrap_model(unet)
-    text_encoder = accelerator.unwrap_model(text_encoder)
-
-    orig_unet_dtype = unet.dtype
-    orig_text_encoder_dtype = text_encoder.dtype
-
-    unet.to(dtype=dtype)
-    text_encoder.to(dtype=dtype)
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
@@ -172,11 +166,6 @@
     image_grid = make_grid(all_samples, grid_rows, grid_cols)
     image_grid.save(file_path, quality=85)

-    unet.to(dtype=orig_unet_dtype)
-    text_encoder.to(dtype=orig_text_encoder_dtype)
-
-    del unet
-    del text_encoder
     del generator
     del pipeline

@@ -393,7 +382,7 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")

-    model = callbacks.on_model()
+    model = callbacks.on_accum_model()
     on_log = callbacks.on_log
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
@@ -559,8 +548,10 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
+        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+
+    kwargs.update(extra)

     vae.to(accelerator.device, dtype=dtype)

diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e88bf90..b4c77f3 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -61,14 +61,11 @@ def dreambooth_strategy_callbacks(
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -94,7 +91,7 @@
         else:
             return nullcontext()

-    def on_model():
+    def on_accum_model():
         return unet

     def on_prepare():
@@ -172,11 +169,29 @@
     @torch.no_grad()
     def on_sample(step):
         with ema_context():
-            save_samples_(step=step)
+            unet_ = accelerator.unwrap_model(unet)
+            text_encoder_ = accelerator.unwrap_model(text_encoder)
+
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype
+
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)
+
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)
+
+            del unet_
+            del text_encoder_
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

     return TrainingCallbacks(
         on_prepare=on_prepare,
-        on_model=on_model,
+        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
@@ -191,9 +206,13 @@ def dreambooth_prepare(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
     unet: UNet2DConditionModel,
-    *args
+    optimizer: torch.optim.Optimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, *args)
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({})


 dreambooth_strategy = TrainingStrategy(
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
new file mode 100644
index 0000000..88d1824
--- /dev/null
+++ b/training/strategy/lora.py
@@ -0,0 +1,147 @@
+from contextlib import nullcontext
+from typing import Optional
+from functools import partial
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+from diffusers.loaders import AttnProcsLayers
+
+from slugify import slugify
+
+from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
+
+
+def lora_strategy_callbacks(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    sample_output_dir: Path,
+    checkpoint_output_dir: Path,
+    seed: int,
+    lora_layers: AttnProcsLayers,
+    max_grad_norm: float = 1.0,
+    sample_batch_size: int = 1,
+    sample_num_batches: int = 1,
+    sample_num_steps: int = 20,
+    sample_guidance_scale: float = 7.5,
+    sample_image_size: Optional[int] = None,
+):
+    sample_output_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
+
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    save_samples_ = partial(
+        save_samples,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        vae=vae,
+        sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        output_dir=sample_output_dir,
+        seed=seed,
+        batch_size=sample_batch_size,
+        num_batches=sample_num_batches,
+        num_steps=sample_num_steps,
+        guidance_scale=sample_guidance_scale,
+        image_size=sample_image_size,
+    )
+
+    def on_prepare():
+        lora_layers.requires_grad_(True)
+
+    def on_accum_model():
+        return unet
+
+    @contextmanager
+    def on_train(epoch: int):
+        tokenizer.train()
+        yield
+
+    @contextmanager
+    def on_eval():
+        tokenizer.eval()
+        yield
+
+    def on_before_optimize(lr: float, epoch: int):
+        if accelerator.sync_gradients:
+            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+
+    @torch.no_grad()
+    def on_checkpoint(step, postfix):
+        print(f"Saving checkpoint for step {step}...")
+        orig_unet_dtype = unet.dtype
+        unet.to(dtype=torch.float32)
+        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        unet.to(dtype=orig_unet_dtype)
+
+    @torch.no_grad()
+    def on_sample(step):
+        orig_unet_dtype = unet.dtype
+        unet.to(dtype=weight_dtype)
+        save_samples_(step=step)
+        unet.to(dtype=orig_unet_dtype)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    return TrainingCallbacks(
+        on_prepare=on_prepare,
+        on_accum_model=on_accum_model,
+        on_train=on_train,
+        on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
+        on_checkpoint=on_checkpoint,
+        on_sample=on_sample,
+    )
+
+
+def lora_prepare(
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    unet: UNet2DConditionModel,
+    optimizer: torch.optim.Optimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    lora_layers: AttnProcsLayers,
+    **kwargs
+):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
+
+
+lora_strategy = TrainingStrategy(
+    callbacks=lora_strategy_callbacks,
+    prepare=lora_prepare,
+)
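Aside (not part of the patch): the on_checkpoint() callback above writes plain attention-processor weight directories via save_attn_procs, so they can be re-applied later with diffusers' load_attn_procs. A brief sketch; the model id and the step/postfix directory name are illustrative:

# Sketch only: re-loading a checkpoint written by on_checkpoint() above.
from pathlib import Path
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet"
)
checkpoint_dir = Path("output/lora/my-project/2023-02-07T20-44-43/model")
unet.load_attn_procs(checkpoint_dir / "1000_end")  # "<step>_<postfix>" directory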
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 14bdafd..d306f18 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -59,14 +59,11 @@ def textual_inversion_strategy_callbacks(
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -94,7 +91,7 @@
         else:
             return nullcontext()

-    def on_model():
+    def on_accum_model():
         return text_encoder.text_model.embeddings.temp_token_embedding

     def on_prepare():
@@ -149,11 +146,29 @@
     @torch.no_grad()
     def on_sample(step):
         with ema_context():
-            save_samples_(step=step)
+            unet_ = accelerator.unwrap_model(unet)
+            text_encoder_ = accelerator.unwrap_model(text_encoder)
+
+            orig_unet_dtype = unet_.dtype
+            orig_text_encoder_dtype = text_encoder_.dtype
+
+            unet_.to(dtype=weight_dtype)
+            text_encoder_.to(dtype=weight_dtype)
+
+            save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
+
+            unet_.to(dtype=orig_unet_dtype)
+            text_encoder_.to(dtype=orig_text_encoder_dtype)
+
+            del unet_
+            del text_encoder_
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

     return TrainingCallbacks(
         on_prepare=on_prepare,
-        on_model=on_model,
+        on_accum_model=on_accum_model,
         on_train=on_train,
         on_eval=on_eval,
         on_before_optimize=on_before_optimize,
@@ -168,7 +183,11 @@ def textual_inversion_prepare(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
     unet: UNet2DConditionModel,
-    *args
+    optimizer: torch.optim.Optimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    **kwargs
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -176,9 +195,10 @@
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    prepped = accelerator.prepare(text_encoder, *args)
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
     unet.to(accelerator.device, dtype=weight_dtype)
-    return (prepped[0], unet) + prepped[1:]
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}


 textual_inversion_strategy = TrainingStrategy(