From 5571c4ebcb39813e2bd8585de30c64bb02f9d7fa Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 8 Jan 2023 09:43:22 +0100
Subject: Improved aspect ratio bucketing

---
 data/csv.py         | 273 +++++++++++++++++++++++++++++-----------------------
 train_dreambooth.py | 100 +++++++++----------
 train_ti.py         |  85 +++++++---------
 training/util.py    |   2 +-
 4 files changed, 237 insertions(+), 223 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 654aec1..9be36ba 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -2,20 +2,28 @@ import math
 import torch
 import json
 from pathlib import Path
+from typing import NamedTuple, Optional, Union, Callable
+
 from PIL import Image
-from torch.utils.data import Dataset, DataLoader, random_split
-from torchvision import transforms
-from typing import Dict, NamedTuple, List, Optional, Union, Callable
-import numpy as np
+from torch.utils.data import IterableDataset, DataLoader, random_split
+from torchvision import transforms
 
-from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt
+from models.clip.prompt import PromptProcessor
 
 image_cache: dict[str, Image.Image] = {}
 
+interpolations = {
+    "linear": transforms.InterpolationMode.NEAREST,
+    "bilinear": transforms.InterpolationMode.BILINEAR,
+    "bicubic": transforms.InterpolationMode.BICUBIC,
+    "lanczos": transforms.InterpolationMode.LANCZOS,
+}
+
+
 def get_image(path):
     if path in image_cache:
         return image_cache[path]
@@ -28,10 +36,46 @@ def get_image(path):
     return image
 
 
-def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+def prepare_prompt(prompt: Union[str, dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
+def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
+    item_order: list[int] = []
+    item_buckets: list[int] = []
+    buckets = [1.0]
+
+    for i in range(1, num_buckets + 1):
+        s = size + i * 64
+        buckets.append(s / size)
+        buckets.append(size / s)
+
+    buckets = torch.tensor(buckets)
+    bucket_indices = torch.arange(len(buckets))
+
+    for i, item in enumerate(items):
+        image = get_image(item)
+        ratio = image.width / image.height
+
+        if ratio >= 1:
+            mask = torch.bitwise_and(buckets >= 1, buckets <= ratio)
+        else:
+            mask = torch.bitwise_and(buckets <= 1, buckets >= ratio)
+
+        if not progressive_buckets:
+            mask = (buckets + (~mask) * math.inf - ratio).abs().argmin()
+
+        indices = bucket_indices[mask]
+
+        if len(indices.shape) == 0:
+            indices = indices.unsqueeze(0)
+
+        item_order += [i] * len(indices)
+        item_buckets += indices
+
+    return buckets.tolist(), item_order, item_buckets
+
+
 class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -41,14 +85,6 @@ class VlpnDataItem(NamedTuple):
     collection: list[str]
 
 
-class VlpnDataBucket():
-    def __init__(self, width: int, height: int):
-        self.width = width
-        self.height = height
-        self.ratio = width / height
-        self.items: list[VlpnDataItem] = []
-
-
 class VlpnDataModule():
     def __init__(
         self,
@@ -60,7 +96,6 @@ class VlpnDataModule():
         size: int = 768,
         num_aspect_ratio_buckets: int = 0,
         progressive_aspect_ratio_buckets: bool = False,
-        repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
         template_key: str = "template",
@@ -86,7 +121,6 @@ class VlpnDataModule():
         self.size = size
         self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
         self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
-        self.repeats = repeats
         self.dropout = dropout
         self.template_key = template_key
         self.interpolation = interpolation
@@ -146,36 +180,6 @@ class VlpnDataModule():
             for i in range(image_multiplier)
         ]
 
-    def generate_buckets(self, items: list[VlpnDataItem]):
-        buckets = [VlpnDataBucket(self.size, self.size)]
-
-        for i in range(1, self.num_aspect_ratio_buckets + 1):
-            s = self.size + i * 64
-            buckets.append(VlpnDataBucket(s, self.size))
-            buckets.append(VlpnDataBucket(self.size, s))
-
-        buckets = np.array(buckets)
-        bucket_ratios = np.array([bucket.ratio for bucket in buckets])
-
-        for item in items:
-            image = get_image(item.instance_image_path)
-            ratio = image.width / image.height
-
-            if ratio >= 1:
-                mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
-            else:
-                mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)
-
-            if not self.progressive_aspect_ratio_buckets:
-                ratios = bucket_ratios.copy()
-                ratios[~mask] = math.inf
-                mask = [np.argmin(np.abs(ratios - ratio))]
-
-            for bucket in buckets[mask]:
-                bucket.items.append(item)
-
-        return [bucket for bucket in buckets if len(bucket.items) != 0]
-
     def setup(self):
         with open(self.data_file, 'rt') as f:
             metadata = json.load(f)
@@ -201,105 +205,136 @@ class VlpnDataModule():
         self.data_train = self.pad_items(data_train, self.num_class_images)
         self.data_val = self.pad_items(data_val)
 
-        buckets = self.generate_buckets(data_train)
-
-        train_datasets = [
-            VlpnDataset(
-                bucket.items, self.prompt_processor,
-                width=bucket.width, height=bucket.height, interpolation=self.interpolation,
-                num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
-            )
-            for bucket in buckets
-        ]
+        train_dataset = VlpnDataset(
+            self.data_train, self.prompt_processor,
+            num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets,
+            batch_size=self.batch_size,
+            size=self.size, interpolation=self.interpolation,
+            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
+        )
 
         val_dataset = VlpnDataset(
-            data_val, self.prompt_processor,
-            width=self.size, height=self.size, interpolation=self.interpolation,
+            self.data_val, self.prompt_processor,
+            batch_size=self.batch_size,
+            size=self.size, interpolation=self.interpolation,
         )
 
-        self.train_dataloaders = [
-            DataLoader(
-                dataset, batch_size=self.batch_size, shuffle=True,
-                pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
-            )
-            for dataset in train_datasets
-        ]
+        self.train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+        )
+
         self.val_dataloader = DataLoader(
-            val_dataset, batch_size=self.batch_size,
-            pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            val_dataset,
+            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
         )
 
 
-class VlpnDataset(Dataset):
+class VlpnDataset(IterableDataset):
     def __init__(
         self,
-        data: List[VlpnDataItem],
+        items: list[VlpnDataItem],
        prompt_processor: PromptProcessor,
+        num_buckets: int = 1,
+        progressive_buckets: bool = False,
+        batch_size: int = 1,
         num_class_images: int = 0,
-        width: int = 768,
-        height: int = 768,
-        repeats: int = 1,
+        size: int = 768,
         dropout: float = 0,
+        shuffle: bool = False,
         interpolation: str = "bicubic",
+        generator: Optional[torch.Generator] = None,
     ):
+        self.items = items
+        self.batch_size = batch_size
 
-        self.data = data
         self.prompt_processor = prompt_processor
         self.num_class_images = num_class_images
+        self.size = size
         self.dropout = dropout
-
-        self.num_instance_images = len(self.data)
-        self._length = self.num_instance_images * repeats
-
-        self.interpolation = {
-            "linear": transforms.InterpolationMode.NEAREST,
-            "bilinear": transforms.InterpolationMode.BILINEAR,
-            "bicubic": transforms.InterpolationMode.BICUBIC,
-            "lanczos": transforms.InterpolationMode.LANCZOS,
-        }[interpolation]
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(min(width, height), interpolation=self.interpolation),
-                transforms.RandomCrop((height, width)),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
+        self.shuffle = shuffle
+        self.interpolation = interpolations[interpolation]
+        self.generator = generator
+
+        buckets, item_order, item_buckets = generate_buckets(
+            [item.instance_image_path for item in items],
+            size,
+            num_buckets,
+            progressive_buckets
         )
 
-    def __len__(self):
-        return self._length
+        self.buckets = torch.tensor(buckets)
+        self.item_order = torch.tensor(item_order)
+        self.item_buckets = torch.tensor(item_buckets)
 
-    def get_example(self, i):
-        item = self.data[i % self.num_instance_images]
-
-        example = {}
-        example["prompts"] = item.prompt
-        example["cprompts"] = item.cprompt
-        example["nprompts"] = item.nprompt
-        example["instance_images"] = get_image(item.instance_image_path)
-        if self.num_class_images != 0:
-            example["class_images"] = get_image(item.class_image_path)
-
-        return example
+    def __len__(self):
+        return len(self.item_buckets)
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+
+        if self.shuffle:
+            perm = torch.randperm(len(self.item_buckets), generator=self.generator)
+            self.item_order = self.item_order[perm]
+            self.item_buckets = self.item_buckets[perm]
+
+        item_mask = torch.ones_like(self.item_buckets, dtype=bool)
+        bucket = -1
+        image_transforms = None
+        batch = []
+        batch_size = self.batch_size
+
+        if worker_info is not None:
+            batch_size = math.ceil(batch_size / worker_info.num_workers)
+            worker_batch = math.ceil(len(self) / worker_info.num_workers)
+            start = worker_info.id * worker_batch
+            end = start + worker_batch
+            item_mask[:start] = False
+            item_mask[end:] = False
+
+        while item_mask.any():
+            item_indices = self.item_order[(self.item_buckets == bucket) & item_mask]
+
+            if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0):
+                yield batch
+                batch = []
+
+            if len(item_indices) == 0:
+                bucket = self.item_buckets[item_mask][0]
+                ratio = self.buckets[bucket]
+                width = self.size * ratio if ratio > 1 else self.size
+                height = self.size / ratio if ratio < 1 else self.size
+
+                image_transforms = transforms.Compose(
+                    [
+                        transforms.Resize(min(width, height), interpolation=self.interpolation),
+                        transforms.RandomCrop((height, width)),
+                        transforms.RandomHorizontalFlip(),
+                        transforms.ToTensor(),
+                        transforms.Normalize([0.5], [0.5]),
+                    ]
+                )
+            else:
+                item_index = item_indices[0]
+                item = self.items[item_index]
+                item_mask[item_index] = False
 
-    def __getitem__(self, i):
-        unprocessed_example = self.get_example(i)
+                example = {}
 
-        example = {}
+                example["prompts"] = keywords_to_prompt(item.prompt)
+                example["cprompts"] = item.cprompt
+                example["nprompts"] = item.nprompt
 
-        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"])
-        example["cprompts"] = unprocessed_example["cprompts"]
-        example["nprompts"] = unprocessed_example["nprompts"]
+                example["instance_images"] = image_transforms(get_image(item.instance_image_path))
+                example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+                    keywords_to_prompt(item.prompt, self.dropout, True)
+                )
 
-        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
-        )
+                if self.num_class_images != 0:
+                    example["class_images"] = image_transforms(get_image(item.class_image_path))
+                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
 
-        if self.num_class_images != 0:
-            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
+                batch.append(example)
 
-        return example
+        if len(batch) != 0:
+            yield batch
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 589af59..42a7d0f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -133,12 +133,6 @@ def parse_args():
         default="cls",
         help="The directory where class images will be saved.",
     )
-    parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -738,7 +732,6 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -751,7 +744,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -770,8 +763,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -820,8 +812,7 @@ def main():
         ema_unet.to(accelerator.device)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
@@ -877,7 +868,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=tokenizer.train,
@@ -960,54 +951,53 @@ def main():
             text_encoder.requires_grad_(False)
 
         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(unet):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        if accelerator.sync_gradients:
-                            params_to_clip = (
-                                itertools.chain(unet.parameters(), text_encoder.parameters())
-                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                                else unet.parameters()
-                            )
-                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        if args.use_ema:
-                            ema_unet.step(unet.parameters())
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(unet):
+                    loss, acc, bsz = loop(step, batch)
 
-                        global_step += 1
+                    accelerator.backward(loss)
 
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0]
-                        }
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            itertools.chain(unet.parameters(), text_encoder.parameters())
+                            if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                            else unet.parameters()
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
                     if args.use_ema:
-                            logs["ema_decay"] = 1 - ema_unet.decay
+                        ema_unet.step(unet.parameters())
+                    optimizer.zero_grad(set_to_none=True)
 
-                        accelerator.log(logs, step=global_step)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
-                        local_progress_bar.set_postfix(**logs)
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0]
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = 1 - ema_unet.decay
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)
 
-                    if global_step >= args.max_train_steps:
-                        break
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
diff --git a/train_ti.py b/train_ti.py
index b4b602b..727b591 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -106,12 +106,6 @@ def parse_args():
         nargs='*',
         help="Exclude all items with a listed collection.",
     )
-    parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -722,7 +716,6 @@ def main():
         size=args.resolution,
         num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
         progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -733,7 +726,7 @@ def main():
     )
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -752,8 +745,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -790,10 +782,9 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, val_dataloader, lr_scheduler
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
-    train_dataloaders = accelerator.prepare(*train_dataloaders)
 
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -811,8 +802,7 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
@@ -870,7 +860,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=on_train,
@@ -949,48 +939,47 @@ def main():
             text_encoder.train()
 
         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(text_encoder):
-                        loss, acc, bsz = loop(step, batch)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    loss, acc, bsz = loop(step, batch)
 
-                        accelerator.backward(loss)
+                    accelerator.backward(loss)
 
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        optimizer.zero_grad(set_to_none=True)
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
 
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        if args.use_ema:
-                            ema_embeddings.step(
-                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_embeddings.step(
+                            text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-                        global_step += 1
+                    global_step += 1
 
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0],
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = ema_embeddings.decay
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = ema_embeddings.decay
 
-                        accelerator.log(logs, step=global_step)
+                    accelerator.log(logs, step=global_step)
 
-                        local_progress_bar.set_postfix(**logs)
+                    local_progress_bar.set_postfix(**logs)
 
-                    if global_step >= args.max_train_steps:
-                        break
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
diff --git a/training/util.py b/training/util.py
index 2b7f71d..ae6bfc4 100644
--- a/training/util.py
+++ b/training/util.py
@@ -59,7 +59,7 @@ class CheckpointerBase:
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        train_data = self.datamodule.train_dataloaders[0]
+        train_data = self.datamodule.train_dataloader
         val_data = self.datamodule.val_dataloader
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-- 
cgit v1.2.3-54-g00ecf
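
As a companion to the patch, the following is a minimal, self-contained sketch of the bucket assignment that the new generate_buckets() in data/csv.py performs. It is not part of the commit: image paths are replaced by plain (width, height) tuples so it runs without PIL or any files on disk, the names assign_buckets and sample_sizes are illustrative only, and the nearest-bucket choice is written as an explicitly masked distance rather than the in-place arithmetic used in the patch. With progressive buckets an image lands in every bucket between square and its own aspect ratio, so it is trained at several crops; otherwise it lands only in the single closest bucket. The reworked IterableDataset then yields ready-made per-bucket batches in which all images share one crop size, which is why the DataLoaders above switch to batch_size=None.

import math

import torch


def assign_buckets(sample_sizes, size=768, num_buckets=2, progressive_buckets=False):
    # Bucket aspect ratios: 1:1 plus wider/taller variants in 64 px steps,
    # mirroring generate_buckets() in data/csv.py.
    ratios = [1.0]
    for i in range(1, num_buckets + 1):
        s = size + i * 64
        ratios.append(s / size)  # landscape bucket, e.g. 832x768
        ratios.append(size / s)  # portrait bucket, e.g. 768x832

    buckets = torch.tensor(ratios)
    bucket_indices = torch.arange(len(buckets))

    item_order = []    # sample index, repeated once per bucket it is assigned to
    item_buckets = []  # bucket index for each entry of item_order

    for i, (width, height) in enumerate(sample_sizes):
        ratio = width / height

        # Candidate buckets lie between square (1.0) and the sample's own ratio.
        if ratio >= 1:
            mask = (buckets >= 1) & (buckets <= ratio)
        else:
            mask = (buckets <= 1) & (buckets >= ratio)

        if progressive_buckets:
            indices = bucket_indices[mask]
        else:
            # Keep only the closest candidate bucket.
            dist = (buckets - ratio).abs()
            dist[~mask] = math.inf
            indices = dist.argmin().unsqueeze(0)

        item_order += [i] * len(indices)
        item_buckets += indices.tolist()

    return buckets.tolist(), item_order, item_buckets


if __name__ == "__main__":
    # Three hypothetical training images: square, landscape, and tall portrait.
    sizes = [(768, 768), (1024, 768), (768, 1152)]
    print(assign_buckets(sizes, size=768, num_buckets=2, progressive_buckets=True))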