path: root/dreambooth.py
author     Volpeon <git@volpeon.ink>  2022-10-04 19:22:22 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-04 19:22:22 +0200
commit     300deaa789a0321f32d5e7f04d9860eaa258110e (patch)
tree       892e89753e5c4d86d787131595751bc03c610be8 /dreambooth.py
parent     Default sample steps 30 -> 40 (diff)
download   textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.gz
           textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.tar.bz2
           textual-inversion-diff-300deaa789a0321f32d5e7f04d9860eaa258110e.zip
Add Textual Inversion with class dataset (a la Dreambooth)
Diffstat (limited to 'dreambooth.py')
-rw-r--r--  dreambooth.py  25
1 file changed, 10 insertions(+), 15 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index aedf25c..0c5c42a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -24,7 +24,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 
 from data.dreambooth.csv import CSVDataModule
-from data.dreambooth.prompt import PromptDataset
 
 logger = get_logger(__name__)
 
@@ -122,7 +121,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=3e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -219,17 +218,10 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=40,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
-        "--class_data_dir",
-        type=str,
-        default=None,
-        required=False,
-        help="A folder containing the training data of class images.",
-    )
-    parser.add_argument(
         "--prior_loss_weight",
         type=float,
         default=1.0,
@@ -311,7 +303,7 @@ class Checkpointer:
         self.output_dir = output_dir
         self.instance_identifier = instance_identifier
         self.sample_image_size = sample_image_size
-        self.seed = seed
+        self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
@@ -406,6 +398,8 @@ class Checkpointer:
         del unwrapped
         del scheduler
         del pipeline
+        del generator
+        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
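The two added del statements extend the existing cleanup: only once the last references to the sampling objects are dropped can torch.cuda.empty_cache() actually return their memory to the allocator. A minimal self-contained sketch of the pattern (the generator setup and latent shape here are illustrative, not taken from the script):

    import torch

    def sample_then_release():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=device).manual_seed(42)
        stable_latents = torch.randn(
            (4, 4, 64, 64), generator=generator, device=device)

        # ... run the sampling pipeline on stable_latents here ...

        # Drop the last references so the tensors become collectible,
        # then hand cached GPU blocks back to the allocator.
        del generator
        del stable_latents
        if torch.cuda.is_available():
            torch.cuda.empty_cache()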
@@ -523,11 +517,13 @@ def main():
         tokenizer=tokenizer,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
+        class_subdir="db_cls",
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
-        collate_fn=collate_fn)
+        collate_fn=collate_fn
+    )
 
     datamodule.prepare_data()
     datamodule.setup()
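With the --class_data_dir flag gone (see the parse_args hunk above), the class-image location is now derived inside the data module from the new class_subdir argument. CSVDataModule's actual resolution logic isn't shown in this diff; a plausible sketch, with the join against the dataset root assumed rather than confirmed:

    from pathlib import Path

    def resolve_class_dir(data_root: str, class_subdir: str = "db_cls") -> Path:
        # Hypothetical helper: derive (and create) the class-image
        # folder from the dataset root instead of a separate CLI flag.
        class_dir = Path(data_root) / class_subdir
        class_dir.mkdir(parents=True, exist_ok=True)
        return class_dir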
@@ -587,7 +583,7 @@ def main():
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
-        seed=args.seed or torch.random.seed()
+        seed=args.seed
     )
 
     # Scheduler and math around the number of training steps.
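This hunk pairs with the Checkpointer change above: the seed fallback moves from the call site into the constructor, so a missing --seed is resolved exactly once and every sample batch across checkpoints reuses the same value. A minimal sketch of the pattern:

    import torch

    class Checkpointer:
        def __init__(self, seed=None):
            # Resolve the fallback once; torch.random.seed() draws and
            # returns a fresh 64-bit seed. (Note: an explicit seed of 0
            # is falsy and would also fall through to the random draw.)
            self.seed = seed or torch.random.seed()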
@@ -699,8 +695,7 @@ def main():
         loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
         # Compute prior loss
-        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                reduction="none").mean([1, 2, 3]).mean()
+        prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
 
         # Add the prior loss to the instance loss.
         loss = loss + args.prior_loss_weight * prior_loss
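The reflowed line computes the standard DreamBooth prior-preservation term: per-sample MSE between predicted and target noise on the class-image batch, reduced the same way as the instance loss and added with a configurable weight. A self-contained sketch of the combined objective, assuming 4D noise tensors of shape (B, C, H, W):

    import torch
    import torch.nn.functional as F

    def dreambooth_loss(noise_pred, noise, noise_pred_prior, noise_prior,
                        prior_loss_weight=1.0):
        # Instance loss: per-sample mean over (C, H, W), then batch mean.
        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
        # Prior-preservation loss on the class batch, same reduction.
        prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
                                reduction="none").mean([1, 2, 3]).mean()
        # The weighted prior term keeps the class concept from drifting
        # while the instance images are being fit.
        return loss + prior_loss_weight * prior_loss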