diff options
| -rw-r--r-- | data/csv.py | 265 | ||||
| -rw-r--r-- | train_dreambooth.py | 92 | ||||
| -rw-r--r-- | train_ti.py | 85 | ||||
| -rw-r--r-- | training/util.py | 2 | 
4 files changed, 229 insertions, 215 deletions
| diff --git a/data/csv.py b/data/csv.py index 654aec1..9be36ba 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -2,20 +2,28 @@ import math | |||
| 2 | import torch | 2 | import torch | 
| 3 | import json | 3 | import json | 
| 4 | from pathlib import Path | 4 | from pathlib import Path | 
| 5 | from typing import NamedTuple, Optional, Union, Callable | ||
| 6 | |||
| 5 | from PIL import Image | 7 | from PIL import Image | 
| 6 | from torch.utils.data import Dataset, DataLoader, random_split | ||
| 7 | from torchvision import transforms | ||
| 8 | from typing import Dict, NamedTuple, List, Optional, Union, Callable | ||
| 9 | 8 | ||
| 10 | import numpy as np | 9 | from torch.utils.data import IterableDataset, DataLoader, random_split | 
| 10 | from torchvision import transforms | ||
| 11 | 11 | ||
| 12 | from models.clip.prompt import PromptProcessor | ||
| 13 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 12 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 
| 13 | from models.clip.prompt import PromptProcessor | ||
| 14 | 14 | ||
| 15 | 15 | ||
| 16 | image_cache: dict[str, Image.Image] = {} | 16 | image_cache: dict[str, Image.Image] = {} | 
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | interpolations = { | ||
| 20 | "linear": transforms.InterpolationMode.NEAREST, | ||
| 21 | "bilinear": transforms.InterpolationMode.BILINEAR, | ||
| 22 | "bicubic": transforms.InterpolationMode.BICUBIC, | ||
| 23 | "lanczos": transforms.InterpolationMode.LANCZOS, | ||
| 24 | } | ||
| 25 | |||
| 26 | |||
| 19 | def get_image(path): | 27 | def get_image(path): | 
| 20 | if path in image_cache: | 28 | if path in image_cache: | 
| 21 | return image_cache[path] | 29 | return image_cache[path] | 
| @@ -28,10 +36,46 @@ def get_image(path): | |||
| 28 | return image | 36 | return image | 
| 29 | 37 | ||
| 30 | 38 | ||
| 31 | def prepare_prompt(prompt: Union[str, Dict[str, str]]): | 39 | def prepare_prompt(prompt: Union[str, dict[str, str]]): | 
| 32 | return {"content": prompt} if isinstance(prompt, str) else prompt | 40 | return {"content": prompt} if isinstance(prompt, str) else prompt | 
| 33 | 41 | ||
| 34 | 42 | ||
| 43 | def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool): | ||
| 44 | item_order: list[int] = [] | ||
| 45 | item_buckets: list[int] = [] | ||
| 46 | buckets = [1.0] | ||
| 47 | |||
| 48 | for i in range(1, num_buckets + 1): | ||
| 49 | s = size + i * 64 | ||
| 50 | buckets.append(s / size) | ||
| 51 | buckets.append(size / s) | ||
| 52 | |||
| 53 | buckets = torch.tensor(buckets) | ||
| 54 | bucket_indices = torch.arange(len(buckets)) | ||
| 55 | |||
| 56 | for i, item in enumerate(items): | ||
| 57 | image = get_image(item) | ||
| 58 | ratio = image.width / image.height | ||
| 59 | |||
| 60 | if ratio >= 1: | ||
| 61 | mask = torch.bitwise_and(buckets >= 1, buckets <= ratio) | ||
| 62 | else: | ||
| 63 | mask = torch.bitwise_and(buckets <= 1, buckets >= ratio) | ||
| 64 | |||
| 65 | if not progressive_buckets: | ||
| 66 | mask = (buckets + (~mask) * math.inf - ratio).abs().argmin() | ||
| 67 | |||
| 68 | indices = bucket_indices[mask] | ||
| 69 | |||
| 70 | if len(indices.shape) == 0: | ||
| 71 | indices = indices.unsqueeze(0) | ||
| 72 | |||
| 73 | item_order += [i] * len(indices) | ||
| 74 | item_buckets += indices | ||
| 75 | |||
| 76 | return buckets.tolist(), item_order, item_buckets | ||
| 77 | |||
| 78 | |||
| 35 | class VlpnDataItem(NamedTuple): | 79 | class VlpnDataItem(NamedTuple): | 
| 36 | instance_image_path: Path | 80 | instance_image_path: Path | 
| 37 | class_image_path: Path | 81 | class_image_path: Path | 
| @@ -41,14 +85,6 @@ class VlpnDataItem(NamedTuple): | |||
| 41 | collection: list[str] | 85 | collection: list[str] | 
| 42 | 86 | ||
| 43 | 87 | ||
| 44 | class VlpnDataBucket(): | ||
| 45 | def __init__(self, width: int, height: int): | ||
| 46 | self.width = width | ||
| 47 | self.height = height | ||
| 48 | self.ratio = width / height | ||
| 49 | self.items: list[VlpnDataItem] = [] | ||
| 50 | |||
| 51 | |||
| 52 | class VlpnDataModule(): | 88 | class VlpnDataModule(): | 
| 53 | def __init__( | 89 | def __init__( | 
| 54 | self, | 90 | self, | 
| @@ -60,7 +96,6 @@ class VlpnDataModule(): | |||
| 60 | size: int = 768, | 96 | size: int = 768, | 
| 61 | num_aspect_ratio_buckets: int = 0, | 97 | num_aspect_ratio_buckets: int = 0, | 
| 62 | progressive_aspect_ratio_buckets: bool = False, | 98 | progressive_aspect_ratio_buckets: bool = False, | 
| 63 | repeats: int = 1, | ||
| 64 | dropout: float = 0, | 99 | dropout: float = 0, | 
| 65 | interpolation: str = "bicubic", | 100 | interpolation: str = "bicubic", | 
| 66 | template_key: str = "template", | 101 | template_key: str = "template", | 
| @@ -86,7 +121,6 @@ class VlpnDataModule(): | |||
| 86 | self.size = size | 121 | self.size = size | 
| 87 | self.num_aspect_ratio_buckets = num_aspect_ratio_buckets | 122 | self.num_aspect_ratio_buckets = num_aspect_ratio_buckets | 
| 88 | self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets | 123 | self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets | 
| 89 | self.repeats = repeats | ||
| 90 | self.dropout = dropout | 124 | self.dropout = dropout | 
| 91 | self.template_key = template_key | 125 | self.template_key = template_key | 
| 92 | self.interpolation = interpolation | 126 | self.interpolation = interpolation | 
| @@ -146,36 +180,6 @@ class VlpnDataModule(): | |||
| 146 | for i in range(image_multiplier) | 180 | for i in range(image_multiplier) | 
| 147 | ] | 181 | ] | 
| 148 | 182 | ||
| 149 | def generate_buckets(self, items: list[VlpnDataItem]): | ||
| 150 | buckets = [VlpnDataBucket(self.size, self.size)] | ||
| 151 | |||
| 152 | for i in range(1, self.num_aspect_ratio_buckets + 1): | ||
| 153 | s = self.size + i * 64 | ||
| 154 | buckets.append(VlpnDataBucket(s, self.size)) | ||
| 155 | buckets.append(VlpnDataBucket(self.size, s)) | ||
| 156 | |||
| 157 | buckets = np.array(buckets) | ||
| 158 | bucket_ratios = np.array([bucket.ratio for bucket in buckets]) | ||
| 159 | |||
| 160 | for item in items: | ||
| 161 | image = get_image(item.instance_image_path) | ||
| 162 | ratio = image.width / image.height | ||
| 163 | |||
| 164 | if ratio >= 1: | ||
| 165 | mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio) | ||
| 166 | else: | ||
| 167 | mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio) | ||
| 168 | |||
| 169 | if not self.progressive_aspect_ratio_buckets: | ||
| 170 | ratios = bucket_ratios.copy() | ||
| 171 | ratios[~mask] = math.inf | ||
| 172 | mask = [np.argmin(np.abs(ratios - ratio))] | ||
| 173 | |||
| 174 | for bucket in buckets[mask]: | ||
| 175 | bucket.items.append(item) | ||
| 176 | |||
| 177 | return [bucket for bucket in buckets if len(bucket.items) != 0] | ||
| 178 | |||
| 179 | def setup(self): | 183 | def setup(self): | 
| 180 | with open(self.data_file, 'rt') as f: | 184 | with open(self.data_file, 'rt') as f: | 
| 181 | metadata = json.load(f) | 185 | metadata = json.load(f) | 
| @@ -201,105 +205,136 @@ class VlpnDataModule(): | |||
| 201 | self.data_train = self.pad_items(data_train, self.num_class_images) | 205 | self.data_train = self.pad_items(data_train, self.num_class_images) | 
| 202 | self.data_val = self.pad_items(data_val) | 206 | self.data_val = self.pad_items(data_val) | 
| 203 | 207 | ||
| 204 | buckets = self.generate_buckets(data_train) | 208 | train_dataset = VlpnDataset( | 
| 205 | 209 | self.data_train, self.prompt_processor, | |
| 206 | train_datasets = [ | 210 | num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets, | 
| 207 | VlpnDataset( | 211 | batch_size=self.batch_size, | 
| 208 | bucket.items, self.prompt_processor, | 212 | size=self.size, interpolation=self.interpolation, | 
| 209 | width=bucket.width, height=bucket.height, interpolation=self.interpolation, | 213 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True, | 
| 210 | num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, | 214 | ) | 
| 211 | ) | ||
| 212 | for bucket in buckets | ||
| 213 | ] | ||
| 214 | 215 | ||
| 215 | val_dataset = VlpnDataset( | 216 | val_dataset = VlpnDataset( | 
| 216 | data_val, self.prompt_processor, | 217 | self.data_val, self.prompt_processor, | 
| 217 | width=self.size, height=self.size, interpolation=self.interpolation, | 218 | batch_size=self.batch_size, | 
| 219 | size=self.size, interpolation=self.interpolation, | ||
| 218 | ) | 220 | ) | 
| 219 | 221 | ||
| 220 | self.train_dataloaders = [ | 222 | self.train_dataloader = DataLoader( | 
| 221 | DataLoader( | 223 | train_dataset, | 
| 222 | dataset, batch_size=self.batch_size, shuffle=True, | 224 | batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 
| 223 | pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 225 | ) | 
| 224 | ) | ||
| 225 | for dataset in train_datasets | ||
| 226 | ] | ||
| 227 | 226 | ||
| 228 | self.val_dataloader = DataLoader( | 227 | self.val_dataloader = DataLoader( | 
| 229 | val_dataset, batch_size=self.batch_size, | 228 | val_dataset, | 
| 230 | pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 229 | batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | 
| 231 | ) | 230 | ) | 
| 232 | 231 | ||
| 233 | 232 | ||
| 234 | class VlpnDataset(Dataset): | 233 | class VlpnDataset(IterableDataset): | 
| 235 | def __init__( | 234 | def __init__( | 
| 236 | self, | 235 | self, | 
| 237 | data: List[VlpnDataItem], | 236 | items: list[VlpnDataItem], | 
| 238 | prompt_processor: PromptProcessor, | 237 | prompt_processor: PromptProcessor, | 
| 238 | num_buckets: int = 1, | ||
| 239 | progressive_buckets: bool = False, | ||
| 240 | batch_size: int = 1, | ||
| 239 | num_class_images: int = 0, | 241 | num_class_images: int = 0, | 
| 240 | width: int = 768, | 242 | size: int = 768, | 
| 241 | height: int = 768, | ||
| 242 | repeats: int = 1, | ||
| 243 | dropout: float = 0, | 243 | dropout: float = 0, | 
| 244 | shuffle: bool = False, | ||
| 244 | interpolation: str = "bicubic", | 245 | interpolation: str = "bicubic", | 
| 246 | generator: Optional[torch.Generator] = None, | ||
| 245 | ): | 247 | ): | 
| 248 | self.items = items | ||
| 249 | self.batch_size = batch_size | ||
| 246 | 250 | ||
| 247 | self.data = data | ||
| 248 | self.prompt_processor = prompt_processor | 251 | self.prompt_processor = prompt_processor | 
| 249 | self.num_class_images = num_class_images | 252 | self.num_class_images = num_class_images | 
| 253 | self.size = size | ||
| 250 | self.dropout = dropout | 254 | self.dropout = dropout | 
| 255 | self.shuffle = shuffle | ||
| 256 | self.interpolation = interpolations[interpolation] | ||
| 257 | self.generator = generator | ||
| 251 | 258 | ||
| 252 | self.num_instance_images = len(self.data) | 259 | buckets, item_order, item_buckets = generate_buckets( | 
| 253 | self._length = self.num_instance_images * repeats | 260 | [item.instance_image_path for item in items], | 
| 254 | 261 | size, | |
| 255 | self.interpolation = { | 262 | num_buckets, | 
| 256 | "linear": transforms.InterpolationMode.NEAREST, | 263 | progressive_buckets | 
| 257 | "bilinear": transforms.InterpolationMode.BILINEAR, | ||
| 258 | "bicubic": transforms.InterpolationMode.BICUBIC, | ||
| 259 | "lanczos": transforms.InterpolationMode.LANCZOS, | ||
| 260 | }[interpolation] | ||
| 261 | self.image_transforms = transforms.Compose( | ||
| 262 | [ | ||
| 263 | transforms.Resize(min(width, height), interpolation=self.interpolation), | ||
| 264 | transforms.RandomCrop((height, width)), | ||
| 265 | transforms.RandomHorizontalFlip(), | ||
| 266 | transforms.ToTensor(), | ||
| 267 | transforms.Normalize([0.5], [0.5]), | ||
| 268 | ] | ||
| 269 | ) | 264 | ) | 
| 270 | 265 | ||
| 266 | self.buckets = torch.tensor(buckets) | ||
| 267 | self.item_order = torch.tensor(item_order) | ||
| 268 | self.item_buckets = torch.tensor(item_buckets) | ||
| 269 | |||
| 271 | def __len__(self): | 270 | def __len__(self): | 
| 272 | return self._length | 271 | return len(self.item_buckets) | 
| 273 | 272 | ||
| 274 | def get_example(self, i): | 273 | def __iter__(self): | 
| 275 | item = self.data[i % self.num_instance_images] | 274 | worker_info = torch.utils.data.get_worker_info() | 
| 276 | 275 | ||
| 277 | example = {} | 276 | if self.shuffle: | 
| 278 | example["prompts"] = item.prompt | 277 | perm = torch.randperm(len(self.item_buckets), generator=self.generator) | 
| 279 | example["cprompts"] = item.cprompt | 278 | self.item_order = self.item_order[perm] | 
| 280 | example["nprompts"] = item.nprompt | 279 | self.item_buckets = self.item_buckets[perm] | 
| 281 | example["instance_images"] = get_image(item.instance_image_path) | ||
| 282 | if self.num_class_images != 0: | ||
| 283 | example["class_images"] = get_image(item.class_image_path) | ||
| 284 | 280 | ||
| 285 | return example | 281 | item_mask = torch.ones_like(self.item_buckets, dtype=bool) | 
| 282 | bucket = -1 | ||
| 283 | image_transforms = None | ||
| 284 | batch = [] | ||
| 285 | batch_size = self.batch_size | ||
| 286 | 286 | ||
| 287 | def __getitem__(self, i): | 287 | if worker_info is not None: | 
| 288 | unprocessed_example = self.get_example(i) | 288 | batch_size = math.ceil(batch_size / worker_info.num_workers) | 
| 289 | worker_batch = math.ceil(len(self) / worker_info.num_workers) | ||
| 290 | start = worker_info.id * worker_batch | ||
| 291 | end = start + worker_batch | ||
| 292 | item_mask[:start] = False | ||
| 293 | item_mask[end:] = False | ||
| 289 | 294 | ||
| 290 | example = {} | 295 | while item_mask.any(): | 
| 296 | item_indices = self.item_order[(self.item_buckets == bucket) & item_mask] | ||
| 291 | 297 | ||
| 292 | example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"]) | 298 | if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0): | 
| 293 | example["cprompts"] = unprocessed_example["cprompts"] | 299 | yield batch | 
| 294 | example["nprompts"] = unprocessed_example["nprompts"] | 300 | batch = [] | 
| 295 | 301 | ||
| 296 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 302 | if len(item_indices) == 0: | 
| 297 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | 303 | bucket = self.item_buckets[item_mask][0] | 
| 298 | keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True) | 304 | ratio = self.buckets[bucket] | 
| 299 | ) | 305 | width = self.size * ratio if ratio > 1 else self.size | 
| 306 | height = self.size / ratio if ratio < 1 else self.size | ||
| 307 | |||
| 308 | image_transforms = transforms.Compose( | ||
| 309 | [ | ||
| 310 | transforms.Resize(min(width, height), interpolation=self.interpolation), | ||
| 311 | transforms.RandomCrop((height, width)), | ||
| 312 | transforms.RandomHorizontalFlip(), | ||
| 313 | transforms.ToTensor(), | ||
| 314 | transforms.Normalize([0.5], [0.5]), | ||
| 315 | ] | ||
| 316 | ) | ||
| 317 | else: | ||
| 318 | item_index = item_indices[0] | ||
| 319 | item = self.items[item_index] | ||
| 320 | item_mask[item_index] = False | ||
| 321 | |||
| 322 | example = {} | ||
| 323 | |||
| 324 | example["prompts"] = keywords_to_prompt(item.prompt) | ||
| 325 | example["cprompts"] = item.cprompt | ||
| 326 | example["nprompts"] = item.nprompt | ||
| 327 | |||
| 328 | example["instance_images"] = image_transforms(get_image(item.instance_image_path)) | ||
| 329 | example["instance_prompt_ids"] = self.prompt_processor.get_input_ids( | ||
| 330 | keywords_to_prompt(item.prompt, self.dropout, True) | ||
| 331 | ) | ||
| 332 | |||
| 333 | if self.num_class_images != 0: | ||
| 334 | example["class_images"] = image_transforms(get_image(item.class_image_path)) | ||
| 335 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) | ||
| 300 | 336 | ||
| 301 | if self.num_class_images != 0: | 337 | batch.append(example) | 
| 302 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | ||
| 303 | example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"]) | ||
| 304 | 338 | ||
| 305 | return example | 339 | if len(batch) != 0: | 
| 340 | yield batch | ||
| diff --git a/train_dreambooth.py b/train_dreambooth.py index 589af59..42a7d0f 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -134,12 +134,6 @@ def parse_args(): | |||
| 134 | help="The directory where class images will be saved.", | 134 | help="The directory where class images will be saved.", | 
| 135 | ) | 135 | ) | 
| 136 | parser.add_argument( | 136 | parser.add_argument( | 
| 137 | "--repeats", | ||
| 138 | type=int, | ||
| 139 | default=1, | ||
| 140 | help="How many times to repeat the training data." | ||
| 141 | ) | ||
| 142 | parser.add_argument( | ||
| 143 | "--output_dir", | 137 | "--output_dir", | 
| 144 | type=str, | 138 | type=str, | 
| 145 | default="output/dreambooth", | 139 | default="output/dreambooth", | 
| @@ -738,7 +732,6 @@ def main(): | |||
| 738 | class_subdir=args.class_image_dir, | 732 | class_subdir=args.class_image_dir, | 
| 739 | num_class_images=args.num_class_images, | 733 | num_class_images=args.num_class_images, | 
| 740 | size=args.resolution, | 734 | size=args.resolution, | 
| 741 | repeats=args.repeats, | ||
| 742 | dropout=args.tag_dropout, | 735 | dropout=args.tag_dropout, | 
| 743 | template_key=args.train_data_template, | 736 | template_key=args.train_data_template, | 
| 744 | valid_set_size=args.valid_set_size, | 737 | valid_set_size=args.valid_set_size, | 
| @@ -751,7 +744,7 @@ def main(): | |||
| 751 | datamodule.prepare_data() | 744 | datamodule.prepare_data() | 
| 752 | datamodule.setup() | 745 | datamodule.setup() | 
| 753 | 746 | ||
| 754 | train_dataloaders = datamodule.train_dataloaders | 747 | train_dataloader = datamodule.train_dataloader | 
| 755 | val_dataloader = datamodule.val_dataloader | 748 | val_dataloader = datamodule.val_dataloader | 
| 756 | 749 | ||
| 757 | if args.num_class_images != 0: | 750 | if args.num_class_images != 0: | 
| @@ -770,8 +763,7 @@ def main(): | |||
| 770 | 763 | ||
| 771 | # Scheduler and math around the number of training steps. | 764 | # Scheduler and math around the number of training steps. | 
| 772 | overrode_max_train_steps = False | 765 | overrode_max_train_steps = False | 
| 773 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) | 766 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 
| 774 | num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) | ||
| 775 | if args.max_train_steps is None: | 767 | if args.max_train_steps is None: | 
| 776 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 768 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 
| 777 | overrode_max_train_steps = True | 769 | overrode_max_train_steps = True | 
| @@ -820,8 +812,7 @@ def main(): | |||
| 820 | ema_unet.to(accelerator.device) | 812 | ema_unet.to(accelerator.device) | 
| 821 | 813 | ||
| 822 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | 814 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | 
| 823 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) | 815 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 
| 824 | num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) | ||
| 825 | if overrode_max_train_steps: | 816 | if overrode_max_train_steps: | 
| 826 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 817 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 
| 827 | 818 | ||
| @@ -877,7 +868,7 @@ def main(): | |||
| 877 | accelerator, | 868 | accelerator, | 
| 878 | text_encoder, | 869 | text_encoder, | 
| 879 | optimizer, | 870 | optimizer, | 
| 880 | train_dataloaders[0], | 871 | train_dataloader, | 
| 881 | val_dataloader, | 872 | val_dataloader, | 
| 882 | loop, | 873 | loop, | 
| 883 | on_train=tokenizer.train, | 874 | on_train=tokenizer.train, | 
| @@ -960,54 +951,53 @@ def main(): | |||
| 960 | text_encoder.requires_grad_(False) | 951 | text_encoder.requires_grad_(False) | 
| 961 | 952 | ||
| 962 | with on_train(): | 953 | with on_train(): | 
| 963 | for train_dataloader in train_dataloaders: | 954 | for step, batch in enumerate(train_dataloader): | 
| 964 | for step, batch in enumerate(train_dataloader): | 955 | with accelerator.accumulate(unet): | 
| 965 | with accelerator.accumulate(unet): | 956 | loss, acc, bsz = loop(step, batch) | 
| 966 | loss, acc, bsz = loop(step, batch) | ||
| 967 | 957 | ||
| 968 | accelerator.backward(loss) | 958 | accelerator.backward(loss) | 
| 969 | 959 | ||
| 970 | if accelerator.sync_gradients: | 960 | if accelerator.sync_gradients: | 
| 971 | params_to_clip = ( | 961 | params_to_clip = ( | 
| 972 | itertools.chain(unet.parameters(), text_encoder.parameters()) | 962 | itertools.chain(unet.parameters(), text_encoder.parameters()) | 
| 973 | if args.train_text_encoder and epoch < args.train_text_encoder_epochs | 963 | if args.train_text_encoder and epoch < args.train_text_encoder_epochs | 
| 974 | else unet.parameters() | 964 | else unet.parameters() | 
| 975 | ) | 965 | ) | 
| 976 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) | 966 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) | 
| 977 | 967 | ||
| 978 | optimizer.step() | 968 | optimizer.step() | 
| 979 | if not accelerator.optimizer_step_was_skipped: | 969 | if not accelerator.optimizer_step_was_skipped: | 
| 980 | lr_scheduler.step() | 970 | lr_scheduler.step() | 
| 981 | if args.use_ema: | 971 | if args.use_ema: | 
| 982 | ema_unet.step(unet.parameters()) | 972 | ema_unet.step(unet.parameters()) | 
| 983 | optimizer.zero_grad(set_to_none=True) | 973 | optimizer.zero_grad(set_to_none=True) | 
| 984 | 974 | ||
| 985 | avg_loss.update(loss.detach_(), bsz) | 975 | avg_loss.update(loss.detach_(), bsz) | 
| 986 | avg_acc.update(acc.detach_(), bsz) | 976 | avg_acc.update(acc.detach_(), bsz) | 
| 987 | 977 | ||
| 988 | # Checks if the accelerator has performed an optimization step behind the scenes | 978 | # Checks if the accelerator has performed an optimization step behind the scenes | 
| 989 | if accelerator.sync_gradients: | 979 | if accelerator.sync_gradients: | 
| 990 | local_progress_bar.update(1) | 980 | local_progress_bar.update(1) | 
| 991 | global_progress_bar.update(1) | 981 | global_progress_bar.update(1) | 
| 992 | 982 | ||
| 993 | global_step += 1 | 983 | global_step += 1 | 
| 994 | 984 | ||
| 995 | logs = { | 985 | logs = { | 
| 996 | "train/loss": avg_loss.avg.item(), | 986 | "train/loss": avg_loss.avg.item(), | 
| 997 | "train/acc": avg_acc.avg.item(), | 987 | "train/acc": avg_acc.avg.item(), | 
| 998 | "train/cur_loss": loss.item(), | 988 | "train/cur_loss": loss.item(), | 
| 999 | "train/cur_acc": acc.item(), | 989 | "train/cur_acc": acc.item(), | 
| 1000 | "lr": lr_scheduler.get_last_lr()[0] | 990 | "lr": lr_scheduler.get_last_lr()[0] | 
| 1001 | } | 991 | } | 
| 1002 | if args.use_ema: | 992 | if args.use_ema: | 
| 1003 | logs["ema_decay"] = 1 - ema_unet.decay | 993 | logs["ema_decay"] = 1 - ema_unet.decay | 
| 1004 | 994 | ||
| 1005 | accelerator.log(logs, step=global_step) | 995 | accelerator.log(logs, step=global_step) | 
| 1006 | 996 | ||
| 1007 | local_progress_bar.set_postfix(**logs) | 997 | local_progress_bar.set_postfix(**logs) | 
| 1008 | 998 | ||
| 1009 | if global_step >= args.max_train_steps: | 999 | if global_step >= args.max_train_steps: | 
| 1010 | break | 1000 | break | 
| 1011 | 1001 | ||
| 1012 | accelerator.wait_for_everyone() | 1002 | accelerator.wait_for_everyone() | 
| 1013 | 1003 | ||
| diff --git a/train_ti.py b/train_ti.py index b4b602b..727b591 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -107,12 +107,6 @@ def parse_args(): | |||
| 107 | help="Exclude all items with a listed collection.", | 107 | help="Exclude all items with a listed collection.", | 
| 108 | ) | 108 | ) | 
| 109 | parser.add_argument( | 109 | parser.add_argument( | 
| 110 | "--repeats", | ||
| 111 | type=int, | ||
| 112 | default=1, | ||
| 113 | help="How many times to repeat the training data." | ||
| 114 | ) | ||
| 115 | parser.add_argument( | ||
| 116 | "--output_dir", | 110 | "--output_dir", | 
| 117 | type=str, | 111 | type=str, | 
| 118 | default="output/text-inversion", | 112 | default="output/text-inversion", | 
| @@ -722,7 +716,6 @@ def main(): | |||
| 722 | size=args.resolution, | 716 | size=args.resolution, | 
| 723 | num_aspect_ratio_buckets=args.num_aspect_ratio_buckets, | 717 | num_aspect_ratio_buckets=args.num_aspect_ratio_buckets, | 
| 724 | progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets, | 718 | progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets, | 
| 725 | repeats=args.repeats, | ||
| 726 | dropout=args.tag_dropout, | 719 | dropout=args.tag_dropout, | 
| 727 | template_key=args.train_data_template, | 720 | template_key=args.train_data_template, | 
| 728 | valid_set_size=args.valid_set_size, | 721 | valid_set_size=args.valid_set_size, | 
| @@ -733,7 +726,7 @@ def main(): | |||
| 733 | ) | 726 | ) | 
| 734 | datamodule.setup() | 727 | datamodule.setup() | 
| 735 | 728 | ||
| 736 | train_dataloaders = datamodule.train_dataloaders | 729 | train_dataloader = datamodule.train_dataloader | 
| 737 | val_dataloader = datamodule.val_dataloader | 730 | val_dataloader = datamodule.val_dataloader | 
| 738 | 731 | ||
| 739 | if args.num_class_images != 0: | 732 | if args.num_class_images != 0: | 
| @@ -752,8 +745,7 @@ def main(): | |||
| 752 | 745 | ||
| 753 | # Scheduler and math around the number of training steps. | 746 | # Scheduler and math around the number of training steps. | 
| 754 | overrode_max_train_steps = False | 747 | overrode_max_train_steps = False | 
| 755 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) | 748 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 
| 756 | num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) | ||
| 757 | if args.max_train_steps is None: | 749 | if args.max_train_steps is None: | 
| 758 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 750 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 
| 759 | overrode_max_train_steps = True | 751 | overrode_max_train_steps = True | 
| @@ -790,10 +782,9 @@ def main(): | |||
| 790 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 782 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 
| 791 | ) | 783 | ) | 
| 792 | 784 | ||
| 793 | text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare( | 785 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 
| 794 | text_encoder, optimizer, val_dataloader, lr_scheduler | 786 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 
| 795 | ) | 787 | ) | 
| 796 | train_dataloaders = accelerator.prepare(*train_dataloaders) | ||
| 797 | 788 | ||
| 798 | # Move vae and unet to device | 789 | # Move vae and unet to device | 
| 799 | vae.to(accelerator.device, dtype=weight_dtype) | 790 | vae.to(accelerator.device, dtype=weight_dtype) | 
| @@ -811,8 +802,7 @@ def main(): | |||
| 811 | unet.eval() | 802 | unet.eval() | 
| 812 | 803 | ||
| 813 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | 804 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | 
| 814 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) | 805 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 
| 815 | num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) | ||
| 816 | if overrode_max_train_steps: | 806 | if overrode_max_train_steps: | 
| 817 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 807 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 
| 818 | 808 | ||
| @@ -870,7 +860,7 @@ def main(): | |||
| 870 | accelerator, | 860 | accelerator, | 
| 871 | text_encoder, | 861 | text_encoder, | 
| 872 | optimizer, | 862 | optimizer, | 
| 873 | train_dataloaders[0], | 863 | train_dataloader, | 
| 874 | val_dataloader, | 864 | val_dataloader, | 
| 875 | loop, | 865 | loop, | 
| 876 | on_train=on_train, | 866 | on_train=on_train, | 
| @@ -949,48 +939,47 @@ def main(): | |||
| 949 | text_encoder.train() | 939 | text_encoder.train() | 
| 950 | 940 | ||
| 951 | with on_train(): | 941 | with on_train(): | 
| 952 | for train_dataloader in train_dataloaders: | 942 | for step, batch in enumerate(train_dataloader): | 
| 953 | for step, batch in enumerate(train_dataloader): | 943 | with accelerator.accumulate(text_encoder): | 
| 954 | with accelerator.accumulate(text_encoder): | 944 | loss, acc, bsz = loop(step, batch) | 
| 955 | loss, acc, bsz = loop(step, batch) | ||
| 956 | 945 | ||
| 957 | accelerator.backward(loss) | 946 | accelerator.backward(loss) | 
| 958 | 947 | ||
| 959 | optimizer.step() | 948 | optimizer.step() | 
| 960 | if not accelerator.optimizer_step_was_skipped: | 949 | if not accelerator.optimizer_step_was_skipped: | 
| 961 | lr_scheduler.step() | 950 | lr_scheduler.step() | 
| 962 | optimizer.zero_grad(set_to_none=True) | 951 | optimizer.zero_grad(set_to_none=True) | 
| 963 | 952 | ||
| 964 | avg_loss.update(loss.detach_(), bsz) | 953 | avg_loss.update(loss.detach_(), bsz) | 
| 965 | avg_acc.update(acc.detach_(), bsz) | 954 | avg_acc.update(acc.detach_(), bsz) | 
| 966 | 955 | ||
| 967 | # Checks if the accelerator has performed an optimization step behind the scenes | 956 | # Checks if the accelerator has performed an optimization step behind the scenes | 
| 968 | if accelerator.sync_gradients: | 957 | if accelerator.sync_gradients: | 
| 969 | if args.use_ema: | 958 | if args.use_ema: | 
| 970 | ema_embeddings.step( | 959 | ema_embeddings.step( | 
| 971 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 960 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 
| 972 | 961 | ||
| 973 | local_progress_bar.update(1) | 962 | local_progress_bar.update(1) | 
| 974 | global_progress_bar.update(1) | 963 | global_progress_bar.update(1) | 
| 975 | 964 | ||
| 976 | global_step += 1 | 965 | global_step += 1 | 
| 977 | 966 | ||
| 978 | logs = { | 967 | logs = { | 
| 979 | "train/loss": avg_loss.avg.item(), | 968 | "train/loss": avg_loss.avg.item(), | 
| 980 | "train/acc": avg_acc.avg.item(), | 969 | "train/acc": avg_acc.avg.item(), | 
| 981 | "train/cur_loss": loss.item(), | 970 | "train/cur_loss": loss.item(), | 
| 982 | "train/cur_acc": acc.item(), | 971 | "train/cur_acc": acc.item(), | 
| 983 | "lr": lr_scheduler.get_last_lr()[0], | 972 | "lr": lr_scheduler.get_last_lr()[0], | 
| 984 | } | 973 | } | 
| 985 | if args.use_ema: | 974 | if args.use_ema: | 
| 986 | logs["ema_decay"] = ema_embeddings.decay | 975 | logs["ema_decay"] = ema_embeddings.decay | 
| 987 | 976 | ||
| 988 | accelerator.log(logs, step=global_step) | 977 | accelerator.log(logs, step=global_step) | 
| 989 | 978 | ||
| 990 | local_progress_bar.set_postfix(**logs) | 979 | local_progress_bar.set_postfix(**logs) | 
| 991 | 980 | ||
| 992 | if global_step >= args.max_train_steps: | 981 | if global_step >= args.max_train_steps: | 
| 993 | break | 982 | break | 
| 994 | 983 | ||
| 995 | accelerator.wait_for_everyone() | 984 | accelerator.wait_for_everyone() | 
| 996 | 985 | ||
| diff --git a/training/util.py b/training/util.py index 2b7f71d..ae6bfc4 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -59,7 +59,7 @@ class CheckpointerBase: | |||
| 59 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 59 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 
| 60 | samples_path = Path(self.output_dir).joinpath("samples") | 60 | samples_path = Path(self.output_dir).joinpath("samples") | 
| 61 | 61 | ||
| 62 | train_data = self.datamodule.train_dataloaders[0] | 62 | train_data = self.datamodule.train_dataloader | 
| 63 | val_data = self.datamodule.val_dataloader | 63 | val_data = self.datamodule.val_dataloader | 
| 64 | 64 | ||
| 65 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 65 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 
