author    Volpeon <git@volpeon.ink>    2023-02-13 17:19:18 +0100
committer Volpeon <git@volpeon.ink>    2023-02-13 17:19:18 +0100
commit    94b676d91382267e7429bd68362019868affd9d1 (patch)
tree      513697739ab25217cbfcff630299d02b1f6e98c8
parent    Integrate Self-Attention-Guided (SAG) Stable Diffusion in my custom pipeline (diff)
Update
 data/csv.py                                          |  8
 infer.py                                             | 14
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 69
 train_dreambooth.py                                  | 10
 train_lora.py                                        | 10
 train_ti.py                                          | 19
 training/functional.py                               |  2
 training/strategy/lora.py                            |  2
 training/strategy/ti.py                              |  2
 util.py                                              |  2
 10 files changed, 73 insertions(+), 65 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index b4c81d7..c5902ed 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -42,7 +42,7 @@ def prepare_prompt(prompt: Union[str, dict[str, str]]):
 
 
 def generate_buckets(
-    items: list[str],
+    items: Union[list[str], list[Path]],
     base_size: int,
     step_size: int = 64,
     max_pixels: Optional[int] = None,
@@ -188,7 +188,7 @@ class VlpnDataModule():
            raise ValueError("data_file must be a file")
 
        self.data_root = self.data_file.parent
-        self.class_root = self.data_root.joinpath(class_subdir)
+        self.class_root = self.data_root / class_subdir
        self.class_root.mkdir(parents=True, exist_ok=True)
        self.num_class_images = num_class_images
 
@@ -218,7 +218,7 @@ class VlpnDataModule():
 
        return [
            VlpnDataItem(
-                self.data_root.joinpath(image.format(item["image"])),
+                self.data_root / image.format(item["image"]),
                None,
                prompt_to_keywords(
                    prompt.format(**prepare_prompt(item["prompt"] if "prompt" in item else "")),
@@ -249,7 +249,7 @@ class VlpnDataModule():
        return [
            VlpnDataItem(
                item.instance_image_path,
-                self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
+                self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}",
                item.prompt,
                item.cprompt,
                item.nprompt,
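
Note on the path handling change that recurs throughout this commit: for pathlib.Path objects the `/` operator is equivalent to `.joinpath()`; the operator form simply reads like a filesystem path. A minimal sketch (directory and file names below are made up for illustration, not taken from the repo):

    from pathlib import Path

    data_root = Path("/tmp/dataset")  # hypothetical root, for illustration only

    # Both spellings build the same Path object.
    a = data_root.joinpath("classes", "cat_0.png")
    b = data_root / "classes" / "cat_0.png"

    assert a == b
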
diff --git a/infer.py b/infer.py
index 42b4e2d..aa75ee5 100644
--- a/infer.py
+++ b/infer.py
@@ -264,16 +264,16 @@ def generate(output_dir: Path, pipeline, args):
 
    if len(args.prompt) != 1:
        if len(args.project) != 0:
-            output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
+            output_dir = output_dir / f"{now}_{slugify(args.project)}"
        else:
-            output_dir = output_dir.joinpath(now)
+            output_dir = output_dir / now
 
        for prompt in args.prompt:
-            dir = output_dir.joinpath(slugify(prompt)[:100])
+            dir = output_dir / slugify(prompt)[:100]
            dir.mkdir(parents=True, exist_ok=True)
            image_dir.append(dir)
    else:
-        output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
+        output_dir = output_dir / f"{now}_{slugify(args.prompt[0])[:100]}"
        output_dir.mkdir(parents=True, exist_ok=True)
        image_dir.append(output_dir)
 
@@ -332,9 +332,9 @@ def generate(output_dir: Path, pipeline, args):
            basename = f"{seed}_{j // len(args.prompt)}"
            dir = image_dir[j % len(args.prompt)]
 
-            image.save(dir.joinpath(f"{basename}.png"))
-            image.save(dir.joinpath(f"{basename}.jpg"), quality=85)
-            with open(dir.joinpath(f"{basename}.txt"), 'w') as f:
+            image.save(dir / f"{basename}.png")
+            image.save(dir / f"{basename}.jpg", quality=85)
+            with open(dir / f"{basename}.txt", 'w') as f:
                f.write(prompt[j % len(args.prompt)])
 
        if torch.cuda.is_available():
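
For orientation, the saving loop above writes each generated image twice (lossless PNG plus a quality-85 JPEG) along with a sidecar .txt holding the prompt. A toy sketch of that naming scheme, with a made-up PIL image standing in for pipeline output:

    from pathlib import Path
    from PIL import Image

    out = Path("/tmp/example_out")  # hypothetical output dir, illustration only
    out.mkdir(parents=True, exist_ok=True)

    image = Image.new("RGB", (64, 64))  # stand-in for a generated image
    basename = f"{42}_{0}"              # seed_index, as in the diff above

    image.save(out / f"{basename}.png")
    image.save(out / f"{basename}.jpg", quality=85)
    (out / f"{basename}.txt").write_text("a photo of an example prompt")
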
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 66566b0..cb09fe1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional, Union, Callable
 
 import numpy as np
 import torch
-import torchvision.transforms as T
+import torch.nn.functional as F
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
@@ -39,6 +39,27 @@ def preprocess(image):
    return 2.0 * image - 1.0
 
 
+def gaussian_blur_2d(img, kernel_size, sigma):
+    ksize_half = (kernel_size - 1) * 0.5
+
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+    x_kernel = pdf / pdf.sum()
+    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
+    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+    img = F.pad(img, padding, mode="reflect")
+    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
+
+    return img
+
+
 class CrossAttnStoreProcessor:
    def __init__(self):
        self.attention_probs = None
@@ -46,13 +67,17 @@ class CrossAttnStoreProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
        query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.cross_attention_norm:
+            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
 
@@ -510,12 +535,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                # in https://arxiv.org/pdf/2210.00939.pdf
                if do_classifier_free_guidance:
                    # DDIM-like prediction of x0
-                    pred_x0 = self.pred_x0_from_eps(latents, noise_pred_uncond, t)
+                    pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
                    # get the stored attention maps
                    uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
                    # self-attention-based degrading of latents
                    degraded_latents = self.sag_masking(
-                        pred_x0, uncond_attn, t, self.pred_eps_from_noise(latents, noise_pred_uncond, t)
+                        pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
                    )
                    uncond_emb, _ = prompt_embeds.chunk(2)
                    # forward and give guidance
@@ -523,12 +548,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                    noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
                else:
                    # DDIM-like prediction of x0
-                    pred_x0 = self.pred_x0_from_eps(latents, noise_pred, t)
+                    pred_x0 = self.pred_x0(latents, noise_pred, t)
                    # get the stored attention maps
                    cond_attn = store_processor.attention_probs
                    # self-attention-based degrading of latents
                    degraded_latents = self.sag_masking(
-                        pred_x0, cond_attn, t, self.pred_eps_from_noise(latents, noise_pred, t)
+                        pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
                    )
                    # forward and give guidance
                    degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
@@ -578,8 +603,7 @@
        attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w))
 
        # Blur according to the self-attention mask
-        transform = T.GaussianBlur(kernel_size=9, sigma=1.0)
-        degraded_latents = transform(original_latents)
+        degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
        degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask)
 
        # Noise it again to match the noise level
@@ -588,19 +612,11 @@
        return degraded_latents
 
    # Modified from diffusers.schedulers.scheduling_ddim.DDIMScheduler.step
-    def pred_x0_from_eps(self, sample, model_output, timestep):
-        # 1. get previous step value (=t-1)
-        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
-
-        # 2. compute alphas, betas
+    # Note: there are some schedulers that clip or do not return x_0 (PNDMScheduler, DDIMScheduler, etc.)
+    def pred_x0(self, sample, model_output, timestep):
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-        # alpha_prod_t_prev = (
-        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
-        # )
 
        beta_prod_t = 1 - alpha_prod_t
-        # 3. compute predicted original sample from predicted noise also called
-        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.scheduler.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.scheduler.config.prediction_type == "sample":
@@ -614,24 +630,13 @@
                f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
                " or `v_prediction`"
            )
-        # # 4. Clip "predicted x_0"
-        # if self.scheduler.config.clip_sample:
-        #     pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
 
        return pred_original_sample
 
-    def pred_eps_from_noise(self, sample, model_output, timestep):
-        # 1. get previous step value (=t-1)
-        # prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
-
-        # 2. compute alphas, betas
+    def pred_epsilon(self, sample, model_output, timestep):
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
-        # alpha_prod_t_prev = (
-        #     self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
-        # )
 
        beta_prod_t = 1 - alpha_prod_t
-        # 3. compute predicted eps from model output
        if self.scheduler.config.prediction_type == "epsilon":
            pred_eps = model_output
        elif self.scheduler.config.prediction_type == "sample":
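
For reference, the new gaussian_blur_2d helper replaces torchvision.transforms.GaussianBlur with a depthwise conv2d over an outer-product Gaussian kernel, which is why the torchvision import could be swapped for torch.nn.functional. A rough usage sketch on a dummy latent batch, assuming the function from the hunk above is in scope (shapes and values are illustrative only, not from the pipeline):

    import torch

    latents = torch.randn(2, 4, 64, 64)  # fake batch: 2 samples, 4 latent channels, 64x64

    # Same call the SAG masking step makes; reflect padding keeps the spatial size.
    blurred = gaussian_blur_2d(latents, kernel_size=9, sigma=1.0)

    assert blurred.shape == latents.shape
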
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8ac70e8..4c1ec31 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -432,7 +432,7 @@ def main():
    args = parse_args()
 
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir = Path(args.output_dir) / slugify(args.project) / now
    output_dir.mkdir(parents=True, exist_ok=True)
 
    accelerator = Accelerator(
@@ -448,7 +448,7 @@ def main():
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
    if args.seed is None:
        args.seed = torch.random.seed() >> 32
@@ -513,8 +513,8 @@ def main():
        prior_loss_weight=args.prior_loss_weight,
    )
 
-    checkpoint_output_dir = output_dir.joinpath("model")
-    sample_output_dir = output_dir.joinpath(f"samples")
+    checkpoint_output_dir = output_dir / "model"
+    sample_output_dir = output_dir / "samples"
 
    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
@@ -596,7 +596,7 @@ def main():
        sample_image_size=args.sample_image_size,
    )
 
-    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+    plot_metrics(metrics, output_dir / "lr.png")
 
 
 if __name__ == "__main__":
diff --git a/train_lora.py b/train_lora.py
index 5fd05cc..a8c1cf6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -392,7 +392,7 @@ def main():
    args = parse_args()
 
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir = Path(args.output_dir) / slugify(args.project) / now
    output_dir.mkdir(parents=True, exist_ok=True)
 
    accelerator = Accelerator(
@@ -408,7 +408,7 @@ def main():
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
    if args.seed is None:
        args.seed = torch.random.seed() >> 32
@@ -489,8 +489,8 @@ def main():
        prior_loss_weight=args.prior_loss_weight,
    )
 
-    checkpoint_output_dir = output_dir.joinpath("model")
-    sample_output_dir = output_dir.joinpath(f"samples")
+    checkpoint_output_dir = output_dir / "model"
+    sample_output_dir = output_dir / "samples"
 
    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
@@ -562,7 +562,7 @@ def main():
        sample_image_size=args.sample_image_size,
    )
 
-    plot_metrics(metrics, output_dir.joinpath("lr.png"))
+    plot_metrics(metrics, output_dir / "lr.png")
 
 
 if __name__ == "__main__":
diff --git a/train_ti.py b/train_ti.py
index c79dfa2..171d085 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -143,7 +143,7 @@ def parse_args():
    parser.add_argument(
        "--num_buckets",
        type=int,
-        default=4,
+        default=0,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
@@ -485,6 +485,9 @@ def parse_args():
 
        if len(args.placeholder_tokens) != len(args.train_data_template):
            raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+    else:
+        if isinstance(args.train_data_template, list):
+            raise ValueError("--train_data_template can't be a list in simultaneous mode")
 
    if isinstance(args.collection, str):
        args.collection = [args.collection]
@@ -503,7 +506,7 @@ def main():
 
    global_step_offset = args.global_step
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir = Path(args.output_dir) / slugify(args.project) / now
    output_dir.mkdir(parents=True, exist_ok=True)
 
    accelerator = Accelerator(
@@ -519,7 +522,7 @@ def main():
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
 
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
    if args.seed is None:
        args.seed = torch.random.seed() >> 32
@@ -570,7 +573,7 @@ def main():
    else:
        optimizer_class = torch.optim.AdamW
 
-    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+    checkpoint_output_dir = output_dir / "checkpoints"
 
    trainer = partial(
        train,
@@ -611,11 +614,11 @@ def main():
        return
 
    if len(placeholder_tokens) == 1:
-        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
-        metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png")
+        sample_output_dir = output_dir / f"samples_{placeholder_tokens[0]}"
+        metrics_output_file = output_dir / f"{placeholder_tokens[0]}.png"
    else:
-        sample_output_dir = output_dir.joinpath("samples")
-        metrics_output_file = output_dir.joinpath(f"lr.png")
+        sample_output_dir = output_dir / "samples"
+        metrics_output_file = output_dir / f"lr.png"
 
    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
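
The new else branch in parse_args guards what its error message calls "simultaneous mode": a list of data templates only makes sense when placeholder tokens are trained one after another. A standalone sketch of the check, assuming the branch pairs with the script's sequential/simultaneous mode switch (the function and argument names here are illustrative, not from the repo):

    def validate_train_data_template(sequential, placeholder_tokens, train_data_template):
        # Illustrative only: mirrors the validation visible in the diff above.
        if sequential:
            if len(placeholder_tokens) != len(train_data_template):
                raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
        else:
            if isinstance(train_data_template, list):
                raise ValueError("--train_data_template can't be a list in simultaneous mode")
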
diff --git a/training/functional.py b/training/functional.py
index ccbb4ad..83e70e2 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -129,7 +129,7 @@ def save_samples(
 
    for pool, data, gen in datasets:
        all_samples = []
-        file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
+        file_path = output_dir / pool / f"step_{step}.jpg"
        file_path.parent.mkdir(parents=True, exist_ok=True)
 
        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index bc10e58..4dd1100 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -91,7 +91,7 @@ def lora_strategy_callbacks(
        print(f"Saving checkpoint for step {step}...")
 
        unet_ = accelerator.unwrap_model(unet)
-        unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
+        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
        del unet_
 
    @torch.no_grad()
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index da2b81c..0de3cb0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -138,7 +138,7 @@ def textual_inversion_strategy_callbacks(
        for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
            text_encoder.text_model.embeddings.save_embed(
                ids,
-                checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
            )
 
    @torch.no_grad()
diff --git a/util.py b/util.py
index 545bcb5..2712525 100644
--- a/util.py
+++ b/util.py
@@ -14,7 +14,7 @@ def load_config(filename):
    args = config["args"]
 
    if "base" in config:
-        args = load_config(Path(filename).parent.joinpath(config["base"])) | args
+        args = load_config(Path(filename).parent / config["base"]) | args
 
    return args
 
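
For context on the util.py change: load_config resolves a "base" config relative to the current file and merges it with the dict union operator (`|`, Python 3.9+), so keys in the current file override the base. A minimal sketch of that merge semantics with made-up config values:

    base_args = {"learning_rate": 1e-4, "num_buckets": 0}
    current_args = {"learning_rate": 5e-5, "project": "example"}

    # As in `load_config(base_path) | args`: the right-hand side wins on conflicts.
    merged = base_args | current_args

    assert merged == {"learning_rate": 5e-5, "num_buckets": 0, "project": "example"}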