diff options
-rw-r--r-- data/csv.py         | 29
-rw-r--r-- environment.yaml    |  1
-rw-r--r-- train_dreambooth.py | 15
-rw-r--r-- train_lora.py       | 15
-rw-r--r-- train_ti.py         | 13
5 files changed, 56 insertions(+), 17 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index e1b92c1..818fcd9 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -186,6 +186,7 @@ class VlpnDataModule():
         dropout: float = 0,
         shuffle: bool = False,
         interpolation: str = "bicubic",
+        color_jitter: bool = True,
         template_key: str = "template",
         placeholder_tokens: list[str] = [],
         valid_set_size: Optional[int] = None,
@@ -219,6 +220,7 @@ class VlpnDataModule():
         self.shuffle = shuffle
         self.template_key = template_key
         self.interpolation = interpolation
+        self.color_jitter = color_jitter
         self.valid_set_size = valid_set_size
         self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
         self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
@@ -323,7 +325,7 @@ class VlpnDataModule():
             num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             batch_size=self.batch_size, fill_batch=True, generator=generator,
-            size=self.size, interpolation=self.interpolation,
+            size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
         )

@@ -343,7 +345,7 @@ class VlpnDataModule():
             num_buckets=self.num_buckets, progressive_buckets=True,
             bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
             batch_size=self.batch_size, generator=generator,
-            size=self.size, interpolation=self.interpolation,
+            size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
         )

         self.val_dataloader = DataLoader(
@@ -370,6 +372,7 @@ class VlpnDataset(IterableDataset):
         dropout: float = 0,
         shuffle: bool = False,
         interpolation: str = "bicubic",
+        color_jitter: bool = True,
         generator: Optional[torch.Generator] = None,
     ):
         self.items = items
@@ -382,6 +385,7 @@ class VlpnDataset(IterableDataset):
         self.dropout = dropout
         self.shuffle = shuffle
         self.interpolation = interpolations[interpolation]
+        self.color_jitter = color_jitter
         self.generator = generator

         self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
@@ -446,15 +450,20 @@ class VlpnDataset(IterableDataset):
                 width = int(self.size * ratio) if ratio > 1 else self.size
                 height = int(self.size / ratio) if ratio < 1 else self.size

-                image_transforms = transforms.Compose(
-                    [
-                        transforms.Resize(self.size, interpolation=self.interpolation),
-                        transforms.RandomCrop((height, width)),
-                        transforms.RandomHorizontalFlip(),
-                        transforms.ToTensor(),
-                        transforms.Normalize([0.5], [0.5]),
-                    ]
-                )
+                image_transforms = [
+                    transforms.Resize(self.size, interpolation=self.interpolation),
+                    transforms.RandomCrop((height, width)),
+                    transforms.RandomHorizontalFlip(),
+                ]
+                if self.color_jitter:
+                    image_transforms += [
+                        transforms.ColorJitter(0.2, 0.1),
+                    ]
+                image_transforms += [
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+                image_transforms = transforms.Compose(image_transforms)

                 continue
diff --git a/environment.yaml b/environment.yaml
index 1de76bd..418cb22 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -24,5 +24,6 @@ dependencies:
   - safetensors==0.3.0
   - setuptools==65.6.3
   - test-tube>=0.7.5
+  - timm==0.8.17.dev0
   - transformers==4.27.1
   - triton==2.0.0.post1
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4c36ae4..48921d4 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"],
+        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
         help='Optimizer to use'
     )
     parser.add_argument(
@@ -513,8 +513,6 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")

-        embeddings.persist()
-
        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

@@ -549,6 +547,17 @@ def main():
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
+    elif args.optimizer == 'adan':
+        try:
+            import timm.optim
+        except ImportError:
+            raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
+
+        create_optimizer = partial(
+            timm.optim.Adan,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
     elif args.optimizer == 'lion':
         try:
             import lion_pytorch
diff --git a/train_lora.py b/train_lora.py
index 538a7f7..73b3e19 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -317,7 +317,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"],
+        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
         help='Optimizer to use'
     )
     parser.add_argument(
@@ -544,8 +544,6 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")

-        embeddings.persist()
-
        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

@@ -580,6 +578,17 @@ def main():
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
+    elif args.optimizer == 'adan':
+        try:
+            import timm.optim
+        except ImportError:
+            raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
+
+        create_optimizer = partial(
+            timm.optim.Adan,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
     elif args.optimizer == 'lion':
         try:
             import lion_pytorch
diff --git a/train_ti.py b/train_ti.py
index 6757bde..fc0d68c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
         "--optimizer",
         type=str,
         default="dadan",
-        choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"],
+        choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
         help='Optimizer to use'
     )
     parser.add_argument(
@@ -679,6 +679,17 @@ def main():
            eps=args.adam_epsilon,
            amsgrad=args.adam_amsgrad,
        )
+    elif args.optimizer == 'adan':
+        try:
+            import timm.optim
+        except ImportError:
+            raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
+
+        create_optimizer = partial(
+            timm.optim.Adan,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
     elif args.optimizer == 'lion':
         try:
             import lion_pytorch