author     Volpeon <git@volpeon.ink>  2023-01-16 21:47:06 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-16 21:47:06 +0100
commit     2dfd1790078753f19ca8c585ac77079f3114f3a9 (patch)
tree       d1d1d643f247767c13535105dbe4afafcc5ab8c0
parent     If valid set size is 0, re-use one image from train set (diff)
Training update
-rw-r--r--  train_dreambooth.py               6
-rw-r--r--  train_ti.py                       8
-rw-r--r--  training/functional.py            4
-rw-r--r--  training/strategy/dreambooth.py  10
-rw-r--r--  training/strategy/ti.py          13
5 files changed, 25 insertions, 16 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6511f9b..d722e68 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -552,6 +552,9 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
+    checkpoint_output_dir = output_dir.joinpath("model")
+    sample_output_dir = output_dir.joinpath(f"samples")
+
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -620,7 +623,8 @@ def main():
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=output_dir,
+        sample_output_dir=sample_output_dir,
+        checkpoint_output_dir=checkpoint_output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
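For orientation, a minimal sketch of the two directories this hunk derives from output_dir; the run directory name is a hypothetical example, only the joinpath calls mirror the diff:

from pathlib import Path

output_dir = Path("output/dreambooth-run")  # hypothetical run directory

# Mirrors the added lines above: the trained pipeline goes to "<output_dir>/model",
# sample grids go to "<output_dir>/samples".
checkpoint_output_dir = output_dir.joinpath("model")
sample_output_dir = output_dir.joinpath("samples")

print(checkpoint_output_dir)  # output/dreambooth-run/model
print(sample_output_dir)      # output/dreambooth-run/samples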
diff --git a/train_ti.py b/train_ti.py
index e696577..e7aeb23 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -572,6 +572,8 @@ def main():
         callbacks_fn=textual_inversion_strategy
     )
 
+    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+
     for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
         range(len(args.placeholder_tokens)),
         args.placeholder_tokens,
@@ -579,8 +581,7 @@ def main():
         args.num_vectors,
         args.train_data_template
     ):
-        cur_subdir = output_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
+        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")
 
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
@@ -655,7 +656,8 @@ def main():
             # --
             tokenizer=tokenizer,
             sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
+            sample_output_dir=sample_output_dir,
+            checkpoint_output_dir=checkpoint_output_dir,
             placeholder_tokens=[placeholder_token],
             placeholder_token_ids=placeholder_token_ids,
             learning_rate=args.learning_rate,
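A similar sketch for the textual-inversion script: one shared checkpoint directory per run, plus one sample directory per placeholder token. The run directory and token are hypothetical; the joinpath expressions match the hunks above:

from pathlib import Path

output_dir = Path("output/ti-run")    # hypothetical run directory
placeholder_token = "<my-token>"      # hypothetical token

checkpoint_output_dir = output_dir.joinpath("checkpoints")
sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")

print(checkpoint_output_dir)  # output/ti-run/checkpoints
print(sample_output_dir)      # output/ti-run/samples_<my-token>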
diff --git a/training/functional.py b/training/functional.py
index 1548784..3d27380 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -87,8 +87,6 @@ def save_samples(
 ):
     print(f"Saving samples for step {step}...")
 
-    samples_path = output_dir.joinpath("samples")
-
     grid_cols = min(batch_size, 4)
     grid_rows = (num_batches * batch_size) // grid_cols
 
@@ -120,7 +118,7 @@ def save_samples(
 
     for pool, data, gen in datasets:
         all_samples = []
-        file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+        file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
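With samples_path gone, save_samples writes directly under whatever output_dir the strategy passes in. A small sketch of the resulting file path; the directory, pool, and step values are made up:

from pathlib import Path

output_dir = Path("output/ti-run/samples_my-token")  # hypothetical, as passed by a strategy
pool, step = "val", 500                              # hypothetical values

file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
file_path.parent.mkdir(parents=True, exist_ok=True)
print(file_path)  # output/ti-run/samples_my-token/val/step_500.jpg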
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index aeaa828..93c81cb 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -27,7 +27,8 @@ def dreambooth_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    output_dir: Path,
+    sample_output_dir: Path,
+    checkpoint_output_dir: Path,
     seed: int,
     train_text_encoder_epochs: int,
     max_grad_norm: float = 1.0,
@@ -47,6 +48,9 @@ def dreambooth_strategy(
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )
 
+    sample_output_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
+
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def dreambooth_strategy(
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         dtype=weight_dtype,
-        output_dir=output_dir,
+        output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
         num_batches=sample_num_batches,
@@ -156,7 +160,7 @@ def dreambooth_strategy(
             tokenizer=tokenizer,
             scheduler=sample_scheduler,
         )
-        pipeline.save_pretrained(output_dir.joinpath("model"))
+        pipeline.save_pretrained(checkpoint_output_dir)
 
         del unet_
         del text_encoder_
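The strategy now receives both paths explicitly instead of deriving them from a single output_dir. A rough sketch of the setup this implies; the concrete paths are placeholders, and the save_pretrained call is shown as a comment because no pipeline is constructed here:

from pathlib import Path

# Hypothetical values standing in for the two new strategy parameters.
sample_output_dir = Path("output/dreambooth-run/samples")
checkpoint_output_dir = Path("output/dreambooth-run/model")

# Mirrors the new setup code above: both directories are created up front,
# so neither save_samples() nor the checkpoint callback has to do it later.
sample_output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

# The final pipeline is then saved straight into checkpoint_output_dir,
# replacing the old output_dir.joinpath("model"):
# pipeline.save_pretrained(checkpoint_output_dir)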
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9d39e15..00f3529 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,8 @@ def textual_inversion_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    output_dir: Path,
+    sample_output_dir: Path,
+    checkpoint_output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
@@ -47,6 +48,9 @@ def textual_inversion_strategy(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
+    sample_output_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
+
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def textual_inversion_strategy(
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         dtype=weight_dtype,
-        output_dir=output_dir,
+        output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
         num_batches=sample_num_batches,
@@ -135,14 +139,11 @@ def textual_inversion_strategy(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        checkpoints_path = output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
         with ema_context():
             for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                    checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                 )
 
     @torch.no_grad()
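For the embedding checkpoints, the filename construction is unchanged apart from the base directory. A small sketch of the resulting path, assuming slugify comes from the python-slugify package as imported in this module; token, step, and postfix are made-up values:

from pathlib import Path
from slugify import slugify  # assumed: python-slugify, matching the import used here

checkpoint_output_dir = Path("output/ti-run/checkpoints")  # hypothetical
token, step, postfix = "<my-token>", 500, "milestone"      # hypothetical values

path = checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
print(path)  # output/ti-run/checkpoints/my-token_500_milestone.bin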