diff options
| -rw-r--r-- | data/dreambooth/csv.py | 108 | ||||
| -rw-r--r-- | data/dreambooth/prompt.py | 4 | ||||
| -rw-r--r-- | data/textual_inversion/csv.py | 3 | ||||
| -rw-r--r-- | dreambooth.py | 168 | ||||
| -rw-r--r-- | textual_inversion.py | 6 |
5 files changed, 127 insertions, 162 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index c0b0067..4ebdc13 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
| @@ -13,13 +13,11 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 13 | batch_size, | 13 | batch_size, |
| 14 | data_file, | 14 | data_file, |
| 15 | tokenizer, | 15 | tokenizer, |
| 16 | instance_prompt, | 16 | instance_identifier, |
| 17 | class_data_root=None, | 17 | class_identifier=None, |
| 18 | class_prompt=None, | ||
| 19 | size=512, | 18 | size=512, |
| 20 | repeats=100, | 19 | repeats=100, |
| 21 | interpolation="bicubic", | 20 | interpolation="bicubic", |
| 22 | identifier="*", | ||
| 23 | center_crop=False, | 21 | center_crop=False, |
| 24 | valid_set_size=None, | 22 | valid_set_size=None, |
| 25 | generator=None, | 23 | generator=None, |
| @@ -32,13 +30,14 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 32 | raise ValueError("data_file must be a file") | 30 | raise ValueError("data_file must be a file") |
| 33 | 31 | ||
| 34 | self.data_root = self.data_file.parent | 32 | self.data_root = self.data_file.parent |
| 33 | self.class_root = self.data_root.joinpath("db_cls") | ||
| 34 | self.class_root.mkdir(parents=True, exist_ok=True) | ||
| 35 | |||
| 35 | self.tokenizer = tokenizer | 36 | self.tokenizer = tokenizer |
| 36 | self.instance_prompt = instance_prompt | 37 | self.instance_identifier = instance_identifier |
| 37 | self.class_data_root = class_data_root | 38 | self.class_identifier = class_identifier |
| 38 | self.class_prompt = class_prompt | ||
| 39 | self.size = size | 39 | self.size = size |
| 40 | self.repeats = repeats | 40 | self.repeats = repeats |
| 41 | self.identifier = identifier | ||
| 42 | self.center_crop = center_crop | 41 | self.center_crop = center_crop |
| 43 | self.interpolation = interpolation | 42 | self.interpolation = interpolation |
| 44 | self.valid_set_size = valid_set_size | 43 | self.valid_set_size = valid_set_size |
| @@ -48,30 +47,36 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 48 | 47 | ||
| 49 | def prepare_data(self): | 48 | def prepare_data(self): |
| 50 | metadata = pd.read_csv(self.data_file) | 49 | metadata = pd.read_csv(self.data_file) |
| 51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 50 | instance_image_paths = [self.data_root.joinpath(f) for f in metadata['image'].values] |
| 51 | class_image_paths = [self.class_root.joinpath(Path(f).name) for f in metadata['image'].values] | ||
| 52 | prompts = metadata['prompt'].values | 52 | prompts = metadata['prompt'].values |
| 53 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) | 53 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(instance_image_paths) |
| 54 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) | 54 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(instance_image_paths) |
| 55 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | 55 | self.data = [(i, c, p, n) |
| 56 | for i, c, p, n, s | ||
| 57 | in zip(instance_image_paths, class_image_paths, prompts, nprompts, skips) | ||
| 58 | if s != "x"] | ||
| 56 | 59 | ||
| 57 | def setup(self, stage=None): | 60 | def setup(self, stage=None): |
| 58 | valid_set_size = int(len(self.data_full) * 0.2) | 61 | valid_set_size = int(len(self.data) * 0.2) |
| 59 | if self.valid_set_size: | 62 | if self.valid_set_size: |
| 60 | valid_set_size = min(valid_set_size, self.valid_set_size) | 63 | valid_set_size = min(valid_set_size, self.valid_set_size) |
| 61 | train_set_size = len(self.data_full) - valid_set_size | 64 | valid_set_size = max(valid_set_size, 1) |
| 65 | train_set_size = len(self.data) - valid_set_size | ||
| 62 | 66 | ||
| 63 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) | 67 | self.data_train, self.data_val = random_split(self.data, [train_set_size, valid_set_size], self.generator) |
| 64 | 68 | ||
| 65 | train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt, | 69 | train_dataset = CSVDataset(self.data_train, self.tokenizer, |
| 66 | class_data_root=self.class_data_root, class_prompt=self.class_prompt, | 70 | instance_identifier=self.instance_identifier, class_identifier=self.class_identifier, |
| 67 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 71 | size=self.size, interpolation=self.interpolation, |
| 68 | center_crop=self.center_crop, repeats=self.repeats, batch_size=self.batch_size) | 72 | center_crop=self.center_crop, repeats=self.repeats) |
| 69 | val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt, | 73 | val_dataset = CSVDataset(self.data_val, self.tokenizer, |
| 70 | size=self.size, interpolation=self.interpolation, identifier=self.identifier, | 74 | instance_identifier=self.instance_identifier, |
| 71 | center_crop=self.center_crop, batch_size=self.batch_size) | 75 | size=self.size, interpolation=self.interpolation, |
| 72 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 76 | center_crop=self.center_crop, repeats=self.repeats) |
| 77 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, drop_last=True, | ||
| 73 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) | 78 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn) |
| 74 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 79 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, drop_last=True, |
| 75 | pin_memory=True, collate_fn=self.collate_fn) | 80 | pin_memory=True, collate_fn=self.collate_fn) |
| 76 | 81 | ||
| 77 | def train_dataloader(self): | 82 | def train_dataloader(self): |
| @@ -85,39 +90,23 @@ class CSVDataset(Dataset): | |||
| 85 | def __init__(self, | 90 | def __init__(self, |
| 86 | data, | 91 | data, |
| 87 | tokenizer, | 92 | tokenizer, |
| 88 | instance_prompt, | 93 | instance_identifier, |
| 89 | class_data_root=None, | 94 | class_identifier=None, |
| 90 | class_prompt=None, | ||
| 91 | size=512, | 95 | size=512, |
| 92 | repeats=1, | 96 | repeats=1, |
| 93 | interpolation="bicubic", | 97 | interpolation="bicubic", |
| 94 | identifier="*", | ||
| 95 | center_crop=False, | 98 | center_crop=False, |
| 96 | batch_size=1, | ||
| 97 | ): | 99 | ): |
| 98 | 100 | ||
| 99 | self.data = data | 101 | self.data = data |
| 100 | self.tokenizer = tokenizer | 102 | self.tokenizer = tokenizer |
| 101 | self.instance_prompt = instance_prompt | 103 | self.instance_identifier = instance_identifier |
| 102 | self.identifier = identifier | 104 | self.class_identifier = class_identifier |
| 103 | self.batch_size = batch_size | ||
| 104 | self.cache = {} | 105 | self.cache = {} |
| 105 | 106 | ||
| 106 | self.num_instance_images = len(self.data) | 107 | self.num_instance_images = len(self.data) |
| 107 | self._length = self.num_instance_images * repeats | 108 | self._length = self.num_instance_images * repeats |
| 108 | 109 | ||
| 109 | if class_data_root is not None: | ||
| 110 | self.class_data_root = Path(class_data_root) | ||
| 111 | self.class_data_root.mkdir(parents=True, exist_ok=True) | ||
| 112 | |||
| 113 | self.class_images = list(self.class_data_root.iterdir()) | ||
| 114 | self.num_class_images = len(self.class_images) | ||
| 115 | self._length = max(self.num_class_images, self.num_instance_images) | ||
| 116 | |||
| 117 | self.class_prompt = class_prompt | ||
| 118 | else: | ||
| 119 | self.class_data_root = None | ||
| 120 | |||
| 121 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, | 110 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, |
| 122 | "bilinear": transforms.InterpolationMode.BILINEAR, | 111 | "bilinear": transforms.InterpolationMode.BILINEAR, |
| 123 | "bicubic": transforms.InterpolationMode.BICUBIC, | 112 | "bicubic": transforms.InterpolationMode.BICUBIC, |
| @@ -134,46 +123,49 @@ class CSVDataset(Dataset): | |||
| 134 | ) | 123 | ) |
| 135 | 124 | ||
| 136 | def __len__(self): | 125 | def __len__(self): |
| 137 | return math.ceil(self._length / self.batch_size) * self.batch_size | 126 | return self._length |
| 138 | 127 | ||
| 139 | def get_example(self, i): | 128 | def get_example(self, i): |
| 140 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] | 129 | instance_image_path, class_image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
| 141 | 130 | ||
| 142 | if image_path in self.cache: | 131 | if instance_image_path in self.cache: |
| 143 | return self.cache[image_path] | 132 | return self.cache[instance_image_path] |
| 144 | 133 | ||
| 145 | example = {} | 134 | example = {} |
| 146 | 135 | ||
| 147 | instance_image = Image.open(image_path) | 136 | example["prompts"] = prompt |
| 137 | example["nprompts"] = nprompt | ||
| 138 | |||
| 139 | instance_image = Image.open(instance_image_path) | ||
| 148 | if not instance_image.mode == "RGB": | 140 | if not instance_image.mode == "RGB": |
| 149 | instance_image = instance_image.convert("RGB") | 141 | instance_image = instance_image.convert("RGB") |
| 150 | 142 | ||
| 151 | prompt = prompt.format(self.identifier) | 143 | instance_prompt = prompt.format(self.instance_identifier) |
| 152 | 144 | ||
| 153 | example["prompts"] = prompt | ||
| 154 | example["nprompts"] = nprompt | ||
| 155 | example["instance_images"] = instance_image | 145 | example["instance_images"] = instance_image |
| 156 | example["instance_prompt_ids"] = self.tokenizer( | 146 | example["instance_prompt_ids"] = self.tokenizer( |
| 157 | self.instance_prompt, | 147 | instance_prompt, |
| 158 | padding="do_not_pad", | 148 | padding="do_not_pad", |
| 159 | truncation=True, | 149 | truncation=True, |
| 160 | max_length=self.tokenizer.model_max_length, | 150 | max_length=self.tokenizer.model_max_length, |
| 161 | ).input_ids | 151 | ).input_ids |
| 162 | 152 | ||
| 163 | if self.class_data_root: | 153 | if self.class_identifier: |
| 164 | class_image = Image.open(self.class_images[i % self.num_class_images]) | 154 | class_image = Image.open(class_image_path) |
| 165 | if not class_image.mode == "RGB": | 155 | if not class_image.mode == "RGB": |
| 166 | class_image = class_image.convert("RGB") | 156 | class_image = class_image.convert("RGB") |
| 167 | 157 | ||
| 158 | class_prompt = prompt.format(self.class_identifier) | ||
| 159 | |||
| 168 | example["class_images"] = class_image | 160 | example["class_images"] = class_image |
| 169 | example["class_prompt_ids"] = self.tokenizer( | 161 | example["class_prompt_ids"] = self.tokenizer( |
| 170 | self.class_prompt, | 162 | class_prompt, |
| 171 | padding="do_not_pad", | 163 | padding="do_not_pad", |
| 172 | truncation=True, | 164 | truncation=True, |
| 173 | max_length=self.tokenizer.model_max_length, | 165 | max_length=self.tokenizer.model_max_length, |
| 174 | ).input_ids | 166 | ).input_ids |
| 175 | 167 | ||
| 176 | self.cache[image_path] = example | 168 | self.cache[instance_image_path] = example |
| 177 | return example | 169 | return example |
| 178 | 170 | ||
| 179 | def __getitem__(self, i): | 171 | def __getitem__(self, i): |
| @@ -185,7 +177,7 @@ class CSVDataset(Dataset): | |||
| 185 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 177 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
| 186 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 178 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
| 187 | 179 | ||
| 188 | if self.class_data_root: | 180 | if self.class_identifier: |
| 189 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) | 181 | example["class_images"] = self.image_transforms(unprocessed_example["class_images"]) |
| 190 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] | 182 | example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"] |
| 191 | 183 | ||
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py index 34f510d..b3a83ce 100644 --- a/data/dreambooth/prompt.py +++ b/data/dreambooth/prompt.py | |||
| @@ -2,8 +2,9 @@ from torch.utils.data import Dataset | |||
| 2 | 2 | ||
| 3 | 3 | ||
| 4 | class PromptDataset(Dataset): | 4 | class PromptDataset(Dataset): |
| 5 | def __init__(self, prompt, num_samples): | 5 | def __init__(self, prompt, nprompt, num_samples): |
| 6 | self.prompt = prompt | 6 | self.prompt = prompt |
| 7 | self.nprompt = nprompt | ||
| 7 | self.num_samples = num_samples | 8 | self.num_samples = num_samples |
| 8 | 9 | ||
| 9 | def __len__(self): | 10 | def __len__(self): |
| @@ -12,5 +13,6 @@ class PromptDataset(Dataset): | |||
| 12 | def __getitem__(self, index): | 13 | def __getitem__(self, index): |
| 13 | example = {} | 14 | example = {} |
| 14 | example["prompt"] = self.prompt | 15 | example["prompt"] = self.prompt |
| 16 | example["nprompt"] = self.nprompt | ||
| 15 | example["index"] = index | 17 | example["index"] = index |
| 16 | return example | 18 | return example |
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 852b1cb..4c5e27e 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
| @@ -52,13 +52,14 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 52 | valid_set_size = int(len(self.data_full) * 0.2) | 52 | valid_set_size = int(len(self.data_full) * 0.2) |
| 53 | if self.valid_set_size: | 53 | if self.valid_set_size: |
| 54 | valid_set_size = min(valid_set_size, self.valid_set_size) | 54 | valid_set_size = min(valid_set_size, self.valid_set_size) |
| 55 | valid_set_size = max(valid_set_size, 1) | ||
| 55 | train_set_size = len(self.data_full) - valid_set_size | 56 | train_set_size = len(self.data_full) - valid_set_size |
| 56 | 57 | ||
| 57 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) | 58 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator) |
| 58 | 59 | ||
| 59 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, | 60 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, |
| 60 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 61 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
| 61 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, | 62 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, |
| 62 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 63 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
| 63 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) | 64 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True) |
| 64 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) | 65 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True) |
diff --git a/dreambooth.py b/dreambooth.py index 9d6b8d6..2fe89ec 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
| @@ -13,13 +13,12 @@ import torch.utils.checkpoint | |||
| 13 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
| 14 | from accelerate.logging import get_logger | 14 | from accelerate.logging import get_logger |
| 15 | from accelerate.utils import LoggerType, set_seed | 15 | from accelerate.utils import LoggerType, set_seed |
| 16 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel | 16 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel |
| 17 | from schedulers.scheduling_euler_a import EulerAScheduler | 17 | from schedulers.scheduling_euler_a import EulerAScheduler |
| 18 | from diffusers.optimization import get_scheduler | 18 | from diffusers.optimization import get_scheduler |
| 19 | from pipelines.stable_diffusion.no_check import NoCheck | ||
| 20 | from PIL import Image | 19 | from PIL import Image |
| 21 | from tqdm.auto import tqdm | 20 | from tqdm.auto import tqdm |
| 22 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | 21 | from transformers import CLIPTextModel, CLIPTokenizer |
| 23 | from slugify import slugify | 22 | from slugify import slugify |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | import json | 24 | import json |
| @@ -56,7 +55,13 @@ def parse_args(): | |||
| 56 | help="A folder containing the training data." | 55 | help="A folder containing the training data." |
| 57 | ) | 56 | ) |
| 58 | parser.add_argument( | 57 | parser.add_argument( |
| 59 | "--identifier", | 58 | "--instance_identifier", |
| 59 | type=str, | ||
| 60 | default=None, | ||
| 61 | help="A token to use as a placeholder for the concept.", | ||
| 62 | ) | ||
| 63 | parser.add_argument( | ||
| 64 | "--class_identifier", | ||
| 60 | type=str, | 65 | type=str, |
| 61 | default=None, | 66 | default=None, |
| 62 | help="A token to use as a placeholder for the concept.", | 67 | help="A token to use as a placeholder for the concept.", |
| @@ -218,12 +223,6 @@ def parse_args(): | |||
| 218 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 223 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
| 219 | ) | 224 | ) |
| 220 | parser.add_argument( | 225 | parser.add_argument( |
| 221 | "--instance_prompt", | ||
| 222 | type=str, | ||
| 223 | default=None, | ||
| 224 | help="The prompt with identifier specifing the instance", | ||
| 225 | ) | ||
| 226 | parser.add_argument( | ||
| 227 | "--class_data_dir", | 226 | "--class_data_dir", |
| 228 | type=str, | 227 | type=str, |
| 229 | default=None, | 228 | default=None, |
| @@ -231,12 +230,6 @@ def parse_args(): | |||
| 231 | help="A folder containing the training data of class images.", | 230 | help="A folder containing the training data of class images.", |
| 232 | ) | 231 | ) |
| 233 | parser.add_argument( | 232 | parser.add_argument( |
| 234 | "--class_prompt", | ||
| 235 | type=str, | ||
| 236 | default=None, | ||
| 237 | help="The prompt to specify images in the same class as provided intance images.", | ||
| 238 | ) | ||
| 239 | parser.add_argument( | ||
| 240 | "--prior_loss_weight", | 233 | "--prior_loss_weight", |
| 241 | type=float, | 234 | type=float, |
| 242 | default=1.0, | 235 | default=1.0, |
| @@ -255,15 +248,6 @@ def parse_args(): | |||
| 255 | help="Max gradient norm." | 248 | help="Max gradient norm." |
| 256 | ) | 249 | ) |
| 257 | parser.add_argument( | 250 | parser.add_argument( |
| 258 | "--num_class_images", | ||
| 259 | type=int, | ||
| 260 | default=100, | ||
| 261 | help=( | ||
| 262 | "Minimal class images for prior perversation loss. If not have enough images, additional images will be" | ||
| 263 | " sampled with class_prompt." | ||
| 264 | ), | ||
| 265 | ) | ||
| 266 | parser.add_argument( | ||
| 267 | "--config", | 251 | "--config", |
| 268 | type=str, | 252 | type=str, |
| 269 | default=None, | 253 | default=None, |
| @@ -286,21 +270,12 @@ def parse_args(): | |||
| 286 | if args.pretrained_model_name_or_path is None: | 270 | if args.pretrained_model_name_or_path is None: |
| 287 | raise ValueError("You must specify --pretrained_model_name_or_path") | 271 | raise ValueError("You must specify --pretrained_model_name_or_path") |
| 288 | 272 | ||
| 289 | if args.instance_prompt is None: | 273 | if args.instance_identifier is None: |
| 290 | raise ValueError("You must specify --instance_prompt") | 274 | raise ValueError("You must specify --instance_identifier") |
| 291 | |||
| 292 | if args.identifier is None: | ||
| 293 | raise ValueError("You must specify --identifier") | ||
| 294 | 275 | ||
| 295 | if args.output_dir is None: | 276 | if args.output_dir is None: |
| 296 | raise ValueError("You must specify --output_dir") | 277 | raise ValueError("You must specify --output_dir") |
| 297 | 278 | ||
| 298 | if args.with_prior_preservation: | ||
| 299 | if args.class_data_dir is None: | ||
| 300 | raise ValueError("You must specify --class_data_dir") | ||
| 301 | if args.class_prompt is None: | ||
| 302 | raise ValueError("You must specify --class_prompt") | ||
| 303 | |||
| 304 | return args | 279 | return args |
| 305 | 280 | ||
| 306 | 281 | ||
| @@ -443,7 +418,7 @@ def main(): | |||
| 443 | args = parse_args() | 418 | args = parse_args() |
| 444 | 419 | ||
| 445 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 420 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 446 | basepath = Path(args.output_dir).joinpath(slugify(args.identifier), now) | 421 | basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now) |
| 447 | basepath.mkdir(parents=True, exist_ok=True) | 422 | basepath.mkdir(parents=True, exist_ok=True) |
| 448 | 423 | ||
| 449 | accelerator = Accelerator( | 424 | accelerator = Accelerator( |
| @@ -488,47 +463,6 @@ def main(): | |||
| 488 | freeze_params(vae.parameters()) | 463 | freeze_params(vae.parameters()) |
| 489 | freeze_params(text_encoder.parameters()) | 464 | freeze_params(text_encoder.parameters()) |
| 490 | 465 | ||
| 491 | # Generate class images, if necessary | ||
| 492 | if args.with_prior_preservation: | ||
| 493 | class_images_dir = Path(args.class_data_dir) | ||
| 494 | class_images_dir.mkdir(parents=True, exist_ok=True) | ||
| 495 | cur_class_images = len(list(class_images_dir.iterdir())) | ||
| 496 | |||
| 497 | if cur_class_images < args.num_class_images: | ||
| 498 | scheduler = EulerAScheduler( | ||
| 499 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 500 | ) | ||
| 501 | |||
| 502 | pipeline = VlpnStableDiffusion( | ||
| 503 | text_encoder=text_encoder, | ||
| 504 | vae=vae, | ||
| 505 | unet=unet, | ||
| 506 | tokenizer=tokenizer, | ||
| 507 | scheduler=scheduler, | ||
| 508 | ).to(accelerator.device) | ||
| 509 | pipeline.enable_attention_slicing() | ||
| 510 | pipeline.set_progress_bar_config(disable=True) | ||
| 511 | |||
| 512 | num_new_images = args.num_class_images - cur_class_images | ||
| 513 | logger.info(f"Number of class images to sample: {num_new_images}.") | ||
| 514 | |||
| 515 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) | ||
| 516 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) | ||
| 517 | |||
| 518 | sample_dataloader = accelerator.prepare(sample_dataloader) | ||
| 519 | |||
| 520 | for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process): | ||
| 521 | with accelerator.autocast(): | ||
| 522 | images = pipeline(example["prompt"]).images | ||
| 523 | |||
| 524 | for i, image in enumerate(images): | ||
| 525 | image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") | ||
| 526 | |||
| 527 | del pipeline | ||
| 528 | |||
| 529 | if torch.cuda.is_available(): | ||
| 530 | torch.cuda.empty_cache() | ||
| 531 | |||
| 532 | if args.scale_lr: | 466 | if args.scale_lr: |
| 533 | args.learning_rate = ( | 467 | args.learning_rate = ( |
| 534 | args.learning_rate * args.gradient_accumulation_steps * | 468 | args.learning_rate * args.gradient_accumulation_steps * |
| @@ -564,6 +498,7 @@ def main(): | |||
| 564 | 498 | ||
| 565 | def collate_fn(examples): | 499 | def collate_fn(examples): |
| 566 | prompts = [example["prompts"] for example in examples] | 500 | prompts = [example["prompts"] for example in examples] |
| 501 | nprompts = [example["nprompts"] for example in examples] | ||
| 567 | input_ids = [example["instance_prompt_ids"] for example in examples] | 502 | input_ids = [example["instance_prompt_ids"] for example in examples] |
| 568 | pixel_values = [example["instance_images"] for example in examples] | 503 | pixel_values = [example["instance_images"] for example in examples] |
| 569 | 504 | ||
| @@ -579,6 +514,7 @@ def main(): | |||
| 579 | 514 | ||
| 580 | batch = { | 515 | batch = { |
| 581 | "prompts": prompts, | 516 | "prompts": prompts, |
| 517 | "nprompts": nprompts, | ||
| 582 | "input_ids": input_ids, | 518 | "input_ids": input_ids, |
| 583 | "pixel_values": pixel_values, | 519 | "pixel_values": pixel_values, |
| 584 | } | 520 | } |
| @@ -588,11 +524,9 @@ def main(): | |||
| 588 | data_file=args.train_data_file, | 524 | data_file=args.train_data_file, |
| 589 | batch_size=args.train_batch_size, | 525 | batch_size=args.train_batch_size, |
| 590 | tokenizer=tokenizer, | 526 | tokenizer=tokenizer, |
| 591 | instance_prompt=args.instance_prompt, | 527 | instance_identifier=args.instance_identifier, |
| 592 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, | 528 | class_identifier=args.class_identifier, |
| 593 | class_prompt=args.class_prompt, | ||
| 594 | size=args.resolution, | 529 | size=args.resolution, |
| 595 | identifier=args.identifier, | ||
| 596 | repeats=args.repeats, | 530 | repeats=args.repeats, |
| 597 | center_crop=args.center_crop, | 531 | center_crop=args.center_crop, |
| 598 | valid_set_size=args.sample_batch_size*args.sample_batches, | 532 | valid_set_size=args.sample_batch_size*args.sample_batches, |
| @@ -601,6 +535,46 @@ def main(): | |||
| 601 | datamodule.prepare_data() | 535 | datamodule.prepare_data() |
| 602 | datamodule.setup() | 536 | datamodule.setup() |
| 603 | 537 | ||
| 538 | if args.class_identifier: | ||
| 539 | missing_data = [item for item in datamodule.data if not item[1].exists()] | ||
| 540 | |||
| 541 | if len(missing_data) != 0: | ||
| 542 | batched_data = [missing_data[i:i+args.sample_batch_size] | ||
| 543 | for i in range(0, len(missing_data), args.sample_batch_size)] | ||
| 544 | |||
| 545 | scheduler = EulerAScheduler( | ||
| 546 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 547 | ) | ||
| 548 | |||
| 549 | pipeline = VlpnStableDiffusion( | ||
| 550 | text_encoder=text_encoder, | ||
| 551 | vae=vae, | ||
| 552 | unet=unet, | ||
| 553 | tokenizer=tokenizer, | ||
| 554 | scheduler=scheduler, | ||
| 555 | ).to(accelerator.device) | ||
| 556 | pipeline.enable_attention_slicing() | ||
| 557 | |||
| 558 | for batch in batched_data: | ||
| 559 | image_name = [p[1] for p in batch] | ||
| 560 | prompt = [p[2] for p in batch] | ||
| 561 | nprompt = [p[3] for p in batch] | ||
| 562 | |||
| 563 | with accelerator.autocast(): | ||
| 564 | images = pipeline( | ||
| 565 | prompt=prompt, | ||
| 566 | negative_prompt=nprompt, | ||
| 567 | num_inference_steps=args.sample_steps | ||
| 568 | ).images | ||
| 569 | |||
| 570 | for i, image in enumerate(images): | ||
| 571 | image.save(image_name[i]) | ||
| 572 | |||
| 573 | del pipeline | ||
| 574 | |||
| 575 | if torch.cuda.is_available(): | ||
| 576 | torch.cuda.empty_cache() | ||
| 577 | |||
| 604 | train_dataloader = datamodule.train_dataloader() | 578 | train_dataloader = datamodule.train_dataloader() |
| 605 | val_dataloader = datamodule.val_dataloader() | 579 | val_dataloader = datamodule.val_dataloader() |
| 606 | 580 | ||
| @@ -718,23 +692,22 @@ def main(): | |||
| 718 | # Predict the noise residual | 692 | # Predict the noise residual |
| 719 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 693 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 720 | 694 | ||
| 721 | with accelerator.autocast(): | 695 | if args.with_prior_preservation: |
| 722 | if args.with_prior_preservation: | 696 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. |
| 723 | # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. | 697 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) |
| 724 | noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) | 698 | noise, noise_prior = torch.chunk(noise, 2, dim=0) |
| 725 | noise, noise_prior = torch.chunk(noise, 2, dim=0) | ||
| 726 | 699 | ||
| 727 | # Compute instance loss | 700 | # Compute instance loss |
| 728 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 701 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 729 | 702 | ||
| 730 | # Compute prior loss | 703 | # Compute prior loss |
| 731 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, | 704 | prior_loss = F.mse_loss(noise_pred_prior, noise_prior, |
| 732 | reduction="none").mean([1, 2, 3]).mean() | 705 | reduction="none").mean([1, 2, 3]).mean() |
| 733 | 706 | ||
| 734 | # Add the prior loss to the instance loss. | 707 | # Add the prior loss to the instance loss. |
| 735 | loss = loss + args.prior_loss_weight * prior_loss | 708 | loss = loss + args.prior_loss_weight * prior_loss |
| 736 | else: | 709 | else: |
| 737 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 710 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 738 | 711 | ||
| 739 | accelerator.backward(loss) | 712 | accelerator.backward(loss) |
| 740 | if accelerator.sync_gradients: | 713 | if accelerator.sync_gradients: |
| @@ -786,8 +759,7 @@ def main(): | |||
| 786 | 759 | ||
| 787 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 760 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 788 | 761 | ||
| 789 | with accelerator.autocast(): | 762 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 790 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 791 | 763 | ||
| 792 | loss = loss.detach().item() | 764 | loss = loss.detach().item() |
| 793 | val_loss += loss | 765 | val_loss += loss |
diff --git a/textual_inversion.py b/textual_inversion.py index 5fc2338..4c4da29 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
| @@ -694,8 +694,7 @@ def main(): | |||
| 694 | # Predict the noise residual | 694 | # Predict the noise residual |
| 695 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 695 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 696 | 696 | ||
| 697 | with accelerator.autocast(): | 697 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 698 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 699 | 698 | ||
| 700 | accelerator.backward(loss) | 699 | accelerator.backward(loss) |
| 701 | 700 | ||
| @@ -766,8 +765,7 @@ def main(): | |||
| 766 | 765 | ||
| 767 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 766 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 768 | 767 | ||
| 769 | with accelerator.autocast(): | 768 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 770 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 771 | 769 | ||
| 772 | loss = loss.detach().item() | 770 | loss = loss.detach().item() |
| 773 | val_loss += loss | 771 | val_loss += loss |
