 data/csv.py         | 144
 infer.py            |  23
 train_dreambooth.py |  12
 train_ti.py         |  94
 training/util.py    |   4
 5 files changed, 151 insertions(+), 126 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 4986153..59d6d8d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -11,11 +11,26 @@ from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt
 
 
+image_cache: dict[str, Image.Image] = {}
+
+
+def get_image(path):
+    if path in image_cache:
+        return image_cache[path]
+
+    image = Image.open(path)
+    if not image.mode == "RGB":
+        image = image.convert("RGB")
+    image_cache[path] = image
+
+    return image
+
+
 def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-class CSVDataItem(NamedTuple):
+class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
     prompt: list[str]
@@ -24,7 +39,15 @@ class CSVDataItem(NamedTuple):
     collection: list[str]
 
 
-class CSVDataModule():
+class VlpnDataBucket():
+    def __init__(self, width: int, height: int):
+        self.width = width
+        self.height = height
+        self.ratio = width / height
+        self.items: list[VlpnDataItem] = []
+
+
+class VlpnDataModule():
     def __init__(
         self,
         batch_size: int,
@@ -36,11 +59,10 @@ class CSVDataModule():
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
-        center_crop: bool = False,
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         seed: Optional[int] = None,
-        filter: Optional[Callable[[CSVDataItem], bool]] = None,
+        filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         collate_fn=None,
         num_workers: int = 0
     ):
@@ -60,7 +82,6 @@ class CSVDataModule():
         self.size = size
         self.repeats = repeats
         self.dropout = dropout
-        self.center_crop = center_crop
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -70,14 +91,14 @@ class CSVDataModule():
         self.num_workers = num_workers
         self.batch_size = batch_size
 
-    def prepare_items(self, template, expansions, data) -> list[CSVDataItem]:
+    def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         image = template["image"] if "image" in template else "{}"
         prompt = template["prompt"] if "prompt" in template else "{content}"
         cprompt = template["cprompt"] if "cprompt" in template else "{content}"
         nprompt = template["nprompt"] if "nprompt" in template else "{content}"
 
         return [
-            CSVDataItem(
+            VlpnDataItem(
                 self.data_root.joinpath(image.format(item["image"])),
                 None,
                 prompt_to_keywords(
@@ -97,17 +118,17 @@ class CSVDataModule():
             for item in data
         ]
 
-    def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
+    def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
         if self.filter is None:
             return items
 
         return [item for item in items if self.filter(item)]
 
-    def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
+    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
         image_multiplier = max(num_class_images, 1)
 
         return [
-            CSVDataItem(
+            VlpnDataItem(
                 item.instance_image_path,
                 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
                 item.prompt,
@@ -119,7 +140,30 @@ class CSVDataModule():
             for i in range(image_multiplier)
         ]
 
-    def prepare_data(self):
+    def generate_buckets(self, items: list[VlpnDataItem]):
+        buckets = [VlpnDataBucket(self.size, self.size)]
+
+        for i in range(1, 5):
+            s = self.size + i * 64
+            buckets.append(VlpnDataBucket(s, self.size))
+            buckets.append(VlpnDataBucket(self.size, s))
+
+        for item in items:
+            image = get_image(item.instance_image_path)
+            ratio = image.width / image.height
+
+            if ratio >= 1:
+                candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio]
+            else:
+                candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio]
+
+            for bucket in candidates:
+                bucket.items.append(item)
+
+        buckets = [bucket for bucket in buckets if len(bucket.items) != 0]
+        return buckets
+
+    def setup(self):
         with open(self.data_file, 'rt') as f:
             metadata = json.load(f)
         template = metadata[self.template_key] if self.template_key in metadata else {}
@@ -144,48 +188,48 @@ class CSVDataModule():
         self.data_train = self.pad_items(data_train, self.num_class_images)
         self.data_val = self.pad_items(data_val)
 
-    def setup(self, stage=None):
-        train_dataset = CSVDataset(
-            self.data_train, self.prompt_processor, batch_size=self.batch_size,
-            num_class_images=self.num_class_images,
-            size=self.size, interpolation=self.interpolation,
-            center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout
-        )
-        val_dataset = CSVDataset(
-            self.data_val, self.prompt_processor, batch_size=self.batch_size,
-            size=self.size, interpolation=self.interpolation,
-            center_crop=self.center_crop
-        )
-        self.train_dataloader_ = DataLoader(
-            train_dataset, batch_size=self.batch_size,
-            shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
-            num_workers=self.num_workers
-        )
-        self.val_dataloader_ = DataLoader(
-            val_dataset, batch_size=self.batch_size,
-            pin_memory=True, collate_fn=self.collate_fn,
-            num_workers=self.num_workers
+        buckets = self.generate_buckets(data_train)
+
+        train_datasets = [
+            VlpnDataset(
+                bucket.items, self.prompt_processor, batch_size=self.batch_size,
+                width=bucket.width, height=bucket.height, interpolation=self.interpolation,
+                num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
+            )
+            for bucket in buckets
+        ]
+
+        val_dataset = VlpnDataset(
+            data_val, self.prompt_processor, batch_size=self.batch_size,
+            width=self.size, height=self.size, interpolation=self.interpolation,
         )
 
-    def train_dataloader(self):
-        return self.train_dataloader_
+        self.train_dataloaders = [
+            DataLoader(
+                dataset, batch_size=self.batch_size, shuffle=True,
+                pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            )
+            for dataset in train_datasets
+        ]
 
-    def val_dataloader(self):
-        return self.val_dataloader_
+        self.val_dataloader = DataLoader(
+            val_dataset, batch_size=self.batch_size,
+            pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+        )
 
 
-class CSVDataset(Dataset):
+class VlpnDataset(Dataset):
     def __init__(
         self,
-        data: List[CSVDataItem],
+        data: List[VlpnDataItem],
         prompt_processor: PromptProcessor,
         batch_size: int = 1,
         num_class_images: int = 0,
-        size: int = 768,
+        width: int = 768,
+        height: int = 768,
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
-        center_crop: bool = False,
     ):
 
         self.data = data
@@ -193,7 +237,6 @@ class CSVDataset(Dataset):
         self.batch_size = batch_size
         self.num_class_images = num_class_images
         self.dropout = dropout
-        self.image_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -206,8 +249,8 @@ class CSVDataset(Dataset):
         }[interpolation]
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=self.interpolation),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.Resize(min(width, height), interpolation=self.interpolation),
+                transforms.RandomCrop((height, width)),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
@@ -217,17 +260,6 @@ class CSVDataset(Dataset):
     def __len__(self):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
-    def get_image(self, path):
-        if path in self.image_cache:
-            return self.image_cache[path]
-
-        image = Image.open(path)
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
-        self.image_cache[path] = image
-
-        return image
-
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
 
@@ -235,9 +267,9 @@ class CSVDataset(Dataset):
         example["prompts"] = item.prompt
         example["cprompts"] = item.cprompt
         example["nprompts"] = item.nprompt
-        example["instance_images"] = self.get_image(item.instance_image_path)
+        example["instance_images"] = get_image(item.instance_image_path)
         if self.num_class_images != 0:
-            example["class_images"] = self.get_image(item.class_image_path)
+            example["class_images"] = get_image(item.class_image_path)
 
         return example
 
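Note: generate_buckets() above keeps one square base bucket and adds landscape and portrait buckets in 64-pixel steps; an image is assigned to every bucket whose aspect ratio lies between 1:1 and the image's own ratio, so the later Resize + RandomCrop((height, width)) only has to trim excess along one axis. Below is a small standalone sketch of that assignment, assuming size=768; the helper name candidate_buckets is made up for illustration and is not part of the commit.

def candidate_buckets(size: int, img_w: int, img_h: int) -> list[tuple[int, int]]:
    # Mirrors VlpnDataModule.generate_buckets: one square bucket plus
    # landscape/portrait buckets in 64px steps above the base size.
    buckets = [(size, size)]
    for i in range(1, 5):
        s = size + i * 64
        buckets.append((s, size))   # landscape bucket
        buckets.append((size, s))   # portrait bucket

    ratio = img_w / img_h
    if ratio >= 1:
        return [(w, h) for (w, h) in buckets if w / h >= 1 and ratio >= w / h]
    return [(w, h) for (w, h) in buckets if w / h <= 1 and ratio <= w / h]

# A 960x768 image (ratio 1.25) is used by every landscape bucket up to its own
# ratio, but not by the wider 1024x768 bucket:
print(candidate_buckets(768, 960, 768))
# -> [(768, 768), (832, 768), (896, 768), (960, 768)]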
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -238,16 +238,15 @@ def create_pipeline(model, dtype):
     return pipeline
 
 
+def shuffle_prompts(prompts: list[str]) -> list[str]:
+    return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts]
+
+
 @torch.inference_mode()
 def generate(output_dir: Path, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
 
-    if args.shuffle:
-        args.prompt *= args.batch_size
-        args.batch_size = 1
-        args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt]
-
     args.prompt = [args.template.format(prompt) for prompt in args.prompt]
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
@@ -263,9 +262,6 @@ def generate(output_dir: Path, pipeline, args):
             dir = output_dir.joinpath(slugify(prompt)[:100])
             dir.mkdir(parents=True, exist_ok=True)
             image_dir.append(dir)
-
-            with open(dir.joinpath('prompt.txt'), 'w') as f:
-                f.write(prompt)
     else:
         output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
         output_dir.mkdir(parents=True, exist_ok=True)
@@ -306,9 +302,10 @@ def generate(output_dir: Path, pipeline, args):
         )
 
         seed = args.seed + i
+        prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt
         generator = torch.Generator(device="cuda").manual_seed(seed)
         images = pipeline(
-            prompt=args.prompt,
+            prompt=prompt,
            negative_prompt=args.negative_prompt,
            height=args.height,
            width=args.width,
@@ -321,9 +318,13 @@ def generate(output_dir: Path, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
+            basename = f"{seed}_{j // len(args.prompt)}"
             dir = image_dir[j % len(args.prompt)]
-            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
-            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
+
+            image.save(dir.joinpath(f"{basename}.png"))
+            image.save(dir.joinpath(f"{basename}.jpg"), quality=85)
+            with open(dir.joinpath(f"{basename}.txt"), 'w') as f:
+                f.write(prompt[j % len(args.prompt)])
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
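Note: with --shuffle, infer.py now re-shuffles the keyword order of each prompt once per batch (per seed) inside the generation loop, instead of shuffling a single time before it, and the shuffled text is written next to the saved images. A rough standalone sketch of that behaviour follows; prompt_to_keywords and keywords_to_prompt are simplified stand-ins for the real helpers in data/keywords.py, which may behave differently.

import random

def prompt_to_keywords(prompt: str) -> list[str]:
    # simplified stand-in: split a comma-separated prompt into keywords
    return [k.strip() for k in prompt.split(",") if k.strip()]

def keywords_to_prompt(keywords: list[str], shuffle: bool = False) -> str:
    if shuffle:
        keywords = random.sample(keywords, len(keywords))
    return ", ".join(keywords)

def shuffle_prompts(prompts: list[str]) -> list[str]:
    return [keywords_to_prompt(prompt_to_keywords(p), shuffle=True) for p in prompts]

# Each batch i gets a freshly shuffled keyword order:
for i in range(3):
    print(shuffle_prompts(["portrait, watercolor, dramatic lighting"]))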
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e8256be..d265bcc 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -22,7 +22,7 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import CSVDataModule, CSVDataItem
+from data.csv import VlpnDataModule, VlpnDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
@@ -172,11 +172,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--center_crop",
-        action="store_true",
-        help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
@@ -698,7 +693,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: CSVDataItem):
+    def keyword_filter(item: VlpnDataItem):
         cond3 = args.collection is None or args.collection in item.collection
         cond4 = args.exclude_collections is None or not any(
             collection in item.collection
@@ -733,7 +728,7 @@ def main():
         }
         return batch
 
-    datamodule = CSVDataModule(
+    datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
@@ -742,7 +737,6 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         dropout=args.tag_dropout,
-        center_crop=args.center_crop,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
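Note: the renamed keyword_filter in train_dreambooth.py still decides inclusion from the item's collection list; only the cond3/cond4 part is visible in this hunk. A runnable sketch of that visible logic follows, with VlpnDataItem reduced to the one field the filter touches; the make_keyword_filter factory and the example values are purely illustrative and not part of the commit.

from typing import NamedTuple, Optional

class VlpnDataItem(NamedTuple):
    collection: list[str]   # reduced to the one field the filter uses

def make_keyword_filter(collection: Optional[str], exclude_collections: Optional[list[str]]):
    def keyword_filter(item: VlpnDataItem) -> bool:
        # keep items in the requested collection...
        cond3 = collection is None or collection in item.collection
        # ...and drop items belonging to any excluded collection
        cond4 = exclude_collections is None or not any(
            c in item.collection for c in exclude_collections
        )
        return cond3 and cond4
    return keyword_filter

keep = make_keyword_filter("portraits", ["wip"])
print(keep(VlpnDataItem(collection=["portraits"])))         # True
print(keep(VlpnDataItem(collection=["portraits", "wip"])))  # False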
diff --git a/train_ti.py b/train_ti.py
index 0ffc9e6..89c6672 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,7 +21,7 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import CSVDataModule, CSVDataItem
+from data.csv import VlpnDataModule, VlpnDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
@@ -146,11 +146,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--center_crop",
-        action="store_true",
-        help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
         "--tag_dropout",
         type=float,
         default=0.1,
@@ -668,7 +663,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: CSVDataItem):
+    def keyword_filter(item: VlpnDataItem):
         cond1 = any(
             keyword in part
             for keyword in args.placeholder_token
@@ -708,7 +703,7 @@ def main():
         }
         return batch
 
-    datamodule = CSVDataModule(
+    datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
@@ -717,7 +712,6 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         dropout=args.tag_dropout,
-        center_crop=args.center_crop,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -725,8 +719,6 @@ def main():
         filter=keyword_filter,
         collate_fn=collate_fn
     )
-
-    datamodule.prepare_data()
     datamodule.setup()
 
     if args.num_class_images != 0:
@@ -769,12 +761,14 @@ def main():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    train_dataloader = datamodule.train_dataloader()
-    val_dataloader = datamodule.val_dataloader()
+    train_dataloaders = datamodule.train_dataloaders
+    default_train_dataloader = train_dataloaders[0]
+    val_dataloader = datamodule.val_dataloader
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -811,9 +805,10 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, val_dataloader, lr_scheduler
     )
+    train_dataloaders = accelerator.prepare(*train_dataloaders)
 
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -831,7 +826,8 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -889,7 +885,7 @@ def main():
             accelerator,
             text_encoder,
             optimizer,
-            train_dataloader,
+            default_train_dataloader,
             val_dataloader,
             loop,
             on_train=on_train,
@@ -968,46 +964,48 @@ def main():
         text_encoder.train()
 
         with on_train():
-            for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(text_encoder):
-                    loss, acc, bsz = loop(step, batch)
+            for train_dataloader in train_dataloaders:
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(text_encoder):
+                        loss, acc, bsz = loop(step, batch)
 
-                    accelerator.backward(loss)
+                        accelerator.backward(loss)
 
-                    optimizer.step()
-                    if not accelerator.optimizer_step_was_skipped:
-                        lr_scheduler.step()
-                    optimizer.zero_grad(set_to_none=True)
+                        optimizer.step()
+                        if not accelerator.optimizer_step_was_skipped:
+                            lr_scheduler.step()
+                        optimizer.zero_grad(set_to_none=True)
 
-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)
 
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        if args.use_ema:
+                            ema_embeddings.step(
+                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
 
-                    global_step += 1
+                        global_step += 1
 
-                    logs = {
-                        "train/loss": avg_loss.avg.item(),
-                        "train/acc": avg_acc.avg.item(),
-                        "train/cur_loss": loss.item(),
-                        "train/cur_acc": acc.item(),
-                        "lr": lr_scheduler.get_last_lr()[0],
-                    }
-                    if args.use_ema:
-                        logs["ema_decay"] = ema_embeddings.decay
+                        logs = {
+                            "train/loss": avg_loss.avg.item(),
+                            "train/acc": avg_acc.avg.item(),
+                            "train/cur_loss": loss.item(),
+                            "train/cur_acc": acc.item(),
+                            "lr": lr_scheduler.get_last_lr()[0],
+                        }
+                        if args.use_ema:
+                            logs["ema_decay"] = ema_embeddings.decay
 
-                    accelerator.log(logs, step=global_step)
+                        accelerator.log(logs, step=global_step)
 
-                    local_progress_bar.set_postfix(**logs)
+                        local_progress_bar.set_postfix(**logs)
 
-                if global_step >= args.max_train_steps:
-                    break
+                    if global_step >= args.max_train_steps:
+                        break
 
         accelerator.wait_for_everyone()
 
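Note: train_ti.py now walks each bucket dataloader in turn inside one epoch, so the step accounting earlier in this file sums the dataloaders' lengths before dividing by the gradient accumulation steps. A toy illustration of that arithmetic, with made-up dataloader lengths:

import math

# One DataLoader per aspect-ratio bucket; the lengths below are illustrative only.
dataloader_lengths = [120, 40, 25]
gradient_accumulation_steps = 4
num_train_epochs = 10

num_update_steps_per_dataloader = sum(dataloader_lengths)
num_update_steps_per_epoch = math.ceil(
    num_update_steps_per_dataloader / gradient_accumulation_steps
)
max_train_steps = num_train_epochs * num_update_steps_per_epoch
print(num_update_steps_per_epoch, max_train_steps)  # 47 470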
diff --git a/training/util.py b/training/util.py
index bc466e2..6f42228 100644
--- a/training/util.py
+++ b/training/util.py
@@ -58,8 +58,8 @@ class CheckpointerBase:
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
+        train_data = self.datamodule.train_dataloaders[0]
+        val_data = self.datamodule.val_dataloader
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
 
