summaryrefslogtreecommitdiffstats
path: root/training/strategy/dreambooth.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/strategy/dreambooth.py')
-rw-r--r--  training/strategy/dreambooth.py | 10
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index aeaa828..93c81cb 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -27,7 +27,8 @@ def dreambooth_strategy(
27 sample_scheduler: DPMSolverMultistepScheduler, 27 sample_scheduler: DPMSolverMultistepScheduler,
28 train_dataloader: DataLoader, 28 train_dataloader: DataLoader,
29 val_dataloader: Optional[DataLoader], 29 val_dataloader: Optional[DataLoader],
30 output_dir: Path, 30 sample_output_dir: Path,
31 checkpoint_output_dir: Path,
31 seed: int, 32 seed: int,
32 train_text_encoder_epochs: int, 33 train_text_encoder_epochs: int,
33 max_grad_norm: float = 1.0, 34 max_grad_norm: float = 1.0,
@@ -47,6 +48,9 @@ def dreambooth_strategy(
47 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 48 "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
48 ) 49 )
49 50
51 sample_output_dir.mkdir(parents=True, exist_ok=True)
52 checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
53
50 weight_dtype = torch.float32 54 weight_dtype = torch.float32
51 if accelerator.state.mixed_precision == "fp16": 55 if accelerator.state.mixed_precision == "fp16":
52 weight_dtype = torch.float16 56 weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def dreambooth_strategy(
64 train_dataloader=train_dataloader, 68 train_dataloader=train_dataloader,
65 val_dataloader=val_dataloader, 69 val_dataloader=val_dataloader,
66 dtype=weight_dtype, 70 dtype=weight_dtype,
67 output_dir=output_dir, 71 output_dir=sample_output_dir,
68 seed=seed, 72 seed=seed,
69 batch_size=sample_batch_size, 73 batch_size=sample_batch_size,
70 num_batches=sample_num_batches, 74 num_batches=sample_num_batches,
@@ -156,7 +160,7 @@ def dreambooth_strategy(
156 tokenizer=tokenizer, 160 tokenizer=tokenizer,
157 scheduler=sample_scheduler, 161 scheduler=sample_scheduler,
158 ) 162 )
159 pipeline.save_pretrained(output_dir.joinpath("model")) 163 pipeline.save_pretrained(checkpoint_output_dir)
160 164
161 del unet_ 165 del unet_
162 del text_encoder_ 166 del text_encoder_