From 5ef5ff5aece1a29995f11943c3ca1d6fe2fabbfa Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 19 Oct 2022 17:45:56 +0200
Subject: Dreambooth: Added option to insert a new input token; removed Dreambooth Plus

---
 dreambooth.py | 48 ++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 6 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index d1cf535..da8399f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -68,6 +68,18 @@ def parse_args():
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
+    parser.add_argument(
+        "--placeholder_token",
+        type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--initializer_token",
+        type=str,
+        default=None,
+        help="A token to use as initializer word."
+    )
     parser.add_argument(
         "--num_class_images",
         type=int,
@@ -114,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=3600,
+        default=6000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -131,13 +143,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=3e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=3e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -476,8 +488,13 @@ class Checkpointer:
 def main():
     args = parse_args()
 
+    instance_identifier = args.instance_identifier
+
+    if args.placeholder_token is not None:
+        instance_identifier = instance_identifier.format(args.placeholder_token)
+
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -514,6 +531,25 @@ def main():
         device=accelerator.device
     ) if args.use_ema else None
 
+    if args.initializer_token is not None:
+        # Convert the initializer_token, placeholder_token to ids
+        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
+        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
+        # Add the placeholder token in tokenizer
+        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+        if num_added_tokens == 0:
+            print(f"Re-using existing token {args.placeholder_token}.")
+        else:
+            print(f"Training new token {args.placeholder_token}.")
+        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+        text_encoder.resize_token_embeddings(len(tokenizer))
+        token_embeds = text_encoder.get_input_embeddings()
+        initializer_token_embeddings = token_embeds(initializer_token_ids)
+        token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings
+
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.gradient_checkpointing:
@@ -605,7 +641,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
-        instance_identifier=args.instance_identifier,
+        instance_identifier=instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -735,7 +771,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         output_dir=basepath,
-        instance_identifier=args.instance_identifier,
+        instance_identifier=instance_identifier,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
-- 
cgit v1.2.3-54-g00ecf
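
The block this commit adds to main() is a textual-inversion-style token setup: register a placeholder token with the tokenizer, then seed its embedding row with the embedding of an existing word. Below is a minimal, self-contained sketch of that technique against the Hugging Face transformers API. The checkpoint name (openai/clip-vit-large-patch14, the text encoder Stable Diffusion uses) and the example tokens are illustrative assumptions; the patch itself operates on whatever tokenizer and text_encoder the script has already loaded.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Assumed checkpoint for illustration; the patch reuses the script's own
# tokenizer and text encoder instead of loading fresh ones.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

placeholder_token = "<*>"   # matches the new --placeholder_token default
initializer_token = "dog"   # hypothetical --initializer_token value

# The initializer word may split into several sub-tokens; like the patch,
# keep only the first one.
initializer_token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
initializer_token_ids = torch.tensor(initializer_token_ids[:1])

# Register the placeholder; add_tokens() returns 0 if the token already
# exists in the vocabulary, which the patch reports as "Re-using".
num_added_tokens = tokenizer.add_tokens(placeholder_token)
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# Grow the embedding matrix to cover the new vocabulary size, then copy
# the initializer's vector into the placeholder's (freshly added) row.
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings()
with torch.no_grad():
    token_embeds.weight.data[placeholder_token_id] = token_embeds(initializer_token_ids)

Truncating to the first sub-token (initializer_token_ids[:1]) keeps a one-to-one mapping between the placeholder and a single embedding row, and starting from a related word's vector rather than a random one gives the new token a semantically meaningful initialization to fine-tune from.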