 train_dreambooth.py             |  6 +++++-
 train_ti.py                     |  8 +++++---
 training/functional.py          |  4 +---
 training/strategy/dreambooth.py | 10 +++++++---
 training/strategy/ti.py         | 13 +++++++------
 5 files changed, 25 insertions(+), 16 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6511f9b..d722e68 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -552,6 +552,9 @@ def main():
         prior_loss_weight=args.prior_loss_weight,
     )
 
+    checkpoint_output_dir = output_dir.joinpath("model")
+    sample_output_dir = output_dir.joinpath(f"samples")
+
     datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
@@ -620,7 +623,8 @@ def main():
         # --
         tokenizer=tokenizer,
         sample_scheduler=sample_scheduler,
-        output_dir=output_dir,
+        sample_output_dir=sample_output_dir,
+        checkpoint_output_dir=checkpoint_output_dir,
         train_text_encoder_epochs=args.train_text_encoder_epochs,
         max_grad_norm=args.max_grad_norm,
         use_ema=args.use_ema,
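
For context, a minimal sketch of the output layout this change gives train_dreambooth.py; the subdirectory names come from the diff, while the root path is hypothetical:

    from pathlib import Path

    # `output_dir` is assumed to already exist as the run's root directory.
    output_dir = Path("output/dreambooth-run")  # hypothetical root

    checkpoint_output_dir = output_dir.joinpath("model")   # saved pipeline weights
    sample_output_dir = output_dir.joinpath("samples")     # preview image grids

    # Both paths are now passed into dreambooth_strategy(), which creates
    # them itself, so the caller no longer needs any mkdir calls here.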
diff --git a/train_ti.py b/train_ti.py
index e696577..e7aeb23 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -572,6 +572,8 @@ def main():
         callbacks_fn=textual_inversion_strategy
     )
 
+    checkpoint_output_dir = output_dir.joinpath("checkpoints")
+
    for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
        range(len(args.placeholder_tokens)),
        args.placeholder_tokens,
@@ -579,8 +581,7 @@ def main():
        args.num_vectors,
        args.train_data_template
    ):
-        cur_subdir = output_dir.joinpath(placeholder_token)
-        cur_subdir.mkdir(parents=True, exist_ok=True)
+        sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}")
 
        placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
            tokenizer=tokenizer,
@@ -655,7 +656,8 @@ def main():
            # --
            tokenizer=tokenizer,
            sample_scheduler=sample_scheduler,
-            output_dir=cur_subdir,
+            sample_output_dir=sample_output_dir,
+            checkpoint_output_dir=checkpoint_output_dir,
            placeholder_tokens=[placeholder_token],
            placeholder_token_ids=placeholder_token_ids,
            learning_rate=args.learning_rate,
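
A sketch of the resulting train_ti.py layout, assuming two placeholder tokens (token values are hypothetical): samples are split per token, while all embedding checkpoints now share a single directory instead of one subdirectory per token.

    from pathlib import Path

    output_dir = Path("output/ti-run")               # hypothetical root
    placeholder_tokens = ["<token-a>", "<token-b>"]  # hypothetical tokens

    checkpoint_output_dir = output_dir.joinpath("checkpoints")  # shared by all tokens
    sample_output_dirs = [
        output_dir.joinpath(f"samples_{token}") for token in placeholder_tokens
    ]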
diff --git a/training/functional.py b/training/functional.py
index 1548784..3d27380 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -87,8 +87,6 @@ def save_samples(
 ):
     print(f"Saving samples for step {step}...")
 
-    samples_path = output_dir.joinpath("samples")
-
     grid_cols = min(batch_size, 4)
     grid_rows = (num_batches * batch_size) // grid_cols
 
@@ -120,7 +118,7 @@ def save_samples(
 
     for pool, data, gen in datasets:
         all_samples = []
-        file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+        file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
         batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
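
With the hard-coded "samples" subdirectory removed, save_samples() treats the `output_dir` it receives as the samples root directly. A minimal sketch of the path it writes (the root path here is hypothetical; pool and step names follow the diff):

    from pathlib import Path

    output_dir = Path("output/dreambooth-run/samples")  # hypothetical samples root
    pool, step = "val", 500                             # hypothetical values

    file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
    file_path.parent.mkdir(parents=True, exist_ok=True)
    print(file_path)  # e.g. output/dreambooth-run/samples/val/step_500.jpg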
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index aeaa828..93c81cb 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -27,7 +27,8 @@ def dreambooth_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    output_dir: Path,
+    sample_output_dir: Path,
+    checkpoint_output_dir: Path,
     seed: int,
     train_text_encoder_epochs: int,
     max_grad_norm: float = 1.0,
@@ -47,6 +48,9 @@ def dreambooth_strategy(
             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
         )
 
+    sample_output_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
+
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def dreambooth_strategy(
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         dtype=weight_dtype,
-        output_dir=output_dir,
+        output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
         num_batches=sample_num_batches,
@@ -156,7 +160,7 @@ def dreambooth_strategy(
             tokenizer=tokenizer,
             scheduler=sample_scheduler,
         )
-        pipeline.save_pretrained(output_dir.joinpath("model"))
+        pipeline.save_pretrained(checkpoint_output_dir)
 
         del unet_
         del text_encoder_
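
The net effect for checkpoints: the "model" subdirectory is no longer hard-coded inside the strategy but chosen by the caller. A small sketch showing the two are equivalent under the default wiring (paths hypothetical):

    from pathlib import Path

    output_dir = Path("output/run")             # hypothetical run root
    old_target = output_dir.joinpath("model")   # previously fixed inside the strategy
    new_target = Path("output/run/model")       # now injected as checkpoint_output_dir
    assert old_target == new_target             # same layout, but now configurable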
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9d39e15..00f3529 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,8 @@ def textual_inversion_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    output_dir: Path,
+    sample_output_dir: Path,
+    checkpoint_output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
     placeholder_token_ids: list[list[int]],
@@ -47,6 +48,9 @@ def textual_inversion_strategy(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
+    sample_output_dir.mkdir(parents=True, exist_ok=True)
+    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
+
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def textual_inversion_strategy(
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
         dtype=weight_dtype,
-        output_dir=output_dir,
+        output_dir=sample_output_dir,
         seed=seed,
         batch_size=sample_batch_size,
         num_batches=sample_num_batches,
@@ -135,14 +139,11 @@ def textual_inversion_strategy(
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
-        checkpoints_path = output_dir.joinpath("checkpoints")
-        checkpoints_path.mkdir(parents=True, exist_ok=True)
-
         with ema_context():
             for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                    checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                 )
 
     @torch.no_grad()
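
The embedding checkpoint filename scheme itself is unchanged; only its root moves to the shared checkpoint_output_dir. A sketch, assuming the python-slugify package provides the slugify used above (token, step, and postfix values are hypothetical):

    from pathlib import Path
    from slugify import slugify  # assumption: python-slugify

    checkpoint_output_dir = Path("output/ti-run/checkpoints")  # hypothetical
    token, step, postfix = "<token-a>", 500, "milestone"       # hypothetical

    path = checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
    print(path)  # output/ti-run/checkpoints/token-a_500_milestone.bin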