summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/functional.py4
-rw-r--r--training/strategy/dreambooth.py10
-rw-r--r--training/strategy/ti.py13
3 files changed, 15 insertions, 12 deletions
diff --git a/training/functional.py b/training/functional.py
index 1548784..3d27380 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -87,8 +87,6 @@ def save_samples(
87): 87):
88 print(f"Saving samples for step {step}...") 88 print(f"Saving samples for step {step}...")
89 89
90 samples_path = output_dir.joinpath("samples")
91
92 grid_cols = min(batch_size, 4) 90 grid_cols = min(batch_size, 4)
93 grid_rows = (num_batches * batch_size) // grid_cols 91 grid_rows = (num_batches * batch_size) // grid_cols
94 92
@@ -120,7 +118,7 @@ def save_samples(
120 118
121 for pool, data, gen in datasets: 119 for pool, data, gen in datasets:
122 all_samples = [] 120 all_samples = []
123 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 121 file_path = output_dir.joinpath(pool, f"step_{step}.jpg")
124 file_path.parent.mkdir(parents=True, exist_ok=True) 122 file_path.parent.mkdir(parents=True, exist_ok=True)
125 123
126 batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) 124 batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index aeaa828..93c81cb 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -27,7 +27,8 @@ def dreambooth_strategy(
27 sample_scheduler: DPMSolverMultistepScheduler, 27 sample_scheduler: DPMSolverMultistepScheduler,
28 train_dataloader: DataLoader, 28 train_dataloader: DataLoader,
29 val_dataloader: Optional[DataLoader], 29 val_dataloader: Optional[DataLoader],
30 output_dir: Path, 30 sample_output_dir: Path,
31 checkpoint_output_dir: Path,
31 seed: int, 32 seed: int,
32 train_text_encoder_epochs: int, 33 train_text_encoder_epochs: int,
33 max_grad_norm: float = 1.0, 34 max_grad_norm: float = 1.0,
@@ -47,6 +48,9 @@ def dreambooth_strategy(
47 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 48 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
48 ) 49 )
49 50
51 sample_output_dir.mkdir(parents=True, exist_ok=True)
52 checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
53
50 weight_dtype = torch.float32 54 weight_dtype = torch.float32
51 if accelerator.state.mixed_precision == "fp16": 55 if accelerator.state.mixed_precision == "fp16":
52 weight_dtype = torch.float16 56 weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def dreambooth_strategy(
64 train_dataloader=train_dataloader, 68 train_dataloader=train_dataloader,
65 val_dataloader=val_dataloader, 69 val_dataloader=val_dataloader,
66 dtype=weight_dtype, 70 dtype=weight_dtype,
67 output_dir=output_dir, 71 output_dir=sample_output_dir,
68 seed=seed, 72 seed=seed,
69 batch_size=sample_batch_size, 73 batch_size=sample_batch_size,
70 num_batches=sample_num_batches, 74 num_batches=sample_num_batches,
@@ -156,7 +160,7 @@ def dreambooth_strategy(
156 tokenizer=tokenizer, 160 tokenizer=tokenizer,
157 scheduler=sample_scheduler, 161 scheduler=sample_scheduler,
158 ) 162 )
159 pipeline.save_pretrained(output_dir.joinpath("model")) 163 pipeline.save_pretrained(checkpoint_output_dir)
160 164
161 del unet_ 165 del unet_
162 del text_encoder_ 166 del text_encoder_
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9d39e15..00f3529 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,8 @@ def textual_inversion_strategy(
27 sample_scheduler: DPMSolverMultistepScheduler, 27 sample_scheduler: DPMSolverMultistepScheduler,
28 train_dataloader: DataLoader, 28 train_dataloader: DataLoader,
29 val_dataloader: Optional[DataLoader], 29 val_dataloader: Optional[DataLoader],
30 output_dir: Path, 30 sample_output_dir: Path,
31 checkpoint_output_dir: Path,
31 seed: int, 32 seed: int,
32 placeholder_tokens: list[str], 33 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 34 placeholder_token_ids: list[list[int]],
@@ -47,6 +48,9 @@ def textual_inversion_strategy(
47 sample_guidance_scale: float = 7.5, 48 sample_guidance_scale: float = 7.5,
48 sample_image_size: Optional[int] = None, 49 sample_image_size: Optional[int] = None,
49): 50):
51 sample_output_dir.mkdir(parents=True, exist_ok=True)
52 checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
53
50 weight_dtype = torch.float32 54 weight_dtype = torch.float32
51 if accelerator.state.mixed_precision == "fp16": 55 if accelerator.state.mixed_precision == "fp16":
52 weight_dtype = torch.float16 56 weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def textual_inversion_strategy(
64 train_dataloader=train_dataloader, 68 train_dataloader=train_dataloader,
65 val_dataloader=val_dataloader, 69 val_dataloader=val_dataloader,
66 dtype=weight_dtype, 70 dtype=weight_dtype,
67 output_dir=output_dir, 71 output_dir=sample_output_dir,
68 seed=seed, 72 seed=seed,
69 batch_size=sample_batch_size, 73 batch_size=sample_batch_size,
70 num_batches=sample_num_batches, 74 num_batches=sample_num_batches,
@@ -135,14 +139,11 @@ def textual_inversion_strategy(
135 def on_checkpoint(step, postfix): 139 def on_checkpoint(step, postfix):
136 print(f"Saving checkpoint for step {step}...") 140 print(f"Saving checkpoint for step {step}...")
137 141
138 checkpoints_path = output_dir.joinpath("checkpoints")
139 checkpoints_path.mkdir(parents=True, exist_ok=True)
140
141 with ema_context(): 142 with ema_context():
142 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): 143 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
143 text_encoder.text_model.embeddings.save_embed( 144 text_encoder.text_model.embeddings.save_embed(
144 ids, 145 ids,
145 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") 146 checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
146 ) 147 )
147 148
148 @torch.no_grad() 149 @torch.no_grad()