From 7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 8 Jan 2023 13:38:43 +0100
Subject: Fixed aspect ratio bucketing; allow passing token IDs to pipeline

---
 data/csv.py                                    | 78 ++++++++++++----
 .../stable_diffusion/vlpn_stable_diffusion.py  | 40 ++++++++---
 train_dreambooth.py                            | 14 ++--
 train_ti.py                                    | 24 ++++---
 training/util.py                               | 14 ++--
 5 files changed, 102 insertions(+), 68 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index 9be36ba..289a64d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,8 +41,8 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]):
 
 
 def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
-    item_order: list[int] = []
-    item_buckets: list[int] = []
+    bucket_items: list[int] = []
+    bucket_assignments: list[int] = []
     buckets = [1.0]
 
     for i in range(1, num_buckets + 1):
@@ -70,10 +70,10 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_
         if len(indices.shape) == 0:
             indices = indices.unsqueeze(0)
 
-        item_order += [i] * len(indices)
-        item_buckets += indices
+        bucket_items += [i] * len(indices)
+        bucket_assignments += indices
 
-    return buckets.tolist(), item_order, item_buckets
+    return buckets.tolist(), bucket_items, bucket_assignments
 
 
 class VlpnDataItem(NamedTuple):
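For orientation, a rough sketch of the bucketing idea behind generate_buckets, assuming each step adds 64 pixels to the long side (per the --num_buckets help text further down). The real function works on torch tensors, also supports progressive buckets, and returns the bucket list plus the parallel bucket_items/bucket_assignments lists; this is only an illustration of the assignment step, not the repository's code:

    # Illustrative sketch only, not the repository's implementation.
    from PIL import Image

    def sketch_generate_buckets(paths, size, num_buckets, step=64):
        # one square bucket, then progressively wider and taller aspect ratios
        ratios = [1.0]
        for i in range(1, num_buckets + 1):
            long_side = size + i * step
            ratios += [long_side / size, size / long_side]

        bucket_items, bucket_assignments = [], []
        for index, path in enumerate(paths):
            with Image.open(path) as img:
                ratio = img.width / img.height
            # assign each image to the bucket whose aspect ratio is closest to its own
            closest = min(range(len(ratios)), key=lambda b: abs(ratios[b] - ratio))
            bucket_items.append(index)
            bucket_assignments.append(closest)

        return ratios, bucket_items, bucket_assignments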
@@ -94,8 +94,8 @@ class VlpnDataModule():
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
-        num_aspect_ratio_buckets: int = 0,
-        progressive_aspect_ratio_buckets: bool = False,
+        num_buckets: int = 0,
+        progressive_buckets: bool = False,
         dropout: float = 0,
         interpolation: str = "bicubic",
         template_key: str = "template",
@@ -119,8 +119,8 @@ class VlpnDataModule():
         self.prompt_processor = prompt_processor
         self.size = size
-        self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
-        self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
+        self.num_buckets = num_buckets
+        self.progressive_buckets = progressive_buckets
         self.dropout = dropout
         self.template_key = template_key
         self.interpolation = interpolation
@@ -207,15 +207,15 @@ class VlpnDataModule():
         train_dataset = VlpnDataset(
             self.data_train, self.prompt_processor,
-            num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets,
-            batch_size=self.batch_size,
+            num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
+            batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
         )

         val_dataset = VlpnDataset(
             self.data_val, self.prompt_processor,
-            batch_size=self.batch_size,
+            batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
         )
@@ -256,7 +256,7 @@ class VlpnDataset(IterableDataset):
         self.interpolation = interpolations[interpolation]
         self.generator = generator

-        buckets, item_order, item_buckets = generate_buckets(
+        buckets, bucket_items, bucket_assignments = generate_buckets(
             [item.instance_image_path for item in items],
             size,
             num_buckets,
@@ -264,23 +264,27 @@ class VlpnDataset(IterableDataset):
         )

         self.buckets = torch.tensor(buckets)
-        self.item_order = torch.tensor(item_order)
-        self.item_buckets = torch.tensor(item_buckets)
+        self.bucket_items = torch.tensor(bucket_items)
+        self.bucket_assignments = torch.tensor(bucket_assignments)
+        self.bucket_item_range = torch.arange(len(bucket_items))
+
+        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()

     def __len__(self):
-        return len(self.item_buckets)
+        return self.length_

     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()

         if self.shuffle:
-            perm = torch.randperm(len(self.item_buckets), generator=self.generator)
-            self.item_order = self.item_order[perm]
-            self.item_buckets = self.item_buckets[perm]
+            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
+            self.bucket_items = self.bucket_items[perm]
+            self.bucket_assignments = self.bucket_assignments[perm]

-        item_mask = torch.ones_like(self.item_buckets, dtype=bool)
-        bucket = -1
         image_transforms = None
+
+        mask = torch.ones_like(self.bucket_assignments, dtype=bool)
+        bucket = -1
         batch = []
         batch_size = self.batch_size
@@ -289,25 +293,30 @@ class VlpnDataset(IterableDataset):
             worker_batch = math.ceil(len(self) / worker_info.num_workers)
             start = worker_info.id * worker_batch
             end = start + worker_batch
-            item_mask[:start] = False
-            item_mask[end:] = False
+            mask[:start] = False
+            mask[end:] = False

-        while item_mask.any():
-            item_indices = self.item_order[(self.item_buckets == bucket) & item_mask]
+        while mask.any():
+            bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
+            bucket_items = self.bucket_items[bucket_mask]

-            if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0):
+            if len(batch) >= batch_size:
                 yield batch
                 batch = []

-            if len(item_indices) == 0:
-                bucket = self.item_buckets[item_mask][0]
+            if len(bucket_items) == 0:
+                if len(batch) != 0:
+                    yield batch
+                    batch = []
+
+                bucket = self.bucket_assignments[mask][0]
                 ratio = self.buckets[bucket]
                 width = self.size * ratio if ratio > 1 else self.size
                 height = self.size / ratio if ratio < 1 else self.size

                 image_transforms = transforms.Compose(
                     [
-                        transforms.Resize(min(width, height), interpolation=self.interpolation),
+                        transforms.Resize(self.size, interpolation=self.interpolation),
                         transforms.RandomCrop((height, width)),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
@@ -315,15 +324,14 @@ class VlpnDataset(IterableDataset):
                         transforms.Normalize([0.5], [0.5]),
                     ]
                 )
             else:
-                item_index = item_indices[0]
+                item_index = bucket_items[0]
                 item = self.items[item_index]
-                item_mask[item_index] = False
+                mask[self.bucket_item_range[bucket_mask][0]] = False

                 example = {}
-                example["prompts"] = keywords_to_prompt(item.prompt)
-                example["cprompts"] = item.cprompt
-                example["nprompts"] = item.nprompt
+                example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
+                example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)

                 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
                 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
@@ -332,7 +340,7 @@ class VlpnDataset(IterableDataset):
                 )

                 if self.num_class_images != 0:
                     example["class_images"] = image_transforms(get_image(item.class_image_path))
-                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
+                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)

                 batch.append(example)
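A small worked illustration (with assumed numbers, not taken from the patch) of the corrected length above: because every aspect ratio bucket is batched on its own, __len__ now counts batches rather than individual items, which is exactly what (bincount / batch_size).ceil().sum() computes.

    import math

    bucket_counts = {0: 10, 1: 3, 2: 5}   # hypothetical images per aspect ratio bucket
    batch_size = 4
    num_batches = sum(math.ceil(n / batch_size) for n in bucket_counts.values())
    print(num_batches)  # 3 + 1 + 2 = 6 batches, matching ceil(bincount / batch_size).sum()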
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 53b5eea..cb300d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -79,6 +79,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             unet=unet,
             scheduler=scheduler,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -160,14 +161,22 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 return torch.device(module._hf_hook.execution_device)
         return self.device

-    def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps):
-        if isinstance(prompt, str):
+    def check_inputs(
+        self,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
+        width: Optional[int],
+        height: Optional[int],
+        strength: float,
+        callback_steps: Optional[int]
+    ):
+        if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)):
             prompt = [prompt]

         if negative_prompt is None:
             negative_prompt = ""

-        if isinstance(negative_prompt, str):
+        if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)):
             negative_prompt = [negative_prompt] * len(prompt)

         if not isinstance(prompt, list):
@@ -196,12 +205,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return prompt, negative_prompt

-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
-        text_input_ids = self.prompt_processor.get_input_ids(prompt)
+    def encode_prompt(
+        self,
+        prompt: Union[List[str], List[List[int]]],
+        negative_prompt: Union[List[str], List[List[int]]],
+        num_images_per_prompt: int,
+        do_classifier_free_guidance: bool,
+        device
+    ):
+        text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt
         text_input_ids *= num_images_per_prompt

         if do_classifier_free_guidance:
-            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            unconditional_input_ids = self.prompt_processor.get_input_ids(
+                negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
@@ -314,12 +331,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str], List[List[str]]],
-        negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         strength: float = 0.8,
-        height: Optional[int] = 768,
-        width: Optional[int] = 768,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
@@ -379,6 +396,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 list of `bool`s denoting whether the corresponding generated image likely represents
                 "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
         """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor

         # 1. Check inputs. Raise error if not correct
         prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)
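A hypothetical usage sketch of the pipeline change above (the surrounding objects and variable names are assumed, not part of the patch): prompts may now be passed either as text or as pre-tokenized input IDs, and height/width fall back to unet.config.sample_size * vae_scale_factor when left at None.

    # `pipeline` is a loaded VlpnStableDiffusion, `prompt_processor` its prompt helper (both assumed to exist).
    prompt_ids = prompt_processor.get_input_ids("a photo of a cat")

    out_from_text = pipeline(prompt="a photo of a cat")   # tokenized inside the pipeline
    out_from_ids = pipeline(prompt=prompt_ids)            # skips re-tokenization; size defaults from the UNet config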
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 42a7d0f..79eede6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -699,9 +699,9 @@ def main():
             return cond3 and cond4

     def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        cprompts = [example["cprompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
+        prompt_ids = [example["prompt_ids"] for example in examples]
+        nprompt_ids = [example["nprompt_ids"] for example in examples]
+
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
@@ -713,16 +713,18 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

+        prompts = prompt_processor.unify_input_ids(prompt_ids)
+        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
         inputs = prompt_processor.unify_input_ids(input_ids)

         batch = {
-            "prompts": prompts,
-            "cprompts": cprompts,
-            "nprompts": nprompts,
+            "prompt_ids": prompts.input_ids,
+            "nprompt_ids": nprompts.input_ids,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
             "attention_mask": inputs.attention_mask,
         }
+
         return batch

     datamodule = VlpnDataModule(
diff --git a/train_ti.py b/train_ti.py
index 727b591..323ef10 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -140,13 +140,13 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--num_aspect_ratio_buckets",
+        "--num_buckets",
         type=int,
         default=4,
-        help="Number of buckets in either direction (adds 64 pixels per step).",
+        help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).",
     )
     parser.add_argument(
-        "--progressive_aspect_ratio_buckets",
+        "--progressive_buckets",
         action="store_true",
         help="Include images in smaller buckets as well.",
     )
@@ -681,9 +681,9 @@ def main():
             return cond1 and cond3 and cond4

     def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        cprompts = [example["cprompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
+        prompt_ids = [example["prompt_ids"] for example in examples]
+        nprompt_ids = [example["nprompt_ids"] for example in examples]
+
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
@@ -695,16 +695,18 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)

+        prompts = prompt_processor.unify_input_ids(prompt_ids)
+        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
         inputs = prompt_processor.unify_input_ids(input_ids)

         batch = {
-            "prompts": prompts,
-            "cprompts": cprompts,
-            "nprompts": nprompts,
+            "prompt_ids": prompts.input_ids,
+            "nprompt_ids": nprompts.input_ids,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
             "attention_mask": inputs.attention_mask,
         }
+
         return batch

     datamodule = VlpnDataModule(
@@ -714,8 +716,8 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
-        progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
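The collate functions above lean on prompt_processor.unify_input_ids, whose implementation is not shown in this patch; from how its result is used (inputs.input_ids, inputs.attention_mask), it presumably pads variable-length ID lists to a common length. A stand-in sketch of that behaviour, purely for illustration and not the actual helper:

    import torch

    def pad_input_ids(id_lists, pad_token_id=0):
        # pad every sequence to the longest one and build a matching attention mask
        max_len = max(len(ids) for ids in id_lists)
        input_ids = torch.full((len(id_lists), max_len), pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros((len(id_lists), max_len), dtype=torch.long)
        for row, ids in enumerate(id_lists):
            input_ids[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
            attention_mask[row, :len(ids)] = 1
        return input_ids, attention_mask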
diff --git a/training/util.py b/training/util.py
index ae6bfc4..60d64f0 100644
--- a/training/util.py
+++ b/training/util.py
@@ -73,20 +73,22 @@ class CheckpointerBase:
         file_path.parent.mkdir(parents=True, exist_ok=True)

         batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-        prompts = [
+        prompt_ids = [
             prompt
             for batch in batches
-            for prompt in batch["prompts"]
+            for prompt in batch["prompt_ids"]
         ]
-        nprompts = [
+        nprompt_ids = [
             prompt
             for batch in batches
-            for prompt in batch["nprompts"]
+            for prompt in batch["nprompt_ids"]
         ]

         for i in range(self.sample_batches):
-            prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-            nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+            start = i * self.sample_batch_size
+            end = (i + 1) * self.sample_batch_size
+            prompt = prompt_ids[start:end]
+            nprompt = nprompt_ids[start:end]

             samples = pipeline(
                 prompt=prompt,
-- 
cgit v1.2.3-70-g09d2
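As a closing illustration (with made-up values, not from the patch), the sampling loop above simply walks fixed-size windows over the collected prompt IDs, cycling over the data when there are fewer entries than samples requested:

    import itertools

    data = [{"prompt_ids": [[1, 2]], "nprompt_ids": [[3]]}]        # hypothetical tiny dataset
    sample_batch_size, sample_batches = 2, 3
    batches = list(itertools.islice(itertools.cycle(data), sample_batch_size * sample_batches))
    prompt_ids = [p for batch in batches for p in batch["prompt_ids"]]

    for i in range(sample_batches):
        start = i * sample_batch_size
        end = (i + 1) * sample_batch_size
        print(prompt_ids[start:end])   # each slice is what gets handed to the pipeline as `prompt`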