summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-04 12:18:07 +0100
committerVolpeon <git@volpeon.ink>2023-01-04 12:18:07 +0100
commit01fee7d37a116265edb0f16e0b2f75d2116eb9f6 (patch)
tree6389f385191247fb3639900da0d29a3064259cb7
parentBetter eval generator (diff)
downloadtextual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.tar.gz
textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.tar.bz2
textual-inversion-diff-01fee7d37a116265edb0f16e0b2f75d2116eb9f6.zip
Various updates
-rw-r--r--data/csv.py45
-rw-r--r--infer.py56
-rw-r--r--train_dreambooth.py8
-rw-r--r--train_ti.py8
-rw-r--r--training/optimization.py4
5 files changed, 87 insertions, 34 deletions
diff --git a/data/csv.py b/data/csv.py
index e901ab4..c505230 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -165,19 +165,27 @@ class CSVDataModule():
165 self.data_val = self.pad_items(data_val) 165 self.data_val = self.pad_items(data_val)
166 166
167 def setup(self, stage=None): 167 def setup(self, stage=None):
168 train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, 168 train_dataset = CSVDataset(
169 num_class_images=self.num_class_images, 169 self.data_train, self.prompt_processor, batch_size=self.batch_size,
170 size=self.size, interpolation=self.interpolation, 170 num_class_images=self.num_class_images,
171 center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) 171 size=self.size, interpolation=self.interpolation,
172 val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, 172 center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout
173 size=self.size, interpolation=self.interpolation, 173 )
174 center_crop=self.center_crop) 174 val_dataset = CSVDataset(
175 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, 175 self.data_val, self.prompt_processor, batch_size=self.batch_size,
176 shuffle=True, pin_memory=True, collate_fn=self.collate_fn, 176 size=self.size, interpolation=self.interpolation,
177 num_workers=self.num_workers) 177 center_crop=self.center_crop
178 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, 178 )
179 pin_memory=True, collate_fn=self.collate_fn, 179 self.train_dataloader_ = DataLoader(
180 num_workers=self.num_workers) 180 train_dataset, batch_size=self.batch_size,
181 shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
182 num_workers=self.num_workers
183 )
184 self.val_dataloader_ = DataLoader(
185 val_dataset, batch_size=self.batch_size,
186 pin_memory=True, collate_fn=self.collate_fn,
187 num_workers=self.num_workers
188 )
181 189
182 def train_dataloader(self): 190 def train_dataloader(self):
183 return self.train_dataloader_ 191 return self.train_dataloader_
@@ -210,11 +218,12 @@ class CSVDataset(Dataset):
210 self.num_instance_images = len(self.data) 218 self.num_instance_images = len(self.data)
211 self._length = self.num_instance_images * repeats 219 self._length = self.num_instance_images * repeats
212 220
213 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, 221 self.interpolation = {
214 "bilinear": transforms.InterpolationMode.BILINEAR, 222 "linear": transforms.InterpolationMode.NEAREST,
215 "bicubic": transforms.InterpolationMode.BICUBIC, 223 "bilinear": transforms.InterpolationMode.BILINEAR,
216 "lanczos": transforms.InterpolationMode.LANCZOS, 224 "bicubic": transforms.InterpolationMode.BICUBIC,
217 }[interpolation] 225 "lanczos": transforms.InterpolationMode.LANCZOS,
226 }[interpolation]
218 self.image_transforms = transforms.Compose( 227 self.image_transforms = transforms.Compose(
219 [ 228 [
220 transforms.Resize(size, interpolation=self.interpolation), 229 transforms.Resize(size, interpolation=self.interpolation),
diff --git a/infer.py b/infer.py
index f88245a..c4d1e0d 100644
--- a/infer.py
+++ b/infer.py
@@ -45,6 +45,7 @@ default_args = {
45 45
46 46
47default_cmds = { 47default_cmds = {
48 "project": "",
48 "scheduler": "dpmsm", 49 "scheduler": "dpmsm",
49 "prompt": None, 50 "prompt": None,
50 "negative_prompt": None, 51 "negative_prompt": None,
@@ -104,6 +105,12 @@ def create_cmd_parser():
104 description="Simple example of a training script." 105 description="Simple example of a training script."
105 ) 106 )
106 parser.add_argument( 107 parser.add_argument(
108 "--project",
109 type=str,
110 default=None,
111 help="The name of the current project.",
112 )
113 parser.add_argument(
107 "--scheduler", 114 "--scheduler",
108 type=str, 115 type=str,
109 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], 116 choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"],
@@ -184,7 +191,16 @@ def save_args(basepath, args, extra={}):
184 json.dump(info, f, indent=4) 191 json.dump(info, f, indent=4)
185 192
186 193
187def create_pipeline(model, embeddings_dir, dtype): 194def load_embeddings(pipeline, embeddings_dir):
195 added_tokens = load_embeddings_from_dir(
196 pipeline.tokenizer,
197 pipeline.text_encoder.text_model.embeddings,
198 Path(embeddings_dir)
199 )
200 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
201
202
203def create_pipeline(model, dtype):
188 print("Loading Stable Diffusion pipeline...") 204 print("Loading Stable Diffusion pipeline...")
189 205
190 tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) 206 tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -193,10 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
193 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) 209 unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
194 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) 210 scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
195 211
196 embeddings = patch_managed_embeddings(text_encoder) 212 patch_managed_embeddings(text_encoder)
197 added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir))
198
199 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
200 213
201 pipeline = VlpnStableDiffusion( 214 pipeline = VlpnStableDiffusion(
202 text_encoder=text_encoder, 215 text_encoder=text_encoder,
@@ -220,7 +233,14 @@ def generate(output_dir, pipeline, args):
220 args.prompt = [args.prompt] 233 args.prompt = [args.prompt]
221 234
222 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 235 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
223 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") 236 use_subdirs = len(args.prompt) != 1
237 if use_subdirs:
238 if len(args.project) != 0:
239 output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
240 else:
241 output_dir = output_dir.joinpath(now)
242 else:
243 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
224 output_dir.mkdir(parents=True, exist_ok=True) 244 output_dir.mkdir(parents=True, exist_ok=True)
225 245
226 args.seed = args.seed or torch.random.seed() 246 args.seed = args.seed or torch.random.seed()
@@ -257,7 +277,8 @@ def generate(output_dir, pipeline, args):
257 dynamic_ncols=True 277 dynamic_ncols=True
258 ) 278 )
259 279
260 generator = torch.Generator(device="cuda").manual_seed(args.seed + i) 280 seed = args.seed + i
281 generator = torch.Generator(device="cuda").manual_seed(seed)
261 images = pipeline( 282 images = pipeline(
262 prompt=args.prompt, 283 prompt=args.prompt,
263 negative_prompt=args.negative_prompt, 284 negative_prompt=args.negative_prompt,
@@ -272,8 +293,13 @@ def generate(output_dir, pipeline, args):
272 ).images 293 ).images
273 294
274 for j, image in enumerate(images): 295 for j, image in enumerate(images):
275 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) 296 image_dir = output_dir
276 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) 297 if use_subdirs:
298 idx = j % len(args.prompt)
299 image_dir = image_dir.joinpath(slugify(args.prompt[idx])[:100])
300 image_dir.mkdir(parents=True, exist_ok=True)
301 image.save(image_dir.joinpath(f"{seed}_{j}.png"))
302 image.save(image_dir.joinpath(f"{seed}_{j}.jpg"), quality=85)
277 303
278 if torch.cuda.is_available(): 304 if torch.cuda.is_available():
279 torch.cuda.empty_cache() 305 torch.cuda.empty_cache()
@@ -283,10 +309,11 @@ class CmdParse(cmd.Cmd):
283 prompt = 'dream> ' 309 prompt = 'dream> '
284 commands = [] 310 commands = []
285 311
286 def __init__(self, output_dir, pipeline, parser): 312 def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser):
287 super().__init__() 313 super().__init__()
288 314
289 self.output_dir = output_dir 315 self.output_dir = output_dir
316 self.ti_embeddings_dir = ti_embeddings_dir
290 self.pipeline = pipeline 317 self.pipeline = pipeline
291 self.parser = parser 318 self.parser = parser
292 319
@@ -302,6 +329,10 @@ class CmdParse(cmd.Cmd):
302 if elements[0] == 'q': 329 if elements[0] == 'q':
303 return True 330 return True
304 331
332 if elements[0] == 'reload_embeddings':
333 load_embeddings(self.pipeline, self.ti_embeddings_dir)
334 return
335
305 try: 336 try:
306 args = run_parser(self.parser, default_cmds, elements) 337 args = run_parser(self.parser, default_cmds, elements)
307 338
@@ -337,9 +368,10 @@ def main():
337 output_dir = Path(args.output_dir) 368 output_dir = Path(args.output_dir)
338 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] 369 dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
339 370
340 pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype) 371 pipeline = create_pipeline(args.model, dtype)
372 load_embeddings(pipeline, args.ti_embeddings_dir)
341 cmd_parser = create_cmd_parser() 373 cmd_parser = create_cmd_parser()
342 cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) 374 cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser)
343 cmd_prompt.cmdloop() 375 cmd_prompt.cmdloop()
344 376
345 377
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5e6e35d..2e0696b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -269,6 +269,12 @@ def parse_args():
269 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' 269 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
270 ) 270 )
271 parser.add_argument( 271 parser.add_argument(
272 "--lr_min_lr",
273 type=float,
274 default=None,
275 help="Minimum learning rate in the lr scheduler."
276 )
277 parser.add_argument(
272 "--use_ema", 278 "--use_ema",
273 action="store_true", 279 action="store_true",
274 default=True, 280 default=True,
@@ -799,6 +805,7 @@ def main():
799 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps 805 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
800 806
801 if args.lr_scheduler == "one_cycle": 807 if args.lr_scheduler == "one_cycle":
808 lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
802 lr_scheduler = get_one_cycle_schedule( 809 lr_scheduler = get_one_cycle_schedule(
803 optimizer=optimizer, 810 optimizer=optimizer,
804 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 811 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -806,6 +813,7 @@ def main():
806 annealing=args.lr_annealing_func, 813 annealing=args.lr_annealing_func,
807 warmup_exp=args.lr_warmup_exp, 814 warmup_exp=args.lr_warmup_exp,
808 annealing_exp=args.lr_annealing_exp, 815 annealing_exp=args.lr_annealing_exp,
816 min_lr=lr_min_lr,
809 ) 817 )
810 elif args.lr_scheduler == "cosine_with_restarts": 818 elif args.lr_scheduler == "cosine_with_restarts":
811 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 819 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
diff --git a/train_ti.py b/train_ti.py
index 6f116c3..1b60f64 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -260,6 +260,12 @@ def parse_args():
260 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' 260 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
261 ) 261 )
262 parser.add_argument( 262 parser.add_argument(
263 "--lr_min_lr",
264 type=float,
265 default=None,
266 help="Minimum learning rate in the lr scheduler."
267 )
268 parser.add_argument(
263 "--use_8bit_adam", 269 "--use_8bit_adam",
264 action="store_true", 270 action="store_true",
265 help="Whether or not to use 8-bit Adam from bitsandbytes." 271 help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -744,6 +750,7 @@ def main():
744 if args.find_lr: 750 if args.find_lr:
745 lr_scheduler = None 751 lr_scheduler = None
746 elif args.lr_scheduler == "one_cycle": 752 elif args.lr_scheduler == "one_cycle":
753 lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
747 lr_scheduler = get_one_cycle_schedule( 754 lr_scheduler = get_one_cycle_schedule(
748 optimizer=optimizer, 755 optimizer=optimizer,
749 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 756 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -751,6 +758,7 @@ def main():
751 annealing=args.lr_annealing_func, 758 annealing=args.lr_annealing_func,
752 warmup_exp=args.lr_warmup_exp, 759 warmup_exp=args.lr_warmup_exp,
753 annealing_exp=args.lr_annealing_exp, 760 annealing_exp=args.lr_annealing_exp,
761 min_lr=lr_min_lr,
754 ) 762 )
755 elif args.lr_scheduler == "cosine_with_restarts": 763 elif args.lr_scheduler == "cosine_with_restarts":
756 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( 764 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
diff --git a/training/optimization.py b/training/optimization.py
index 14c2bd5..dd84f9c 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,10 +5,6 @@ from functools import partial
5import torch 5import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.utils import logging
9
10logger = logging.get_logger(__name__)
11
12 8
13class OneCyclePhase(NamedTuple): 9class OneCyclePhase(NamedTuple):
14 step_min: int 10 step_min: int