From e2d3a62bce63fcde940395a1c5618c4eb43385a9 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 09:25:13 +0100
Subject: Cleanup

---
 train_ti.py | 33 ++++++++++++---------------------
 1 file changed, 12 insertions(+), 21 deletions(-)

(limited to 'train_ti.py')

diff --git a/train_ti.py b/train_ti.py
index d2ca7eb..d752927 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -180,15 +180,6 @@ def parse_args():
         default="auto",
         help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
     )
-    parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
     parser.add_argument(
         "--num_train_epochs",
         type=int,
@@ -575,24 +566,24 @@ def main():
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
@@ -616,7 +607,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -625,7 +616,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_embeddings = EMAModel(
@@ -708,7 +701,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -807,7 +799,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -825,7 +816,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -849,7 +841,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -862,7 +854,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
             num_epochs=args.num_train_epochs,
--
cgit v1.2.3-54-g00ecf
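
Note: the new logging in the "@@ -625,7 +616,9 @@" hunk reports, for each placeholder token, its token id(s) and how many initializer token ids it maps to. Below is a minimal standalone sketch of that bookkeeping with made-up token values; in train_ti.py the three lists come from add_placeholder_tokens, which per this patch now also returns the initializer token ids.

    # Standalone sketch of the token-stats logging added in this patch.
    # The token and id values here are hypothetical, for illustration only.
    placeholder_tokens = ["<token1>", "<token2>"]      # hypothetical tokens
    placeholder_token_ids = [[49408], [49409]]         # hypothetical tokenizer ids
    initializer_token_ids = [[320], [1125, 2368]]      # hypothetical initializer ids

    if len(placeholder_token_ids) != 0:
        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
        placeholder_token_stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
        # -> Added 2 new tokens: [('<token1>', [49408], 1), ('<token2>', [49409], 2)]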