From 6c8cffe28baeafac77d047ff3f8ded9418033e2f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 15:52:43 +0100 Subject: More training adjustments --- data/csv.py | 39 ++++++++++++++------------ train_dreambooth.py | 71 ++++++++++++++++++++++++++++++++++++++++-------- train_ti.py | 17 ++++++++---- training/functional.py | 5 ++-- training/optimization.py | 10 +++---- training/strategy/ti.py | 2 +- 6 files changed, 101 insertions(+), 43 deletions(-) diff --git a/data/csv.py b/data/csv.py index dec66d7..85b98f8 100644 --- a/data/csv.py +++ b/data/csv.py @@ -174,7 +174,8 @@ class VlpnDataModule(): interpolation: str = "bicubic", template_key: str = "template", valid_set_size: Optional[int] = None, - valid_set_repeat: int = 1, + train_set_pad: Optional[int] = None, + valid_set_pad: Optional[int] = None, seed: Optional[int] = None, filter: Optional[Callable[[VlpnDataItem], bool]] = None, dtype: torch.dtype = torch.float32, @@ -202,7 +203,8 @@ class VlpnDataModule(): self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size - self.valid_set_repeat = valid_set_repeat + self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size + self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size self.seed = seed self.filter = filter self.batch_size = batch_size @@ -267,9 +269,6 @@ class VlpnDataModule(): items = self.prepare_items(template, expansions, items) items = self.filter_items(items) - if (len(items) < self.batch_size): - items = (items * self.batch_size)[:self.batch_size] - num_images = len(items) valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 @@ -283,14 +282,17 @@ class VlpnDataModule(): collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) if valid_set_size == 0: - data_train, data_val = items, [] + data_train, data_val = items, items[:1] else: data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) - self.data_train = self.pad_items(data_train, self.num_class_images) + data_train = self.pad_items(data_train, self.num_class_images) + + if len(data_train) < self.train_set_pad: + data_train *= math.ceil(self.train_set_pad / len(data_train)) - train_dataset = VlpnDataset( - self.data_train, self.tokenizer, + self.train_dataset = VlpnDataset( + data_train, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, generator=generator, @@ -299,24 +301,26 @@ class VlpnDataModule(): ) self.train_dataloader = DataLoader( - train_dataset, + self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) - if valid_set_size != 0: - self.data_val = self.pad_items(data_val) + if len(data_val) != 0: + data_val = self.pad_items(data_val) + + if len(data_val) < self.valid_set_pad: + data_val *= math.ceil(self.valid_set_pad / len(data_val)) - val_dataset = VlpnDataset( - self.data_val, self.tokenizer, + self.val_dataset = VlpnDataset( + data_val, self.tokenizer, num_buckets=self.num_buckets, progressive_buckets=True, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - repeat=self.valid_set_repeat, batch_size=self.batch_size, generator=generator, size=self.size, interpolation=self.interpolation, ) self.val_dataloader = DataLoader( - val_dataset, + self.val_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) else: @@ -332,7 +336,6 @@ class VlpnDataset(IterableDataset): bucket_step_size: int = 64, bucket_max_pixels: Optional[int] = None, progressive_buckets: bool = False, - repeat: int = 1, batch_size: int = 1, num_class_images: int = 0, size: int = 768, @@ -341,7 +344,7 @@ class VlpnDataset(IterableDataset): interpolation: str = "bicubic", generator: Optional[torch.Generator] = None, ): - self.items = items * repeat + self.items = items self.batch_size = batch_size self.tokenizer = tokenizer diff --git a/train_dreambooth.py b/train_dreambooth.py index a9fbbbd..1dc41b1 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -54,6 +54,18 @@ def parse_args(): type=str, default="template", ) + parser.add_argument( + "--train_set_pad", + type=int, + default=None, + help="The number to fill train dataset items up to." + ) + parser.add_argument( + "--valid_set_pad", + type=int, + default=None, + help="The number to fill validation dataset items up to." + ) parser.add_argument( "--project", type=str, @@ -187,11 +199,23 @@ def parse_args(): type=int, default=100 ) + parser.add_argument( + "--ti_data_template", + type=str, + nargs='*', + default=[], + ) parser.add_argument( "--ti_num_train_epochs", type=int, default=10 ) + parser.add_argument( + "--ti_batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader." + ) parser.add_argument( "--max_train_steps", type=int, @@ -458,6 +482,12 @@ def parse_args(): if len(args.placeholder_tokens) != len(args.num_vectors): raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") + if isinstance(args.ti_data_template, str): + args.ti_data_template = [args.ti_data_template] + + if len(args.ti_data_template) == 0: + raise ValueError("You must specify --ti_data_template") + if isinstance(args.collection, str): args.collection = [args.collection] @@ -491,6 +521,8 @@ def main(): set_seed(args.seed) + seed_generator = torch.Generator().manual_seed(args.seed) + save_args(output_dir, args) tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( @@ -512,6 +544,8 @@ def main(): if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") + embeddings.persist() + added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") @@ -545,7 +579,6 @@ def main(): vae=vae, noise_scheduler=noise_scheduler, dtype=weight_dtype, - seed=args.seed, with_prior_preservation=args.num_class_images != 0, prior_loss_weight=args.prior_loss_weight, ) @@ -557,13 +590,17 @@ def main(): cur_dir = output_dir.joinpath("1-ti") cur_dir.mkdir(parents=True, exist_ok=True) - for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors): - print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})") - + for i, placeholder_token, initializer_token, num_vectors, data_template in zip( + range(len(args.placeholder_tokens)), + args.placeholder_tokens, + args.initializer_tokens, + args.num_vectors, + args.ti_data_template + ): cur_subdir = cur_dir.joinpath(placeholder_token) cur_subdir.mkdir(parents=True, exist_ok=True) - placeholder_token_ids, _ = add_placeholder_tokens( + placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, embeddings=embeddings, placeholder_tokens=[placeholder_token], @@ -571,17 +608,23 @@ def main(): num_vectors=[num_vectors] ) + print( + f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") + + args.seed = seed_generator.seed() + datamodule = VlpnDataModule( data_file=args.train_data_file, - batch_size=args.train_batch_size, + batch_size=args.ti_batch_size, tokenizer=tokenizer, class_subdir=args.class_image_dir, num_class_images=args.num_class_images, size=args.resolution, shuffle=not args.no_tag_shuffle, - template_key=args.train_data_template, + template_key=data_template, valid_set_size=1, - valid_set_repeat=args.valid_set_repeat, + train_set_pad=args.train_set_pad, + valid_set_pad=args.valid_set_pad, seed=args.seed, filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), dtype=weight_dtype @@ -591,7 +634,9 @@ def main(): optimizer = optimizer_class( text_encoder.text_model.embeddings.temp_token_embedding.parameters(), lr=args.ti_learning_rate, + betas=(args.adam_beta1, args.adam_beta2), weight_decay=0.0, + eps=args.adam_epsilon, ) lr_scheduler = get_scheduler( @@ -600,7 +645,6 @@ def main(): num_training_steps_per_epoch=len(datamodule.train_dataloader), gradient_accumulation_steps=args.gradient_accumulation_steps, train_epochs=args.ti_num_train_epochs, - warmup_epochs=args.ti_num_train_epochs // 4, ) trainer( @@ -608,10 +652,11 @@ def main(): project="textual_inversion", train_dataloader=datamodule.train_dataloader, val_dataloader=datamodule.val_dataloader, + seed=args.seed, optimizer=optimizer, lr_scheduler=lr_scheduler, num_train_epochs=args.ti_num_train_epochs, - sample_frequency=2, + sample_frequency=args.ti_num_train_epochs // 5, checkpoint_frequency=9999999, # -- tokenizer=tokenizer, @@ -637,7 +682,7 @@ def main(): cur_dir = output_dir.joinpath("2-db") cur_dir.mkdir(parents=True, exist_ok=True) - args.seed = (args.seed + 28635) >> 32 + args.seed = seed_generator.seed() datamodule = VlpnDataModule( data_file=args.train_data_file, @@ -654,7 +699,8 @@ def main(): shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, valid_set_size=args.valid_set_size, - valid_set_repeat=args.valid_set_repeat, + train_set_pad=args.train_set_pad, + valid_set_pad=args.valid_set_pad, seed=args.seed, filter=partial(keyword_filter, None, args.collection, args.exclude_collections), dtype=weight_dtype @@ -697,6 +743,7 @@ def main(): project="dreambooth", train_dataloader=datamodule.train_dataloader, val_dataloader=datamodule.val_dataloader, + seed=args.seed, optimizer=optimizer, lr_scheduler=lr_scheduler, num_train_epochs=args.num_train_epochs, diff --git a/train_ti.py b/train_ti.py index a894ee7..7aecdef 100644 --- a/train_ti.py +++ b/train_ti.py @@ -360,10 +360,16 @@ def parse_args(): help="Number of images in the validation dataset." ) parser.add_argument( - "--valid_set_repeat", + "--train_set_pad", type=int, - default=1, - help="Times the images in the validation dataset are repeated." + default=None, + help="The number to fill train dataset items up to." + ) + parser.add_argument( + "--valid_set_pad", + type=int, + default=None, + help="The number to fill validation dataset items up to." ) parser.add_argument( "--train_batch_size", @@ -575,7 +581,8 @@ def main(): shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, valid_set_size=args.valid_set_size, - valid_set_repeat=args.valid_set_repeat, + train_set_pad=args.train_set_pad, + valid_set_pad=args.valid_set_pad, seed=args.seed, filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), dtype=weight_dtype @@ -590,7 +597,7 @@ def main(): unet, tokenizer, sample_scheduler, - datamodule.data_train, + datamodule.train_dataset, args.sample_batch_size, args.sample_image_size, args.sample_steps diff --git a/training/functional.py b/training/functional.py index c6b4dc3..b6b5d87 100644 --- a/training/functional.py +++ b/training/functional.py @@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol from tqdm.auto import tqdm from PIL import Image +from data.csv import VlpnDataset from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings from models.clip.util import get_extended_embeddings @@ -175,12 +176,12 @@ def generate_class_images( unet: UNet2DConditionModel, tokenizer: MultiCLIPTokenizer, sample_scheduler: DPMSolverMultistepScheduler, - data_train, + train_dataset: VlpnDataset, sample_batch_size: int, sample_image_size: int, sample_steps: int ): - missing_data = [item for item in data_train if not item.class_image_path.exists()] + missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] if len(missing_data) == 0: return diff --git a/training/optimization.py b/training/optimization.py index 5db7794..6dee4bc 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -49,8 +49,8 @@ def get_one_cycle_schedule( annealing: Literal["cos", "half_cos", "linear"] = "cos", warmup_exp: int = 1, annealing_exp: int = 1, - min_lr: int = 0.04, - mid_point: int = 0.3, + min_lr: float = 0.04, + mid_point: float = 0.3, last_epoch: int = -1 ): if warmup == "linear": @@ -91,10 +91,10 @@ def get_scheduler( id: str, optimizer: torch.optim.Optimizer, num_training_steps_per_epoch: int, - gradient_accumulation_steps: int, + gradient_accumulation_steps: int = 1, min_lr: float = 0.04, - warmup_func: str = "cos", - annealing_func: str = "cos", + warmup_func: Literal["cos", "linear"] = "cos", + annealing_func: Literal["cos", "half_cos", "linear"] = "cos", warmup_exp: int = 1, annealing_exp: int = 1, cycles: int = 1, diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 568f9eb..9d39e15 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -36,7 +36,7 @@ def textual_inversion_strategy( use_emb_decay: bool = False, emb_decay_target: float = 0.4, emb_decay_factor: float = 1, - emb_decay_start: float = 1e-4, + emb_decay_start: float = 0, use_ema: bool = False, ema_inv_gamma: float = 1.0, ema_power: int = 1, -- cgit v1.2.3-54-g00ecf