From f4f996681ca340e940315ca0ebc162c655904a7d Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 5 Apr 2023 16:02:04 +0200 Subject: Add color jitter --- data/csv.py | 29 +++++++++++++++++++---------- environment.yaml | 1 + train_dreambooth.py | 15 ++++++++++++--- train_lora.py | 15 ++++++++++++--- train_ti.py | 13 ++++++++++++- 5 files changed, 56 insertions(+), 17 deletions(-) diff --git a/data/csv.py b/data/csv.py index e1b92c1..818fcd9 100644 --- a/data/csv.py +++ b/data/csv.py @@ -186,6 +186,7 @@ class VlpnDataModule(): dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", + color_jitter: bool = True, template_key: str = "template", placeholder_tokens: list[str] = [], valid_set_size: Optional[int] = None, @@ -219,6 +220,7 @@ class VlpnDataModule(): self.shuffle = shuffle self.template_key = template_key self.interpolation = interpolation + self.color_jitter = color_jitter self.valid_set_size = valid_set_size self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size @@ -323,7 +325,7 @@ class VlpnDataModule(): num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, fill_batch=True, generator=generator, - size=self.size, interpolation=self.interpolation, + size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, ) @@ -343,7 +345,7 @@ class VlpnDataModule(): num_buckets=self.num_buckets, progressive_buckets=True, bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, batch_size=self.batch_size, generator=generator, - size=self.size, interpolation=self.interpolation, + size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, ) self.val_dataloader = DataLoader( @@ -370,6 +372,7 @@ class VlpnDataset(IterableDataset): dropout: float = 0, shuffle: bool = False, interpolation: str = "bicubic", + color_jitter: bool = True, generator: Optional[torch.Generator] = None, ): self.items = items @@ -382,6 +385,7 @@ class VlpnDataset(IterableDataset): self.dropout = dropout self.shuffle = shuffle self.interpolation = interpolations[interpolation] + self.color_jitter = color_jitter self.generator = generator self.buckets, self.bucket_items, self.bucket_assignments = generate_buckets( @@ -446,15 +450,20 @@ class VlpnDataset(IterableDataset): width = int(self.size * ratio) if ratio > 1 else self.size height = int(self.size / ratio) if ratio < 1 else self.size - image_transforms = transforms.Compose( - [ - transforms.Resize(self.size, interpolation=self.interpolation), - transforms.RandomCrop((height, width)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), + image_transforms = [ + transforms.Resize(self.size, interpolation=self.interpolation), + transforms.RandomCrop((height, width)), + transforms.RandomHorizontalFlip(), + ] + if self.color_jitter: + image_transforms += [ + transforms.ColorJitter(0.2, 0.1), ] - ) + image_transforms += [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + image_transforms = transforms.Compose(image_transforms) continue diff --git a/environment.yaml b/environment.yaml index 1de76bd..418cb22 100644 --- a/environment.yaml +++ b/environment.yaml @@ -24,5 +24,6 @@ dependencies: - safetensors==0.3.0 - setuptools==65.6.3 - test-tube>=0.7.5 + - timm==0.8.17.dev0 - transformers==4.27.1 - triton==2.0.0.post1 diff --git a/train_dreambooth.py b/train_dreambooth.py index 4c36ae4..48921d4 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -306,7 +306,7 @@ def parse_args(): "--optimizer", type=str, default="dadan", - choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], + choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], help='Optimizer to use' ) parser.add_argument( @@ -513,8 +513,6 @@ def main(): if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") - embeddings.persist() - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") @@ -549,6 +547,17 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'adan': + try: + import timm.optim + except ImportError: + raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") + + create_optimizer = partial( + timm.optim.Adan, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) elif args.optimizer == 'lion': try: import lion_pytorch diff --git a/train_lora.py b/train_lora.py index 538a7f7..73b3e19 100644 --- a/train_lora.py +++ b/train_lora.py @@ -317,7 +317,7 @@ def parse_args(): "--optimizer", type=str, default="dadan", - choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], + choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], help='Optimizer to use' ) parser.add_argument( @@ -544,8 +544,6 @@ def main(): if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") - embeddings.persist() - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") @@ -580,6 +578,17 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'adan': + try: + import timm.optim + except ImportError: + raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") + + create_optimizer = partial( + timm.optim.Adan, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) elif args.optimizer == 'lion': try: import lion_pytorch diff --git a/train_ti.py b/train_ti.py index 6757bde..fc0d68c 100644 --- a/train_ti.py +++ b/train_ti.py @@ -330,7 +330,7 @@ def parse_args(): "--optimizer", type=str, default="dadan", - choices=["adam", "adam8bit", "lion", "dadam", "dadan", "adafactor"], + choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], help='Optimizer to use' ) parser.add_argument( @@ -679,6 +679,17 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) + elif args.optimizer == 'adan': + try: + import timm.optim + except ImportError: + raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") + + create_optimizer = partial( + timm.optim.Adan, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) elif args.optimizer == 'lion': try: import lion_pytorch -- cgit v1.2.3-70-g09d2