diff options
author | Volpeon <git@volpeon.ink> | 2022-10-04 19:22:22 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-04 19:22:22 +0200 |
commit | 300deaa789a0321f32d5e7f04d9860eaa258110e (patch) | |
tree | 892e89753e5c4d86d787131595751bc03c610be8 /dreambooth.py | |
parent | Default sample steps 30 -> 40 (diff) | |
download | textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.gz textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.bz2 textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.zip |
Add Textual Inversion with class dataset (a la Dreambooth)
Diffstat (limited to 'dreambooth.py')
-rw-r--r-- | dreambooth.py | 25 |
1 file changed, 10 insertions(+), 15 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py index aedf25c..0c5c42a 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -24,7 +24,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
24 | import json | 24 | import json |
25 | 25 | ||
26 | from data.dreambooth.csv import CSVDataModule | 26 | from data.dreambooth.csv import CSVDataModule |
27 | from data.dreambooth.prompt import PromptDataset | ||
28 | 27 | ||
29 | logger = get_logger(__name__) | 28 | logger = get_logger(__name__) |
30 | 29 | ||
@@ -122,7 +121,7 @@ def parse_args(): | |||
122 | parser.add_argument( | 121 | parser.add_argument( |
123 | "--learning_rate", | 122 | "--learning_rate", |
124 | type=float, | 123 | type=float, |
125 | default=3e-6, | 124 | default=1e-6, |
126 | help="Initial learning rate (after the potential warmup period) to use.", | 125 | help="Initial learning rate (after the potential warmup period) to use.", |
127 | ) | 126 | ) |
128 | parser.add_argument( | 127 | parser.add_argument( |
@@ -219,17 +218,10 @@ def parse_args(): | |||
219 | parser.add_argument( | 218 | parser.add_argument( |
220 | "--sample_steps", | 219 | "--sample_steps", |
221 | type=int, | 220 | type=int, |
222 | default=40, | 221 | default=30, |
223 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 222 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
224 | ) | 223 | ) |
225 | parser.add_argument( | 224 | parser.add_argument( |
226 | "--class_data_dir", | ||
227 | type=str, | ||
228 | default=None, | ||
229 | required=False, | ||
230 | help="A folder containing the training data of class images.", | ||
231 | ) | ||
232 | parser.add_argument( | ||
233 | "--prior_loss_weight", | 225 | "--prior_loss_weight", |
234 | type=float, | 226 | type=float, |
235 | default=1.0, | 227 | default=1.0, |
@@ -311,7 +303,7 @@ class Checkpointer: | |||
311 | self.output_dir = output_dir | 303 | self.output_dir = output_dir |
312 | self.instance_identifier = instance_identifier | 304 | self.instance_identifier = instance_identifier |
313 | self.sample_image_size = sample_image_size | 305 | self.sample_image_size = sample_image_size |
314 | self.seed = seed | 306 | self.seed = seed or torch.random.seed() |
315 | self.sample_batches = sample_batches | 307 | self.sample_batches = sample_batches |
316 | self.sample_batch_size = sample_batch_size | 308 | self.sample_batch_size = sample_batch_size |
317 | 309 | ||
@@ -406,6 +398,8 @@ class Checkpointer: | |||
406 | del unwrapped | 398 | del unwrapped |
407 | del scheduler | 399 | del scheduler |
408 | del pipeline | 400 | del pipeline |
401 | del generator | ||
402 | del stable_latents | ||
409 | 403 | ||
410 | if torch.cuda.is_available(): | 404 | if torch.cuda.is_available(): |
411 | torch.cuda.empty_cache() | 405 | torch.cuda.empty_cache() |
@@ -523,11 +517,13 @@ def main(): | |||
523 | tokenizer=tokenizer, | 517 | tokenizer=tokenizer, |
524 | instance_identifier=args.instance_identifier, | 518 | instance_identifier=args.instance_identifier, |
525 | class_identifier=args.class_identifier, | 519 | class_identifier=args.class_identifier, |
520 | class_subdir="db_cls", | ||
526 | size=args.resolution, | 521 | size=args.resolution, |
527 | repeats=args.repeats, | 522 | repeats=args.repeats, |
528 | center_crop=args.center_crop, | 523 | center_crop=args.center_crop, |
529 | valid_set_size=args.sample_batch_size*args.sample_batches, | 524 | valid_set_size=args.sample_batch_size*args.sample_batches, |
530 | collate_fn=collate_fn) | 525 | collate_fn=collate_fn |
526 | ) | ||
531 | 527 | ||
532 | datamodule.prepare_data() | 528 | datamodule.prepare_data() |
533 | datamodule.setup() | 529 | datamodule.setup() |
@@ -587,7 +583,7 @@ def main(): | |||
587 | sample_image_size=args.sample_image_size, | 583 | sample_image_size=args.sample_image_size, |
588 | sample_batch_size=args.sample_batch_size, | 584 | sample_batch_size=args.sample_batch_size, |
589 | sample_batches=args.sample_batches, | 585 | sample_batches=args.sample_batches, |
590 | seed=args.seed or torch.random.seed() | 586 | seed=args.seed |
591 | ) | 587 | ) |
592 | 588 | ||
593 | # Scheduler and math around the number of training steps. | 589 | # Scheduler and math around the number of training steps. |
@@ -699,8 +695,7 @@ def main(): | |||
699 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 695 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
700 | 696 | ||
701 | # Compute prior loss | 697 | # Compute prior loss |
702 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, | 698 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() |
703 | reduction="none").mean([1, 2, 3]).mean() | ||
704 | 699 | ||
705 | # Add the prior loss to the instance loss. | 700 | # Add the prior loss to the instance loss. |
706 | loss = loss + args.prior_loss_weight * prior_loss | 701 | loss = loss + args.prior_loss_weight * prior_loss |