From 728dfcf57c30f40236b3a00d7380c4e0057cacb3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Mon, 17 Oct 2022 22:08:58 +0200
Subject: Implemented extended prompt limit

---
 dreambooth_plus.py | 44 ++++++++++++++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 12 deletions(-)

(limited to 'dreambooth_plus.py')

diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index a98417f..ae31377 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -124,7 +124,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1500,
+        default=1400,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -147,7 +147,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-6,
+        default=1e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -469,9 +469,16 @@ class Checkpointer:

                 for i in range(self.sample_batches):
                     batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                    prompt = [prompt.format(self.instance_identifier)
-                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                    prompt = [
+                        [p.format(self.instance_identifier) for p in prompt]
+                        for batch in batches
+                        for prompt in batch["prompts"]
+                    ][:self.sample_batch_size]
+                    nprompt = [
+                        prompt
+                        for batch in batches
+                        for prompt in batch["nprompts"]
+                    ][:self.sample_batch_size]

                     samples = pipeline(
                         prompt=prompt,
@@ -666,6 +673,17 @@ def main():
         }
         return batch

+    def encode_input_ids(input_ids):
+        text_embeddings = []
+
+        for ids in input_ids:
+            embeddings = text_encoder(ids)[0]
+            embeddings = embeddings.reshape((1, -1, 768))
+            text_embeddings.append(embeddings)
+
+        text_embeddings = torch.cat(text_embeddings)
+        return text_embeddings
+
     datamodule = CSVDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -688,8 +706,10 @@ def main():
         missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]

         if len(missing_data) != 0:
-            batched_data = [missing_data[i:i+args.sample_batch_size]
-                            for i in range(0, len(missing_data), args.sample_batch_size)]
+            batched_data = [
+                missing_data[i:i+args.sample_batch_size]
+                for i in range(0, len(missing_data), args.sample_batch_size)
+            ]
             scheduler = EulerAScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
             )
@@ -706,9 +726,9 @@ def main():

             with torch.inference_mode():
                 for batch in batched_data:
-                    image_name = [p.class_image_path for p in batch]
-                    prompt = [p.prompt.format(args.class_identifier) for p in batch]
-                    nprompt = [p.nprompt for p in batch]
+                    image_name = [item.class_image_path for item in batch]
+                    prompt = [[p.format(args.class_identifier) for p in item.prompt] for item in batch]
+                    nprompt = [item.nprompt for item in batch]

                     images = pipeline(
                         prompt=prompt,
@@ -855,7 +875,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])

                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -954,7 +974,7 @@ def main():

                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = encode_input_ids(batch["input_ids"])

                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
--
cgit v1.2.3-54-g00ecf
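
Note on the new encode_input_ids() helper: the reshape to (1, -1, 768) suggests that batch["input_ids"] now carries, per example, a tensor of several 77-token chunks, so prompts longer than CLIP's window are encoded chunk by chunk and their embeddings concatenated along the sequence axis. The chunk-splitting itself happens in the data pipeline and is not part of this commit. The following is a minimal standalone sketch of that per-example step, assuming the Hugging Face CLIP ViT-L/14 tokenizer and text encoder (hidden size 768) used by Stable Diffusion v1; the encode_long_prompt name and the example sub-prompts are illustrative only.

# Sketch: encode an over-length prompt as concatenated 77-token chunks,
# mirroring the per-example reshape done by encode_input_ids() above.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

def encode_long_prompt(sub_prompts):
    # Tokenize every sub-prompt into a fixed 77-token window -> (n_chunks, 77)
    ids = tokenizer(
        sub_prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,  # 77 for CLIP
        truncation=True,
        return_tensors="pt",
    ).input_ids
    with torch.no_grad():
        embeddings = text_encoder(ids)[0]        # (n_chunks, 77, 768)
    # Flatten the chunks along the sequence axis, as encode_input_ids() does
    return embeddings.reshape((1, -1, 768))      # (1, n_chunks * 77, 768)

# A prompt split into two chunks yields a (1, 154, 768) conditioning tensor;
# the UNet's cross-attention consumes the longer sequence unchanged.
print(encode_long_prompt(["a photo of a cat", "sitting on a red sofa"]).shape)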