From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 21 Jun 2023 13:28:49 +0200
Subject: Update

---
 train_ti.py | 379 ++++++++++++++++++++++++++++++++----------------------------
 1 file changed, 201 insertions(+), 178 deletions(-)

diff --git a/train_ti.py b/train_ti.py
index f60e3e5..c6f0b3a 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -32,7 +32,7 @@ from util.files import load_config, load_embeddings_from_dir
 
 logger = get_logger(__name__)
 
-warnings.filterwarnings('ignore')
+warnings.filterwarnings("ignore")
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -46,9 +46,7 @@ hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -65,12 +63,12 @@ def parse_args():
         "--train_data_file",
         type=str,
         default=None,
-        help="A CSV file containing the training data."
+        help="A CSV file containing the training data.",
     )
     parser.add_argument(
         "--train_data_template",
         type=str,
-        nargs='*',
+        nargs="*",
         default="template",
     )
     parser.add_argument(
@@ -80,59 +78,47 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--auto_cycles",
-        type=str,
-        default="o",
-        help="Cycles to run automatically."
+        "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
     )
     parser.add_argument(
-        "--cycle_decay",
-        type=float,
-        default=1.0,
-        help="Learning rate decay per cycle."
+        "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
     )
     parser.add_argument(
         "--placeholder_tokens",
         type=str,
-        nargs='*',
+        nargs="*",
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_tokens",
         type=str,
-        nargs='*',
-        help="A token to use as initializer word."
+        nargs="*",
+        help="A token to use as initializer word.",
     )
     parser.add_argument(
-        "--filter_tokens",
-        type=str,
-        nargs='*',
-        help="Tokens to filter the dataset by."
+        "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
     )
     parser.add_argument(
         "--initializer_noise",
         type=float,
         default=0,
-        help="Noise to apply to the initializer word"
+        help="Noise to apply to the initializer word",
     )
     parser.add_argument(
         "--alias_tokens",
         type=str,
-        nargs='*',
+        nargs="*",
         default=[],
-        help="Tokens to create an alias for."
+        help="Tokens to create an alias for.",
     )
     parser.add_argument(
         "--inverted_initializer_tokens",
         type=str,
-        nargs='*',
-        help="A token to use as initializer word."
+        nargs="*",
+        help="A token to use as initializer word.",
     )
     parser.add_argument(
-        "--num_vectors",
-        type=int,
-        nargs='*',
-        help="Number of vectors per embedding."
+        "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
         "--sequential",
@@ -147,7 +133,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=0,
-        help="How many class images to generate."
+ help="How many class images to generate.", ) parser.add_argument( "--class_image_dir", @@ -158,7 +144,7 @@ def parse_args(): parser.add_argument( "--exclude_collections", type=str, - nargs='*', + nargs="*", help="Exclude all items with a listed collection.", ) parser.add_argument( @@ -181,14 +167,11 @@ def parse_args(): parser.add_argument( "--collection", type=str, - nargs='*', + nargs="*", help="A collection to filter the dataset.", ) parser.add_argument( - "--seed", - type=int, - default=None, - help="A seed for reproducible training." + "--seed", type=int, default=None, help="A seed for reproducible training." ) parser.add_argument( "--resolution", @@ -244,7 +227,7 @@ def parse_args(): type=str, default="auto", choices=["all", "trailing", "leading", "between", "auto", "off"], - help='Vector shuffling algorithm.', + help="Vector shuffling algorithm.", ) parser.add_argument( "--offset_noise_strength", @@ -256,18 +239,10 @@ def parse_args(): "--input_pertubation", type=float, default=0, - help="The scale of input pretubation. Recommended 0.1." - ) - parser.add_argument( - "--num_train_epochs", - type=int, - default=None - ) - parser.add_argument( - "--num_train_steps", - type=int, - default=2000 + help="The scale of input pretubation. Recommended 0.1.", ) + parser.add_argument("--num_train_epochs", type=int, default=None) + parser.add_argument("--num_train_steps", type=int, default=2000) parser.add_argument( "--gradient_accumulation_steps", type=int, @@ -299,27 +274,31 @@ def parse_args(): "--lr_scheduler", type=str, default="one_cycle", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", - "constant", "constant_with_warmup", "one_cycle"], - help='The scheduler type to use.', + choices=[ + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + "constant", + "constant_with_warmup", + "one_cycle", + ], + help="The scheduler type to use.", ) parser.add_argument( "--lr_warmup_epochs", type=int, default=10, - help="Number of steps for the warmup in the lr scheduler." + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( - "--lr_mid_point", - type=float, - default=0.3, - help="OneCycle schedule mid point." + "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." ) parser.add_argument( "--lr_cycles", type=int, default=None, - help="Number of restart cycles in the lr scheduler." + help="Number of restart cycles in the lr scheduler.", ) parser.add_argument( "--lr_warmup_func", @@ -331,7 +310,7 @@ def parse_args(): "--lr_warmup_exp", type=int, default=1, - help='If lr_warmup_func is "cos", exponent to modify the function' + help='If lr_warmup_func is "cos", exponent to modify the function', ) parser.add_argument( "--lr_annealing_func", @@ -343,89 +322,67 @@ def parse_args(): "--lr_annealing_exp", type=int, default=1, - help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' + help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', ) parser.add_argument( "--lr_min_lr", type=float, default=0.04, - help="Minimum learning rate in the lr scheduler." + help="Minimum learning rate in the lr scheduler.", ) parser.add_argument( - "--use_ema", - action="store_true", - help="Whether to use EMA model." 
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=4/5
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
-        "--min_snr_gamma",
-        type=int,
-        default=5,
-        help="MinSNR gamma."
+        "--use_ema", action="store_true", help="Whether to use EMA model."
     )
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
+    parser.add_argument("--ema_power", type=float, default=4 / 5)
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
+    parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
     parser.add_argument(
         "--schedule_sampler",
         type=str,
         default="uniform",
         choices=["uniform", "loss-second-moment"],
-        help="Noise schedule sampler."
+        help="Noise schedule sampler.",
     )
     parser.add_argument(
         "--optimizer",
         type=str,
         default="adan",
         choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
-        help='Optimizer to use'
+        help="Optimizer to use",
     )
     parser.add_argument(
         "--dadaptation_d0",
         type=float,
         default=1e-6,
-        help="The d0 parameter for Dadaptation optimizers."
+        help="The d0 parameter for Dadaptation optimizers.",
     )
     parser.add_argument(
         "--adam_beta1",
         type=float,
         default=None,
-        help="The beta1 parameter for the Adam optimizer."
+        help="The beta1 parameter for the Adam optimizer.",
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
         default=None,
-        help="The beta2 parameter for the Adam optimizer."
+        help="The beta2 parameter for the Adam optimizer.",
     )
     parser.add_argument(
-        "--adam_weight_decay",
-        type=float,
-        default=2e-2,
-        help="Weight decay to use."
+        "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
     )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
         default=1e-08,
-        help="Epsilon value for the Adam optimizer"
+        help="Epsilon value for the Adam optimizer",
     )
     parser.add_argument(
         "--adam_amsgrad",
         type=bool,
         default=False,
-        help="Amsgrad value for the Adam optimizer"
+        help="Amsgrad value for the Adam optimizer",
     )
     parser.add_argument(
         "--mixed_precision",
@@ -456,7 +413,7 @@ def parse_args():
     )
     parser.add_argument(
         "--no_milestone_checkpoints",
-        action='store_true',
+        action="store_true",
         help="If checkpoints are saved on maximum accuracy",
     )
     parser.add_argument(
@@ -493,25 +450,25 @@ def parse_args():
         "--valid_set_size",
         type=int,
         default=None,
-        help="Number of images in the validation dataset."
+        help="Number of images in the validation dataset.",
     )
     parser.add_argument(
         "--train_set_pad",
         type=int,
         default=None,
-        help="The number to fill train dataset items up to."
+        help="The number to fill train dataset items up to.",
     )
     parser.add_argument(
         "--valid_set_pad",
         type=int,
         default=None,
-        help="The number to fill validation dataset items up to."
+        help="The number to fill validation dataset items up to.",
     )
     parser.add_argument(
         "--train_batch_size",
         type=int,
         default=1,
-        help="Batch size (per device) for the training dataloader."
+        help="Batch size (per device) for the training dataloader.",
     )
     parser.add_argument(
         "--sample_steps",
@@ -523,14 +480,9 @@ def parse_args():
         "--prior_loss_weight",
         type=float,
         default=1.0,
-        help="The weight of prior preservation loss."
-    )
-    parser.add_argument(
-        "--emb_alpha",
-        type=float,
-        default=1.0,
-        help="Embedding alpha"
+        help="The weight of prior preservation loss.",
     )
+    parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
     parser.add_argument(
         "--emb_dropout",
         type=float,
@@ -538,21 +490,13 @@ def parse_args():
         help="Embedding dropout probability.",
     )
     parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
+        "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
     )
     parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
+        "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
     )
     parser.add_argument(
-        "--emb_decay",
-        default=1e+2,
-        type=float,
-        help="Embedding decay factor."
+        "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -563,7 +507,7 @@ def parse_args():
         "--resume_from",
         type=str,
         default=None,
-        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
+        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)",
     )
     parser.add_argument(
         "--global_step",
@@ -574,7 +518,7 @@ def parse_args():
         "--config",
         type=str,
         default=None,
-        help="Path to a JSON configuration file containing arguments for invoking this script."
+        help="Path to a JSON configuration file containing arguments for invoking this script.",
     )
 
     args = parser.parse_args()
@@ -595,29 +539,44 @@ def parse_args():
         args.placeholder_tokens = [args.placeholder_tokens]
 
     if isinstance(args.initializer_tokens, str):
-        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
+        args.initializer_tokens = [args.initializer_tokens] * len(
+            args.placeholder_tokens
+        )
 
     if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+        args.placeholder_tokens = [
+            f"<*{i}>" for i in range(len(args.initializer_tokens))
+        ]
 
     if len(args.initializer_tokens) == 0:
         args.initializer_tokens = args.placeholder_tokens.copy()
 
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
-        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
+        raise ValueError(
+            "--placeholder_tokens and --initializer_tokens must have the same number of items"
+        )
 
     if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
+        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
+            args.placeholder_tokens
+        )
 
-    if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
+    if (
+        isinstance(args.inverted_initializer_tokens, list)
+        and len(args.inverted_initializer_tokens) != 0
+    ):
         args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
         args.initializer_tokens += args.inverted_initializer_tokens
 
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
-    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
-        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
+        args.num_vectors
+    ):
+        raise ValueError(
+            "--placeholder_tokens and --num_vectors must have the same number of items"
+        )
 
     if args.alias_tokens is None:
         args.alias_tokens = []
@@ -639,16 +598,22 @@ def parse_args():
         ]
 
         if isinstance(args.train_data_template, str):
-            args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
+            args.train_data_template = [args.train_data_template] * len(
+                args.placeholder_tokens
+            )
 
         if len(args.placeholder_tokens) != len(args.train_data_template):
-            raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+            raise ValueError(
+                "--placeholder_tokens and --train_data_template must have the same number of items"
+            )
 
         if args.num_vectors is None:
             args.num_vectors = [None] * len(args.placeholder_tokens)
     else:
         if isinstance(args.train_data_template, list):
-            raise ValueError("--train_data_template can't be a list in simultaneous mode")
+            raise ValueError(
+                "--train_data_template can't be a list in simultaneous mode"
+            )
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -660,13 +625,13 @@ def parse_args():
         raise ValueError("You must specify --output_dir")
 
     if args.adam_beta1 is None:
-        if args.optimizer == 'lion':
+        if args.optimizer == "lion":
             args.adam_beta1 = 0.95
         else:
             args.adam_beta1 = 0.9
 
     if args.adam_beta2 is None:
-        if args.optimizer == 'lion':
+        if args.optimizer == "lion":
             args.adam_beta2 = 0.98
         else:
             args.adam_beta2 = 0.999
@@ -679,13 +644,13 @@ def main():
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir)/slugify(args.project)/now
+    output_dir = Path(args.output_dir) / slugify(args.project) / now
     output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         project_dir=f"{output_dir}",
-        mixed_precision=args.mixed_precision
+        mixed_precision=args.mixed_precision,
     )
 
     weight_dtype = torch.float32
@@ -703,9 +668,15 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
-    embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
-    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
+        args.pretrained_model_name_or_path
+    )
+    embeddings = patch_managed_embeddings(
+        text_encoder, args.emb_alpha, args.emb_dropout
+    )
+    schedule_sampler = create_named_schedule_sampler(
+        args.schedule_sampler, noise_scheduler.config.num_train_timesteps
+    )
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -717,16 +688,16 @@ def main():
         unet.enable_xformers_memory_efficient_attention()
     elif args.compile_unet:
         unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
-        
+
         proc = AttnProcessor()
-        
+
         def fn_recursive_set_proc(module: torch.nn.Module):
             if hasattr(module, "processor"):
                 module.processor = proc
-            
+
             for child in module.children():
                 fn_recursive_set_proc(child)
-        
+
         fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
@@ -751,18 +722,24 @@ def main():
             tokenizer=tokenizer,
             embeddings=embeddings,
             placeholder_tokens=alias_placeholder_tokens,
-            initializer_tokens=alias_initializer_tokens
+            initializer_tokens=alias_initializer_tokens,
        )
         embeddings.persist()
-        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+        print(
+            f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
+        )
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
-        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+        added_tokens, added_ids = load_embeddings_from_dir(
+            tokenizer, embeddings, embeddings_dir
+        )
+        print(
+            f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
+        )
 
         if args.train_dir_embeddings:
             args.placeholder_tokens = added_tokens
@@ -772,19 +749,23 @@ def main():
 
     if args.scale_lr:
         args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
+            args.learning_rate
+            * args.gradient_accumulation_steps
+            * args.train_batch_size
+            * accelerator.num_processes
         )
 
     if args.find_lr:
         args.learning_rate = 1e-5
         args.lr_scheduler = "exponential_growth"
 
-    if args.optimizer == 'adam8bit':
+    if args.optimizer == "adam8bit":
         try:
             import bitsandbytes as bnb
         except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
 
         create_optimizer = partial(
             bnb.optim.AdamW8bit,
@@ -793,7 +774,7 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'adam':
+    elif args.optimizer == "adam":
         create_optimizer = partial(
             torch.optim.AdamW,
             betas=(args.adam_beta1, args.adam_beta2),
@@ -801,11 +782,13 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'adan':
+    elif args.optimizer == "adan":
         try:
             import timm.optim
         except ImportError:
-            raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
+            raise ImportError(
+                "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
+            )
 
         create_optimizer = partial(
             timm.optim.Adan,
@@ -813,11 +796,13 @@ def main():
             eps=args.adam_epsilon,
             no_prox=True,
         )
-    elif args.optimizer == 'lion':
+    elif args.optimizer == "lion":
         try:
             import lion_pytorch
         except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+            raise ImportError(
+                "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
+            )
 
         create_optimizer = partial(
             lion_pytorch.Lion,
@@ -825,7 +810,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             use_triton=True,
         )
-    elif args.optimizer == 'adafactor':
+    elif args.optimizer == "adafactor":
         create_optimizer = partial(
             transformers.optimization.Adafactor,
             weight_decay=args.adam_weight_decay,
@@ -837,11 +822,13 @@ def main():
         args.lr_scheduler = "adafactor"
         args.lr_min_lr = args.learning_rate
         args.learning_rate = None
-    elif args.optimizer == 'dadam':
+    elif args.optimizer == "dadam":
         try:
             import dadaptation
         except ImportError:
-            raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
+            raise ImportError(
+                "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
+            )
 
         create_optimizer = partial(
             dadaptation.DAdaptAdam,
@@ -851,11 +838,13 @@ def main():
             decouple=True,
             d0=args.dadaptation_d0,
         )
-    elif args.optimizer == 'dadan':
+    elif args.optimizer == "dadan":
         try:
             import dadaptation
         except ImportError:
-            raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
+            raise ImportError(
+                "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
+            )
 
         create_optimizer = partial(
             dadaptation.DAdaptAdan,
@@ -864,7 +853,7 @@ def main():
             d0=args.dadaptation_d0,
         )
     else:
-        raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
+        raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
     trainer = partial(
         train,
@@ -904,10 +893,21 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    optimizer = create_optimizer(
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        lr=learning_rate,
+    )
+
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
     data_npgenerator = np.random.default_rng(args.seed)
 
-    def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
+    def run(
+        i: int,
+        placeholder_tokens: list[str],
+        initializer_tokens: list[str],
+        num_vectors: Union[int, list[int]],
+        data_template: str,
+    ):
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -917,14 +917,23 @@ def main():
             initializer_noise=args.initializer_noise,
         )
 
-        stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
+        stats = list(
+            zip(
+                placeholder_tokens,
+                placeholder_token_ids,
+                initializer_tokens,
+                initializer_token_ids,
+            )
+        )
 
         print("")
         print(f"============ TI batch {i + 1} ============")
         print("")
         print(stats)
 
-        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
+        filter_tokens = [
+            token for token in args.filter_tokens if token in placeholder_tokens
+        ]
 
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
@@ -945,7 +954,9 @@ def main():
             valid_set_size=args.valid_set_size,
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
-            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
+            filter=partial(
+                keyword_filter, filter_tokens, args.collection, args.exclude_collections
+            ),
             dtype=weight_dtype,
             generator=data_generator,
             npgenerator=data_npgenerator,
@@ -955,11 +966,16 @@ def main():
         num_train_epochs = args.num_train_epochs
         sample_frequency = args.sample_frequency
         if num_train_epochs is None:
-            num_train_epochs = math.ceil(
-                args.num_train_steps / len(datamodule.train_dataset)
-            ) * args.gradient_accumulation_steps
-            sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
-        num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps)
+            num_train_epochs = (
+                math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+                * args.gradient_accumulation_steps
+            )
+            sample_frequency = math.ceil(
+                num_train_epochs * (sample_frequency / args.num_train_steps)
+            )
+        num_training_steps_per_epoch = math.ceil(
+            len(datamodule.train_dataset) / args.gradient_accumulation_steps
+        )
         num_train_steps = num_training_steps_per_epoch * num_train_epochs
         if args.sample_num is not None:
             sample_frequency = math.ceil(num_train_epochs / args.sample_num)
@@ -988,7 +1004,8 @@ def main():
                 response = auto_cycles.pop(0)
             else:
                 response = input(
-                    "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+                    "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
+                )
 
             if response.lower().strip() == "o":
                 if args.learning_rate is not None:
@@ -1018,10 +1035,8 @@ def main():
             print(f"------------ TI cycle {training_iter + 1}: {response} ------------")
             print("")
 
-            optimizer = create_optimizer(
-                text_encoder.text_model.embeddings.token_embedding.parameters(),
-                lr=learning_rate,
-            )
+            for group, lr in zip(optimizer.param_groups, [learning_rate]):
+                group["lr"] = lr
 
             lr_scheduler = get_scheduler(
                 lr_scheduler,
@@ -1040,7 +1055,9 @@ def main():
                 mid_point=args.lr_mid_point,
             )
 
-            checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}"
+            checkpoint_output_dir = (
+                output_dir / project / f"checkpoints_{training_iter}"
+            )
 
             trainer(
                 train_dataloader=datamodule.train_dataloader,
@@ -1070,14 +1087,20 @@ def main():
         accelerator.end_training()
 
     if not args.sequential:
-        run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
+        run(
+            0,
+            args.placeholder_tokens,
+            args.initializer_tokens,
+            args.num_vectors,
+            args.train_data_template,
+        )
     else:
         for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
             range(len(args.placeholder_tokens)),
             args.placeholder_tokens,
             args.initializer_tokens,
            args.num_vectors,
-            args.train_data_template
+            args.train_data_template,
         ):
             run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
 
     embeddings.persist()
-- 
cgit v1.2.3-54-g00ecf
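Note on the functional change in this patch: besides the mechanical reformatting (double quotes, wrapped call arguments), the optimizer is now created once, before the training-cycle loop, and each new cycle only rewrites the learning rate of the existing parameter groups instead of constructing a fresh optimizer. The sketch below is a minimal, hypothetical illustration of that pattern (the embedding, set_cycle_lr, and cycle_decay names here are stand-ins for illustration, not code taken from train_ti.py).

    # Minimal sketch: reuse one optimizer across cycles and update its LR in place,
    # mirroring the change above where create_optimizer() moved out of the per-cycle
    # loop and the loop now assigns group["lr"] on optimizer.param_groups.
    import torch

    embedding = torch.nn.Embedding(16, 8)  # hypothetical stand-in for the trained token embedding
    optimizer = torch.optim.AdamW(embedding.parameters(), lr=1e-3)  # created once


    def set_cycle_lr(opt: torch.optim.Optimizer, lr: float) -> None:
        # Same effect as the patch's
        # `for group, lr in zip(optimizer.param_groups, [learning_rate]): group["lr"] = lr`
        for group in opt.param_groups:
            group["lr"] = lr


    learning_rate = 1e-3
    cycle_decay = 0.5  # cf. the --cycle_decay argument
    for cycle in range(3):
        set_cycle_lr(optimizer, learning_rate)
        # ... run one training cycle here; optimizer state (moments etc.) carries over ...
        learning_rate *= cycle_decay

One practical consequence of this design is that Adam-style second-moment statistics survive across cycles, so a new cycle resumes with the accumulated optimizer state rather than a cold start.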