-rw-r--r--  data/csv.py                                          |  8
-rw-r--r--  infer.py                                             | 14
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 69
-rw-r--r--  train_dreambooth.py                                  | 10
-rw-r--r--  train_lora.py                                        | 10
-rw-r--r--  train_ti.py                                          | 19
-rw-r--r--  training/functional.py                               |  2
-rw-r--r--  training/strategy/lora.py                            |  2
-rw-r--r--  training/strategy/ti.py                              |  2
-rw-r--r--  util.py                                              |  2
10 files changed, 73 insertions, 65 deletions
diff --git a/data/csv.py b/data/csv.py
index b4c81d7..c5902ed 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -42,7 +42,7 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]):
 
 
 def generate_buckets(
-    items: list[str],
+    items: Union[list[str], list[Path]],
     base_size: int,
     step_size: int = 64,
     max_pixels: Optional[int] = None,
@@ -188,7 +188,7 @@ class VlpnDataModule():
             raise ValueError("data_file must be a file")
 
         self.data_root = self.data_file.parent
-        self.class_root = self.data_root.joinpath(class_subdir)
+        self.class_root = self.data_root / class_subdir
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.num_class_images = num_class_images
 
@@ -218,7 +218,7 @@ class VlpnDataModule():
 
         return [
             VlpnDataItem(
-                self.data_root.joinpath(image.format(item["image"])),
+                self.data_root / image.format(item["image"]),
                 None,
                 prompt_to_keywords(
                     prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
@@ -249,7 +249,7 @@ class VlpnDataModule():
         return [
             VlpnDataItem(
                 item.instance_image_path,
-                self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
+                self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
                 item.prompt,
                 item.cprompt,
                 item.nprompt,
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -264,16 +264,16 @@ def generate(output_dir: Path, pipeline, args):
 
     if len(args.prompt) != 1:
         if len(args.project) != 0:
-            output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
+            output_dir = output_dir / f"{now}_{slugify(args.project)}"
         else:
-            output_dir = output_dir.joinpath(now)
+            output_dir = output_dir / now
 
         for prompt in args.prompt:
-            dir = output_dir.joinpath(slugify(prompt)[:100])
+            dir = output_dir / slugify(prompt)[:100]
             dir.mkdir(parents=True, exist_ok=True)
             image_dir.append(dir)
     else:
-        output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
+        output_dir = output_dir / f"{now}_{slugify(args.prompt[0])[:100]}"
         output_dir.mkdir(parents=True, exist_ok=True)
         image_dir.append(output_dir)
 
@@ -332,9 +332,9 @@ def generate(output_dir: Path, pipeline, args):
             basename = f"{seed}_{j // len(args.prompt)}"
             dir = image_dir[j % len(args.prompt)]
 
-            image.save(dir.joinpath(f"{basename}.png"))
-            image.save(dir.joinpath(f"{basename}.jpg"), quality=85)
-            with open(dir.joinpath(f"{basename}.txt"), 'w') as f:
+            image.save(dir / f"{basename}.png")
+            image.save(dir / f"{basename}.jpg", quality=85)
+            with open(dir / f"{basename}.txt", 'w') as f:
                 f.write(prompt[j % len(args.prompt)])
 
     if torch.cuda.is_available():
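Throughout infer.py and the training scripts below, the change is mechanical: Path.joinpath() is replaced by pathlib's overloaded / operator, which builds identical Path objects. A minimal sketch of the equivalence (the paths and timestamp are illustrative, not taken from the repo):

    from pathlib import Path

    now = "2023-01-01T00-00-00"  # illustrative timestamp
    base = Path("output")

    # joinpath() and / produce the same Path, including with several segments:
    assert base / "project" / now == base.joinpath("project", now)
    assert base / f"{now}_sample" == base.joinpath(f"{now}_sample")

    # The result is still a Path, so mkdir(), open(), etc. work unchanged:
    out = base / "project" / now
    print(out)  # output/project/2023-01-01T00-00-00 (POSIX rendering)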
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 66566b0..cb09fe1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional, Union, Callable
 
 import numpy as np
 import torch
-import torchvision.transforms as T
+import torch.nn.functional as F
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -39,6 +39,27 @@ def preprocess(image):
     return 2.0 * image - 1.0
 
 
+def gaussian_blur_2d(img, kernel_size, sigma):
+    ksize_half = (kernel_size - 1) * 0.5
+
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+    x_kernel = pdf / pdf.sum()
+    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
+    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+    img = F.pad(img, padding, mode="reflect")
+    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
+
+    return img
+
+
 class CrossAttnStoreProcessor:
     def __init__(self):
         self.attention_probs = None
@@ -46,13 +67,17 @@ class CrossAttnStoreProcessor:
     def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
         query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.cross_attention_norm:
+            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
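CrossAttnStoreProcessor mirrors the default attention computation but keeps the attention probabilities around in self.attention_probs so the self-attention-guidance (SAG) code further down can build its mask from them. This hunk does not show where the processor is attached; in diffusers' reference StableDiffusionSAGPipeline it is assigned to the UNet mid-block's first self-attention module, roughly as sketched below (the pipe variable and the attachment point are assumptions, not part of this diff):

    # Assumed wiring, modeled on diffusers' StableDiffusionSAGPipeline:
    store_processor = CrossAttnStoreProcessor()
    pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor

    # After a UNet forward pass, store_processor.attention_probs holds the self-attention
    # maps; under classifier-free guidance the batch is doubled, hence the .chunk(2) below.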
@@ -510,12 +535,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 # in https://arxiv.org/pdf/2210.00939.pdf
                 if do_classifier_free_guidance:
                     # DDIM-like prediction of x0
-                    pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t)
+                    pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
                     # get the stored attention maps
                     uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
                     # self-attention-based degrading of latents
                     degraded_latents = self.sag_masking(
-                        pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t)
+                        pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
                     )
                     uncond_emb, _ = prompt_embeds.chunk(2)
                     # forward and give guidance
@@ -523,12 +548,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                     noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
                 else:
                     # DDIM-like prediction of x0
-                    pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t)
+                    pred_x0 = self.pred_x0(latents, noise_pred, t)
                     # get the stored attention maps
                     cond_attn = store_processor.attention_probs
                     # self-attention-based degrading of latents
                     degraded_latents = self.sag_masking(
-                        pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t)
+                        pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
                     )
                     # forward and give guidance
                     degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
@@ -578,8 +603,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
 
         # Blur according to the self-attention mask
-        transform = T.GaussianBlur(kernel_size=9, sigma=1.0)
-        degraded_latents = transform(original_latents)
+        degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
         degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
 
         # Noise it again to match the noise level
@@ -588,19 +612,11 @@ class VlpnStableDiffusion(DiffusionPipeline):
         return degraded_latents
 
     # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
-    def pred_x0_from_eps(self, sample, model_output, timestep):
-        # 1. get previous step value (=t-1)
-        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
-
-        # 2. compute alphas, betas
+    # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.)
+    def pred_x0(self, sample, model_output, timestep):
         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-        # alpha_prod_t_prev = (
-        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
-        # )
 
         beta_prod_t = 1 - alpha_prod_t
-        # 3. compute predicted original sample from predicted noise also called
-        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         if self.scheduler.config.prediction_type == "epsilon":
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
         elif self.scheduler.config.prediction_type == "sample":
@@ -614,24 +630,13 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
                 " or `v_prediction`"
             )
-        # # 4. Clip "predicted x_0"
-        # if self.scheduler.config.clip_sample:
-        #     pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
 
         return pred_original_sample
 
-    def pred_eps_from_noise(self, sample, model_output, timestep):
-        # 1. get previous step value (=t-1)
-        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
-
-        # 2. compute alphas, betas
+    def pred_epsilon(self, sample, model_output, timestep):
         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-        # alpha_prod_t_prev = (
-        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
-        # )
 
         beta_prod_t = 1 - alpha_prod_t
-        # 3. compute predicted eps from model output
         if self.scheduler.config.prediction_type == "epsilon":
             pred_eps = model_output
         elif self.scheduler.config.prediction_type == "sample":
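The new gaussian_blur_2d helper removes the torchvision dependency from the SAG path: it builds a 1-D Gaussian kernel, outer-multiplies it into a 2-D kernel, and applies it as a depthwise (grouped) convolution after reflect padding, which is essentially what T.GaussianBlur did internally. A small sanity check, assuming the repository root is on PYTHONPATH and torchvision is still available for comparison:

    import torch
    import torchvision.transforms as T

    from pipelines.stable_diffusion.vlpn_stable_diffusion import gaussian_blur_2d

    latents = torch.randn(2, 4, 64, 64)  # (batch, channels, height, width), e.g. SD latents

    blurred = gaussian_blur_2d(latents, kernel_size=9, sigma=1.0)
    reference = T.GaussianBlur(kernel_size=9, sigma=1.0)(latents)

    assert blurred.shape == latents.shape                  # reflect padding preserves spatial size
    assert torch.allclose(blurred, reference, atol=1e-4)   # same kernel formula and padding mode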
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8ac70e8..4c1ec31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -432,7 +432,7 @@ def main():
     args = parse_args()
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir = Path(args.output_dir) / slugify(args.project) / now
     output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -448,7 +448,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
@@ -513,8 +513,8 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
-    checkpoint_output_dir = output_dir.joinpath("model")
-    sample_output_dir = output_dir.joinpath(f"samples")
+    checkpoint_output_dir = output_dir / "model"
+    sample_output_dir = output_dir / "samples"
 
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
@@ -596,7 +596,7 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+    plot_metrics(metrics, output_dir / "lr.png")
 
 
 if __name__ == "__main__":
diff --git a/train_lora.py b/train_lora.py
index 5fd05cc..a8c1cf6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -392,7 +392,7 @@ def main():
     args = parse_args()
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir = Path(args.output_dir) / slugify(args.project) / now
     output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -408,7 +408,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
@@ -489,8 +489,8 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
-    checkpoint_output_dir = output_dir.joinpath("model")
-    sample_output_dir = output_dir.joinpath(f"samples")
+    checkpoint_output_dir = output_dir / "model"
+    sample_output_dir = output_dir / "samples"
 
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
@@ -562,7 +562,7 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+    plot_metrics(metrics, output_dir / "lr.png")
 
 
 if __name__ == "__main__":
diff --git a/train_ti.py b/train_ti.py
index c79dfa2..171d085 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -143,7 +143,7 @@ def parse_args():
     parser.add_argument(
         "--num_buckets",
         type=int,
-        default=4,
+        default=0,
         help="Number of aspect ratio buckets in either direction.",
     )
     parser.add_argument(
@@ -485,6 +485,9 @@ def parse_args():
 
         if len(args.placeholder_tokens) != len(args.train_data_template):
             raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+    else:
+        if isinstance(args.train_data_template, list):
+            raise ValueError("--train_data_template can't be a list in simultaneous mode")
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -503,7 +506,7 @@ def main():
 
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir = Path(args.output_dir) / slugify(args.project) / now
     output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -519,7 +522,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
@@ -570,7 +573,7 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
-    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+    checkpoint_output_dir = output_dir / "checkpoints"
 
     trainer = partial(
         train,
@@ -611,11 +614,11 @@
         return
 
     if len(placeholder_tokens) == 1:
-        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
-        metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png")
+        sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
+        metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
     else:
-        sample_output_dir = output_dir.joinpath("samples")
-        metrics_output_file = output_dir.joinpath(f"lr.png")
+        sample_output_dir = output_dir / "samples"
+        metrics_output_file = output_dir / f"lr.png"
 
     placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
diff --git a/training/functional.py b/training/functional.py
index ccbb4ad..83e70e2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -129,7 +129,7 @@ def save_samples(
 
     for pool, data, gen in datasets:
         all_samples = []
-        file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
+        file_path = output_dir / pool / f"step_{step}.jpg"
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index bc10e58..4dd1100 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -91,7 +91,7 @@ def lora_strategy_callbacks(
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet)
-        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
         del unet_
 
     @torch.no_grad()
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index da2b81c..0de3cb0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -138,7 +138,7 @@ def textual_inversion_strategy_callbacks(
         for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
             text_encoder.text_model.embeddings.save_embed(
                 ids,
-                checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
             )
 
     @torch.no_grad()
diff --git a/util.py b/util.py
--- a/util.py
+++ b/util.py
@@ -14,7 +14,7 @@ def load_config(filename):
     args = config["args"]
 
     if "base" in config:
-        args = load_config(Path(filename).parent.joinpath(config["base"])) | args
+        args = load_config(Path(filename).parent / config["base"]) | args
 
     return args
 
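In load_config, the dict union operator (Python 3.9+) merges the base config's args with the current file's; the right-hand operand wins on key conflicts, so values in the current config override anything inherited via "base". A tiny illustration with made-up keys:

    # Made-up keys, purely to show the merge order of `base_args | args`:
    base_args = {"pretrained_model": "stable-diffusion-v1-5", "resolution": 512}
    args = {"resolution": 768}

    merged = base_args | args
    assert merged == {"pretrained_model": "stable-diffusion-v1-5", "resolution": 768}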