author     Volpeon <git@volpeon.ink>  2023-01-08 13:38:43 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-08 13:38:43 +0100
commit     7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41 (patch)
tree       d275e13506ca737efef18dc6dffa05f4e0d6759f
parent     Improved aspect ratio bucketing (diff)
download   textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.gz
           textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.tar.bz2
           textual-inversion-diff-7cd9f00f5f9c1c5679e64b3db8d0fd6d83813f41.zip
Fixed aspect ratio bucketing; allow passing token IDs to pipeline
-rw-r--r--  data/csv.py                                           78
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py   40
-rw-r--r--  train_dreambooth.py                                   14
-rw-r--r--  train_ti.py                                           24
-rw-r--r--  training/util.py                                      14
5 files changed, 102 insertions, 68 deletions
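
The user-facing change on the pipeline side is that prompts can now be passed either as text or as pre-tokenized input IDs, which is what the sampling code in training/util.py now does with the "prompt_ids" / "nprompt_ids" batch entries. A minimal sketch of both call styles, assuming an already constructed VlpnStableDiffusion instance named `pipeline` whose prompt_processor is reachable as an attribute (as the diff below suggests); the prompt strings are illustrative only:

    # Plain text prompts, as before
    result = pipeline(
        prompt="a photo of a <token>",
        negative_prompt="lowres, blurry",
        num_inference_steps=50,
    )

    # New in this commit: pre-tokenized prompts (List[int], or List[List[int]] for a batch)
    prompt_ids = pipeline.prompt_processor.get_input_ids("a photo of a <token>")
    nprompt_ids = pipeline.prompt_processor.get_input_ids("lowres, blurry")
    result = pipeline(
        prompt=prompt_ids,
        negative_prompt=nprompt_ids,
        num_inference_steps=50,
    )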
diff --git a/data/csv.py b/data/csv.py
index 9be36ba..289a64d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -41,8 +41,8 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]):
 
 
 def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
-    item_order: list[int] = []
-    item_buckets: list[int] = []
+    bucket_items: list[int] = []
+    bucket_assignments: list[int] = []
     buckets = [1.0]
 
     for i in range(1, num_buckets + 1):
@@ -70,10 +70,10 @@ def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_
         if len(indices.shape) == 0:
             indices = indices.unsqueeze(0)
 
-        item_order += [i] * len(indices)
-        item_buckets += indices
+        bucket_items += [i] * len(indices)
+        bucket_assignments += indices
 
-    return buckets.tolist(), item_order, item_buckets
+    return buckets.tolist(), bucket_items, bucket_assignments
 
 
 class VlpnDataItem(NamedTuple):
@@ -94,8 +94,8 @@ class VlpnDataModule():
         class_subdir: str = "cls",
         num_class_images: int = 1,
         size: int = 768,
-        num_aspect_ratio_buckets: int = 0,
-        progressive_aspect_ratio_buckets: bool = False,
+        num_buckets: int = 0,
+        progressive_buckets: bool = False,
         dropout: float = 0,
         interpolation: str = "bicubic",
         template_key: str = "template",
@@ -119,8 +119,8 @@ class VlpnDataModule():
 
         self.prompt_processor = prompt_processor
         self.size = size
-        self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
-        self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
+        self.num_buckets = num_buckets
+        self.progressive_buckets = progressive_buckets
         self.dropout = dropout
         self.template_key = template_key
         self.interpolation = interpolation
@@ -207,15 +207,15 @@ class VlpnDataModule():
 
         train_dataset = VlpnDataset(
             self.data_train, self.prompt_processor,
-            num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets,
-            batch_size=self.batch_size,
+            num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
+            batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
         )
 
         val_dataset = VlpnDataset(
             self.data_val, self.prompt_processor,
-            batch_size=self.batch_size,
+            batch_size=self.batch_size, generator=generator,
             size=self.size, interpolation=self.interpolation,
         )
 
@@ -256,7 +256,7 @@ class VlpnDataset(IterableDataset):
         self.interpolation = interpolations[interpolation]
         self.generator = generator
 
-        buckets, item_order, item_buckets = generate_buckets(
+        buckets, bucket_items, bucket_assignments = generate_buckets(
             [item.instance_image_path for item in items],
             size,
             num_buckets,
@@ -264,23 +264,27 @@ class VlpnDataset(IterableDataset):
         )
 
         self.buckets = torch.tensor(buckets)
-        self.item_order = torch.tensor(item_order)
-        self.item_buckets = torch.tensor(item_buckets)
+        self.bucket_items = torch.tensor(bucket_items)
+        self.bucket_assignments = torch.tensor(bucket_assignments)
+        self.bucket_item_range = torch.arange(len(bucket_items))
+
+        self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item()
 
     def __len__(self):
-        return len(self.item_buckets)
+        return self.length_
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
 
         if self.shuffle:
-            perm = torch.randperm(len(self.item_buckets), generator=self.generator)
-            self.item_order = self.item_order[perm]
-            self.item_buckets = self.item_buckets[perm]
+            perm = torch.randperm(len(self.bucket_assignments), generator=self.generator)
+            self.bucket_items = self.bucket_items[perm]
+            self.bucket_assignments = self.bucket_assignments[perm]
 
-        item_mask = torch.ones_like(self.item_buckets, dtype=bool)
-        bucket = -1
         image_transforms = None
+
+        mask = torch.ones_like(self.bucket_assignments, dtype=bool)
+        bucket = -1
         batch = []
         batch_size = self.batch_size
 
@@ -289,25 +293,30 @@ class VlpnDataset(IterableDataset):
             worker_batch = math.ceil(len(self) / worker_info.num_workers)
             start = worker_info.id * worker_batch
             end = start + worker_batch
-            item_mask[:start] = False
-            item_mask[end:] = False
+            mask[:start] = False
+            mask[end:] = False
 
-        while item_mask.any():
-            item_indices = self.item_order[(self.item_buckets == bucket) & item_mask]
+        while mask.any():
+            bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
+            bucket_items = self.bucket_items[bucket_mask]
 
-            if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0):
+            if len(batch) >= batch_size:
                 yield batch
                 batch = []
 
-            if len(item_indices) == 0:
-                bucket = self.item_buckets[item_mask][0]
+            if len(bucket_items) == 0:
+                if len(batch) != 0:
+                    yield batch
+                    batch = []
+
+                bucket = self.bucket_assignments[mask][0]
                 ratio = self.buckets[bucket]
                 width = self.size * ratio if ratio > 1 else self.size
                 height = self.size / ratio if ratio < 1 else self.size
 
                 image_transforms = transforms.Compose(
                     [
-                        transforms.Resize(min(width, height), interpolation=self.interpolation),
+                        transforms.Resize(self.size, interpolation=self.interpolation),
                         transforms.RandomCrop((height, width)),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
@@ -315,15 +324,14 @@ class VlpnDataset(IterableDataset):
                     ]
                 )
             else:
-                item_index = item_indices[0]
+                item_index = bucket_items[0]
                 item = self.items[item_index]
-                item_mask[item_index] = False
+                mask[self.bucket_item_range[bucket_mask][0]] = False
 
                 example = {}
 
-                example["prompts"] = keywords_to_prompt(item.prompt)
-                example["cprompts"] = item.cprompt
-                example["nprompts"] = item.nprompt
+                example["prompt_ids"] = self.prompt_processor.get_input_ids(keywords_to_prompt(item.prompt))
+                example["nprompt_ids"] = self.prompt_processor.get_input_ids(item.nprompt)
 
                 example["instance_images"] = image_transforms(get_image(item.instance_image_path))
                 example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
@@ -332,7 +340,7 @@ class VlpnDataset(IterableDataset):
 
                 if self.num_class_images != 0:
                     example["class_images"] = image_transforms(get_image(item.class_image_path))
-                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
+                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(item.cprompt)
 
                 batch.append(example)
 
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 53b5eea..cb300d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -79,6 +79,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             unet=unet,
             scheduler=scheduler,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -160,14 +161,22 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 return torch.device(module._hf_hook.execution_device)
         return self.device
 
-    def check_inputs(self, prompt, negative_prompt, width, height, strength, callback_steps):
-        if isinstance(prompt, str):
+    def check_inputs(
+        self,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
+        width: Optional[int],
+        height: Optional[int],
+        strength: float,
+        callback_steps: Optional[int]
+    ):
+        if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)):
             prompt = [prompt]
 
         if negative_prompt is None:
             negative_prompt = ""
 
-        if isinstance(negative_prompt, str):
+        if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)):
             negative_prompt = [negative_prompt] * len(prompt)
 
         if not isinstance(prompt, list):
@@ -196,12 +205,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         return prompt, negative_prompt
 
-    def encode_prompt(self, prompt, negative_prompt, num_images_per_prompt, do_classifier_free_guidance, device):
-        text_input_ids = self.prompt_processor.get_input_ids(prompt)
+    def encode_prompt(
+        self,
+        prompt: Union[List[str], List[List[int]]],
+        negative_prompt: Union[List[str], List[List[int]]],
+        num_images_per_prompt: int,
+        do_classifier_free_guidance: bool,
+        device
+    ):
+        text_input_ids = self.prompt_processor.get_input_ids(prompt) if isinstance(prompt[0], str) else prompt
         text_input_ids *= num_images_per_prompt
 
         if do_classifier_free_guidance:
-            unconditional_input_ids = self.prompt_processor.get_input_ids(negative_prompt)
+            unconditional_input_ids = self.prompt_processor.get_input_ids(
+                negative_prompt) if isinstance(negative_prompt[0], str) else negative_prompt
             unconditional_input_ids *= num_images_per_prompt
             text_input_ids = unconditional_input_ids + text_input_ids
 
@@ -314,12 +331,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str], List[List[str]]],
-        negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         strength: float = 0.8,
-        height: Optional[int] = 768,
-        width: Optional[int] = 768,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         eta: Optional[float] = 0.0,
@@ -379,6 +396,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
                 (nsfw) content, according to the `safety_checker`.
         """
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         # 1. Check inputs. Raise error if not correct
         prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps)
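A note on the new height/width defaults above: with a typical Stable Diffusion VAE that has four block_out_channels entries, vae_scale_factor = 2 ** (4 - 1) = 8, so a UNet sample_size of 96 gives 96 * 8 = 768 as the fallback resolution, matching the previous hard-coded default of 768, while a sample_size of 64 would give 512. These concrete config values are assumptions about the checkpoints in use, not something this diff pins down.
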
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 42a7d0f..79eede6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -699,9 +699,9 @@ def main():
         return cond3 and cond4
 
     def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        cprompts = [example["cprompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
+        prompt_ids = [example["prompt_ids"] for example in examples]
+        nprompt_ids = [example["nprompt_ids"] for example in examples]
+
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
@@ -713,16 +713,18 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
+        prompts = prompt_processor.unify_input_ids(prompt_ids)
+        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
         inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
-            "prompts": prompts,
-            "cprompts": cprompts,
-            "nprompts": nprompts,
+            "prompt_ids": prompts.input_ids,
+            "nprompt_ids": nprompts.input_ids,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
             "attention_mask": inputs.attention_mask,
         }
+
         return batch
 
     datamodule = VlpnDataModule(
diff --git a/train_ti.py b/train_ti.py
index 727b591..323ef10 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -140,13 +140,13 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--num_aspect_ratio_buckets",
+        "--num_buckets",
         type=int,
         default=4,
-        help="Number of buckets in either direction (adds 64 pixels per step).",
+        help="Number of aspect ratio buckets in either direction (adds 64 pixels per step).",
     )
     parser.add_argument(
-        "--progressive_aspect_ratio_buckets",
+        "--progressive_buckets",
         action="store_true",
         help="Include images in smaller buckets as well.",
     )
@@ -681,9 +681,9 @@ def main():
         return cond1 and cond3 and cond4
 
     def collate_fn(examples):
-        prompts = [example["prompts"] for example in examples]
-        cprompts = [example["cprompts"] for example in examples]
-        nprompts = [example["nprompts"] for example in examples]
+        prompt_ids = [example["prompt_ids"] for example in examples]
+        nprompt_ids = [example["nprompt_ids"] for example in examples]
+
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
 
@@ -695,16 +695,18 @@ def main():
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
+        prompts = prompt_processor.unify_input_ids(prompt_ids)
+        nprompts = prompt_processor.unify_input_ids(nprompt_ids)
         inputs = prompt_processor.unify_input_ids(input_ids)
 
         batch = {
-            "prompts": prompts,
-            "cprompts": cprompts,
-            "nprompts": nprompts,
+            "prompt_ids": prompts.input_ids,
+            "nprompt_ids": nprompts.input_ids,
             "input_ids": inputs.input_ids,
             "pixel_values": pixel_values,
             "attention_mask": inputs.attention_mask,
         }
+
         return batch
 
     datamodule = VlpnDataModule(
@@ -714,8 +716,8 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
-        progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
+        num_buckets=args.num_buckets,
+        progressive_buckets=args.progressive_buckets,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
diff --git a/training/util.py b/training/util.py
index ae6bfc4..60d64f0 100644
--- a/training/util.py
+++ b/training/util.py
@@ -73,20 +73,22 @@ class CheckpointerBase:
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-        prompts = [
+        prompt_ids = [
             prompt
             for batch in batches
-            for prompt in batch["prompts"]
+            for prompt in batch["prompt_ids"]
         ]
-        nprompts = [
+        nprompt_ids = [
             prompt
             for batch in batches
-            for prompt in batch["nprompts"]
+            for prompt in batch["nprompt_ids"]
         ]
 
         for i in range(self.sample_batches):
-            prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-            nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+            start = i * self.sample_batch_size
+            end = (i + 1) * self.sample_batch_size
+            prompt = prompt_ids[start:end]
+            nprompt = nprompt_ids[start:end]
 
             samples = pipeline(
                 prompt=prompt,