diff options
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 770 |
1 files changed, 568 insertions, 202 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2aca1e7..659b84c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -5,34 +5,70 @@ import itertools | |||
5 | from pathlib import Path | 5 | from pathlib import Path |
6 | from functools import partial | 6 | from functools import partial |
7 | import math | 7 | import math |
8 | import warnings | ||
8 | 9 | ||
9 | import torch | 10 | import torch |
11 | import torch._dynamo | ||
10 | import torch.utils.checkpoint | 12 | import torch.utils.checkpoint |
13 | import hidet | ||
11 | 14 | ||
12 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
13 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
15 | from slugify import slugify | 18 | |
19 | # from diffusers.models.attention_processor import AttnProcessor | ||
20 | from diffusers.utils.import_utils import is_xformers_available | ||
16 | import transformers | 21 | import transformers |
17 | 22 | ||
18 | from util.files import load_config, load_embeddings_from_dir | 23 | import numpy as np |
24 | from slugify import slugify | ||
25 | |||
19 | from data.csv import VlpnDataModule, keyword_filter | 26 | from data.csv import VlpnDataModule, keyword_filter |
20 | from training.functional import train, get_models | 27 | from models.clip.embeddings import patch_managed_embeddings |
28 | from training.functional import train, add_placeholder_tokens, get_models | ||
21 | from training.strategy.dreambooth import dreambooth_strategy | 29 | from training.strategy.dreambooth import dreambooth_strategy |
22 | from training.optimization import get_scheduler | 30 | from training.optimization import get_scheduler |
23 | from training.util import save_args | 31 | from training.sampler import create_named_schedule_sampler |
32 | from training.util import AverageMeter, save_args | ||
33 | from util.files import load_config, load_embeddings_from_dir | ||
34 | |||
24 | 35 | ||
25 | logger = get_logger(__name__) | 36 | logger = get_logger(__name__) |
26 | 37 | ||
38 | warnings.filterwarnings("ignore") | ||
39 | |||
27 | 40 | ||
28 | torch.backends.cuda.matmul.allow_tf32 = True | 41 | torch.backends.cuda.matmul.allow_tf32 = True |
29 | torch.backends.cudnn.benchmark = True | 42 | torch.backends.cudnn.benchmark = True |
30 | 43 | ||
44 | # torch._dynamo.config.log_level = logging.WARNING | ||
45 | torch._dynamo.config.suppress_errors = True | ||
46 | |||
47 | hidet.torch.dynamo_config.use_tensor_core(True) | ||
48 | hidet.torch.dynamo_config.search_space(0) | ||
49 | |||
50 | |||
51 | def patch_xformers(dtype): | ||
52 | if is_xformers_available(): | ||
53 | import xformers | ||
54 | import xformers.ops | ||
55 | |||
56 | orig_xformers_memory_efficient_attention = ( | ||
57 | xformers.ops.memory_efficient_attention | ||
58 | ) | ||
59 | |||
60 | def xformers_memory_efficient_attention( | ||
61 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
62 | ): | ||
63 | return orig_xformers_memory_efficient_attention( | ||
64 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
65 | ) | ||
66 | |||
67 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
68 | |||
31 | 69 | ||
32 | def parse_args(): | 70 | def parse_args(): |
33 | parser = argparse.ArgumentParser( | 71 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
34 | description="Simple example of a training script." | ||
35 | ) | ||
36 | parser.add_argument( | 72 | parser.add_argument( |
37 | "--pretrained_model_name_or_path", | 73 | "--pretrained_model_name_or_path", |
38 | type=str, | 74 | type=str, |
@@ -49,7 +85,7 @@ def parse_args(): | |||
49 | "--train_data_file", | 85 | "--train_data_file", |
50 | type=str, | 86 | type=str, |
51 | default=None, | 87 | default=None, |
52 | help="A folder containing the training data." | 88 | help="A folder containing the training data.", |
53 | ) | 89 | ) |
54 | parser.add_argument( | 90 | parser.add_argument( |
55 | "--train_data_template", | 91 | "--train_data_template", |
@@ -60,13 +96,13 @@ def parse_args(): | |||
60 | "--train_set_pad", | 96 | "--train_set_pad", |
61 | type=int, | 97 | type=int, |
62 | default=None, | 98 | default=None, |
63 | help="The number to fill train dataset items up to." | 99 | help="The number to fill train dataset items up to.", |
64 | ) | 100 | ) |
65 | parser.add_argument( | 101 | parser.add_argument( |
66 | "--valid_set_pad", | 102 | "--valid_set_pad", |
67 | type=int, | 103 | type=int, |
68 | default=None, | 104 | default=None, |
69 | help="The number to fill validation dataset items up to." | 105 | help="The number to fill validation dataset items up to.", |
70 | ) | 106 | ) |
71 | parser.add_argument( | 107 | parser.add_argument( |
72 | "--project", | 108 | "--project", |
@@ -75,20 +111,58 @@ def parse_args(): | |||
75 | help="The name of the current project.", | 111 | help="The name of the current project.", |
76 | ) | 112 | ) |
77 | parser.add_argument( | 113 | parser.add_argument( |
78 | "--exclude_collections", | 114 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
115 | ) | ||
116 | parser.add_argument( | ||
117 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." | ||
118 | ) | ||
119 | parser.add_argument( | ||
120 | "--placeholder_tokens", | ||
79 | type=str, | 121 | type=str, |
80 | nargs='*', | 122 | nargs="*", |
81 | help="Exclude all items with a listed collection.", | 123 | help="A token to use as a placeholder for the concept.", |
82 | ) | 124 | ) |
83 | parser.add_argument( | 125 | parser.add_argument( |
84 | "--train_text_encoder_epochs", | 126 | "--initializer_tokens", |
85 | default=999999, | 127 | type=str, |
86 | help="Number of epochs the text encoder will be trained." | 128 | nargs="*", |
129 | help="A token to use as initializer word.", | ||
130 | ) | ||
131 | parser.add_argument( | ||
132 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." | ||
133 | ) | ||
134 | parser.add_argument( | ||
135 | "--initializer_noise", | ||
136 | type=float, | ||
137 | default=0, | ||
138 | help="Noise to apply to the initializer word", | ||
139 | ) | ||
140 | parser.add_argument( | ||
141 | "--alias_tokens", | ||
142 | type=str, | ||
143 | nargs="*", | ||
144 | default=[], | ||
145 | help="Tokens to create an alias for.", | ||
146 | ) | ||
147 | parser.add_argument( | ||
148 | "--inverted_initializer_tokens", | ||
149 | type=str, | ||
150 | nargs="*", | ||
151 | help="A token to use as initializer word.", | ||
152 | ) | ||
153 | parser.add_argument( | ||
154 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." | ||
155 | ) | ||
156 | parser.add_argument( | ||
157 | "--exclude_collections", | ||
158 | type=str, | ||
159 | nargs="*", | ||
160 | help="Exclude all items with a listed collection.", | ||
87 | ) | 161 | ) |
88 | parser.add_argument( | 162 | parser.add_argument( |
89 | "--num_buckets", | 163 | "--num_buckets", |
90 | type=int, | 164 | type=int, |
91 | default=0, | 165 | default=2, |
92 | help="Number of aspect ratio buckets in either direction.", | 166 | help="Number of aspect ratio buckets in either direction.", |
93 | ) | 167 | ) |
94 | parser.add_argument( | 168 | parser.add_argument( |
@@ -120,19 +194,6 @@ def parse_args(): | |||
120 | help="Shuffle tags.", | 194 | help="Shuffle tags.", |
121 | ) | 195 | ) |
122 | parser.add_argument( | 196 | parser.add_argument( |
123 | "--vector_dropout", | ||
124 | type=int, | ||
125 | default=0, | ||
126 | help="Vector dropout probability.", | ||
127 | ) | ||
128 | parser.add_argument( | ||
129 | "--vector_shuffle", | ||
130 | type=str, | ||
131 | default="auto", | ||
132 | choices=["all", "trailing", "leading", "between", "auto", "off"], | ||
133 | help='Vector shuffling algorithm.', | ||
134 | ) | ||
135 | parser.add_argument( | ||
136 | "--guidance_scale", | 197 | "--guidance_scale", |
137 | type=float, | 198 | type=float, |
138 | default=0, | 199 | default=0, |
@@ -141,7 +202,7 @@ def parse_args(): | |||
141 | "--num_class_images", | 202 | "--num_class_images", |
142 | type=int, | 203 | type=int, |
143 | default=0, | 204 | default=0, |
144 | help="How many class images to generate." | 205 | help="How many class images to generate.", |
145 | ) | 206 | ) |
146 | parser.add_argument( | 207 | parser.add_argument( |
147 | "--class_image_dir", | 208 | "--class_image_dir", |
@@ -162,16 +223,18 @@ def parse_args(): | |||
162 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 223 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
163 | ) | 224 | ) |
164 | parser.add_argument( | 225 | parser.add_argument( |
226 | "--train_dir_embeddings", | ||
227 | action="store_true", | ||
228 | help="Train embeddings loaded from embeddings directory.", | ||
229 | ) | ||
230 | parser.add_argument( | ||
165 | "--collection", | 231 | "--collection", |
166 | type=str, | 232 | type=str, |
167 | nargs='*', | 233 | nargs="*", |
168 | help="A collection to filter the dataset.", | 234 | help="A collection to filter the dataset.", |
169 | ) | 235 | ) |
170 | parser.add_argument( | 236 | parser.add_argument( |
171 | "--seed", | 237 | "--seed", type=int, default=None, help="A seed for reproducible training." |
172 | type=int, | ||
173 | default=None, | ||
174 | help="A seed for reproducible training." | ||
175 | ) | 238 | ) |
176 | parser.add_argument( | 239 | parser.add_argument( |
177 | "--resolution", | 240 | "--resolution", |
@@ -189,15 +252,13 @@ def parse_args(): | |||
189 | help="Perlin offset noise strength.", | 252 | help="Perlin offset noise strength.", |
190 | ) | 253 | ) |
191 | parser.add_argument( | 254 | parser.add_argument( |
192 | "--num_train_epochs", | 255 | "--input_pertubation", |
193 | type=int, | 256 | type=float, |
194 | default=None | 257 | default=0, |
195 | ) | 258 | help="The scale of input pretubation. Recommended 0.1.", |
196 | parser.add_argument( | ||
197 | "--num_train_steps", | ||
198 | type=int, | ||
199 | default=2000 | ||
200 | ) | 259 | ) |
260 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
261 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
201 | parser.add_argument( | 262 | parser.add_argument( |
202 | "--gradient_accumulation_steps", | 263 | "--gradient_accumulation_steps", |
203 | type=int, | 264 | type=int, |
@@ -205,9 +266,9 @@ def parse_args(): | |||
205 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 266 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
206 | ) | 267 | ) |
207 | parser.add_argument( | 268 | parser.add_argument( |
208 | "--gradient_checkpointing", | 269 | "--train_text_encoder_cycles", |
209 | action="store_true", | 270 | default=999999, |
210 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | 271 | help="Number of epochs the text encoder will be trained.", |
211 | ) | 272 | ) |
212 | parser.add_argument( | 273 | parser.add_argument( |
213 | "--find_lr", | 274 | "--find_lr", |
@@ -215,9 +276,15 @@ def parse_args(): | |||
215 | help="Automatically find a learning rate (no training).", | 276 | help="Automatically find a learning rate (no training).", |
216 | ) | 277 | ) |
217 | parser.add_argument( | 278 | parser.add_argument( |
218 | "--learning_rate", | 279 | "--learning_rate_unet", |
280 | type=float, | ||
281 | default=1e-4, | ||
282 | help="Initial learning rate (after the potential warmup period) to use.", | ||
283 | ) | ||
284 | parser.add_argument( | ||
285 | "--learning_rate_text", | ||
219 | type=float, | 286 | type=float, |
220 | default=2e-6, | 287 | default=5e-5, |
221 | help="Initial learning rate (after the potential warmup period) to use.", | 288 | help="Initial learning rate (after the potential warmup period) to use.", |
222 | ) | 289 | ) |
223 | parser.add_argument( | 290 | parser.add_argument( |
@@ -229,27 +296,31 @@ def parse_args(): | |||
229 | "--lr_scheduler", | 296 | "--lr_scheduler", |
230 | type=str, | 297 | type=str, |
231 | default="one_cycle", | 298 | default="one_cycle", |
232 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 299 | choices=[ |
233 | "constant", "constant_with_warmup", "one_cycle"], | 300 | "linear", |
234 | help='The scheduler type to use.', | 301 | "cosine", |
302 | "cosine_with_restarts", | ||
303 | "polynomial", | ||
304 | "constant", | ||
305 | "constant_with_warmup", | ||
306 | "one_cycle", | ||
307 | ], | ||
308 | help="The scheduler type to use.", | ||
235 | ) | 309 | ) |
236 | parser.add_argument( | 310 | parser.add_argument( |
237 | "--lr_warmup_epochs", | 311 | "--lr_warmup_epochs", |
238 | type=int, | 312 | type=int, |
239 | default=10, | 313 | default=10, |
240 | help="Number of steps for the warmup in the lr scheduler." | 314 | help="Number of steps for the warmup in the lr scheduler.", |
241 | ) | 315 | ) |
242 | parser.add_argument( | 316 | parser.add_argument( |
243 | "--lr_mid_point", | 317 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
244 | type=float, | ||
245 | default=0.3, | ||
246 | help="OneCycle schedule mid point." | ||
247 | ) | 318 | ) |
248 | parser.add_argument( | 319 | parser.add_argument( |
249 | "--lr_cycles", | 320 | "--lr_cycles", |
250 | type=int, | 321 | type=int, |
251 | default=None, | 322 | default=None, |
252 | help="Number of restart cycles in the lr scheduler (if supported)." | 323 | help="Number of restart cycles in the lr scheduler (if supported).", |
253 | ) | 324 | ) |
254 | parser.add_argument( | 325 | parser.add_argument( |
255 | "--lr_warmup_func", | 326 | "--lr_warmup_func", |
@@ -261,7 +332,7 @@ def parse_args(): | |||
261 | "--lr_warmup_exp", | 332 | "--lr_warmup_exp", |
262 | type=int, | 333 | type=int, |
263 | default=1, | 334 | default=1, |
264 | help='If lr_warmup_func is "cos", exponent to modify the function' | 335 | help='If lr_warmup_func is "cos", exponent to modify the function', |
265 | ) | 336 | ) |
266 | parser.add_argument( | 337 | parser.add_argument( |
267 | "--lr_annealing_func", | 338 | "--lr_annealing_func", |
@@ -273,76 +344,76 @@ def parse_args(): | |||
273 | "--lr_annealing_exp", | 344 | "--lr_annealing_exp", |
274 | type=int, | 345 | type=int, |
275 | default=3, | 346 | default=3, |
276 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 347 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
277 | ) | 348 | ) |
278 | parser.add_argument( | 349 | parser.add_argument( |
279 | "--lr_min_lr", | 350 | "--lr_min_lr", |
280 | type=float, | 351 | type=float, |
281 | default=0.04, | 352 | default=0.04, |
282 | help="Minimum learning rate in the lr scheduler." | 353 | help="Minimum learning rate in the lr scheduler.", |
283 | ) | ||
284 | parser.add_argument( | ||
285 | "--use_ema", | ||
286 | action="store_true", | ||
287 | help="Whether to use EMA model." | ||
288 | ) | ||
289 | parser.add_argument( | ||
290 | "--ema_inv_gamma", | ||
291 | type=float, | ||
292 | default=1.0 | ||
293 | ) | 354 | ) |
355 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
294 | parser.add_argument( | 356 | parser.add_argument( |
295 | "--ema_power", | 357 | "--schedule_sampler", |
296 | type=float, | 358 | type=str, |
297 | default=6/7 | 359 | default="uniform", |
298 | ) | 360 | choices=["uniform", "loss-second-moment"], |
299 | parser.add_argument( | 361 | help="Noise schedule sampler.", |
300 | "--ema_max_decay", | ||
301 | type=float, | ||
302 | default=0.9999 | ||
303 | ) | 362 | ) |
304 | parser.add_argument( | 363 | parser.add_argument( |
305 | "--optimizer", | 364 | "--optimizer", |
306 | type=str, | 365 | type=str, |
307 | default="dadan", | 366 | default="adan", |
308 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 367 | choices=[ |
309 | help='Optimizer to use' | 368 | "adam", |
369 | "adam8bit", | ||
370 | "adan", | ||
371 | "lion", | ||
372 | "dadam", | ||
373 | "dadan", | ||
374 | "dlion", | ||
375 | "adafactor", | ||
376 | ], | ||
377 | help="Optimizer to use", | ||
310 | ) | 378 | ) |
311 | parser.add_argument( | 379 | parser.add_argument( |
312 | "--dadaptation_d0", | 380 | "--dadaptation_d0", |
313 | type=float, | 381 | type=float, |
314 | default=1e-6, | 382 | default=1e-6, |
315 | help="The d0 parameter for Dadaptation optimizers." | 383 | help="The d0 parameter for Dadaptation optimizers.", |
384 | ) | ||
385 | parser.add_argument( | ||
386 | "--dadaptation_growth_rate", | ||
387 | type=float, | ||
388 | default=math.inf, | ||
389 | help="The growth_rate parameter for Dadaptation optimizers.", | ||
316 | ) | 390 | ) |
317 | parser.add_argument( | 391 | parser.add_argument( |
318 | "--adam_beta1", | 392 | "--adam_beta1", |
319 | type=float, | 393 | type=float, |
320 | default=None, | 394 | default=None, |
321 | help="The beta1 parameter for the Adam optimizer." | 395 | help="The beta1 parameter for the Adam optimizer.", |
322 | ) | 396 | ) |
323 | parser.add_argument( | 397 | parser.add_argument( |
324 | "--adam_beta2", | 398 | "--adam_beta2", |
325 | type=float, | 399 | type=float, |
326 | default=None, | 400 | default=None, |
327 | help="The beta2 parameter for the Adam optimizer." | 401 | help="The beta2 parameter for the Adam optimizer.", |
328 | ) | 402 | ) |
329 | parser.add_argument( | 403 | parser.add_argument( |
330 | "--adam_weight_decay", | 404 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
331 | type=float, | ||
332 | default=1e-2, | ||
333 | help="Weight decay to use." | ||
334 | ) | 405 | ) |
335 | parser.add_argument( | 406 | parser.add_argument( |
336 | "--adam_epsilon", | 407 | "--adam_epsilon", |
337 | type=float, | 408 | type=float, |
338 | default=1e-08, | 409 | default=1e-08, |
339 | help="Epsilon value for the Adam optimizer" | 410 | help="Epsilon value for the Adam optimizer", |
340 | ) | 411 | ) |
341 | parser.add_argument( | 412 | parser.add_argument( |
342 | "--adam_amsgrad", | 413 | "--adam_amsgrad", |
343 | type=bool, | 414 | type=bool, |
344 | default=False, | 415 | default=False, |
345 | help="Amsgrad value for the Adam optimizer" | 416 | help="Amsgrad value for the Adam optimizer", |
346 | ) | 417 | ) |
347 | parser.add_argument( | 418 | parser.add_argument( |
348 | "--mixed_precision", | 419 | "--mixed_precision", |
@@ -356,12 +427,28 @@ def parse_args(): | |||
356 | ), | 427 | ), |
357 | ) | 428 | ) |
358 | parser.add_argument( | 429 | parser.add_argument( |
430 | "--compile_unet", | ||
431 | action="store_true", | ||
432 | help="Compile UNet with Torch Dynamo.", | ||
433 | ) | ||
434 | parser.add_argument( | ||
435 | "--use_xformers", | ||
436 | action="store_true", | ||
437 | help="Use xformers.", | ||
438 | ) | ||
439 | parser.add_argument( | ||
359 | "--sample_frequency", | 440 | "--sample_frequency", |
360 | type=int, | 441 | type=int, |
361 | default=1, | 442 | default=1, |
362 | help="How often to save a checkpoint and sample image", | 443 | help="How often to save a checkpoint and sample image", |
363 | ) | 444 | ) |
364 | parser.add_argument( | 445 | parser.add_argument( |
446 | "--sample_num", | ||
447 | type=int, | ||
448 | default=None, | ||
449 | help="How often to save a checkpoint and sample image (in number of samples)", | ||
450 | ) | ||
451 | parser.add_argument( | ||
365 | "--sample_image_size", | 452 | "--sample_image_size", |
366 | type=int, | 453 | type=int, |
367 | default=768, | 454 | default=768, |
@@ -383,19 +470,19 @@ def parse_args(): | |||
383 | "--valid_set_size", | 470 | "--valid_set_size", |
384 | type=int, | 471 | type=int, |
385 | default=None, | 472 | default=None, |
386 | help="Number of images in the validation dataset." | 473 | help="Number of images in the validation dataset.", |
387 | ) | 474 | ) |
388 | parser.add_argument( | 475 | parser.add_argument( |
389 | "--valid_set_repeat", | 476 | "--valid_set_repeat", |
390 | type=int, | 477 | type=int, |
391 | default=1, | 478 | default=1, |
392 | help="Times the images in the validation dataset are repeated." | 479 | help="Times the images in the validation dataset are repeated.", |
393 | ) | 480 | ) |
394 | parser.add_argument( | 481 | parser.add_argument( |
395 | "--train_batch_size", | 482 | "--train_batch_size", |
396 | type=int, | 483 | type=int, |
397 | default=1, | 484 | default=1, |
398 | help="Batch size (per device) for the training dataloader." | 485 | help="Batch size (per device) for the training dataloader.", |
399 | ) | 486 | ) |
400 | parser.add_argument( | 487 | parser.add_argument( |
401 | "--sample_steps", | 488 | "--sample_steps", |
@@ -407,13 +494,18 @@ def parse_args(): | |||
407 | "--prior_loss_weight", | 494 | "--prior_loss_weight", |
408 | type=float, | 495 | type=float, |
409 | default=1.0, | 496 | default=1.0, |
410 | help="The weight of prior preservation loss." | 497 | help="The weight of prior preservation loss.", |
411 | ) | 498 | ) |
499 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
500 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
412 | parser.add_argument( | 501 | parser.add_argument( |
413 | "--max_grad_norm", | 502 | "--emb_dropout", |
414 | default=1.0, | ||
415 | type=float, | 503 | type=float, |
416 | help="Max gradient norm." | 504 | default=0, |
505 | help="Embedding dropout probability.", | ||
506 | ) | ||
507 | parser.add_argument( | ||
508 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." | ||
417 | ) | 509 | ) |
418 | parser.add_argument( | 510 | parser.add_argument( |
419 | "--noise_timesteps", | 511 | "--noise_timesteps", |
@@ -424,7 +516,7 @@ def parse_args(): | |||
424 | "--config", | 516 | "--config", |
425 | type=str, | 517 | type=str, |
426 | default=None, | 518 | default=None, |
427 | help="Path to a JSON configuration file containing arguments for invoking this script." | 519 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
428 | ) | 520 | ) |
429 | 521 | ||
430 | args = parser.parse_args() | 522 | args = parser.parse_args() |
@@ -441,6 +533,67 @@ def parse_args(): | |||
441 | if args.project is None: | 533 | if args.project is None: |
442 | raise ValueError("You must specify --project") | 534 | raise ValueError("You must specify --project") |
443 | 535 | ||
536 | if args.initializer_tokens is None: | ||
537 | args.initializer_tokens = [] | ||
538 | |||
539 | if args.placeholder_tokens is None: | ||
540 | args.placeholder_tokens = [] | ||
541 | |||
542 | if isinstance(args.placeholder_tokens, str): | ||
543 | args.placeholder_tokens = [args.placeholder_tokens] | ||
544 | |||
545 | if isinstance(args.initializer_tokens, str): | ||
546 | args.initializer_tokens = [args.initializer_tokens] * len( | ||
547 | args.placeholder_tokens | ||
548 | ) | ||
549 | |||
550 | if len(args.placeholder_tokens) == 0: | ||
551 | args.placeholder_tokens = [ | ||
552 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
553 | ] | ||
554 | |||
555 | if len(args.initializer_tokens) == 0: | ||
556 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
557 | |||
558 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | ||
559 | raise ValueError( | ||
560 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
561 | ) | ||
562 | |||
563 | if isinstance(args.inverted_initializer_tokens, str): | ||
564 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( | ||
565 | args.placeholder_tokens | ||
566 | ) | ||
567 | |||
568 | if ( | ||
569 | isinstance(args.inverted_initializer_tokens, list) | ||
570 | and len(args.inverted_initializer_tokens) != 0 | ||
571 | ): | ||
572 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
573 | args.initializer_tokens += args.inverted_initializer_tokens | ||
574 | |||
575 | if isinstance(args.num_vectors, int): | ||
576 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | ||
577 | |||
578 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( | ||
579 | args.num_vectors | ||
580 | ): | ||
581 | raise ValueError( | ||
582 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
583 | ) | ||
584 | |||
585 | if args.alias_tokens is None: | ||
586 | args.alias_tokens = [] | ||
587 | |||
588 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | ||
589 | raise ValueError("--alias_tokens must be a list with an even number of items") | ||
590 | |||
591 | if args.filter_tokens is None: | ||
592 | args.filter_tokens = args.placeholder_tokens.copy() | ||
593 | |||
594 | if isinstance(args.filter_tokens, str): | ||
595 | args.filter_tokens = [args.filter_tokens] | ||
596 | |||
444 | if isinstance(args.collection, str): | 597 | if isinstance(args.collection, str): |
445 | args.collection = [args.collection] | 598 | args.collection = [args.collection] |
446 | 599 | ||
@@ -451,15 +604,15 @@ def parse_args(): | |||
451 | raise ValueError("You must specify --output_dir") | 604 | raise ValueError("You must specify --output_dir") |
452 | 605 | ||
453 | if args.adam_beta1 is None: | 606 | if args.adam_beta1 is None: |
454 | if args.optimizer in ('adam', 'adam8bit'): | 607 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
455 | args.adam_beta1 = 0.9 | 608 | args.adam_beta1 = 0.9 |
456 | elif args.optimizer == 'lion': | 609 | elif args.optimizer in ("lion", "dlion"): |
457 | args.adam_beta1 = 0.95 | 610 | args.adam_beta1 = 0.95 |
458 | 611 | ||
459 | if args.adam_beta2 is None: | 612 | if args.adam_beta2 is None: |
460 | if args.optimizer in ('adam', 'adam8bit'): | 613 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
461 | args.adam_beta2 = 0.999 | 614 | args.adam_beta2 = 0.999 |
462 | elif args.optimizer == 'lion': | 615 | elif args.optimizer in ("lion", "dlion"): |
463 | args.adam_beta2 = 0.98 | 616 | args.adam_beta2 = 0.98 |
464 | 617 | ||
465 | return args | 618 | return args |
@@ -475,7 +628,7 @@ def main(): | |||
475 | accelerator = Accelerator( | 628 | accelerator = Accelerator( |
476 | log_with=LoggerType.TENSORBOARD, | 629 | log_with=LoggerType.TENSORBOARD, |
477 | project_dir=f"{output_dir}", | 630 | project_dir=f"{output_dir}", |
478 | mixed_precision=args.mixed_precision | 631 | mixed_precision=args.mixed_precision, |
479 | ) | 632 | ) |
480 | 633 | ||
481 | weight_dtype = torch.float32 | 634 | weight_dtype = torch.float32 |
@@ -484,6 +637,8 @@ def main(): | |||
484 | elif args.mixed_precision == "bf16": | 637 | elif args.mixed_precision == "bf16": |
485 | weight_dtype = torch.bfloat16 | 638 | weight_dtype = torch.bfloat16 |
486 | 639 | ||
640 | patch_xformers(weight_dtype) | ||
641 | |||
487 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 642 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
488 | 643 | ||
489 | if args.seed is None: | 644 | if args.seed is None: |
@@ -493,44 +648,125 @@ def main(): | |||
493 | 648 | ||
494 | save_args(output_dir, args) | 649 | save_args(output_dir, args) |
495 | 650 | ||
496 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 651 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
497 | args.pretrained_model_name_or_path) | 652 | args.pretrained_model_name_or_path |
498 | 653 | ) | |
499 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 654 | embeddings = patch_managed_embeddings( |
500 | tokenizer.set_dropout(args.vector_dropout) | 655 | text_encoder, args.emb_alpha, args.emb_dropout |
656 | ) | ||
657 | schedule_sampler = create_named_schedule_sampler( | ||
658 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
659 | ) | ||
501 | 660 | ||
502 | vae.enable_slicing() | 661 | vae.enable_slicing() |
503 | vae.set_use_memory_efficient_attention_xformers(True) | 662 | |
504 | unet.enable_xformers_memory_efficient_attention() | 663 | if args.use_xformers: |
664 | vae.set_use_memory_efficient_attention_xformers(True) | ||
665 | unet.enable_xformers_memory_efficient_attention() | ||
666 | # elif args.compile_unet: | ||
667 | # unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False | ||
668 | # | ||
669 | # proc = AttnProcessor() | ||
670 | # | ||
671 | # def fn_recursive_set_proc(module: torch.nn.Module): | ||
672 | # if hasattr(module, "processor"): | ||
673 | # module.processor = proc | ||
674 | # | ||
675 | # for child in module.children(): | ||
676 | # fn_recursive_set_proc(child) | ||
677 | # | ||
678 | # fn_recursive_set_proc(unet) | ||
505 | 679 | ||
506 | if args.gradient_checkpointing: | 680 | if args.gradient_checkpointing: |
507 | unet.enable_gradient_checkpointing() | 681 | unet.enable_gradient_checkpointing() |
508 | text_encoder.gradient_checkpointing_enable() | 682 | |
683 | if len(args.alias_tokens) != 0: | ||
684 | alias_placeholder_tokens = args.alias_tokens[::2] | ||
685 | alias_initializer_tokens = args.alias_tokens[1::2] | ||
686 | |||
687 | added_tokens, added_ids = add_placeholder_tokens( | ||
688 | tokenizer=tokenizer, | ||
689 | embeddings=embeddings, | ||
690 | placeholder_tokens=alias_placeholder_tokens, | ||
691 | initializer_tokens=alias_initializer_tokens, | ||
692 | ) | ||
693 | embeddings.persist() | ||
694 | print( | ||
695 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
696 | ) | ||
697 | |||
698 | placeholder_tokens = [] | ||
699 | placeholder_token_ids = [] | ||
509 | 700 | ||
510 | if args.embeddings_dir is not None: | 701 | if args.embeddings_dir is not None: |
511 | embeddings_dir = Path(args.embeddings_dir) | 702 | embeddings_dir = Path(args.embeddings_dir) |
512 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 703 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
513 | raise ValueError("--embeddings_dir must point to an existing directory") | 704 | raise ValueError("--embeddings_dir must point to an existing directory") |
514 | 705 | ||
515 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 706 | added_tokens, added_ids = load_embeddings_from_dir( |
516 | embeddings.persist() | 707 | tokenizer, embeddings, embeddings_dir |
517 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 708 | ) |
709 | |||
710 | placeholder_tokens = added_tokens | ||
711 | placeholder_token_ids = added_ids | ||
712 | |||
713 | print( | ||
714 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
715 | ) | ||
716 | |||
717 | if args.train_dir_embeddings: | ||
718 | print("Training embeddings from embeddings dir") | ||
719 | else: | ||
720 | embeddings.persist() | ||
721 | |||
722 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | ||
723 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
724 | tokenizer=tokenizer, | ||
725 | embeddings=embeddings, | ||
726 | placeholder_tokens=args.placeholder_tokens, | ||
727 | initializer_tokens=args.initializer_tokens, | ||
728 | num_vectors=args.num_vectors, | ||
729 | initializer_noise=args.initializer_noise, | ||
730 | ) | ||
731 | |||
732 | placeholder_tokens = args.placeholder_tokens | ||
733 | |||
734 | stats = list( | ||
735 | zip( | ||
736 | placeholder_tokens, | ||
737 | placeholder_token_ids, | ||
738 | args.initializer_tokens, | ||
739 | initializer_token_ids, | ||
740 | ) | ||
741 | ) | ||
742 | print(f"Training embeddings: {stats}") | ||
518 | 743 | ||
519 | if args.scale_lr: | 744 | if args.scale_lr: |
520 | args.learning_rate = ( | 745 | args.learning_rate_unet = ( |
521 | args.learning_rate * args.gradient_accumulation_steps * | 746 | args.learning_rate_unet |
522 | args.train_batch_size * accelerator.num_processes | 747 | * args.gradient_accumulation_steps |
748 | * args.train_batch_size | ||
749 | * accelerator.num_processes | ||
750 | ) | ||
751 | args.learning_rate_text = ( | ||
752 | args.learning_rate_text | ||
753 | * args.gradient_accumulation_steps | ||
754 | * args.train_batch_size | ||
755 | * accelerator.num_processes | ||
523 | ) | 756 | ) |
524 | 757 | ||
525 | if args.find_lr: | 758 | if args.find_lr: |
526 | args.learning_rate = 1e-6 | 759 | args.learning_rate_unet = 1e-6 |
760 | args.learning_rate_text = 1e-6 | ||
527 | args.lr_scheduler = "exponential_growth" | 761 | args.lr_scheduler = "exponential_growth" |
528 | 762 | ||
529 | if args.optimizer == 'adam8bit': | 763 | if args.optimizer == "adam8bit": |
530 | try: | 764 | try: |
531 | import bitsandbytes as bnb | 765 | import bitsandbytes as bnb |
532 | except ImportError: | 766 | except ImportError: |
533 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 767 | raise ImportError( |
768 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
769 | ) | ||
534 | 770 | ||
535 | create_optimizer = partial( | 771 | create_optimizer = partial( |
536 | bnb.optim.AdamW8bit, | 772 | bnb.optim.AdamW8bit, |
@@ -539,7 +775,7 @@ def main(): | |||
539 | eps=args.adam_epsilon, | 775 | eps=args.adam_epsilon, |
540 | amsgrad=args.adam_amsgrad, | 776 | amsgrad=args.adam_amsgrad, |
541 | ) | 777 | ) |
542 | elif args.optimizer == 'adam': | 778 | elif args.optimizer == "adam": |
543 | create_optimizer = partial( | 779 | create_optimizer = partial( |
544 | torch.optim.AdamW, | 780 | torch.optim.AdamW, |
545 | betas=(args.adam_beta1, args.adam_beta2), | 781 | betas=(args.adam_beta1, args.adam_beta2), |
@@ -547,22 +783,27 @@ def main(): | |||
547 | eps=args.adam_epsilon, | 783 | eps=args.adam_epsilon, |
548 | amsgrad=args.adam_amsgrad, | 784 | amsgrad=args.adam_amsgrad, |
549 | ) | 785 | ) |
550 | elif args.optimizer == 'adan': | 786 | elif args.optimizer == "adan": |
551 | try: | 787 | try: |
552 | import timm.optim | 788 | import timm.optim |
553 | except ImportError: | 789 | except ImportError: |
554 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 790 | raise ImportError( |
791 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
792 | ) | ||
555 | 793 | ||
556 | create_optimizer = partial( | 794 | create_optimizer = partial( |
557 | timm.optim.Adan, | 795 | timm.optim.Adan, |
558 | weight_decay=args.adam_weight_decay, | 796 | weight_decay=args.adam_weight_decay, |
559 | eps=args.adam_epsilon, | 797 | eps=args.adam_epsilon, |
798 | no_prox=True, | ||
560 | ) | 799 | ) |
561 | elif args.optimizer == 'lion': | 800 | elif args.optimizer == "lion": |
562 | try: | 801 | try: |
563 | import lion_pytorch | 802 | import lion_pytorch |
564 | except ImportError: | 803 | except ImportError: |
565 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 804 | raise ImportError( |
805 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
806 | ) | ||
566 | 807 | ||
567 | create_optimizer = partial( | 808 | create_optimizer = partial( |
568 | lion_pytorch.Lion, | 809 | lion_pytorch.Lion, |
@@ -570,7 +811,7 @@ def main(): | |||
570 | weight_decay=args.adam_weight_decay, | 811 | weight_decay=args.adam_weight_decay, |
571 | use_triton=True, | 812 | use_triton=True, |
572 | ) | 813 | ) |
573 | elif args.optimizer == 'adafactor': | 814 | elif args.optimizer == "adafactor": |
574 | create_optimizer = partial( | 815 | create_optimizer = partial( |
575 | transformers.optimization.Adafactor, | 816 | transformers.optimization.Adafactor, |
576 | weight_decay=args.adam_weight_decay, | 817 | weight_decay=args.adam_weight_decay, |
@@ -580,13 +821,16 @@ def main(): | |||
580 | ) | 821 | ) |
581 | 822 | ||
582 | args.lr_scheduler = "adafactor" | 823 | args.lr_scheduler = "adafactor" |
583 | args.lr_min_lr = args.learning_rate | 824 | args.lr_min_lr = args.learning_rate_unet |
584 | args.learning_rate = None | 825 | args.learning_rate_unet = None |
585 | elif args.optimizer == 'dadam': | 826 | args.learning_rate_text = None |
827 | elif args.optimizer == "dadam": | ||
586 | try: | 828 | try: |
587 | import dadaptation | 829 | import dadaptation |
588 | except ImportError: | 830 | except ImportError: |
589 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 831 | raise ImportError( |
832 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
833 | ) | ||
590 | 834 | ||
591 | create_optimizer = partial( | 835 | create_optimizer = partial( |
592 | dadaptation.DAdaptAdam, | 836 | dadaptation.DAdaptAdam, |
@@ -595,46 +839,65 @@ def main(): | |||
595 | eps=args.adam_epsilon, | 839 | eps=args.adam_epsilon, |
596 | decouple=True, | 840 | decouple=True, |
597 | d0=args.dadaptation_d0, | 841 | d0=args.dadaptation_d0, |
842 | growth_rate=args.dadaptation_growth_rate, | ||
598 | ) | 843 | ) |
599 | 844 | ||
600 | args.learning_rate = 1.0 | 845 | args.learning_rate_unet = 1.0 |
601 | elif args.optimizer == 'dadan': | 846 | args.learning_rate_text = 1.0 |
847 | elif args.optimizer == "dadan": | ||
602 | try: | 848 | try: |
603 | import dadaptation | 849 | import dadaptation |
604 | except ImportError: | 850 | except ImportError: |
605 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 851 | raise ImportError( |
852 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
853 | ) | ||
606 | 854 | ||
607 | create_optimizer = partial( | 855 | create_optimizer = partial( |
608 | dadaptation.DAdaptAdan, | 856 | dadaptation.DAdaptAdan, |
609 | weight_decay=args.adam_weight_decay, | 857 | weight_decay=args.adam_weight_decay, |
610 | eps=args.adam_epsilon, | 858 | eps=args.adam_epsilon, |
611 | d0=args.dadaptation_d0, | 859 | d0=args.dadaptation_d0, |
860 | growth_rate=args.dadaptation_growth_rate, | ||
612 | ) | 861 | ) |
613 | 862 | ||
614 | args.learning_rate = 1.0 | 863 | args.learning_rate_unet = 1.0 |
864 | args.learning_rate_text = 1.0 | ||
865 | elif args.optimizer == "dlion": | ||
866 | raise ImportError("DLion has not been merged into dadaptation yet") | ||
615 | else: | 867 | else: |
616 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 868 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
617 | 869 | ||
618 | trainer = partial( | 870 | trainer = partial( |
619 | train, | 871 | train, |
620 | accelerator=accelerator, | 872 | accelerator=accelerator, |
621 | unet=unet, | 873 | unet=unet, |
622 | text_encoder=text_encoder, | 874 | text_encoder=text_encoder, |
875 | tokenizer=tokenizer, | ||
623 | vae=vae, | 876 | vae=vae, |
624 | noise_scheduler=noise_scheduler, | 877 | noise_scheduler=noise_scheduler, |
878 | schedule_sampler=schedule_sampler, | ||
879 | min_snr_gamma=args.min_snr_gamma, | ||
625 | dtype=weight_dtype, | 880 | dtype=weight_dtype, |
881 | seed=args.seed, | ||
882 | compile_unet=args.compile_unet, | ||
626 | guidance_scale=args.guidance_scale, | 883 | guidance_scale=args.guidance_scale, |
627 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 884 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
628 | no_val=args.valid_set_size == 0, | 885 | sample_scheduler=sample_scheduler, |
886 | sample_batch_size=args.sample_batch_size, | ||
887 | sample_num_batches=args.sample_batches, | ||
888 | sample_num_steps=args.sample_steps, | ||
889 | sample_image_size=args.sample_image_size, | ||
890 | max_grad_norm=args.max_grad_norm, | ||
629 | ) | 891 | ) |
630 | 892 | ||
631 | checkpoint_output_dir = output_dir / "model" | 893 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
632 | sample_output_dir = output_dir / "samples" | 894 | data_npgenerator = np.random.default_rng(args.seed) |
633 | 895 | ||
634 | datamodule = VlpnDataModule( | 896 | create_datamodule = partial( |
897 | VlpnDataModule, | ||
635 | data_file=args.train_data_file, | 898 | data_file=args.train_data_file, |
636 | batch_size=args.train_batch_size, | ||
637 | tokenizer=tokenizer, | 899 | tokenizer=tokenizer, |
900 | constant_prompt_length=args.compile_unet, | ||
638 | class_subdir=args.class_image_dir, | 901 | class_subdir=args.class_image_dir, |
639 | with_guidance=args.guidance_scale != 0, | 902 | with_guidance=args.guidance_scale != 0, |
640 | num_class_images=args.num_class_images, | 903 | num_class_images=args.num_class_images, |
@@ -643,83 +906,186 @@ def main(): | |||
643 | progressive_buckets=args.progressive_buckets, | 906 | progressive_buckets=args.progressive_buckets, |
644 | bucket_step_size=args.bucket_step_size, | 907 | bucket_step_size=args.bucket_step_size, |
645 | bucket_max_pixels=args.bucket_max_pixels, | 908 | bucket_max_pixels=args.bucket_max_pixels, |
646 | dropout=args.tag_dropout, | ||
647 | shuffle=not args.no_tag_shuffle, | 909 | shuffle=not args.no_tag_shuffle, |
648 | template_key=args.train_data_template, | 910 | template_key=args.train_data_template, |
649 | valid_set_size=args.valid_set_size, | ||
650 | train_set_pad=args.train_set_pad, | 911 | train_set_pad=args.train_set_pad, |
651 | valid_set_pad=args.valid_set_pad, | 912 | valid_set_pad=args.valid_set_pad, |
652 | seed=args.seed, | 913 | dtype=weight_dtype, |
653 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 914 | generator=data_generator, |
654 | dtype=weight_dtype | 915 | npgenerator=data_npgenerator, |
655 | ) | ||
656 | datamodule.setup() | ||
657 | |||
658 | num_train_epochs = args.num_train_epochs | ||
659 | sample_frequency = args.sample_frequency | ||
660 | if num_train_epochs is None: | ||
661 | num_train_epochs = math.ceil( | ||
662 | args.num_train_steps / len(datamodule.train_dataset) | ||
663 | ) * args.gradient_accumulation_steps | ||
664 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | ||
665 | |||
666 | params_to_optimize = (unet.parameters(), ) | ||
667 | if args.train_text_encoder_epochs != 0: | ||
668 | params_to_optimize += ( | ||
669 | text_encoder.text_model.encoder.parameters(), | ||
670 | text_encoder.text_model.final_layer_norm.parameters(), | ||
671 | ) | ||
672 | |||
673 | optimizer = create_optimizer( | ||
674 | itertools.chain(*params_to_optimize), | ||
675 | lr=args.learning_rate, | ||
676 | ) | 916 | ) |
677 | 917 | ||
678 | lr_scheduler = get_scheduler( | 918 | create_lr_scheduler = partial( |
679 | args.lr_scheduler, | 919 | get_scheduler, |
680 | optimizer=optimizer, | ||
681 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
682 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
683 | min_lr=args.lr_min_lr, | 920 | min_lr=args.lr_min_lr, |
684 | warmup_func=args.lr_warmup_func, | 921 | warmup_func=args.lr_warmup_func, |
685 | annealing_func=args.lr_annealing_func, | 922 | annealing_func=args.lr_annealing_func, |
686 | warmup_exp=args.lr_warmup_exp, | 923 | warmup_exp=args.lr_warmup_exp, |
687 | annealing_exp=args.lr_annealing_exp, | 924 | annealing_exp=args.lr_annealing_exp, |
688 | cycles=args.lr_cycles, | ||
689 | end_lr=1e2, | 925 | end_lr=1e2, |
690 | train_epochs=num_train_epochs, | ||
691 | warmup_epochs=args.lr_warmup_epochs, | ||
692 | mid_point=args.lr_mid_point, | 926 | mid_point=args.lr_mid_point, |
693 | ) | 927 | ) |
694 | 928 | ||
695 | trainer( | 929 | # Dreambooth |
696 | strategy=dreambooth_strategy, | 930 | # -------------------------------------------------------------------------------- |
697 | project="dreambooth", | 931 | |
698 | train_dataloader=datamodule.train_dataloader, | 932 | dreambooth_datamodule = create_datamodule( |
699 | val_dataloader=datamodule.val_dataloader, | 933 | valid_set_size=args.valid_set_size, |
700 | seed=args.seed, | 934 | batch_size=args.train_batch_size, |
701 | optimizer=optimizer, | 935 | dropout=args.tag_dropout, |
702 | lr_scheduler=lr_scheduler, | 936 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
703 | num_train_epochs=num_train_epochs, | 937 | ) |
704 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 938 | dreambooth_datamodule.setup() |
705 | sample_frequency=sample_frequency, | 939 | |
706 | offset_noise_strength=args.offset_noise_strength, | 940 | num_train_epochs = args.num_train_epochs |
707 | # -- | 941 | dreambooth_sample_frequency = args.sample_frequency |
708 | tokenizer=tokenizer, | 942 | if num_train_epochs is None: |
709 | sample_scheduler=sample_scheduler, | 943 | num_train_epochs = ( |
710 | sample_output_dir=sample_output_dir, | 944 | math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset)) |
711 | checkpoint_output_dir=checkpoint_output_dir, | 945 | * args.gradient_accumulation_steps |
712 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 946 | ) |
713 | max_grad_norm=args.max_grad_norm, | 947 | dreambooth_sample_frequency = math.ceil( |
714 | use_ema=args.use_ema, | 948 | num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps) |
715 | ema_inv_gamma=args.ema_inv_gamma, | 949 | ) |
716 | ema_power=args.ema_power, | 950 | num_training_steps_per_epoch = math.ceil( |
717 | ema_max_decay=args.ema_max_decay, | 951 | len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps |
718 | sample_batch_size=args.sample_batch_size, | ||
719 | sample_num_batches=args.sample_batches, | ||
720 | sample_num_steps=args.sample_steps, | ||
721 | sample_image_size=args.sample_image_size, | ||
722 | ) | 952 | ) |
953 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | ||
954 | if args.sample_num is not None: | ||
955 | dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | ||
956 | |||
957 | dreambooth_project = "dreambooth" | ||
958 | |||
959 | if accelerator.is_main_process: | ||
960 | accelerator.init_trackers(dreambooth_project) | ||
961 | |||
962 | dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples" | ||
963 | |||
964 | training_iter = 0 | ||
965 | auto_cycles = list(args.auto_cycles) | ||
966 | learning_rate_unet = args.learning_rate_unet | ||
967 | learning_rate_text = args.learning_rate_text | ||
968 | lr_scheduler = args.lr_scheduler | ||
969 | lr_warmup_epochs = args.lr_warmup_epochs | ||
970 | lr_cycles = args.lr_cycles | ||
971 | |||
972 | avg_loss = AverageMeter() | ||
973 | avg_acc = AverageMeter() | ||
974 | avg_loss_val = AverageMeter() | ||
975 | avg_acc_val = AverageMeter() | ||
976 | |||
977 | params_to_optimize = [ | ||
978 | { | ||
979 | "params": (param for param in unet.parameters() if param.requires_grad), | ||
980 | "lr": learning_rate_unet, | ||
981 | }, | ||
982 | { | ||
983 | "params": ( | ||
984 | param for param in text_encoder.parameters() if param.requires_grad | ||
985 | ), | ||
986 | "lr": learning_rate_text, | ||
987 | }, | ||
988 | ] | ||
989 | group_labels = ["unet", "text"] | ||
990 | |||
991 | dreambooth_optimizer = create_optimizer(params_to_optimize) | ||
992 | |||
993 | while True: | ||
994 | if len(auto_cycles) != 0: | ||
995 | response = auto_cycles.pop(0) | ||
996 | else: | ||
997 | response = input( | ||
998 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " | ||
999 | ) | ||
1000 | |||
1001 | if response.lower().strip() == "o": | ||
1002 | if args.learning_rate_unet is not None: | ||
1003 | learning_rate_unet = ( | ||
1004 | args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) | ||
1005 | ) | ||
1006 | if args.learning_rate_text is not None: | ||
1007 | learning_rate_text = ( | ||
1008 | args.learning_rate_text * 2 * (args.cycle_decay**training_iter) | ||
1009 | ) | ||
1010 | else: | ||
1011 | learning_rate_unet = args.learning_rate_unet * ( | ||
1012 | args.cycle_decay**training_iter | ||
1013 | ) | ||
1014 | learning_rate_text = args.learning_rate_text * ( | ||
1015 | args.cycle_decay**training_iter | ||
1016 | ) | ||
1017 | |||
1018 | if response.lower().strip() == "o": | ||
1019 | lr_scheduler = "one_cycle" | ||
1020 | lr_warmup_epochs = args.lr_warmup_epochs | ||
1021 | lr_cycles = args.lr_cycles | ||
1022 | elif response.lower().strip() == "w": | ||
1023 | lr_scheduler = "constant_with_warmup" | ||
1024 | lr_warmup_epochs = num_train_epochs | ||
1025 | elif response.lower().strip() == "c": | ||
1026 | lr_scheduler = "constant" | ||
1027 | elif response.lower().strip() == "d": | ||
1028 | lr_scheduler = "cosine" | ||
1029 | lr_warmup_epochs = 0 | ||
1030 | lr_cycles = 1 | ||
1031 | elif response.lower().strip() == "s": | ||
1032 | break | ||
1033 | else: | ||
1034 | continue | ||
1035 | |||
1036 | print("") | ||
1037 | print( | ||
1038 | f"============ Dreambooth cycle {training_iter + 1}: {response} ============" | ||
1039 | ) | ||
1040 | print("") | ||
1041 | |||
1042 | for group, lr in zip( | ||
1043 | dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text] | ||
1044 | ): | ||
1045 | group["lr"] = lr | ||
1046 | |||
1047 | dreambooth_lr_scheduler = create_lr_scheduler( | ||
1048 | lr_scheduler, | ||
1049 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
1050 | optimizer=dreambooth_optimizer, | ||
1051 | num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader), | ||
1052 | train_epochs=num_train_epochs, | ||
1053 | cycles=lr_cycles, | ||
1054 | warmup_epochs=lr_warmup_epochs, | ||
1055 | ) | ||
1056 | |||
1057 | dreambooth_checkpoint_output_dir = ( | ||
1058 | output_dir / dreambooth_project / f"model_{training_iter}" | ||
1059 | ) | ||
1060 | |||
1061 | trainer( | ||
1062 | strategy=dreambooth_strategy, | ||
1063 | train_dataloader=dreambooth_datamodule.train_dataloader, | ||
1064 | val_dataloader=dreambooth_datamodule.val_dataloader, | ||
1065 | optimizer=dreambooth_optimizer, | ||
1066 | lr_scheduler=dreambooth_lr_scheduler, | ||
1067 | num_train_epochs=num_train_epochs, | ||
1068 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
1069 | global_step_offset=training_iter * num_train_steps, | ||
1070 | cycle=training_iter, | ||
1071 | train_text_encoder_cycles=args.train_text_encoder_cycles, | ||
1072 | # -- | ||
1073 | group_labels=group_labels, | ||
1074 | sample_output_dir=dreambooth_sample_output_dir, | ||
1075 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | ||
1076 | sample_frequency=dreambooth_sample_frequency, | ||
1077 | offset_noise_strength=args.offset_noise_strength, | ||
1078 | input_pertubation=args.input_pertubation, | ||
1079 | no_val=args.valid_set_size == 0, | ||
1080 | avg_loss=avg_loss, | ||
1081 | avg_acc=avg_acc, | ||
1082 | avg_loss_val=avg_loss_val, | ||
1083 | avg_acc_val=avg_acc_val, | ||
1084 | ) | ||
1085 | |||
1086 | training_iter += 1 | ||
1087 | |||
1088 | accelerator.end_training() | ||
723 | 1089 | ||
724 | 1090 | ||
725 | if __name__ == "__main__": | 1091 | if __name__ == "__main__": |