author     Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-21 13:28:49 +0200
commit     8364ce697ddf6117fdd4f7222832d546d63880de (patch)
tree       152c99815bbd8b2659d0dabe63c98f63151c97c2 /train_dreambooth.py
parent     Fix LoRA training with DAdan (diff)
Update
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  770
1 file changed, 568 insertions, 202 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2aca1e7..659b84c 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -5,34 +5,70 @@ import itertools
5from pathlib import Path 5from pathlib import Path
6from functools import partial 6from functools import partial
7import math 7import math
8import warnings
8 9
9import torch 10import torch
11import torch._dynamo
10import torch.utils.checkpoint 12import torch.utils.checkpoint
13import hidet
11 14
12from accelerate import Accelerator 15from accelerate import Accelerator
13from accelerate.logging import get_logger 16from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
15from slugify import slugify 18
19# from diffusers.models.attention_processor import AttnProcessor
20from diffusers.utils.import_utils import is_xformers_available
16import transformers 21import transformers
17 22
18from util.files import load_config, load_embeddings_from_dir 23import numpy as np
24from slugify import slugify
25
19from data.csv import VlpnDataModule, keyword_filter 26from data.csv import VlpnDataModule, keyword_filter
20from training.functional import train, get_models 27from models.clip.embeddings import patch_managed_embeddings
28from training.functional import train, add_placeholder_tokens, get_models
21from training.strategy.dreambooth import dreambooth_strategy 29from training.strategy.dreambooth import dreambooth_strategy
22from training.optimization import get_scheduler 30from training.optimization import get_scheduler
23from training.util import save_args 31from training.sampler import create_named_schedule_sampler
32from training.util import AverageMeter, save_args
33from util.files import load_config, load_embeddings_from_dir
34
24 35
25logger = get_logger(__name__) 36logger = get_logger(__name__)
26 37
38warnings.filterwarnings("ignore")
39
27 40
28torch.backends.cuda.matmul.allow_tf32 = True 41torch.backends.cuda.matmul.allow_tf32 = True
29torch.backends.cudnn.benchmark = True 42torch.backends.cudnn.benchmark = True
30 43
44# torch._dynamo.config.log_level = logging.WARNING
45torch._dynamo.config.suppress_errors = True
46
47hidet.torch.dynamo_config.use_tensor_core(True)
48hidet.torch.dynamo_config.search_space(0)
49
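Note: a minimal sketch (not part of this commit) of how the hidet Dynamo backend configured above is typically attached to a model via torch.compile; the --compile_unet flag added later in this diff is assumed to gate such a call.

    import torch
    import hidet  # importing hidet registers the "hidet" backend with torch._dynamo

    # Mirror the settings from the diff: prefer tensor cores, use the smallest
    # kernel search space (fastest compilation, least autotuning).
    hidet.torch.dynamo_config.use_tensor_core(True)
    hidet.torch.dynamo_config.search_space(0)

    def maybe_compile(module: torch.nn.Module, enabled: bool) -> torch.nn.Module:
        # With torch._dynamo.config.suppress_errors = True (set above),
        # compilation failures fall back to eager execution instead of raising.
        return torch.compile(module, backend="hidet") if enabled else module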
50
51def patch_xformers(dtype):
52 if is_xformers_available():
53 import xformers
54 import xformers.ops
55
56 orig_xformers_memory_efficient_attention = (
57 xformers.ops.memory_efficient_attention
58 )
59
60 def xformers_memory_efficient_attention(
61 query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
62 ):
63 return orig_xformers_memory_efficient_attention(
64 query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
65 )
66
67 xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
68
31 69
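As an illustration only (assumed tensor shapes; requires CUDA and the xformers package), the monkey-patch above makes every later call through xformers.ops.memory_efficient_attention cast its inputs to the training dtype:

    import torch
    import xformers.ops

    patch_xformers(torch.float16)  # replaces xformers.ops.memory_efficient_attention with the wrapper

    q = torch.randn(1, 64, 8, 40, device="cuda")      # float32 projections...
    out = xformers.ops.memory_efficient_attention(q, q, q)
    # ...are cast to float16 before the fused attention kernel runs.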
32def parse_args(): 70def parse_args():
33 parser = argparse.ArgumentParser( 71 parser = argparse.ArgumentParser(description="Simple example of a training script.")
34 description="Simple example of a training script."
35 )
36 parser.add_argument( 72 parser.add_argument(
37 "--pretrained_model_name_or_path", 73 "--pretrained_model_name_or_path",
38 type=str, 74 type=str,
@@ -49,7 +85,7 @@ def parse_args():
49 "--train_data_file", 85 "--train_data_file",
50 type=str, 86 type=str,
51 default=None, 87 default=None,
52 help="A folder containing the training data." 88 help="A folder containing the training data.",
53 ) 89 )
54 parser.add_argument( 90 parser.add_argument(
55 "--train_data_template", 91 "--train_data_template",
@@ -60,13 +96,13 @@ def parse_args():
60 "--train_set_pad", 96 "--train_set_pad",
61 type=int, 97 type=int,
62 default=None, 98 default=None,
63 help="The number to fill train dataset items up to." 99 help="The number to fill train dataset items up to.",
64 ) 100 )
65 parser.add_argument( 101 parser.add_argument(
66 "--valid_set_pad", 102 "--valid_set_pad",
67 type=int, 103 type=int,
68 default=None, 104 default=None,
69 help="The number to fill validation dataset items up to." 105 help="The number to fill validation dataset items up to.",
70 ) 106 )
71 parser.add_argument( 107 parser.add_argument(
72 "--project", 108 "--project",
@@ -75,20 +111,58 @@ def parse_args():
75 help="The name of the current project.", 111 help="The name of the current project.",
76 ) 112 )
77 parser.add_argument( 113 parser.add_argument(
78 "--exclude_collections", 114 "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
115 )
116 parser.add_argument(
117 "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
118 )
119 parser.add_argument(
120 "--placeholder_tokens",
79 type=str, 121 type=str,
80 nargs='*', 122 nargs="*",
81 help="Exclude all items with a listed collection.", 123 help="A token to use as a placeholder for the concept.",
82 ) 124 )
83 parser.add_argument( 125 parser.add_argument(
84 "--train_text_encoder_epochs", 126 "--initializer_tokens",
85 default=999999, 127 type=str,
86 help="Number of epochs the text encoder will be trained." 128 nargs="*",
129 help="A token to use as initializer word.",
130 )
131 parser.add_argument(
132 "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
133 )
134 parser.add_argument(
135 "--initializer_noise",
136 type=float,
137 default=0,
138 help="Noise to apply to the initializer word",
139 )
140 parser.add_argument(
141 "--alias_tokens",
142 type=str,
143 nargs="*",
144 default=[],
145 help="Tokens to create an alias for.",
146 )
147 parser.add_argument(
148 "--inverted_initializer_tokens",
149 type=str,
150 nargs="*",
151 help="A token to use as initializer word.",
152 )
153 parser.add_argument(
154 "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
155 )
156 parser.add_argument(
157 "--exclude_collections",
158 type=str,
159 nargs="*",
160 help="Exclude all items with a listed collection.",
87 ) 161 )
88 parser.add_argument( 162 parser.add_argument(
89 "--num_buckets", 163 "--num_buckets",
90 type=int, 164 type=int,
91 default=0, 165 default=2,
92 help="Number of aspect ratio buckets in either direction.", 166 help="Number of aspect ratio buckets in either direction.",
93 ) 167 )
94 parser.add_argument( 168 parser.add_argument(
@@ -120,19 +194,6 @@ def parse_args():
120 help="Shuffle tags.", 194 help="Shuffle tags.",
121 ) 195 )
122 parser.add_argument( 196 parser.add_argument(
123 "--vector_dropout",
124 type=int,
125 default=0,
126 help="Vector dropout probability.",
127 )
128 parser.add_argument(
129 "--vector_shuffle",
130 type=str,
131 default="auto",
132 choices=["all", "trailing", "leading", "between", "auto", "off"],
133 help='Vector shuffling algorithm.',
134 )
135 parser.add_argument(
136 "--guidance_scale", 197 "--guidance_scale",
137 type=float, 198 type=float,
138 default=0, 199 default=0,
@@ -141,7 +202,7 @@ def parse_args():
141 "--num_class_images", 202 "--num_class_images",
142 type=int, 203 type=int,
143 default=0, 204 default=0,
144 help="How many class images to generate." 205 help="How many class images to generate.",
145 ) 206 )
146 parser.add_argument( 207 parser.add_argument(
147 "--class_image_dir", 208 "--class_image_dir",
@@ -162,16 +223,18 @@ def parse_args():
162 help="The embeddings directory where Textual Inversion embeddings are stored.", 223 help="The embeddings directory where Textual Inversion embeddings are stored.",
163 ) 224 )
164 parser.add_argument( 225 parser.add_argument(
226 "--train_dir_embeddings",
227 action="store_true",
228 help="Train embeddings loaded from embeddings directory.",
229 )
230 parser.add_argument(
165 "--collection", 231 "--collection",
166 type=str, 232 type=str,
167 nargs='*', 233 nargs="*",
168 help="A collection to filter the dataset.", 234 help="A collection to filter the dataset.",
169 ) 235 )
170 parser.add_argument( 236 parser.add_argument(
171 "--seed", 237 "--seed", type=int, default=None, help="A seed for reproducible training."
172 type=int,
173 default=None,
174 help="A seed for reproducible training."
175 ) 238 )
176 parser.add_argument( 239 parser.add_argument(
177 "--resolution", 240 "--resolution",
@@ -189,15 +252,13 @@ def parse_args():
189 help="Perlin offset noise strength.", 252 help="Perlin offset noise strength.",
190 ) 253 )
191 parser.add_argument( 254 parser.add_argument(
192 "--num_train_epochs", 255 "--input_pertubation",
193 type=int, 256 type=float,
194 default=None 257 default=0,
195 ) 258 help="The scale of input pretubation. Recommended 0.1.",
196 parser.add_argument(
197 "--num_train_steps",
198 type=int,
199 default=2000
200 ) 259 )
260 parser.add_argument("--num_train_epochs", type=int, default=None)
261 parser.add_argument("--num_train_steps", type=int, default=2000)
201 parser.add_argument( 262 parser.add_argument(
202 "--gradient_accumulation_steps", 263 "--gradient_accumulation_steps",
203 type=int, 264 type=int,
@@ -205,9 +266,9 @@ def parse_args():
205 help="Number of updates steps to accumulate before performing a backward/update pass.", 266 help="Number of updates steps to accumulate before performing a backward/update pass.",
206 ) 267 )
207 parser.add_argument( 268 parser.add_argument(
208 "--gradient_checkpointing", 269 "--train_text_encoder_cycles",
209 action="store_true", 270 default=999999,
210 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 271 help="Number of epochs the text encoder will be trained.",
211 ) 272 )
212 parser.add_argument( 273 parser.add_argument(
213 "--find_lr", 274 "--find_lr",
@@ -215,9 +276,15 @@ def parse_args():
215 help="Automatically find a learning rate (no training).", 276 help="Automatically find a learning rate (no training).",
216 ) 277 )
217 parser.add_argument( 278 parser.add_argument(
218 "--learning_rate", 279 "--learning_rate_unet",
280 type=float,
281 default=1e-4,
282 help="Initial learning rate (after the potential warmup period) to use.",
283 )
284 parser.add_argument(
285 "--learning_rate_text",
219 type=float, 286 type=float,
220 default=2e-6, 287 default=5e-5,
221 help="Initial learning rate (after the potential warmup period) to use.", 288 help="Initial learning rate (after the potential warmup period) to use.",
222 ) 289 )
223 parser.add_argument( 290 parser.add_argument(
@@ -229,27 +296,31 @@ def parse_args():
229 "--lr_scheduler", 296 "--lr_scheduler",
230 type=str, 297 type=str,
231 default="one_cycle", 298 default="one_cycle",
232 choices=["linear", "cosine", "cosine_with_restarts", "polynomial", 299 choices=[
233 "constant", "constant_with_warmup", "one_cycle"], 300 "linear",
234 help='The scheduler type to use.', 301 "cosine",
302 "cosine_with_restarts",
303 "polynomial",
304 "constant",
305 "constant_with_warmup",
306 "one_cycle",
307 ],
308 help="The scheduler type to use.",
235 ) 309 )
236 parser.add_argument( 310 parser.add_argument(
237 "--lr_warmup_epochs", 311 "--lr_warmup_epochs",
238 type=int, 312 type=int,
239 default=10, 313 default=10,
240 help="Number of steps for the warmup in the lr scheduler." 314 help="Number of steps for the warmup in the lr scheduler.",
241 ) 315 )
242 parser.add_argument( 316 parser.add_argument(
243 "--lr_mid_point", 317 "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point."
244 type=float,
245 default=0.3,
246 help="OneCycle schedule mid point."
247 ) 318 )
248 parser.add_argument( 319 parser.add_argument(
249 "--lr_cycles", 320 "--lr_cycles",
250 type=int, 321 type=int,
251 default=None, 322 default=None,
252 help="Number of restart cycles in the lr scheduler (if supported)." 323 help="Number of restart cycles in the lr scheduler (if supported).",
253 ) 324 )
254 parser.add_argument( 325 parser.add_argument(
255 "--lr_warmup_func", 326 "--lr_warmup_func",
@@ -261,7 +332,7 @@ def parse_args():
261 "--lr_warmup_exp", 332 "--lr_warmup_exp",
262 type=int, 333 type=int,
263 default=1, 334 default=1,
264 help='If lr_warmup_func is "cos", exponent to modify the function' 335 help='If lr_warmup_func is "cos", exponent to modify the function',
265 ) 336 )
266 parser.add_argument( 337 parser.add_argument(
267 "--lr_annealing_func", 338 "--lr_annealing_func",
@@ -273,76 +344,76 @@ def parse_args():
273 "--lr_annealing_exp", 344 "--lr_annealing_exp",
274 type=int, 345 type=int,
275 default=3, 346 default=3,
276 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' 347 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
277 ) 348 )
278 parser.add_argument( 349 parser.add_argument(
279 "--lr_min_lr", 350 "--lr_min_lr",
280 type=float, 351 type=float,
281 default=0.04, 352 default=0.04,
282 help="Minimum learning rate in the lr scheduler." 353 help="Minimum learning rate in the lr scheduler.",
283 )
284 parser.add_argument(
285 "--use_ema",
286 action="store_true",
287 help="Whether to use EMA model."
288 )
289 parser.add_argument(
290 "--ema_inv_gamma",
291 type=float,
292 default=1.0
293 ) 354 )
355 parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
294 parser.add_argument( 356 parser.add_argument(
295 "--ema_power", 357 "--schedule_sampler",
296 type=float, 358 type=str,
297 default=6/7 359 default="uniform",
298 ) 360 choices=["uniform", "loss-second-moment"],
299 parser.add_argument( 361 help="Noise schedule sampler.",
300 "--ema_max_decay",
301 type=float,
302 default=0.9999
303 ) 362 )
304 parser.add_argument( 363 parser.add_argument(
305 "--optimizer", 364 "--optimizer",
306 type=str, 365 type=str,
307 default="dadan", 366 default="adan",
308 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 367 choices=[
309 help='Optimizer to use' 368 "adam",
369 "adam8bit",
370 "adan",
371 "lion",
372 "dadam",
373 "dadan",
374 "dlion",
375 "adafactor",
376 ],
377 help="Optimizer to use",
310 ) 378 )
311 parser.add_argument( 379 parser.add_argument(
312 "--dadaptation_d0", 380 "--dadaptation_d0",
313 type=float, 381 type=float,
314 default=1e-6, 382 default=1e-6,
315 help="The d0 parameter for Dadaptation optimizers." 383 help="The d0 parameter for Dadaptation optimizers.",
384 )
385 parser.add_argument(
386 "--dadaptation_growth_rate",
387 type=float,
388 default=math.inf,
389 help="The growth_rate parameter for Dadaptation optimizers.",
316 ) 390 )
317 parser.add_argument( 391 parser.add_argument(
318 "--adam_beta1", 392 "--adam_beta1",
319 type=float, 393 type=float,
320 default=None, 394 default=None,
321 help="The beta1 parameter for the Adam optimizer." 395 help="The beta1 parameter for the Adam optimizer.",
322 ) 396 )
323 parser.add_argument( 397 parser.add_argument(
324 "--adam_beta2", 398 "--adam_beta2",
325 type=float, 399 type=float,
326 default=None, 400 default=None,
327 help="The beta2 parameter for the Adam optimizer." 401 help="The beta2 parameter for the Adam optimizer.",
328 ) 402 )
329 parser.add_argument( 403 parser.add_argument(
330 "--adam_weight_decay", 404 "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
331 type=float,
332 default=1e-2,
333 help="Weight decay to use."
334 ) 405 )
335 parser.add_argument( 406 parser.add_argument(
336 "--adam_epsilon", 407 "--adam_epsilon",
337 type=float, 408 type=float,
338 default=1e-08, 409 default=1e-08,
339 help="Epsilon value for the Adam optimizer" 410 help="Epsilon value for the Adam optimizer",
340 ) 411 )
341 parser.add_argument( 412 parser.add_argument(
342 "--adam_amsgrad", 413 "--adam_amsgrad",
343 type=bool, 414 type=bool,
344 default=False, 415 default=False,
345 help="Amsgrad value for the Adam optimizer" 416 help="Amsgrad value for the Adam optimizer",
346 ) 417 )
347 parser.add_argument( 418 parser.add_argument(
348 "--mixed_precision", 419 "--mixed_precision",
@@ -356,12 +427,28 @@ def parse_args():
356 ), 427 ),
357 ) 428 )
358 parser.add_argument( 429 parser.add_argument(
430 "--compile_unet",
431 action="store_true",
432 help="Compile UNet with Torch Dynamo.",
433 )
434 parser.add_argument(
435 "--use_xformers",
436 action="store_true",
437 help="Use xformers.",
438 )
439 parser.add_argument(
359 "--sample_frequency", 440 "--sample_frequency",
360 type=int, 441 type=int,
361 default=1, 442 default=1,
362 help="How often to save a checkpoint and sample image", 443 help="How often to save a checkpoint and sample image",
363 ) 444 )
364 parser.add_argument( 445 parser.add_argument(
446 "--sample_num",
447 type=int,
448 default=None,
449 help="How often to save a checkpoint and sample image (in number of samples)",
450 )
451 parser.add_argument(
365 "--sample_image_size", 452 "--sample_image_size",
366 type=int, 453 type=int,
367 default=768, 454 default=768,
@@ -383,19 +470,19 @@ def parse_args():
383 "--valid_set_size", 470 "--valid_set_size",
384 type=int, 471 type=int,
385 default=None, 472 default=None,
386 help="Number of images in the validation dataset." 473 help="Number of images in the validation dataset.",
387 ) 474 )
388 parser.add_argument( 475 parser.add_argument(
389 "--valid_set_repeat", 476 "--valid_set_repeat",
390 type=int, 477 type=int,
391 default=1, 478 default=1,
392 help="Times the images in the validation dataset are repeated." 479 help="Times the images in the validation dataset are repeated.",
393 ) 480 )
394 parser.add_argument( 481 parser.add_argument(
395 "--train_batch_size", 482 "--train_batch_size",
396 type=int, 483 type=int,
397 default=1, 484 default=1,
398 help="Batch size (per device) for the training dataloader." 485 help="Batch size (per device) for the training dataloader.",
399 ) 486 )
400 parser.add_argument( 487 parser.add_argument(
401 "--sample_steps", 488 "--sample_steps",
@@ -407,13 +494,18 @@ def parse_args():
407 "--prior_loss_weight", 494 "--prior_loss_weight",
408 type=float, 495 type=float,
409 default=1.0, 496 default=1.0,
410 help="The weight of prior preservation loss." 497 help="The weight of prior preservation loss.",
411 ) 498 )
499 parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
500 parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
412 parser.add_argument( 501 parser.add_argument(
413 "--max_grad_norm", 502 "--emb_dropout",
414 default=1.0,
415 type=float, 503 type=float,
416 help="Max gradient norm." 504 default=0,
505 help="Embedding dropout probability.",
506 )
507 parser.add_argument(
508 "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
417 ) 509 )
418 parser.add_argument( 510 parser.add_argument(
419 "--noise_timesteps", 511 "--noise_timesteps",
@@ -424,7 +516,7 @@ def parse_args():
424 "--config", 516 "--config",
425 type=str, 517 type=str,
426 default=None, 518 default=None,
427 help="Path to a JSON configuration file containing arguments for invoking this script." 519 help="Path to a JSON configuration file containing arguments for invoking this script.",
428 ) 520 )
429 521
430 args = parser.parse_args() 522 args = parser.parse_args()
@@ -441,6 +533,67 @@ def parse_args():
441 if args.project is None: 533 if args.project is None:
442 raise ValueError("You must specify --project") 534 raise ValueError("You must specify --project")
443 535
536 if args.initializer_tokens is None:
537 args.initializer_tokens = []
538
539 if args.placeholder_tokens is None:
540 args.placeholder_tokens = []
541
542 if isinstance(args.placeholder_tokens, str):
543 args.placeholder_tokens = [args.placeholder_tokens]
544
545 if isinstance(args.initializer_tokens, str):
546 args.initializer_tokens = [args.initializer_tokens] * len(
547 args.placeholder_tokens
548 )
549
550 if len(args.placeholder_tokens) == 0:
551 args.placeholder_tokens = [
552 f"<*{i}>" for i in range(len(args.initializer_tokens))
553 ]
554
555 if len(args.initializer_tokens) == 0:
556 args.initializer_tokens = args.placeholder_tokens.copy()
557
558 if len(args.placeholder_tokens) != len(args.initializer_tokens):
559 raise ValueError(
560 "--placeholder_tokens and --initializer_tokens must have the same number of items"
561 )
562
563 if isinstance(args.inverted_initializer_tokens, str):
564 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
565 args.placeholder_tokens
566 )
567
568 if (
569 isinstance(args.inverted_initializer_tokens, list)
570 and len(args.inverted_initializer_tokens) != 0
571 ):
572 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
573 args.initializer_tokens += args.inverted_initializer_tokens
574
575 if isinstance(args.num_vectors, int):
576 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
577
578 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
579 args.num_vectors
580 ):
581 raise ValueError(
582 "--placeholder_tokens and --num_vectors must have the same number of items"
583 )
584
585 if args.alias_tokens is None:
586 args.alias_tokens = []
587
588 if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0:
589 raise ValueError("--alias_tokens must be a list with an even number of items")
590
591 if args.filter_tokens is None:
592 args.filter_tokens = args.placeholder_tokens.copy()
593
594 if isinstance(args.filter_tokens, str):
595 args.filter_tokens = [args.filter_tokens]
596
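A hypothetical worked example of the normalization rules above (a simplified re-statement with made-up token names, not the script itself): string values are promoted to lists, missing placeholders are auto-named, and inverted initializers append "inv_"-prefixed placeholders.

    placeholder_tokens = ["<obj>"]
    initializer_tokens = ["person"]
    inverted_initializer_tokens = ["woman"]

    # Same append/broadcast logic as above, applied to concrete values.
    placeholder_tokens += [f"inv_{t}" for t in placeholder_tokens]
    initializer_tokens += inverted_initializer_tokens
    filter_tokens = placeholder_tokens.copy()  # default when --filter_tokens is omitted

    assert placeholder_tokens == ["<obj>", "inv_<obj>"]
    assert initializer_tokens == ["person", "woman"]
    assert filter_tokens == ["<obj>", "inv_<obj>"]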
444 if isinstance(args.collection, str): 597 if isinstance(args.collection, str):
445 args.collection = [args.collection] 598 args.collection = [args.collection]
446 599
@@ -451,15 +604,15 @@ def parse_args():
451 raise ValueError("You must specify --output_dir") 604 raise ValueError("You must specify --output_dir")
452 605
453 if args.adam_beta1 is None: 606 if args.adam_beta1 is None:
454 if args.optimizer in ('adam', 'adam8bit'): 607 if args.optimizer in ("adam", "adam8bit", "dadam"):
455 args.adam_beta1 = 0.9 608 args.adam_beta1 = 0.9
456 elif args.optimizer == 'lion': 609 elif args.optimizer in ("lion", "dlion"):
457 args.adam_beta1 = 0.95 610 args.adam_beta1 = 0.95
458 611
459 if args.adam_beta2 is None: 612 if args.adam_beta2 is None:
460 if args.optimizer in ('adam', 'adam8bit'): 613 if args.optimizer in ("adam", "adam8bit", "dadam"):
461 args.adam_beta2 = 0.999 614 args.adam_beta2 = 0.999
462 elif args.optimizer == 'lion': 615 elif args.optimizer in ("lion", "dlion"):
463 args.adam_beta2 = 0.98 616 args.adam_beta2 = 0.98
464 617
465 return args 618 return args
@@ -475,7 +628,7 @@ def main():
475 accelerator = Accelerator( 628 accelerator = Accelerator(
476 log_with=LoggerType.TENSORBOARD, 629 log_with=LoggerType.TENSORBOARD,
477 project_dir=f"{output_dir}", 630 project_dir=f"{output_dir}",
478 mixed_precision=args.mixed_precision 631 mixed_precision=args.mixed_precision,
479 ) 632 )
480 633
481 weight_dtype = torch.float32 634 weight_dtype = torch.float32
@@ -484,6 +637,8 @@ def main():
484 elif args.mixed_precision == "bf16": 637 elif args.mixed_precision == "bf16":
485 weight_dtype = torch.bfloat16 638 weight_dtype = torch.bfloat16
486 639
640 patch_xformers(weight_dtype)
641
487 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) 642 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
488 643
489 if args.seed is None: 644 if args.seed is None:
@@ -493,44 +648,125 @@ def main():
493 648
494 save_args(output_dir, args) 649 save_args(output_dir, args)
495 650
496 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 651 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
497 args.pretrained_model_name_or_path) 652 args.pretrained_model_name_or_path
498 653 )
499 tokenizer.set_use_vector_shuffle(args.vector_shuffle) 654 embeddings = patch_managed_embeddings(
500 tokenizer.set_dropout(args.vector_dropout) 655 text_encoder, args.emb_alpha, args.emb_dropout
656 )
657 schedule_sampler = create_named_schedule_sampler(
658 args.schedule_sampler, noise_scheduler.config.num_train_timesteps
659 )
501 660
502 vae.enable_slicing() 661 vae.enable_slicing()
503 vae.set_use_memory_efficient_attention_xformers(True) 662
504 unet.enable_xformers_memory_efficient_attention() 663 if args.use_xformers:
664 vae.set_use_memory_efficient_attention_xformers(True)
665 unet.enable_xformers_memory_efficient_attention()
666 # elif args.compile_unet:
667 # unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
668 #
669 # proc = AttnProcessor()
670 #
671 # def fn_recursive_set_proc(module: torch.nn.Module):
672 # if hasattr(module, "processor"):
673 # module.processor = proc
674 #
675 # for child in module.children():
676 # fn_recursive_set_proc(child)
677 #
678 # fn_recursive_set_proc(unet)
505 679
506 if args.gradient_checkpointing: 680 if args.gradient_checkpointing:
507 unet.enable_gradient_checkpointing() 681 unet.enable_gradient_checkpointing()
508 text_encoder.gradient_checkpointing_enable() 682
683 if len(args.alias_tokens) != 0:
684 alias_placeholder_tokens = args.alias_tokens[::2]
685 alias_initializer_tokens = args.alias_tokens[1::2]
686
687 added_tokens, added_ids = add_placeholder_tokens(
688 tokenizer=tokenizer,
689 embeddings=embeddings,
690 placeholder_tokens=alias_placeholder_tokens,
691 initializer_tokens=alias_initializer_tokens,
692 )
693 embeddings.persist()
694 print(
695 f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
696 )
697
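For illustration (hypothetical token names), the --alias_tokens handling above splits one flat, even-length list into (placeholder, initializer) pairs via extended slicing:

    alias_tokens = ["<sks>", "dog", "<painting>", "oil painting"]
    alias_placeholder_tokens = alias_tokens[::2]    # every even index
    alias_initializer_tokens = alias_tokens[1::2]   # every odd index

    assert list(zip(alias_placeholder_tokens, alias_initializer_tokens)) == [
        ("<sks>", "dog"),
        ("<painting>", "oil painting"),
    ]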
698 placeholder_tokens = []
699 placeholder_token_ids = []
509 700
510 if args.embeddings_dir is not None: 701 if args.embeddings_dir is not None:
511 embeddings_dir = Path(args.embeddings_dir) 702 embeddings_dir = Path(args.embeddings_dir)
512 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 703 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
513 raise ValueError("--embeddings_dir must point to an existing directory") 704 raise ValueError("--embeddings_dir must point to an existing directory")
514 705
515 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 706 added_tokens, added_ids = load_embeddings_from_dir(
516 embeddings.persist() 707 tokenizer, embeddings, embeddings_dir
517 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 708 )
709
710 placeholder_tokens = added_tokens
711 placeholder_token_ids = added_ids
712
713 print(
714 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
715 )
716
717 if args.train_dir_embeddings:
718 print("Training embeddings from embeddings dir")
719 else:
720 embeddings.persist()
721
722 if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
723 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
724 tokenizer=tokenizer,
725 embeddings=embeddings,
726 placeholder_tokens=args.placeholder_tokens,
727 initializer_tokens=args.initializer_tokens,
728 num_vectors=args.num_vectors,
729 initializer_noise=args.initializer_noise,
730 )
731
732 placeholder_tokens = args.placeholder_tokens
733
734 stats = list(
735 zip(
736 placeholder_tokens,
737 placeholder_token_ids,
738 args.initializer_tokens,
739 initializer_token_ids,
740 )
741 )
742 print(f"Training embeddings: {stats}")
518 743
519 if args.scale_lr: 744 if args.scale_lr:
520 args.learning_rate = ( 745 args.learning_rate_unet = (
521 args.learning_rate * args.gradient_accumulation_steps * 746 args.learning_rate_unet
522 args.train_batch_size * accelerator.num_processes 747 * args.gradient_accumulation_steps
748 * args.train_batch_size
749 * accelerator.num_processes
750 )
751 args.learning_rate_text = (
752 args.learning_rate_text
753 * args.gradient_accumulation_steps
754 * args.train_batch_size
755 * accelerator.num_processes
523 ) 756 )
524 757
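A small worked example (assumed values) of the --scale_lr rule above: the base rate is multiplied by gradient-accumulation steps, per-device batch size, and process count.

    learning_rate_unet = 1e-4
    gradient_accumulation_steps, train_batch_size, num_processes = 2, 4, 1

    effective_lr = (
        learning_rate_unet * gradient_accumulation_steps * train_batch_size * num_processes
    )
    assert abs(effective_lr - 8e-4) < 1e-12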
525 if args.find_lr: 758 if args.find_lr:
526 args.learning_rate = 1e-6 759 args.learning_rate_unet = 1e-6
760 args.learning_rate_text = 1e-6
527 args.lr_scheduler = "exponential_growth" 761 args.lr_scheduler = "exponential_growth"
528 762
529 if args.optimizer == 'adam8bit': 763 if args.optimizer == "adam8bit":
530 try: 764 try:
531 import bitsandbytes as bnb 765 import bitsandbytes as bnb
532 except ImportError: 766 except ImportError:
533 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") 767 raise ImportError(
768 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
769 )
534 770
535 create_optimizer = partial( 771 create_optimizer = partial(
536 bnb.optim.AdamW8bit, 772 bnb.optim.AdamW8bit,
@@ -539,7 +775,7 @@ def main():
539 eps=args.adam_epsilon, 775 eps=args.adam_epsilon,
540 amsgrad=args.adam_amsgrad, 776 amsgrad=args.adam_amsgrad,
541 ) 777 )
542 elif args.optimizer == 'adam': 778 elif args.optimizer == "adam":
543 create_optimizer = partial( 779 create_optimizer = partial(
544 torch.optim.AdamW, 780 torch.optim.AdamW,
545 betas=(args.adam_beta1, args.adam_beta2), 781 betas=(args.adam_beta1, args.adam_beta2),
@@ -547,22 +783,27 @@ def main():
547 eps=args.adam_epsilon, 783 eps=args.adam_epsilon,
548 amsgrad=args.adam_amsgrad, 784 amsgrad=args.adam_amsgrad,
549 ) 785 )
550 elif args.optimizer == 'adan': 786 elif args.optimizer == "adan":
551 try: 787 try:
552 import timm.optim 788 import timm.optim
553 except ImportError: 789 except ImportError:
554 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") 790 raise ImportError(
791 "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
792 )
555 793
556 create_optimizer = partial( 794 create_optimizer = partial(
557 timm.optim.Adan, 795 timm.optim.Adan,
558 weight_decay=args.adam_weight_decay, 796 weight_decay=args.adam_weight_decay,
559 eps=args.adam_epsilon, 797 eps=args.adam_epsilon,
798 no_prox=True,
560 ) 799 )
561 elif args.optimizer == 'lion': 800 elif args.optimizer == "lion":
562 try: 801 try:
563 import lion_pytorch 802 import lion_pytorch
564 except ImportError: 803 except ImportError:
565 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") 804 raise ImportError(
805 "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
806 )
566 807
567 create_optimizer = partial( 808 create_optimizer = partial(
568 lion_pytorch.Lion, 809 lion_pytorch.Lion,
@@ -570,7 +811,7 @@ def main():
570 weight_decay=args.adam_weight_decay, 811 weight_decay=args.adam_weight_decay,
571 use_triton=True, 812 use_triton=True,
572 ) 813 )
573 elif args.optimizer == 'adafactor': 814 elif args.optimizer == "adafactor":
574 create_optimizer = partial( 815 create_optimizer = partial(
575 transformers.optimization.Adafactor, 816 transformers.optimization.Adafactor,
576 weight_decay=args.adam_weight_decay, 817 weight_decay=args.adam_weight_decay,
@@ -580,13 +821,16 @@ def main():
580 ) 821 )
581 822
582 args.lr_scheduler = "adafactor" 823 args.lr_scheduler = "adafactor"
583 args.lr_min_lr = args.learning_rate 824 args.lr_min_lr = args.learning_rate_unet
584 args.learning_rate = None 825 args.learning_rate_unet = None
585 elif args.optimizer == 'dadam': 826 args.learning_rate_text = None
827 elif args.optimizer == "dadam":
586 try: 828 try:
587 import dadaptation 829 import dadaptation
588 except ImportError: 830 except ImportError:
589 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") 831 raise ImportError(
832 "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
833 )
590 834
591 create_optimizer = partial( 835 create_optimizer = partial(
592 dadaptation.DAdaptAdam, 836 dadaptation.DAdaptAdam,
@@ -595,46 +839,65 @@ def main():
595 eps=args.adam_epsilon, 839 eps=args.adam_epsilon,
596 decouple=True, 840 decouple=True,
597 d0=args.dadaptation_d0, 841 d0=args.dadaptation_d0,
842 growth_rate=args.dadaptation_growth_rate,
598 ) 843 )
599 844
600 args.learning_rate = 1.0 845 args.learning_rate_unet = 1.0
601 elif args.optimizer == 'dadan': 846 args.learning_rate_text = 1.0
847 elif args.optimizer == "dadan":
602 try: 848 try:
603 import dadaptation 849 import dadaptation
604 except ImportError: 850 except ImportError:
605 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") 851 raise ImportError(
852 "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
853 )
606 854
607 create_optimizer = partial( 855 create_optimizer = partial(
608 dadaptation.DAdaptAdan, 856 dadaptation.DAdaptAdan,
609 weight_decay=args.adam_weight_decay, 857 weight_decay=args.adam_weight_decay,
610 eps=args.adam_epsilon, 858 eps=args.adam_epsilon,
611 d0=args.dadaptation_d0, 859 d0=args.dadaptation_d0,
860 growth_rate=args.dadaptation_growth_rate,
612 ) 861 )
613 862
614 args.learning_rate = 1.0 863 args.learning_rate_unet = 1.0
864 args.learning_rate_text = 1.0
865 elif args.optimizer == "dlion":
866 raise ImportError("DLion has not been merged into dadaptation yet")
615 else: 867 else:
616 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 868 raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
617 869
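A minimal sketch (assumed model and hyperparameters) of the D-Adaptation convention reflected above: the optimizer estimates its own step size, so the nominal learning rate is set to 1.0 while d0 and growth_rate only seed and bound that estimate.

    import torch
    import dadaptation

    model = torch.nn.Linear(16, 16)
    optimizer = dadaptation.DAdaptAdam(
        model.parameters(),
        lr=1.0,                     # required convention for D-Adaptation
        decouple=True,              # AdamW-style decoupled weight decay
        d0=1e-6,                    # initial step-size estimate (--dadaptation_d0)
        growth_rate=float("inf"),   # uncapped growth (--dadaptation_growth_rate default)
    )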
618 trainer = partial( 870 trainer = partial(
619 train, 871 train,
620 accelerator=accelerator, 872 accelerator=accelerator,
621 unet=unet, 873 unet=unet,
622 text_encoder=text_encoder, 874 text_encoder=text_encoder,
875 tokenizer=tokenizer,
623 vae=vae, 876 vae=vae,
624 noise_scheduler=noise_scheduler, 877 noise_scheduler=noise_scheduler,
878 schedule_sampler=schedule_sampler,
879 min_snr_gamma=args.min_snr_gamma,
625 dtype=weight_dtype, 880 dtype=weight_dtype,
881 seed=args.seed,
882 compile_unet=args.compile_unet,
626 guidance_scale=args.guidance_scale, 883 guidance_scale=args.guidance_scale,
627 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, 884 prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
628 no_val=args.valid_set_size == 0, 885 sample_scheduler=sample_scheduler,
886 sample_batch_size=args.sample_batch_size,
887 sample_num_batches=args.sample_batches,
888 sample_num_steps=args.sample_steps,
889 sample_image_size=args.sample_image_size,
890 max_grad_norm=args.max_grad_norm,
629 ) 891 )
630 892
631 checkpoint_output_dir = output_dir / "model" 893 data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
632 sample_output_dir = output_dir / "samples" 894 data_npgenerator = np.random.default_rng(args.seed)
633 895
634 datamodule = VlpnDataModule( 896 create_datamodule = partial(
897 VlpnDataModule,
635 data_file=args.train_data_file, 898 data_file=args.train_data_file,
636 batch_size=args.train_batch_size,
637 tokenizer=tokenizer, 899 tokenizer=tokenizer,
900 constant_prompt_length=args.compile_unet,
638 class_subdir=args.class_image_dir, 901 class_subdir=args.class_image_dir,
639 with_guidance=args.guidance_scale != 0, 902 with_guidance=args.guidance_scale != 0,
640 num_class_images=args.num_class_images, 903 num_class_images=args.num_class_images,
@@ -643,83 +906,186 @@ def main():
643 progressive_buckets=args.progressive_buckets, 906 progressive_buckets=args.progressive_buckets,
644 bucket_step_size=args.bucket_step_size, 907 bucket_step_size=args.bucket_step_size,
645 bucket_max_pixels=args.bucket_max_pixels, 908 bucket_max_pixels=args.bucket_max_pixels,
646 dropout=args.tag_dropout,
647 shuffle=not args.no_tag_shuffle, 909 shuffle=not args.no_tag_shuffle,
648 template_key=args.train_data_template, 910 template_key=args.train_data_template,
649 valid_set_size=args.valid_set_size,
650 train_set_pad=args.train_set_pad, 911 train_set_pad=args.train_set_pad,
651 valid_set_pad=args.valid_set_pad, 912 valid_set_pad=args.valid_set_pad,
652 seed=args.seed, 913 dtype=weight_dtype,
653 filter=partial(keyword_filter, None, args.collection, args.exclude_collections), 914 generator=data_generator,
654 dtype=weight_dtype 915 npgenerator=data_npgenerator,
655 )
656 datamodule.setup()
657
658 num_train_epochs = args.num_train_epochs
659 sample_frequency = args.sample_frequency
660 if num_train_epochs is None:
661 num_train_epochs = math.ceil(
662 args.num_train_steps / len(datamodule.train_dataset)
663 ) * args.gradient_accumulation_steps
664 sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
665
666 params_to_optimize = (unet.parameters(), )
667 if args.train_text_encoder_epochs != 0:
668 params_to_optimize += (
669 text_encoder.text_model.encoder.parameters(),
670 text_encoder.text_model.final_layer_norm.parameters(),
671 )
672
673 optimizer = create_optimizer(
674 itertools.chain(*params_to_optimize),
675 lr=args.learning_rate,
676 ) 916 )
677 917
678 lr_scheduler = get_scheduler( 918 create_lr_scheduler = partial(
679 args.lr_scheduler, 919 get_scheduler,
680 optimizer=optimizer,
681 num_training_steps_per_epoch=len(datamodule.train_dataloader),
682 gradient_accumulation_steps=args.gradient_accumulation_steps,
683 min_lr=args.lr_min_lr, 920 min_lr=args.lr_min_lr,
684 warmup_func=args.lr_warmup_func, 921 warmup_func=args.lr_warmup_func,
685 annealing_func=args.lr_annealing_func, 922 annealing_func=args.lr_annealing_func,
686 warmup_exp=args.lr_warmup_exp, 923 warmup_exp=args.lr_warmup_exp,
687 annealing_exp=args.lr_annealing_exp, 924 annealing_exp=args.lr_annealing_exp,
688 cycles=args.lr_cycles,
689 end_lr=1e2, 925 end_lr=1e2,
690 train_epochs=num_train_epochs,
691 warmup_epochs=args.lr_warmup_epochs,
692 mid_point=args.lr_mid_point, 926 mid_point=args.lr_mid_point,
693 ) 927 )
694 928
695 trainer( 929 # Dreambooth
696 strategy=dreambooth_strategy, 930 # --------------------------------------------------------------------------------
697 project="dreambooth", 931
698 train_dataloader=datamodule.train_dataloader, 932 dreambooth_datamodule = create_datamodule(
699 val_dataloader=datamodule.val_dataloader, 933 valid_set_size=args.valid_set_size,
700 seed=args.seed, 934 batch_size=args.train_batch_size,
701 optimizer=optimizer, 935 dropout=args.tag_dropout,
702 lr_scheduler=lr_scheduler, 936 filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
703 num_train_epochs=num_train_epochs, 937 )
704 gradient_accumulation_steps=args.gradient_accumulation_steps, 938 dreambooth_datamodule.setup()
705 sample_frequency=sample_frequency, 939
706 offset_noise_strength=args.offset_noise_strength, 940 num_train_epochs = args.num_train_epochs
707 # -- 941 dreambooth_sample_frequency = args.sample_frequency
708 tokenizer=tokenizer, 942 if num_train_epochs is None:
709 sample_scheduler=sample_scheduler, 943 num_train_epochs = (
710 sample_output_dir=sample_output_dir, 944 math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset))
711 checkpoint_output_dir=checkpoint_output_dir, 945 * args.gradient_accumulation_steps
712 train_text_encoder_epochs=args.train_text_encoder_epochs, 946 )
713 max_grad_norm=args.max_grad_norm, 947 dreambooth_sample_frequency = math.ceil(
714 use_ema=args.use_ema, 948 num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps)
715 ema_inv_gamma=args.ema_inv_gamma, 949 )
716 ema_power=args.ema_power, 950 num_training_steps_per_epoch = math.ceil(
717 ema_max_decay=args.ema_max_decay, 951 len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps
718 sample_batch_size=args.sample_batch_size,
719 sample_num_batches=args.sample_batches,
720 sample_num_steps=args.sample_steps,
721 sample_image_size=args.sample_image_size,
722 ) 952 )
953 num_train_steps = num_training_steps_per_epoch * num_train_epochs
954 if args.sample_num is not None:
955 dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
956
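Worked numbers (assumed dataset size) for the epoch/step derivation above:

    import math

    num_train_steps, dataset_len, grad_accum, sample_frequency = 2000, 400, 2, 1

    num_train_epochs = math.ceil(num_train_steps / dataset_len) * grad_accum           # 5 * 2 = 10
    steps_per_epoch = math.ceil(dataset_len / grad_accum)                               # 200
    total_steps = steps_per_epoch * num_train_epochs                                    # 2000
    sample_every = math.ceil(num_train_epochs * (sample_frequency / num_train_steps))   # 1

    assert (num_train_epochs, steps_per_epoch, total_steps, sample_every) == (10, 200, 2000, 1)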
957 dreambooth_project = "dreambooth"
958
959 if accelerator.is_main_process:
960 accelerator.init_trackers(dreambooth_project)
961
962 dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples"
963
964 training_iter = 0
965 auto_cycles = list(args.auto_cycles)
966 learning_rate_unet = args.learning_rate_unet
967 learning_rate_text = args.learning_rate_text
968 lr_scheduler = args.lr_scheduler
969 lr_warmup_epochs = args.lr_warmup_epochs
970 lr_cycles = args.lr_cycles
971
972 avg_loss = AverageMeter()
973 avg_acc = AverageMeter()
974 avg_loss_val = AverageMeter()
975 avg_acc_val = AverageMeter()
976
977 params_to_optimize = [
978 {
979 "params": (param for param in unet.parameters() if param.requires_grad),
980 "lr": learning_rate_unet,
981 },
982 {
983 "params": (
984 param for param in text_encoder.parameters() if param.requires_grad
985 ),
986 "lr": learning_rate_text,
987 },
988 ]
989 group_labels = ["unet", "text"]
990
991 dreambooth_optimizer = create_optimizer(params_to_optimize)
992
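A minimal sketch (toy modules, AdamW standing in for whichever optimizer --optimizer selects) of the two parameter groups built above; separate groups let the cycle loop below rescale the UNet and text-encoder learning rates independently:

    import torch

    unet_stub = torch.nn.Linear(8, 8)
    text_stub = torch.nn.Linear(8, 8)

    optimizer = torch.optim.AdamW([
        {"params": unet_stub.parameters(), "lr": 1e-4},   # "unet" group
        {"params": text_stub.parameters(), "lr": 5e-5},   # "text" group
    ])

    # The cycle loop updates the groups in place, e.g. doubling both for a one_cycle run:
    for group, lr in zip(optimizer.param_groups, [2e-4, 1e-4]):
        group["lr"] = lr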
993 while True:
994 if len(auto_cycles) != 0:
995 response = auto_cycles.pop(0)
996 else:
997 response = input(
998 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
999 )
1000
1001 if response.lower().strip() == "o":
1002 if args.learning_rate_unet is not None:
1003 learning_rate_unet = (
1004 args.learning_rate_unet * 2 * (args.cycle_decay**training_iter)
1005 )
1006 if args.learning_rate_text is not None:
1007 learning_rate_text = (
1008 args.learning_rate_text * 2 * (args.cycle_decay**training_iter)
1009 )
1010 else:
1011 learning_rate_unet = args.learning_rate_unet * (
1012 args.cycle_decay**training_iter
1013 )
1014 learning_rate_text = args.learning_rate_text * (
1015 args.cycle_decay**training_iter
1016 )
1017
1018 if response.lower().strip() == "o":
1019 lr_scheduler = "one_cycle"
1020 lr_warmup_epochs = args.lr_warmup_epochs
1021 lr_cycles = args.lr_cycles
1022 elif response.lower().strip() == "w":
1023 lr_scheduler = "constant_with_warmup"
1024 lr_warmup_epochs = num_train_epochs
1025 elif response.lower().strip() == "c":
1026 lr_scheduler = "constant"
1027 elif response.lower().strip() == "d":
1028 lr_scheduler = "cosine"
1029 lr_warmup_epochs = 0
1030 lr_cycles = 1
1031 elif response.lower().strip() == "s":
1032 break
1033 else:
1034 continue
1035
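For illustration, --auto_cycles is consumed one character per cycle before falling back to the interactive prompt; the mapping used above is:

    schedules = {
        "o": "one_cycle",              # warmup + anneal, doubled base lr
        "w": "constant_with_warmup",   # warmup spans the whole run
        "c": "constant",
        "d": "cosine",                 # pure decay, no warmup, single cycle
        "s": "stop",
    }
    auto_cycles = list("owd")  # e.g. --auto_cycles owd
    assert [schedules[c] for c in auto_cycles] == ["one_cycle", "constant_with_warmup", "cosine"]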
1036 print("")
1037 print(
1038 f"============ Dreambooth cycle {training_iter + 1}: {response} ============"
1039 )
1040 print("")
1041
1042 for group, lr in zip(
1043 dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text]
1044 ):
1045 group["lr"] = lr
1046
1047 dreambooth_lr_scheduler = create_lr_scheduler(
1048 lr_scheduler,
1049 gradient_accumulation_steps=args.gradient_accumulation_steps,
1050 optimizer=dreambooth_optimizer,
1051 num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader),
1052 train_epochs=num_train_epochs,
1053 cycles=lr_cycles,
1054 warmup_epochs=lr_warmup_epochs,
1055 )
1056
1057 dreambooth_checkpoint_output_dir = (
1058 output_dir / dreambooth_project / f"model_{training_iter}"
1059 )
1060
1061 trainer(
1062 strategy=dreambooth_strategy,
1063 train_dataloader=dreambooth_datamodule.train_dataloader,
1064 val_dataloader=dreambooth_datamodule.val_dataloader,
1065 optimizer=dreambooth_optimizer,
1066 lr_scheduler=dreambooth_lr_scheduler,
1067 num_train_epochs=num_train_epochs,
1068 gradient_accumulation_steps=args.gradient_accumulation_steps,
1069 global_step_offset=training_iter * num_train_steps,
1070 cycle=training_iter,
1071 train_text_encoder_cycles=args.train_text_encoder_cycles,
1072 # --
1073 group_labels=group_labels,
1074 sample_output_dir=dreambooth_sample_output_dir,
1075 checkpoint_output_dir=dreambooth_checkpoint_output_dir,
1076 sample_frequency=dreambooth_sample_frequency,
1077 offset_noise_strength=args.offset_noise_strength,
1078 input_pertubation=args.input_pertubation,
1079 no_val=args.valid_set_size == 0,
1080 avg_loss=avg_loss,
1081 avg_acc=avg_acc,
1082 avg_loss_val=avg_loss_val,
1083 avg_acc_val=avg_acc_val,
1084 )
1085
1086 training_iter += 1
1087
1088 accelerator.end_training()
723 1089
724 1090
725if __name__ == "__main__": 1091if __name__ == "__main__":