From 300deaa789a0321f32d5e7f04d9860eaa258110e Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 4 Oct 2022 19:22:22 +0200
Subject: Add Textual Inversion with class dataset (a la Dreambooth)

---
 dreambooth.py | 25 ++++++++++---------------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index aedf25c..0c5c42a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -24,7 +24,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 
 from data.dreambooth.csv import CSVDataModule
-from data.dreambooth.prompt import PromptDataset
 
 logger = get_logger(__name__)
 
@@ -122,7 +121,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=3e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -219,16 +218,9 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=40,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
-    parser.add_argument(
-        "--class_data_dir",
-        type=str,
-        default=None,
-        required=False,
-        help="A folder containing the training data of class images.",
-    )
     parser.add_argument(
         "--prior_loss_weight",
         type=float,
@@ -311,7 +303,7 @@ class Checkpointer:
         self.output_dir = output_dir
         self.instance_identifier = instance_identifier
         self.sample_image_size = sample_image_size
-        self.seed = seed
+        self.seed = seed or torch.random.seed()
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
@@ -406,6 +398,8 @@ class Checkpointer:
         del unwrapped
         del scheduler
         del pipeline
+        del generator
+        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -523,11 +517,13 @@ def main():
         tokenizer=tokenizer,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
+        class_subdir="db_cls",
         size=args.resolution,
         repeats=args.repeats,
         center_crop=args.center_crop,
         valid_set_size=args.sample_batch_size*args.sample_batches,
-        collate_fn=collate_fn)
+        collate_fn=collate_fn
+    )
 
     datamodule.prepare_data()
     datamodule.setup()
@@ -587,7 +583,7 @@ def main():
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
-        seed=args.seed or torch.random.seed()
+        seed=args.seed
     )
 
     # Scheduler and math around the number of training steps.
@@ -699,8 +695,7 @@ def main():
                     loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior,
-                                            reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
-- 
cgit v1.2.3-54-g00ecf
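
Note on the last hunk: the patch shows only the loss computation, not how noise_pred_prior and noise_prior are produced. Below is a minimal sketch of the prior-preservation loss pattern this commit keeps, assuming the common Dreambooth convention of concatenating instance and class batches and splitting the model output with torch.chunk; the function name dreambooth_loss and the tensor shapes are hypothetical, not taken from the patch.

import torch
import torch.nn.functional as F

def dreambooth_loss(model_pred, target, prior_loss_weight):
    # Assumed layout: the first half of the batch holds instance samples,
    # the second half the class (prior) samples, as in the usual Dreambooth setup.
    noise_pred, noise_pred_prior = torch.chunk(model_pred, 2, dim=0)
    noise, noise_prior = torch.chunk(target, 2, dim=0)

    # Per-sample MSE over channel/height/width, then mean over the batch,
    # mirroring the .mean([1, 2, 3]).mean() calls in the hunk above.
    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

    # Weight the class-image loss and add it to the instance loss
    # (prior_loss_weight corresponds to the --prior_loss_weight argument).
    return loss + prior_loss_weight * prior_loss

# Example with dummy latents (4 instance + 4 class images per batch):
# loss = dreambooth_loss(torch.randn(8, 4, 64, 64), torch.randn(8, 4, 64, 64), 1.0)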