summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/csv.py (renamed from data/dreambooth/csv.py)0
-rw-r--r--data/textual_inversion/csv.py150
-rw-r--r--dreambooth.py2
-rw-r--r--infer.py14
-rw-r--r--pipelines/stable_diffusion/vlpn_stable_diffusion.py35
-rw-r--r--schedulers/scheduling_euler_a.py45
-rw-r--r--textual_dreambooth.py917
-rw-r--r--textual_inversion.py112
8 files changed, 159 insertions, 1116 deletions
diff --git a/data/dreambooth/csv.py b/data/csv.py
index abd329d..abd329d 100644
--- a/data/dreambooth/csv.py
+++ b/data/csv.py
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
deleted file mode 100644
index 4c5e27e..0000000
--- a/data/textual_inversion/csv.py
+++ /dev/null
@@ -1,150 +0,0 @@
1import os
2import numpy as np
3import pandas as pd
4from pathlib import Path
5import math
6import pytorch_lightning as pl
7from PIL import Image
8from torch.utils.data import Dataset, DataLoader, random_split
9from torchvision import transforms
10
11
12class CSVDataModule(pl.LightningDataModule):
13 def __init__(self,
14 batch_size,
15 data_file,
16 tokenizer,
17 size=512,
18 repeats=100,
19 interpolation="bicubic",
20 placeholder_token="*",
21 center_crop=False,
22 valid_set_size=None,
23 generator=None):
24 super().__init__()
25
26 self.data_file = Path(data_file)
27
28 if not self.data_file.is_file():
29 raise ValueError("data_file must be a file")
30
31 self.data_root = self.data_file.parent
32 self.tokenizer = tokenizer
33 self.size = size
34 self.repeats = repeats
35 self.placeholder_token = placeholder_token
36 self.center_crop = center_crop
37 self.interpolation = interpolation
38 self.valid_set_size = valid_set_size
39 self.generator = generator
40
41 self.batch_size = batch_size
42
43 def prepare_data(self):
44 metadata = pd.read_csv(self.data_file)
45 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
46 prompts = metadata['prompt'].values
47 nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
48 skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
49 self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
50
51 def setup(self, stage=None):
52 valid_set_size = int(len(self.data_full) * 0.2)
53 if self.valid_set_size:
54 valid_set_size = min(valid_set_size, self.valid_set_size)
55 valid_set_size = max(valid_set_size, 1)
56 train_set_size = len(self.data_full) - valid_set_size
57
58 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size], self.generator)
59
60 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
61 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
62 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
63 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
64 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
65 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)
66
67 def train_dataloader(self):
68 return self.train_dataloader_
69
70 def val_dataloader(self):
71 return self.val_dataloader_
72
73
74class CSVDataset(Dataset):
75 def __init__(self,
76 data,
77 tokenizer,
78 size=512,
79 repeats=1,
80 interpolation="bicubic",
81 placeholder_token="*",
82 center_crop=False,
83 batch_size=1,
84 ):
85
86 self.data = data
87 self.tokenizer = tokenizer
88 self.placeholder_token = placeholder_token
89 self.batch_size = batch_size
90 self.cache = {}
91
92 self.num_instance_images = len(self.data)
93 self._length = self.num_instance_images * repeats
94
95 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
96 "bilinear": transforms.InterpolationMode.BILINEAR,
97 "bicubic": transforms.InterpolationMode.BICUBIC,
98 "lanczos": transforms.InterpolationMode.LANCZOS,
99 }[interpolation]
100 self.image_transforms = transforms.Compose(
101 [
102 transforms.Resize(size, interpolation=self.interpolation),
103 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
104 transforms.RandomHorizontalFlip(),
105 transforms.ToTensor(),
106 transforms.Normalize([0.5], [0.5]),
107 ]
108 )
109
110 def __len__(self):
111 return math.ceil(self._length / self.batch_size) * self.batch_size
112
113 def get_example(self, i):
114 image_path, prompt, nprompt = self.data[i % self.num_instance_images]
115
116 if image_path in self.cache:
117 return self.cache[image_path]
118
119 example = {}
120
121 instance_image = Image.open(image_path)
122 if not instance_image.mode == "RGB":
123 instance_image = instance_image.convert("RGB")
124
125 prompt = prompt.format(self.placeholder_token)
126
127 example["prompts"] = prompt
128 example["nprompts"] = nprompt
129 example["pixel_values"] = instance_image
130 example["input_ids"] = self.tokenizer(
131 prompt,
132 padding="max_length",
133 truncation=True,
134 max_length=self.tokenizer.model_max_length,
135 return_tensors="pt",
136 ).input_ids[0]
137
138 self.cache[image_path] = example
139 return example
140
141 def __getitem__(self, i):
142 example = {}
143 unprocessed_example = self.get_example(i)
144
145 example["prompts"] = unprocessed_example["prompts"]
146 example["nprompts"] = unprocessed_example["nprompts"]
147 example["input_ids"] = unprocessed_example["input_ids"]
148 example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
149
150 return example
diff --git a/dreambooth.py b/dreambooth.py
index 0c5c42a..0e69d79 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -23,7 +23,7 @@ from slugify import slugify
23from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 23from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
24import json 24import json
25 25
26from data.dreambooth.csv import CSVDataModule 26from data.csv import CSVDataModule
27 27
28logger = get_logger(__name__) 28logger = get_logger(__name__)
29 29
diff --git a/infer.py b/infer.py
index 3487e5a..34e570a 100644
--- a/infer.py
+++ b/infer.py
@@ -171,6 +171,18 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
171 embeddings_dir = Path(embeddings_dir) 171 embeddings_dir = Path(embeddings_dir)
172 embeddings_dir.mkdir(parents=True, exist_ok=True) 172 embeddings_dir.mkdir(parents=True, exist_ok=True)
173 173
174 for file in embeddings_dir.iterdir():
175 placeholder_token = file.stem
176
177 num_added_tokens = tokenizer.add_tokens(placeholder_token)
178 if num_added_tokens == 0:
179 raise ValueError(
180 f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
181 " `placeholder_token` that is not already in the tokenizer."
182 )
183
184 text_encoder.resize_token_embeddings(len(tokenizer))
185
174 token_embeds = text_encoder.get_input_embeddings().weight.data 186 token_embeds = text_encoder.get_input_embeddings().weight.data
175 187
176 for file in embeddings_dir.iterdir(): 188 for file in embeddings_dir.iterdir():
@@ -187,6 +199,8 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
187 199
188 token_embeds[placeholder_token_id] = emb 200 token_embeds[placeholder_token_id] = emb
189 201
202 print(f"Loaded embedding: {placeholder_token}")
203
190 204
191def create_pipeline(model, scheduler, embeddings_dir, dtype): 205def create_pipeline(model, scheduler, embeddings_dir, dtype):
192 print("Loading Stable Diffusion pipeline...") 206 print("Loading Stable Diffusion pipeline...")
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8fbe5f9..a198cf6 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -216,7 +216,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
216 216
217 offset = self.scheduler.config.get("steps_offset", 0) 217 offset = self.scheduler.config.get("steps_offset", 0)
218 init_timestep = num_inference_steps + offset 218 init_timestep = num_inference_steps + offset
219 ensure_sigma = not isinstance(latents, PIL.Image.Image)
220 219
221 # get the initial random noise unless the user supplied it 220 # get the initial random noise unless the user supplied it
222 221
@@ -246,13 +245,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
246 init_timestep = int(num_inference_steps * strength) + offset 245 init_timestep = int(num_inference_steps * strength) + offset
247 init_timestep = min(init_timestep, num_inference_steps) 246 init_timestep = min(init_timestep, num_inference_steps)
248 247
249 if isinstance(self.scheduler, LMSDiscreteScheduler): 248 timesteps = self.scheduler.timesteps[-init_timestep]
250 timesteps = torch.tensor( 249 timesteps = torch.tensor([timesteps] * batch_size, device=self.device)
251 [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
252 )
253 else:
254 timesteps = self.scheduler.timesteps[-init_timestep]
255 timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
256 250
257 # add noise to latents using the timesteps 251 # add noise to latents using the timesteps
258 noise = torch.randn(latents.shape, generator=generator, device=self.device) 252 noise = torch.randn(latents.shape, generator=generator, device=self.device)
@@ -263,13 +257,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
263 if latents.device != self.device: 257 if latents.device != self.device:
264 raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}") 258 raise ValueError(f"Unexpected latents device, got {latents.device}, expected {self.device}")
265 259
266 # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
267 if ensure_sigma:
268 if isinstance(self.scheduler, LMSDiscreteScheduler):
269 latents = latents * self.scheduler.sigmas[0]
270 elif isinstance(self.scheduler, EulerAScheduler):
271 latents = latents * self.scheduler.sigmas[0]
272
273 t_start = max(num_inference_steps - init_timestep + offset, 0) 260 t_start = max(num_inference_steps - init_timestep + offset, 0)
274 261
275 # Some schedulers like PNDM have timesteps as arrays 262 # Some schedulers like PNDM have timesteps as arrays
@@ -290,19 +277,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
290 extra_step_kwargs["generator"] = generator 277 extra_step_kwargs["generator"] = generator
291 278
292 for i, t in enumerate(self.progress_bar(timesteps_tensor)): 279 for i, t in enumerate(self.progress_bar(timesteps_tensor)):
293 t_index = t_start + i
294
295 # expand the latents if we are doing classifier free guidance 280 # expand the latents if we are doing classifier free guidance
296 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 281 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
297 282 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
298 if isinstance(self.scheduler, LMSDiscreteScheduler):
299 sigma = self.scheduler.sigmas[t_index]
300 # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
301 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
302 283
303 noise_pred = None 284 noise_pred = None
304 if isinstance(self.scheduler, EulerAScheduler): 285 if isinstance(self.scheduler, EulerAScheduler):
305 sigma = self.scheduler.sigmas[t].reshape(1) 286 sigma = t.reshape(1)
306 sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) 287 sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
307 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, 288 noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
308 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) 289 text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
@@ -316,13 +297,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
316 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 297 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
317 298
318 # compute the previous noisy sample x_t -> x_t-1 299 # compute the previous noisy sample x_t -> x_t-1
319 if isinstance(self.scheduler, LMSDiscreteScheduler): 300 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
320 latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
321 elif isinstance(self.scheduler, EulerAScheduler):
322 latents = self.scheduler.step(noise_pred, t_index, t_index + 1,
323 latents, **extra_step_kwargs).prev_sample
324 else:
325 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
326 301
327 # scale and decode the image latents with vae 302 # scale and decode the image latents with vae
328 latents = 1 / 0.18215 * latents 303 latents = 1 / 0.18215 * latents
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index c6436d8..13ea6b3 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -171,6 +171,9 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
171 self.alphas = 1.0 - self.betas 171 self.alphas = 1.0 - self.betas
172 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 172 self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
173 173
174 # standard deviation of the initial noise distribution
175 self.init_noise_sigma = 1.0
176
174 # setable values 177 # setable values
175 self.num_inference_steps = None 178 self.num_inference_steps = None
176 self.timesteps = np.arange(0, num_train_timesteps)[::-1] 179 self.timesteps = np.arange(0, num_train_timesteps)[::-1]
@@ -190,13 +193,33 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
190 self.num_inference_steps = num_inference_steps 193 self.num_inference_steps = num_inference_steps
191 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 194 self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
192 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) 195 self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
193 self.timesteps = np.arange(0, self.num_inference_steps) 196 self.timesteps = self.sigmas[:-1]
197 self.is_scale_input_called = False
198
199 def scale_model_input(self, sample: torch.FloatTensor, timestep: int) -> torch.FloatTensor:
200 """
201 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
202 current timestep.
203 Args:
204 sample (`torch.FloatTensor`): input sample
205 timestep (`int`, optional): current timestep
206 Returns:
207 `torch.FloatTensor`: scaled input sample
208 """
209 if isinstance(timestep, torch.Tensor):
210 timestep = timestep.to(self.timesteps.device)
211 if self.is_scale_input_called:
212 return sample
213 step_index = (self.timesteps == timestep).nonzero().item()
214 sigma = self.sigmas[step_index]
215 sample = sample * sigma
216 self.is_scale_input_called = True
217 return sample
194 218
195 def step( 219 def step(
196 self, 220 self,
197 model_output: torch.FloatTensor, 221 model_output: torch.FloatTensor,
198 timestep: int, 222 timestep: Union[float, torch.FloatTensor],
199 timestep_prev: int,
200 sample: torch.FloatTensor, 223 sample: torch.FloatTensor,
201 generator: torch.Generator = None, 224 generator: torch.Generator = None,
202 return_dict: bool = True, 225 return_dict: bool = True,
@@ -219,8 +242,13 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
219 returning a tuple, the first element is the sample tensor. 242 returning a tuple, the first element is the sample tensor.
220 243
221 """ 244 """
222 s = self.sigmas[timestep] 245 if isinstance(timestep, torch.Tensor):
223 s_prev = self.sigmas[timestep_prev] 246 timestep = timestep.to(self.timesteps.device)
247 step_index = (self.timesteps == timestep).nonzero().item()
248 step_prev_index = step_index + 1
249
250 s = self.sigmas[step_index]
251 s_prev = self.sigmas[step_prev_index]
224 latents = sample 252 latents = sample
225 253
226 sigma_down, sigma_up = get_ancestral_step(s, s_prev) 254 sigma_down, sigma_up = get_ancestral_step(s, s_prev)
@@ -271,14 +299,17 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
271 self, 299 self,
272 original_samples: torch.FloatTensor, 300 original_samples: torch.FloatTensor,
273 noise: torch.FloatTensor, 301 noise: torch.FloatTensor,
274 timesteps: torch.IntTensor, 302 timesteps: torch.FloatTensor,
275 ) -> torch.FloatTensor: 303 ) -> torch.FloatTensor:
276 sigmas = self.sigmas.to(original_samples.device) 304 sigmas = self.sigmas.to(original_samples.device)
305 schedule_timesteps = self.timesteps.to(original_samples.device)
277 timesteps = timesteps.to(original_samples.device) 306 timesteps = timesteps.to(original_samples.device)
307 step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
278 308
279 sigma = sigmas[timesteps].flatten() 309 sigma = sigmas[step_indices].flatten()
280 while len(sigma.shape) < len(original_samples.shape): 310 while len(sigma.shape) < len(original_samples.shape):
281 sigma = sigma.unsqueeze(-1) 311 sigma = sigma.unsqueeze(-1)
282 312
283 noisy_samples = original_samples + noise * sigma 313 noisy_samples = original_samples + noise * sigma
314 self.is_scale_input_called = True
284 return noisy_samples 315 return noisy_samples
diff --git a/textual_dreambooth.py b/textual_dreambooth.py
deleted file mode 100644
index c07d98b..0000000
--- a/textual_dreambooth.py
+++ /dev/null
@@ -1,917 +0,0 @@
1import argparse
2import itertools
3import math
4import os
5import datetime
6import logging
7from pathlib import Path
8
9import numpy as np
10import torch
11import torch.nn.functional as F
12import torch.utils.checkpoint
13
14from accelerate import Accelerator
15from accelerate.logging import get_logger
16from accelerate.utils import LoggerType, set_seed
17from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
18from schedulers.scheduling_euler_a import EulerAScheduler
19from diffusers.optimization import get_scheduler
20from PIL import Image
21from tqdm.auto import tqdm
22from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
23from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25import json
26import os
27
28from data.dreambooth.csv import CSVDataModule
29
30logger = get_logger(__name__)
31
32
33torch.backends.cuda.matmul.allow_tf32 = True
34
35
36def parse_args():
37 parser = argparse.ArgumentParser(
38 description="Simple example of a training script."
39 )
40 parser.add_argument(
41 "--pretrained_model_name_or_path",
42 type=str,
43 default=None,
44 help="Path to pretrained model or model identifier from huggingface.co/models.",
45 )
46 parser.add_argument(
47 "--tokenizer_name",
48 type=str,
49 default=None,
50 help="Pretrained tokenizer name or path if not the same as model_name",
51 )
52 parser.add_argument(
53 "--train_data_file",
54 type=str,
55 default=None,
56 help="A CSV file containing the training data."
57 )
58 parser.add_argument(
59 "--placeholder_token",
60 type=str,
61 default=None,
62 help="A token to use as a placeholder for the concept.",
63 )
64 parser.add_argument(
65 "--initializer_token",
66 type=str,
67 default=None,
68 help="A token to use as initializer word."
69 )
70 parser.add_argument(
71 "--use_class_images",
72 action="store_true",
73 default=True,
74 help="Include class images in the loss calculation a la Dreambooth.",
75 )
76 parser.add_argument(
77 "--repeats",
78 type=int,
79 default=100,
80 help="How many times to repeat the training data.")
81 parser.add_argument(
82 "--output_dir",
83 type=str,
84 default="output/text-inversion",
85 help="The output directory where the model predictions and checkpoints will be written.",
86 )
87 parser.add_argument(
88 "--seed",
89 type=int,
90 default=None,
91 help="A seed for reproducible training.")
92 parser.add_argument(
93 "--resolution",
94 type=int,
95 default=512,
96 help=(
97 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
98 " resolution"
99 ),
100 )
101 parser.add_argument(
102 "--center_crop",
103 action="store_true",
104 help="Whether to center crop images before resizing to resolution"
105 )
106 parser.add_argument(
107 "--num_train_epochs",
108 type=int,
109 default=100)
110 parser.add_argument(
111 "--max_train_steps",
112 type=int,
113 default=5000,
114 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
115 )
116 parser.add_argument(
117 "--gradient_accumulation_steps",
118 type=int,
119 default=1,
120 help="Number of updates steps to accumulate before performing a backward/update pass.",
121 )
122 parser.add_argument(
123 "--gradient_checkpointing",
124 action="store_true",
125 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
126 )
127 parser.add_argument(
128 "--learning_rate",
129 type=float,
130 default=1e-4,
131 help="Initial learning rate (after the potential warmup period) to use.",
132 )
133 parser.add_argument(
134 "--scale_lr",
135 action="store_true",
136 default=True,
137 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
138 )
139 parser.add_argument(
140 "--lr_scheduler",
141 type=str,
142 default="constant",
143 help=(
144 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
145 ' "constant", "constant_with_warmup"]'
146 ),
147 )
148 parser.add_argument(
149 "--lr_warmup_steps",
150 type=int,
151 default=500,
152 help="Number of steps for the warmup in the lr scheduler."
153 )
154 parser.add_argument(
155 "--use_8bit_adam",
156 action="store_true",
157 help="Whether or not to use 8-bit Adam from bitsandbytes."
158 )
159 parser.add_argument(
160 "--adam_beta1",
161 type=float,
162 default=0.9,
163 help="The beta1 parameter for the Adam optimizer."
164 )
165 parser.add_argument(
166 "--adam_beta2",
167 type=float,
168 default=0.999,
169 help="The beta2 parameter for the Adam optimizer."
170 )
171 parser.add_argument(
172 "--adam_weight_decay",
173 type=float,
174 default=1e-2,
175 help="Weight decay to use."
176 )
177 parser.add_argument(
178 "--adam_epsilon",
179 type=float,
180 default=1e-08,
181 help="Epsilon value for the Adam optimizer"
182 )
183 parser.add_argument(
184 "--mixed_precision",
185 type=str,
186 default="no",
187 choices=["no", "fp16", "bf16"],
188 help=(
189 "Whether to use mixed precision. Choose"
190 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
191 "and an Nvidia Ampere GPU."
192 ),
193 )
194 parser.add_argument(
195 "--local_rank",
196 type=int,
197 default=-1,
198 help="For distributed training: local_rank"
199 )
200 parser.add_argument(
201 "--checkpoint_frequency",
202 type=int,
203 default=500,
204 help="How often to save a checkpoint and sample image",
205 )
206 parser.add_argument(
207 "--sample_image_size",
208 type=int,
209 default=512,
210 help="Size of sample images",
211 )
212 parser.add_argument(
213 "--sample_batches",
214 type=int,
215 default=1,
216 help="Number of sample batches to generate per checkpoint",
217 )
218 parser.add_argument(
219 "--sample_batch_size",
220 type=int,
221 default=1,
222 help="Number of samples to generate per batch",
223 )
224 parser.add_argument(
225 "--train_batch_size",
226 type=int,
227 default=1,
228 help="Batch size (per device) for the training dataloader."
229 )
230 parser.add_argument(
231 "--sample_steps",
232 type=int,
233 default=30,
234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
235 )
236 parser.add_argument(
237 "--prior_loss_weight",
238 type=float,
239 default=1.0,
240 help="The weight of prior preservation loss."
241 )
242 parser.add_argument(
243 "--resume_from",
244 type=str,
245 default=None,
246 help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
247 )
248 parser.add_argument(
249 "--resume_checkpoint",
250 type=str,
251 default=None,
252 help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
253 )
254 parser.add_argument(
255 "--config",
256 type=str,
257 default=None,
258 help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
259 )
260
261 args = parser.parse_args()
262 if args.resume_from is not None:
263 with open(f"{args.resume_from}/resume.json", 'rt') as f:
264 args = parser.parse_args(
265 namespace=argparse.Namespace(**json.load(f)["args"]))
266 elif args.config is not None:
267 with open(args.config, 'rt') as f:
268 args = parser.parse_args(
269 namespace=argparse.Namespace(**json.load(f)["args"]))
270
271 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
272 if env_local_rank != -1 and env_local_rank != args.local_rank:
273 args.local_rank = env_local_rank
274
275 if args.train_data_file is None:
276 raise ValueError("You must specify --train_data_file")
277
278 if args.pretrained_model_name_or_path is None:
279 raise ValueError("You must specify --pretrained_model_name_or_path")
280
281 if args.placeholder_token is None:
282 raise ValueError("You must specify --placeholder_token")
283
284 if args.initializer_token is None:
285 raise ValueError("You must specify --initializer_token")
286
287 if args.output_dir is None:
288 raise ValueError("You must specify --output_dir")
289
290 return args
291
292
293def freeze_params(params):
294 for param in params:
295 param.requires_grad = False
296
297
298def save_resume_file(basepath, args, extra={}):
299 info = {"args": vars(args)}
300 info["args"].update(extra)
301 with open(f"{basepath}/resume.json", "w") as f:
302 json.dump(info, f, indent=4)
303
304
305def make_grid(images, rows, cols):
306 w, h = images[0].size
307 grid = Image.new('RGB', size=(cols*w, rows*h))
308 for i, image in enumerate(images):
309 grid.paste(image, box=(i % cols*w, i//cols*h))
310 return grid
311
312
313class Checkpointer:
314 def __init__(
315 self,
316 datamodule,
317 accelerator,
318 vae,
319 unet,
320 tokenizer,
321 placeholder_token,
322 placeholder_token_id,
323 output_dir,
324 sample_image_size,
325 sample_batches,
326 sample_batch_size,
327 seed
328 ):
329 self.datamodule = datamodule
330 self.accelerator = accelerator
331 self.vae = vae
332 self.unet = unet
333 self.tokenizer = tokenizer
334 self.placeholder_token = placeholder_token
335 self.placeholder_token_id = placeholder_token_id
336 self.output_dir = output_dir
337 self.sample_image_size = sample_image_size
338 self.seed = seed or torch.random.seed()
339 self.sample_batches = sample_batches
340 self.sample_batch_size = sample_batch_size
341
342 @torch.no_grad()
343 def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
344 print("Saving checkpoint for step %d..." % step)
345
346 if path is None:
347 checkpoints_path = f"{self.output_dir}/checkpoints"
348 os.makedirs(checkpoints_path, exist_ok=True)
349
350 unwrapped = self.accelerator.unwrap_model(text_encoder)
351
352 # Save a checkpoint
353 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
354 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
355
356 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
357 if path is not None:
358 torch.save(learned_embeds_dict, path)
359 else:
360 torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
361 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
362
363 del unwrapped
364 del learned_embeds
365
366 @torch.no_grad()
367 def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
368 samples_path = Path(self.output_dir).joinpath("samples")
369
370 unwrapped = self.accelerator.unwrap_model(text_encoder)
371 scheduler = EulerAScheduler(
372 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
373 )
374
375 # Save a sample image
376 pipeline = VlpnStableDiffusion(
377 text_encoder=unwrapped,
378 vae=self.vae,
379 unet=self.unet,
380 tokenizer=self.tokenizer,
381 scheduler=scheduler,
382 ).to(self.accelerator.device)
383 pipeline.enable_attention_slicing()
384
385 train_data = self.datamodule.train_dataloader()
386 val_data = self.datamodule.val_dataloader()
387
388 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
389 stable_latents = torch.randn(
390 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
391 device=pipeline.device,
392 generator=generator,
393 )
394
395 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
396 all_samples = []
397 file_path = samples_path.joinpath(pool, f"step_{step}.png")
398 file_path.parent.mkdir(parents=True, exist_ok=True)
399
400 data_enum = enumerate(data)
401
402 for i in range(self.sample_batches):
403 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
404 prompt = [prompt.format(self.placeholder_token)
405 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
406 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
407
408 with self.accelerator.autocast():
409 samples = pipeline(
410 prompt=prompt,
411 negative_prompt=nprompt,
412 height=self.sample_image_size,
413 width=self.sample_image_size,
414 latents=latents[:len(prompt)] if latents is not None else None,
415 generator=generator if latents is not None else None,
416 guidance_scale=guidance_scale,
417 eta=eta,
418 num_inference_steps=num_inference_steps,
419 output_type='pil'
420 )["sample"]
421
422 all_samples += samples
423
424 del samples
425
426 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
427 image_grid.save(file_path)
428
429 del all_samples
430 del image_grid
431
432 del unwrapped
433 del scheduler
434 del pipeline
435 del generator
436 del stable_latents
437
438 if torch.cuda.is_available():
439 torch.cuda.empty_cache()
440
441
442def main():
443 args = parse_args()
444
445 global_step_offset = 0
446 if args.resume_from is not None:
447 basepath = Path(args.resume_from)
448 print("Resuming state from %s" % args.resume_from)
449 with open(basepath.joinpath("resume.json"), 'r') as f:
450 state = json.load(f)
451 global_step_offset = state["args"].get("global_step", 0)
452
453 print("We've trained %d steps so far" % global_step_offset)
454 else:
455 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
456 basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
457 basepath.mkdir(parents=True, exist_ok=True)
458
459 accelerator = Accelerator(
460 log_with=LoggerType.TENSORBOARD,
461 logging_dir=f"{basepath}",
462 gradient_accumulation_steps=args.gradient_accumulation_steps,
463 mixed_precision=args.mixed_precision
464 )
465
466 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
467
468 # If passed along, set the training seed now.
469 if args.seed is not None:
470 set_seed(args.seed)
471
472 # Load the tokenizer and add the placeholder token as a additional special token
473 if args.tokenizer_name:
474 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
475 elif args.pretrained_model_name_or_path:
476 tokenizer = CLIPTokenizer.from_pretrained(
477 args.pretrained_model_name_or_path + '/tokenizer'
478 )
479
480 # Add the placeholder token in tokenizer
481 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
482 if num_added_tokens == 0:
483 raise ValueError(
484 f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
485 " `placeholder_token` that is not already in the tokenizer."
486 )
487
488 # Convert the initializer_token, placeholder_token to ids
489 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
490 # Check if initializer_token is a single token or a sequence of tokens
491 if len(initializer_token_ids) > 1:
492 raise ValueError(
493 f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
494
495 initializer_token_ids = torch.tensor(initializer_token_ids)
496 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
497
498 # Load models and create wrapper for stable diffusion
499 text_encoder = CLIPTextModel.from_pretrained(
500 args.pretrained_model_name_or_path + '/text_encoder',
501 )
502 vae = AutoencoderKL.from_pretrained(
503 args.pretrained_model_name_or_path + '/vae',
504 )
505 unet = UNet2DConditionModel.from_pretrained(
506 args.pretrained_model_name_or_path + '/unet',
507 )
508
509 if args.gradient_checkpointing:
510 unet.enable_gradient_checkpointing()
511
512 slice_size = unet.config.attention_head_dim // 2
513 unet.set_attention_slice(slice_size)
514
515 # Resize the token embeddings as we are adding new special tokens to the tokenizer
516 text_encoder.resize_token_embeddings(len(tokenizer))
517
518 # Initialise the newly added placeholder token with the embeddings of the initializer token
519 token_embeds = text_encoder.get_input_embeddings().weight.data
520
521 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
522
523 if args.resume_checkpoint is not None:
524 token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
525 args.placeholder_token]
526 else:
527 token_embeds[placeholder_token_id] = initializer_token_embeddings
528
529 # Freeze vae and unet
530 freeze_params(vae.parameters())
531 freeze_params(unet.parameters())
532 # Freeze all parameters except for the token embeddings in text encoder
533 params_to_freeze = itertools.chain(
534 text_encoder.text_model.encoder.parameters(),
535 text_encoder.text_model.final_layer_norm.parameters(),
536 text_encoder.text_model.embeddings.position_embedding.parameters(),
537 )
538 freeze_params(params_to_freeze)
539
540 if args.scale_lr:
541 args.learning_rate = (
542 args.learning_rate * args.gradient_accumulation_steps *
543 args.train_batch_size * accelerator.num_processes
544 )
545
546 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
547 if args.use_8bit_adam:
548 try:
549 import bitsandbytes as bnb
550 except ImportError:
551 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
552
553 optimizer_class = bnb.optim.AdamW8bit
554 else:
555 optimizer_class = torch.optim.AdamW
556
557 # Initialize the optimizer
558 optimizer = optimizer_class(
559 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
560 lr=args.learning_rate,
561 betas=(args.adam_beta1, args.adam_beta2),
562 weight_decay=args.adam_weight_decay,
563 eps=args.adam_epsilon,
564 )
565
566 noise_scheduler = DDPMScheduler(
567 beta_start=0.00085,
568 beta_end=0.012,
569 beta_schedule="scaled_linear",
570 num_train_timesteps=1000
571 )
572
573 def collate_fn(examples):
574 prompts = [example["prompts"] for example in examples]
575 nprompts = [example["nprompts"] for example in examples]
576 input_ids = [example["instance_prompt_ids"] for example in examples]
577 pixel_values = [example["instance_images"] for example in examples]
578
579 # concat class and instance examples for prior preservation
580 if args.use_class_images and "class_prompt_ids" in examples[0]:
581 input_ids += [example["class_prompt_ids"] for example in examples]
582 pixel_values += [example["class_images"] for example in examples]
583
584 pixel_values = torch.stack(pixel_values)
585 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
586
587 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
588
589 batch = {
590 "prompts": prompts,
591 "nprompts": nprompts,
592 "input_ids": input_ids,
593 "pixel_values": pixel_values,
594 }
595 return batch
596
597 datamodule = CSVDataModule(
598 data_file=args.train_data_file,
599 batch_size=args.train_batch_size,
600 tokenizer=tokenizer,
601 instance_identifier=args.placeholder_token,
602 class_identifier=args.initializer_token if args.use_class_images else None,
603 class_subdir="ti_cls",
604 size=args.resolution,
605 repeats=args.repeats,
606 center_crop=args.center_crop,
607 valid_set_size=args.sample_batch_size*args.sample_batches,
608 collate_fn=collate_fn
609 )
610
611 datamodule.prepare_data()
612 datamodule.setup()
613
614 if args.use_class_images:
615 missing_data = [item for item in datamodule.data if not item[1].exists()]
616
617 if len(missing_data) != 0:
618 batched_data = [missing_data[i:i+args.sample_batch_size]
619 for i in range(0, len(missing_data), args.sample_batch_size)]
620
621 scheduler = EulerAScheduler(
622 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
623 )
624
625 pipeline = VlpnStableDiffusion(
626 text_encoder=text_encoder,
627 vae=vae,
628 unet=unet,
629 tokenizer=tokenizer,
630 scheduler=scheduler,
631 ).to(accelerator.device)
632 pipeline.enable_attention_slicing()
633
634 for batch in batched_data:
635 image_name = [p[1] for p in batch]
636 prompt = [p[2].format(args.initializer_token) for p in batch]
637 nprompt = [p[3] for p in batch]
638
639 with accelerator.autocast():
640 images = pipeline(
641 prompt=prompt,
642 negative_prompt=nprompt,
643 num_inference_steps=args.sample_steps
644 ).images
645
646 for i, image in enumerate(images):
647 image.save(image_name[i])
648
649 del pipeline
650
651 if torch.cuda.is_available():
652 torch.cuda.empty_cache()
653
654 train_dataloader = datamodule.train_dataloader()
655 val_dataloader = datamodule.val_dataloader()
656
657 checkpointer = Checkpointer(
658 datamodule=datamodule,
659 accelerator=accelerator,
660 vae=vae,
661 unet=unet,
662 tokenizer=tokenizer,
663 placeholder_token=args.placeholder_token,
664 placeholder_token_id=placeholder_token_id,
665 output_dir=basepath,
666 sample_image_size=args.sample_image_size,
667 sample_batch_size=args.sample_batch_size,
668 sample_batches=args.sample_batches,
669 seed=args.seed
670 )
671
672 # Scheduler and math around the number of training steps.
673 overrode_max_train_steps = False
674 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
675 if args.max_train_steps is None:
676 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
677 overrode_max_train_steps = True
678
679 lr_scheduler = get_scheduler(
680 args.lr_scheduler,
681 optimizer=optimizer,
682 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
683 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
684 )
685
686 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
687 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
688 )
689
690 # Move vae and unet to device
691 vae.to(accelerator.device)
692 unet.to(accelerator.device)
693
694 # Keep vae and unet in eval mode as we don't train these
695 vae.eval()
696 unet.eval()
697
698 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
699 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
700 if overrode_max_train_steps:
701 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
702
703 num_val_steps_per_epoch = len(val_dataloader)
704 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
705 val_steps = num_val_steps_per_epoch * num_epochs
706
707 # We need to initialize the trackers we use, and also store our configuration.
708 # The trackers initializes automatically on the main process.
709 if accelerator.is_main_process:
710 accelerator.init_trackers("textual_inversion", config=vars(args))
711
712 # Train!
713 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
714
715 logger.info("***** Running training *****")
716 logger.info(f" Num Epochs = {num_epochs}")
717 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
718 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
719 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
720 logger.info(f" Total optimization steps = {args.max_train_steps}")
721 # Only show the progress bar once on each machine.
722
723 global_step = 0
724 min_val_loss = np.inf
725
726 if accelerator.is_main_process:
727 checkpointer.save_samples(
728 0,
729 text_encoder,
730 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
731
732 local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
733 disable=not accelerator.is_local_main_process)
734 local_progress_bar.set_description("Batch X out of Y")
735
736 global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
737 global_progress_bar.set_description("Total progress")
738
739 try:
740 for epoch in range(num_epochs):
741 local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
742 local_progress_bar.reset()
743
744 text_encoder.train()
745 train_loss = 0.0
746
747 for step, batch in enumerate(train_dataloader):
748 with accelerator.accumulate(text_encoder):
749 # Convert images to latent space
750 with torch.no_grad():
751 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
752 latents = latents * 0.18215
753
754 # Sample noise that we'll add to the latents
755 noise = torch.randn(latents.shape).to(latents.device)
756 bsz = latents.shape[0]
757 # Sample a random timestep for each image
758 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
759 (bsz,), device=latents.device)
760 timesteps = timesteps.long()
761
762 # Add noise to the latents according to the noise magnitude at each timestep
763 # (this is the forward diffusion process)
764 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
765
766 # Get the text embedding for conditioning
767 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
768
769 # Predict the noise residual
770 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
771
772 if args.use_class_images:
773 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
774 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
775 noise, noise_prior = torch.chunk(noise, 2, dim=0)
776
777 # Compute instance loss
778 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
779
780 # Compute prior loss
781 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
782
783 # Add the prior loss to the instance loss.
784 loss = loss + args.prior_loss_weight * prior_loss
785 else:
786 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
787
788 accelerator.backward(loss)
789
790 # Zero out the gradients for all token embeddings except the newly added
791 # embeddings for the concept, as we only want to optimize the concept embeddings
792 if accelerator.num_processes > 1:
793 grads = text_encoder.module.get_input_embeddings().weight.grad
794 else:
795 grads = text_encoder.get_input_embeddings().weight.grad
796 # Get the index for tokens that we want to zero the grads for
797 index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
798 grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
799
800 optimizer.step()
801 if not accelerator.optimizer_step_was_skipped:
802 lr_scheduler.step()
803 optimizer.zero_grad(set_to_none=True)
804
805 loss = loss.detach().item()
806 train_loss += loss
807
808 # Checks if the accelerator has performed an optimization step behind the scenes
809 if accelerator.sync_gradients:
810 local_progress_bar.update(1)
811 global_progress_bar.update(1)
812
813 global_step += 1
814
815 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
816 local_progress_bar.clear()
817 global_progress_bar.clear()
818
819 checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
820 save_resume_file(basepath, args, {
821 "global_step": global_step + global_step_offset,
822 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
823 })
824
825 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
826 local_progress_bar.set_postfix(**logs)
827
828 if global_step >= args.max_train_steps:
829 break
830
831 train_loss /= len(train_dataloader)
832
833 accelerator.wait_for_everyone()
834
835 text_encoder.eval()
836 val_loss = 0.0
837
838 for step, batch in enumerate(val_dataloader):
839 with torch.no_grad():
840 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
841 latents = latents * 0.18215
842
843 noise = torch.randn(latents.shape).to(latents.device)
844 bsz = latents.shape[0]
845 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
846 (bsz,), device=latents.device)
847 timesteps = timesteps.long()
848
849 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
850
851 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
852
853 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
854
855 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
856
857 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
858
859 loss = loss.detach().item()
860 val_loss += loss
861
862 if accelerator.sync_gradients:
863 local_progress_bar.update(1)
864 global_progress_bar.update(1)
865
866 logs = {"mode": "validation", "loss": loss}
867 local_progress_bar.set_postfix(**logs)
868
869 val_loss /= len(val_dataloader)
870
871 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
872
873 local_progress_bar.clear()
874 global_progress_bar.clear()
875
876 if min_val_loss > val_loss:
877 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
878 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
879 min_val_loss = val_loss
880
881 if accelerator.is_main_process:
882 checkpointer.save_samples(
883 global_step + global_step_offset,
884 text_encoder,
885 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
886
887 # Create the pipeline using using the trained modules and save it.
888 if accelerator.is_main_process:
889 print("Finished! Saving final checkpoint and resume state.")
890 checkpointer.checkpoint(
891 global_step + global_step_offset,
892 "end",
893 text_encoder,
894 path=f"{basepath}/learned_embeds.bin"
895 )
896
897 save_resume_file(basepath, args, {
898 "global_step": global_step + global_step_offset,
899 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
900 })
901
902 accelerator.end_training()
903
904 except KeyboardInterrupt:
905 if accelerator.is_main_process:
906 print("Interrupted, saving checkpoint and resume state...")
907 checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
908 save_resume_file(basepath, args, {
909 "global_step": global_step + global_step_offset,
910 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
911 })
912 accelerator.end_training()
913 quit()
914
915
916if __name__ == "__main__":
917 main()
diff --git a/textual_inversion.py b/textual_inversion.py
index 7919ebd..11c324d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25import json 25import json
26import os 26import os
27 27
28from data.textual_inversion.csv import CSVDataModule 28from data.csv import CSVDataModule
29 29
30logger = get_logger(__name__) 30logger = get_logger(__name__)
31 31
@@ -68,10 +68,10 @@ def parse_args():
68 help="A token to use as initializer word." 68 help="A token to use as initializer word."
69 ) 69 )
70 parser.add_argument( 70 parser.add_argument(
71 "--vectors_per_token", 71 "--use_class_images",
72 type=int, 72 action="store_true",
73 default=1, 73 default=True,
74 help="Vectors per token." 74 help="Include class images in the loss calculation a la Dreambooth.",
75 ) 75 )
76 parser.add_argument( 76 parser.add_argument(
77 "--repeats", 77 "--repeats",
@@ -234,6 +234,12 @@ def parse_args():
234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
235 ) 235 )
236 parser.add_argument( 236 parser.add_argument(
237 "--prior_loss_weight",
238 type=float,
239 default=1.0,
240 help="The weight of prior preservation loss."
241 )
242 parser.add_argument(
237 "--resume_from", 243 "--resume_from",
238 type=str, 244 type=str,
239 default=None, 245 default=None,
@@ -395,7 +401,8 @@ class Checkpointer:
395 401
396 for i in range(self.sample_batches): 402 for i in range(self.sample_batches):
397 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] 403 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
398 prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] 404 prompt = [prompt.format(self.placeholder_token)
405 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
399 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] 406 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
400 407
401 with self.accelerator.autocast(): 408 with self.accelerator.autocast():
@@ -556,25 +563,94 @@ def main():
556 eps=args.adam_epsilon, 563 eps=args.adam_epsilon,
557 ) 564 )
558 565
559 # TODO (patil-suraj): laod scheduler using args
560 noise_scheduler = DDPMScheduler( 566 noise_scheduler = DDPMScheduler(
561 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 567 beta_start=0.00085,
568 beta_end=0.012,
569 beta_schedule="scaled_linear",
570 num_train_timesteps=1000
562 ) 571 )
563 572
573 def collate_fn(examples):
574 prompts = [example["prompts"] for example in examples]
575 nprompts = [example["nprompts"] for example in examples]
576 input_ids = [example["instance_prompt_ids"] for example in examples]
577 pixel_values = [example["instance_images"] for example in examples]
578
579 # concat class and instance examples for prior preservation
580 if args.use_class_images and "class_prompt_ids" in examples[0]:
581 input_ids += [example["class_prompt_ids"] for example in examples]
582 pixel_values += [example["class_images"] for example in examples]
583
584 pixel_values = torch.stack(pixel_values)
585 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
586
587 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
588
589 batch = {
590 "prompts": prompts,
591 "nprompts": nprompts,
592 "input_ids": input_ids,
593 "pixel_values": pixel_values,
594 }
595 return batch
596
564 datamodule = CSVDataModule( 597 datamodule = CSVDataModule(
565 data_file=args.train_data_file, 598 data_file=args.train_data_file,
566 batch_size=args.train_batch_size, 599 batch_size=args.train_batch_size,
567 tokenizer=tokenizer, 600 tokenizer=tokenizer,
601 instance_identifier=args.placeholder_token,
602 class_identifier=args.initializer_token if args.use_class_images else None,
603 class_subdir="ti_cls",
568 size=args.resolution, 604 size=args.resolution,
569 placeholder_token=args.placeholder_token,
570 repeats=args.repeats, 605 repeats=args.repeats,
571 center_crop=args.center_crop, 606 center_crop=args.center_crop,
572 valid_set_size=args.sample_batch_size*args.sample_batches 607 valid_set_size=args.sample_batch_size*args.sample_batches,
608 collate_fn=collate_fn
573 ) 609 )
574 610
575 datamodule.prepare_data() 611 datamodule.prepare_data()
576 datamodule.setup() 612 datamodule.setup()
577 613
614 if args.use_class_images:
615 missing_data = [item for item in datamodule.data if not item[1].exists()]
616
617 if len(missing_data) != 0:
618 batched_data = [missing_data[i:i+args.sample_batch_size]
619 for i in range(0, len(missing_data), args.sample_batch_size)]
620
621 scheduler = EulerAScheduler(
622 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
623 )
624
625 pipeline = VlpnStableDiffusion(
626 text_encoder=text_encoder,
627 vae=vae,
628 unet=unet,
629 tokenizer=tokenizer,
630 scheduler=scheduler,
631 ).to(accelerator.device)
632 pipeline.enable_attention_slicing()
633
634 for batch in batched_data:
635 image_name = [p[1] for p in batch]
636 prompt = [p[2].format(args.initializer_token) for p in batch]
637 nprompt = [p[3] for p in batch]
638
639 with accelerator.autocast():
640 images = pipeline(
641 prompt=prompt,
642 negative_prompt=nprompt,
643 num_inference_steps=args.sample_steps
644 ).images
645
646 for i, image in enumerate(images):
647 image.save(image_name[i])
648
649 del pipeline
650
651 if torch.cuda.is_available():
652 torch.cuda.empty_cache()
653
578 train_dataloader = datamodule.train_dataloader() 654 train_dataloader = datamodule.train_dataloader()
579 val_dataloader = datamodule.val_dataloader() 655 val_dataloader = datamodule.val_dataloader()
580 656
@@ -693,7 +769,21 @@ def main():
693 # Predict the noise residual 769 # Predict the noise residual
694 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 770 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
695 771
696 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 772 if args.use_class_images:
773 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
774 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
775 noise, noise_prior = torch.chunk(noise, 2, dim=0)
776
777 # Compute instance loss
778 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
779
780 # Compute prior loss
781 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
782
783 # Add the prior loss to the instance loss.
784 loss = loss + args.prior_loss_weight * prior_loss
785 else:
786 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
697 787
698 accelerator.backward(loss) 788 accelerator.backward(loss)
699 789