From 3ee13893f9a4973ac75f45fe9318c35760dd4b1f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 7 Jan 2023 13:57:46 +0100 Subject: Added progressive aspect ratio bucketing --- data/csv.py | 144 ++++++++++++++++++++++++++++++++-------------------- infer.py | 23 +++++---- train_dreambooth.py | 12 ++--- train_ti.py | 94 +++++++++++++++++----------------- training/util.py | 4 +- 5 files changed, 151 insertions(+), 126 deletions(-) diff --git a/data/csv.py b/data/csv.py index 4986153..59d6d8d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -11,11 +11,26 @@ from models.clip.prompt import PromptProcessor from data.keywords import prompt_to_keywords, keywords_to_prompt +image_cache: dict[str, Image.Image] = {} + + +def get_image(path): + if path in image_cache: + return image_cache[path] + + image = Image.open(path) + if not image.mode == "RGB": + image = image.convert("RGB") + image_cache[path] = image + + return image + + def prepare_prompt(prompt: Union[str, Dict[str, str]]): return {"content": prompt} if isinstance(prompt, str) else prompt -class CSVDataItem(NamedTuple): +class VlpnDataItem(NamedTuple): instance_image_path: Path class_image_path: Path prompt: list[str] @@ -24,7 +39,15 @@ class CSVDataItem(NamedTuple): collection: list[str] -class CSVDataModule(): +class VlpnDataBucket(): + def __init__(self, width: int, height: int): + self.width = width + self.height = height + self.ratio = width / height + self.items: list[VlpnDataItem] = [] + + +class VlpnDataModule(): def __init__( self, batch_size: int, @@ -36,11 +59,10 @@ class CSVDataModule(): repeats: int = 1, dropout: float = 0, interpolation: str = "bicubic", - center_crop: bool = False, template_key: str = "template", valid_set_size: Optional[int] = None, seed: Optional[int] = None, - filter: Optional[Callable[[CSVDataItem], bool]] = None, + filter: Optional[Callable[[VlpnDataItem], bool]] = None, collate_fn=None, num_workers: int = 0 ): @@ -60,7 +82,6 @@ class CSVDataModule(): self.size = size self.repeats = repeats self.dropout = dropout - self.center_crop = center_crop self.template_key = template_key self.interpolation = interpolation self.valid_set_size = valid_set_size @@ -70,14 +91,14 @@ class CSVDataModule(): self.num_workers = num_workers self.batch_size = batch_size - def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: + def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: image = template["image"] if "image" in template else "{}" prompt = template["prompt"] if "prompt" in template else "{content}" cprompt = template["cprompt"] if "cprompt" in template else "{content}" nprompt = template["nprompt"] if "nprompt" in template else "{content}" return [ - CSVDataItem( + VlpnDataItem( self.data_root.joinpath(image.format(item["image"])), None, prompt_to_keywords( @@ -97,17 +118,17 @@ class CSVDataModule(): for item in data ] - def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: + def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: if self.filter is None: return items return [item for item in items if self.filter(item)] - def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: + def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: image_multiplier = max(num_class_images, 1) return [ - CSVDataItem( + VlpnDataItem( item.instance_image_path, self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), item.prompt, @@ -119,7 +140,30 @@ class CSVDataModule(): for i in range(image_multiplier) ] - def prepare_data(self): + def generate_buckets(self, items: list[VlpnDataItem]): + buckets = [VlpnDataBucket(self.size, self.size)] + + for i in range(1, 5): + s = self.size + i * 64 + buckets.append(VlpnDataBucket(s, self.size)) + buckets.append(VlpnDataBucket(self.size, s)) + + for item in items: + image = get_image(item.instance_image_path) + ratio = image.width / image.height + + if ratio >= 1: + candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio] + else: + candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio] + + for bucket in candidates: + bucket.items.append(item) + + buckets = [bucket for bucket in buckets if len(bucket.items) != 0] + return buckets + + def setup(self): with open(self.data_file, 'rt') as f: metadata = json.load(f) template = metadata[self.template_key] if self.template_key in metadata else {} @@ -144,48 +188,48 @@ class CSVDataModule(): self.data_train = self.pad_items(data_train, self.num_class_images) self.data_val = self.pad_items(data_val) - def setup(self, stage=None): - train_dataset = CSVDataset( - self.data_train, self.prompt_processor, batch_size=self.batch_size, - num_class_images=self.num_class_images, - size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout - ) - val_dataset = CSVDataset( - self.data_val, self.prompt_processor, batch_size=self.batch_size, - size=self.size, interpolation=self.interpolation, - center_crop=self.center_crop - ) - self.train_dataloader_ = DataLoader( - train_dataset, batch_size=self.batch_size, - shuffle=True, pin_memory=True, collate_fn=self.collate_fn, - num_workers=self.num_workers - ) - self.val_dataloader_ = DataLoader( - val_dataset, batch_size=self.batch_size, - pin_memory=True, collate_fn=self.collate_fn, - num_workers=self.num_workers + buckets = self.generate_buckets(data_train) + + train_datasets = [ + VlpnDataset( + bucket.items, self.prompt_processor, batch_size=self.batch_size, + width=bucket.width, height=bucket.height, interpolation=self.interpolation, + num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, + ) + for bucket in buckets + ] + + val_dataset = VlpnDataset( + data_val, self.prompt_processor, batch_size=self.batch_size, + width=self.size, height=self.size, interpolation=self.interpolation, ) - def train_dataloader(self): - return self.train_dataloader_ + self.train_dataloaders = [ + DataLoader( + dataset, batch_size=self.batch_size, shuffle=True, + pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers + ) + for dataset in train_datasets + ] - def val_dataloader(self): - return self.val_dataloader_ + self.val_dataloader = DataLoader( + val_dataset, batch_size=self.batch_size, + pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers + ) -class CSVDataset(Dataset): +class VlpnDataset(Dataset): def __init__( self, - data: List[CSVDataItem], + data: List[VlpnDataItem], prompt_processor: PromptProcessor, batch_size: int = 1, num_class_images: int = 0, - size: int = 768, + width: int = 768, + height: int = 768, repeats: int = 1, dropout: float = 0, interpolation: str = "bicubic", - center_crop: bool = False, ): self.data = data @@ -193,7 +237,6 @@ class CSVDataset(Dataset): self.batch_size = batch_size self.num_class_images = num_class_images self.dropout = dropout - self.image_cache = {} self.num_instance_images = len(self.data) self._length = self.num_instance_images * repeats @@ -206,8 +249,8 @@ class CSVDataset(Dataset): }[interpolation] self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=self.interpolation), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize(min(width, height), interpolation=self.interpolation), + transforms.RandomCrop((height, width)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), @@ -217,17 +260,6 @@ class CSVDataset(Dataset): def __len__(self): return math.ceil(self._length / self.batch_size) * self.batch_size - def get_image(self, path): - if path in self.image_cache: - return self.image_cache[path] - - image = Image.open(path) - if not image.mode == "RGB": - image = image.convert("RGB") - self.image_cache[path] = image - - return image - def get_example(self, i): item = self.data[i % self.num_instance_images] @@ -235,9 +267,9 @@ class CSVDataset(Dataset): example["prompts"] = item.prompt example["cprompts"] = item.cprompt example["nprompts"] = item.nprompt - example["instance_images"] = self.get_image(item.instance_image_path) + example["instance_images"] = get_image(item.instance_image_path) if self.num_class_images != 0: - example["class_images"] = self.get_image(item.class_image_path) + example["class_images"] = get_image(item.class_image_path) return example diff --git a/infer.py b/infer.py index d3d5f1b..2b07b21 100644 --- a/infer.py +++ b/infer.py @@ -238,16 +238,15 @@ def create_pipeline(model, dtype): return pipeline +def shuffle_prompts(prompts: list[str]) -> list[str]: + return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts] + + @torch.inference_mode() def generate(output_dir: Path, pipeline, args): if isinstance(args.prompt, str): args.prompt = [args.prompt] - if args.shuffle: - args.prompt *= args.batch_size - args.batch_size = 1 - args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] - args.prompt = [args.template.format(prompt) for prompt in args.prompt] now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") @@ -263,9 +262,6 @@ def generate(output_dir: Path, pipeline, args): dir = output_dir.joinpath(slugify(prompt)[:100]) dir.mkdir(parents=True, exist_ok=True) image_dir.append(dir) - - with open(dir.joinpath('prompt.txt'), 'w') as f: - f.write(prompt) else: output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") output_dir.mkdir(parents=True, exist_ok=True) @@ -306,9 +302,10 @@ def generate(output_dir: Path, pipeline, args): ) seed = args.seed + i + prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt generator = torch.Generator(device="cuda").manual_seed(seed) images = pipeline( - prompt=args.prompt, + prompt=prompt, negative_prompt=args.negative_prompt, height=args.height, width=args.width, @@ -321,9 +318,13 @@ def generate(output_dir: Path, pipeline, args): ).images for j, image in enumerate(images): + basename = f"{seed}_{j // len(args.prompt)}" dir = image_dir[j % len(args.prompt)] - image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) - image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) + + image.save(dir.joinpath(f"{basename}.png")) + image.save(dir.joinpath(f"{basename}.jpg"), quality=85) + with open(dir.joinpath(f"{basename}.txt"), 'w') as f: + f.write(prompt[j % len(args.prompt)]) if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/train_dreambooth.py b/train_dreambooth.py index e8256be..d265bcc 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -22,7 +22,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from data.csv import CSVDataModule, CSVDataItem +from data.csv import VlpnDataModule, VlpnDataItem from training.common import run_model from training.optimization import get_one_cycle_schedule from training.lr import LRFinder @@ -171,11 +171,6 @@ def parse_args(): " resolution" ), ) - parser.add_argument( - "--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution" - ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -698,7 +693,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - def keyword_filter(item: CSVDataItem): + def keyword_filter(item: VlpnDataItem): cond3 = args.collection is None or args.collection in item.collection cond4 = args.exclude_collections is None or not any( collection in item.collection @@ -733,7 +728,7 @@ def main(): } return batch - datamodule = CSVDataModule( + datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, @@ -742,7 +737,6 @@ def main(): size=args.resolution, repeats=args.repeats, dropout=args.tag_dropout, - center_crop=args.center_crop, template_key=args.train_data_template, valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, diff --git a/train_ti.py b/train_ti.py index 0ffc9e6..89c6672 100644 --- a/train_ti.py +++ b/train_ti.py @@ -21,7 +21,7 @@ from slugify import slugify from util import load_config, load_embeddings_from_dir from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion -from data.csv import CSVDataModule, CSVDataItem +from data.csv import VlpnDataModule, VlpnDataItem from training.common import run_model from training.optimization import get_one_cycle_schedule from training.lr import LRFinder @@ -145,11 +145,6 @@ def parse_args(): " resolution" ), ) - parser.add_argument( - "--center_crop", - action="store_true", - help="Whether to center crop images before resizing to resolution" - ) parser.add_argument( "--tag_dropout", type=float, @@ -668,7 +663,7 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - def keyword_filter(item: CSVDataItem): + def keyword_filter(item: VlpnDataItem): cond1 = any( keyword in part for keyword in args.placeholder_token @@ -708,7 +703,7 @@ def main(): } return batch - datamodule = CSVDataModule( + datamodule = VlpnDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, prompt_processor=prompt_processor, @@ -717,7 +712,6 @@ def main(): size=args.resolution, repeats=args.repeats, dropout=args.tag_dropout, - center_crop=args.center_crop, template_key=args.train_data_template, valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, @@ -725,8 +719,6 @@ def main(): filter=keyword_filter, collate_fn=collate_fn ) - - datamodule.prepare_data() datamodule.setup() if args.num_class_images != 0: @@ -769,12 +761,14 @@ def main(): if torch.cuda.is_available(): torch.cuda.empty_cache() - train_dataloader = datamodule.train_dataloader() - val_dataloader = datamodule.val_dataloader() + train_dataloaders = datamodule.train_dataloaders + default_train_dataloader = train_dataloaders[0] + val_dataloader = datamodule.val_dataloader # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) + num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -811,9 +805,10 @@ def main(): num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, ) - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, val_dataloader, lr_scheduler ) + train_dataloaders = accelerator.prepare(*train_dataloaders) # Move vae and unet to device vae.to(accelerator.device, dtype=weight_dtype) @@ -831,7 +826,8 @@ def main(): unet.eval() # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) + num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch @@ -889,7 +885,7 @@ def main(): accelerator, text_encoder, optimizer, - train_dataloader, + default_train_dataloader, val_dataloader, loop, on_train=on_train, @@ -968,46 +964,48 @@ def main(): text_encoder.train() with on_train(): - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate(text_encoder): - loss, acc, bsz = loop(step, batch) + for train_dataloader in train_dataloaders: + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + loss, acc, bsz = loop(step, batch) - accelerator.backward(loss) + accelerator.backward(loss) - optimizer.step() - if not accelerator.optimizer_step_was_skipped: - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + optimizer.step() + if not accelerator.optimizer_step_was_skipped: + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) - avg_loss.update(loss.detach_(), bsz) - avg_acc.update(acc.detach_(), bsz) + avg_loss.update(loss.detach_(), bsz) + avg_acc.update(acc.detach_(), bsz) - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - if args.use_ema: - ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_embeddings.step( + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) - local_progress_bar.update(1) - global_progress_bar.update(1) + local_progress_bar.update(1) + global_progress_bar.update(1) - global_step += 1 + global_step += 1 - logs = { - "train/loss": avg_loss.avg.item(), - "train/acc": avg_acc.avg.item(), - "train/cur_loss": loss.item(), - "train/cur_acc": acc.item(), - "lr": lr_scheduler.get_last_lr()[0], - } - if args.use_ema: - logs["ema_decay"] = ema_embeddings.decay + logs = { + "train/loss": avg_loss.avg.item(), + "train/acc": avg_acc.avg.item(), + "train/cur_loss": loss.item(), + "train/cur_acc": acc.item(), + "lr": lr_scheduler.get_last_lr()[0], + } + if args.use_ema: + logs["ema_decay"] = ema_embeddings.decay - accelerator.log(logs, step=global_step) + accelerator.log(logs, step=global_step) - local_progress_bar.set_postfix(**logs) + local_progress_bar.set_postfix(**logs) - if global_step >= args.max_train_steps: - break + if global_step >= args.max_train_steps: + break accelerator.wait_for_everyone() diff --git a/training/util.py b/training/util.py index bc466e2..6f42228 100644 --- a/training/util.py +++ b/training/util.py @@ -58,8 +58,8 @@ class CheckpointerBase: def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): samples_path = Path(self.output_dir).joinpath("samples") - train_data = self.datamodule.train_dataloader() - val_data = self.datamodule.val_dataloader() + train_data = self.datamodule.train_dataloaders[0] + val_data = self.datamodule.val_dataloader generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) -- cgit v1.2.3-54-g00ecf