summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py29
-rw-r--r--environment.yaml1
-rw-r--r--train_dreambooth.py15
-rw-r--r--train_lora.py15
-rw-r--r--train_ti.py13
5 files changed, 56 insertions, 17 deletions
diff --git a/data/csv.py b/data/csv.py
index e1b92c1..818fcd9 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -186,6 +186,7 @@ class VlpnDataModule():
186 dropout: float = 0, 186 dropout: float = 0,
187 shuffle: bool = False, 187 shuffle: bool = False,
188 interpolation: str = "bicubic", 188 interpolation: str = "bicubic",
189 color_jitter: bool = True,
189 template_key: str = "template", 190 template_key: str = "template",
190 placeholder_tokens: list[str] = [], 191 placeholder_tokens: list[str] = [],
191 valid_set_size: Optional[int] = None, 192 valid_set_size: Optional[int] = None,
@@ -219,6 +220,7 @@ class VlpnDataModule():
219 self.shuffle = shuffle 220 self.shuffle = shuffle
220 self.template_key = template_key 221 self.template_key = template_key
221 self.interpolation = interpolation 222 self.interpolation = interpolation
223 self.color_jitter = color_jitter
222 self.valid_set_size = valid_set_size 224 self.valid_set_size = valid_set_size
223 self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size 225 self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
224 self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size 226 self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
@@ -323,7 +325,7 @@ class VlpnDataModule():
323 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 325 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
324 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 326 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
325 batch_size=self.batch_size, fill_batch=True, generator=generator, 327 batch_size=self.batch_size, fill_batch=True, generator=generator,
326 size=self.size, interpolation=self.interpolation, 328 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
327 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, 329 num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
328 ) 330 )
329 331
@@ -343,7 +345,7 @@ class VlpnDataModule():
343 num_buckets=self.num_buckets, progressive_buckets=True, 345 num_buckets=self.num_buckets, progressive_buckets=True,
344 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 346 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
345 batch_size=self.batch_size, generator=generator, 347 batch_size=self.batch_size, generator=generator,
346 size=self.size, interpolation=self.interpolation, 348 size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter,
347 ) 349 )
348 350
349 self.val_dataloader = DataLoader( 351 self.val_dataloader = DataLoader(
@@ -370,6 +372,7 @@ class VlpnDataset(IterableDataset):
370 dropout: float = 0, 372 dropout: float = 0,
371 shuffle: bool = False, 373 shuffle: bool = False,
372 interpolation: str = "bicubic", 374 interpolation: str = "bicubic",
375 color_jitter: bool = True,
373 generator: Optional[torch.Generator] = None, 376 generator: Optional[torch.Generator] = None,
374 ): 377 ):
375 self.items = items 378 self.items = items
@@ -382,6 +385,7 @@ class VlpnDataset(IterableDataset):
382 self.dropout = dropout 385 self.dropout = dropout
383 self.shuffle = shuffle 386 self.shuffle = shuffle
384 self.interpolation = interpolations[interpolation] 387 self.interpolation = interpolations[interpolation]
388 self.color_jitter = color_jitter
385 self.generator = generator 389 self.generator = generator
386 390
387 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( 391 self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets(
@@ -446,15 +450,20 @@ class VlpnDataset(IterableDataset):
446 width = int(self.size * ratio) if ratio > 1 else self.size 450 width = int(self.size * ratio) if ratio > 1 else self.size
447 height = int(self.size / ratio) if ratio < 1 else self.size 451 height = int(self.size / ratio) if ratio < 1 else self.size
448 452
449 image_transforms = transforms.Compose( 453 image_transforms = [
450 [ 454 transforms.Resize(self.size, interpolation=self.interpolation),
451 transforms.Resize(self.size, interpolation=self.interpolation), 455 transforms.RandomCrop((height, width)),
452 transforms.RandomCrop((height, width)), 456 transforms.RandomHorizontalFlip(),
453 transforms.RandomHorizontalFlip(), 457 ]
454 transforms.ToTensor(), 458 if self.color_jitter:
455 transforms.Normalize([0.5], [0.5]), 459 image_transforms += [
460 transforms.ColorJitter(0.2, 0.1),
456 ] 461 ]
457 ) 462 image_transforms += [
463 transforms.ToTensor(),
464 transforms.Normalize([0.5], [0.5]),
465 ]
466 image_transforms = transforms.Compose(image_transforms)
458 467
459 continue 468 continue
460 469
diff --git a/environment.yaml b/environment.yaml
index 1de76bd..418cb22 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -24,5 +24,6 @@ dependencies:
24 - safetensors==0.3.0 24 - safetensors==0.3.0
25 - setuptools==65.6.3 25 - setuptools==65.6.3
26 - test-tube>=0.7.5 26 - test-tube>=0.7.5
27 - timm==0.8.17.dev0
27 - transformers==4.27.1 28 - transformers==4.27.1
28 - triton==2.0.0.post1 29 - triton==2.0.0.post1
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 4c36ae4..48921d4 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -306,7 +306,7 @@ def parse_args():
306 "--optimizer", 306 "--optimizer",
307 type=str, 307 type=str,
308 default="dadan", 308 default="dadan",
309 choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], 309 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
310 help='Optimizer to use' 310 help='Optimizer to use'
311 ) 311 )
312 parser.add_argument( 312 parser.add_argument(
@@ -513,8 +513,6 @@ def main():
513 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 513 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
514 raise ValueError("--embeddings_dir must point to an existing directory") 514 raise ValueError("--embeddings_dir must point to an existing directory")
515 515
516 embeddings.persist()
517
518 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 516 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
519 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 517 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
520 518
@@ -549,6 +547,17 @@ def main():
549 eps=args.adam_epsilon, 547 eps=args.adam_epsilon,
550 amsgrad=args.adam_amsgrad, 548 amsgrad=args.adam_amsgrad,
551 ) 549 )
550 elif args.optimizer == 'adan':
551 try:
552 import timm.optim
553 except ImportError:
554 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
555
556 create_optimizer = partial(
557 timm.optim.Adan,
558 weight_decay=args.adam_weight_decay,
559 eps=args.adam_epsilon,
560 )
552 elif args.optimizer == 'lion': 561 elif args.optimizer == 'lion':
553 try: 562 try:
554 import lion_pytorch 563 import lion_pytorch
diff --git a/train_lora.py b/train_lora.py
index 538a7f7..73b3e19 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -317,7 +317,7 @@ def parse_args():
317 "--optimizer", 317 "--optimizer",
318 type=str, 318 type=str,
319 default="dadan", 319 default="dadan",
320 choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], 320 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
321 help='Optimizer to use' 321 help='Optimizer to use'
322 ) 322 )
323 parser.add_argument( 323 parser.add_argument(
@@ -544,8 +544,6 @@ def main():
544 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 544 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
545 raise ValueError("--embeddings_dir must point to an existing directory") 545 raise ValueError("--embeddings_dir must point to an existing directory")
546 546
547 embeddings.persist()
548
549 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 547 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
550 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 548 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
551 549
@@ -580,6 +578,17 @@ def main():
580 eps=args.adam_epsilon, 578 eps=args.adam_epsilon,
581 amsgrad=args.adam_amsgrad, 579 amsgrad=args.adam_amsgrad,
582 ) 580 )
581 elif args.optimizer == 'adan':
582 try:
583 import timm.optim
584 except ImportError:
585 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
586
587 create_optimizer = partial(
588 timm.optim.Adan,
589 weight_decay=args.adam_weight_decay,
590 eps=args.adam_epsilon,
591 )
583 elif args.optimizer == 'lion': 592 elif args.optimizer == 'lion':
584 try: 593 try:
585 import lion_pytorch 594 import lion_pytorch
diff --git a/train_ti.py b/train_ti.py
index 6757bde..fc0d68c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -330,7 +330,7 @@ def parse_args():
330 "--optimizer", 330 "--optimizer",
331 type=str, 331 type=str,
332 default="dadan", 332 default="dadan",
333 choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], 333 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
334 help='Optimizer to use' 334 help='Optimizer to use'
335 ) 335 )
336 parser.add_argument( 336 parser.add_argument(
@@ -679,6 +679,17 @@ def main():
679 eps=args.adam_epsilon, 679 eps=args.adam_epsilon,
680 amsgrad=args.adam_amsgrad, 680 amsgrad=args.adam_amsgrad,
681 ) 681 )
682 elif args.optimizer == 'adan':
683 try:
684 import timm.optim
685 except ImportError:
686 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
687
688 create_optimizer = partial(
689 timm.optim.Adan,
690 weight_decay=args.adam_weight_decay,
691 eps=args.adam_epsilon,
692 )
682 elif args.optimizer == 'lion': 693 elif args.optimizer == 'lion':
683 try: 694 try:
684 import lion_pytorch 695 import lion_pytorch