diff options
Diffstat (limited to 'dreambooth.py')
| -rw-r--r-- | dreambooth.py | 25 |
1 file changed, 10 insertions, 15 deletions
diff --git a/dreambooth.py b/dreambooth.py index aedf25c..0c5c42a 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -24,7 +24,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
| 24 | import json | 24 | import json |
| 25 | 25 | ||
| 26 | from data.dreambooth.csv import CSVDataModule | 26 | from data.dreambooth.csv import CSVDataModule |
| 27 | from data.dreambooth.prompt import PromptDataset | ||
| 28 | 27 | ||
| 29 | logger = get_logger(__name__) | 28 | logger = get_logger(__name__) |
| 30 | 29 | ||
| @@ -122,7 +121,7 @@ def parse_args(): | |||
| 122 | parser.add_argument( | 121 | parser.add_argument( |
| 123 | "--learning_rate", | 122 | "--learning_rate", |
| 124 | type=float, | 123 | type=float, |
| 125 | default=3e-6, | 124 | default=1e-6, |
| 126 | help="Initial learning rate (after the potential warmup period) to use.", | 125 | help="Initial learning rate (after the potential warmup period) to use.", |
| 127 | ) | 126 | ) |
| 128 | parser.add_argument( | 127 | parser.add_argument( |
| @@ -219,17 +218,10 @@ def parse_args(): | |||
| 219 | parser.add_argument( | 218 | parser.add_argument( |
| 220 | "--sample_steps", | 219 | "--sample_steps", |
| 221 | type=int, | 220 | type=int, |
| 222 | default=40, | 221 | default=30, |
| 223 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 222 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
| 224 | ) | 223 | ) |
| 225 | parser.add_argument( | 224 | parser.add_argument( |
| 226 | "--class_data_dir", | ||
| 227 | type=str, | ||
| 228 | default=None, | ||
| 229 | required=False, | ||
| 230 | help="A folder containing the training data of class images.", | ||
| 231 | ) | ||
| 232 | parser.add_argument( | ||
| 233 | "--prior_loss_weight", | 225 | "--prior_loss_weight", |
| 234 | type=float, | 226 | type=float, |
| 235 | default=1.0, | 227 | default=1.0, |
| @@ -311,7 +303,7 @@ class Checkpointer: | |||
| 311 | self.output_dir = output_dir | 303 | self.output_dir = output_dir |
| 312 | self.instance_identifier = instance_identifier | 304 | self.instance_identifier = instance_identifier |
| 313 | self.sample_image_size = sample_image_size | 305 | self.sample_image_size = sample_image_size |
| 314 | self.seed = seed | 306 | self.seed = seed or torch.random.seed() |
| 315 | self.sample_batches = sample_batches | 307 | self.sample_batches = sample_batches |
| 316 | self.sample_batch_size = sample_batch_size | 308 | self.sample_batch_size = sample_batch_size |
| 317 | 309 | ||
| @@ -406,6 +398,8 @@ class Checkpointer: | |||
| 406 | del unwrapped | 398 | del unwrapped |
| 407 | del scheduler | 399 | del scheduler |
| 408 | del pipeline | 400 | del pipeline |
| 401 | del generator | ||
| 402 | del stable_latents | ||
| 409 | 403 | ||
| 410 | if torch.cuda.is_available(): | 404 | if torch.cuda.is_available(): |
| 411 | torch.cuda.empty_cache() | 405 | torch.cuda.empty_cache() |
| @@ -523,11 +517,13 @@ def main(): | |||
| 523 | tokenizer=tokenizer, | 517 | tokenizer=tokenizer, |
| 524 | instance_identifier=args.instance_identifier, | 518 | instance_identifier=args.instance_identifier, |
| 525 | class_identifier=args.class_identifier, | 519 | class_identifier=args.class_identifier, |
| 520 | class_subdir="db_cls", | ||
| 526 | size=args.resolution, | 521 | size=args.resolution, |
| 527 | repeats=args.repeats, | 522 | repeats=args.repeats, |
| 528 | center_crop=args.center_crop, | 523 | center_crop=args.center_crop, |
| 529 | valid_set_size=args.sample_batch_size*args.sample_batches, | 524 | valid_set_size=args.sample_batch_size*args.sample_batches, |
| 530 | collate_fn=collate_fn) | 525 | collate_fn=collate_fn |
| 526 | ) | ||
| 531 | 527 | ||
| 532 | datamodule.prepare_data() | 528 | datamodule.prepare_data() |
| 533 | datamodule.setup() | 529 | datamodule.setup() |
| @@ -587,7 +583,7 @@ def main(): | |||
| 587 | sample_image_size=args.sample_image_size, | 583 | sample_image_size=args.sample_image_size, |
| 588 | sample_batch_size=args.sample_batch_size, | 584 | sample_batch_size=args.sample_batch_size, |
| 589 | sample_batches=args.sample_batches, | 585 | sample_batches=args.sample_batches, |
| 590 | seed=args.seed or torch.random.seed() | 586 | seed=args.seed |
| 591 | ) | 587 | ) |
| 592 | 588 | ||
| 593 | # Scheduler and math around the number of training steps. | 589 | # Scheduler and math around the number of training steps. |
| @@ -699,8 +695,7 @@ def main(): | |||
| 699 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 695 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 700 | 696 | ||
| 701 | # Compute prior loss | 697 | # Compute prior loss |
| 702 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, | 698 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() |
| 703 | reduction="none").mean([1, 2, 3]).mean() | ||
| 704 | 699 | ||
| 705 | # Add the prior loss to the instance loss. | 700 | # Add the prior loss to the instance loss. |
| 706 | loss = loss + args.prior_loss_weight * prior_loss | 701 | loss = loss + args.prior_loss_weight * prior_loss |
