author     Volpeon <git@volpeon.ink>  2022-10-10 12:46:57 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-10 12:46:57 +0200
commit     cd80af823d31148f9c0fa4d8045b773adfe1e6c3 (patch)
tree       308ec793db17d5335c4cba37b760a0bdf3701f67
parent     Updated default params (diff)
Dreambooth: Add EMA support
-rw-r--r--  dreambooth.py | 50
1 file changed, 44 insertions(+), 6 deletions(-)
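For context: an exponential moving average (EMA) keeps a shadow copy of the model whose weights trail the optimized weights, which tends to give smoother samples than the raw training weights. The decay follows a warmup schedule controlled by the new flags below; a minimal sketch, assuming the formula used by diffusers.training_utils.EMAModel around the time of this commit (exact clamping details may differ between versions):

def ema_decay(step, inv_gamma=0.1, power=1.0, max_value=0.9999):
    # Decay starts near 0 (shadow tracks the live model closely) and
    # approaches max_value as training proceeds.
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max_value, max(0.0, value))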
diff --git a/dreambooth.py b/dreambooth.py
index 9f1b7af..f7d31d2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -16,6 +16,7 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -111,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=600,
+        default=5000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -128,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-7,
+        default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -140,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="linear",
+        default="cosine",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -153,8 +154,30 @@
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use EMA model."
+    )
+    parser.add_argument(
+        "--ema_inv_gamma",
+        type=float,
+        default=0.1
+    )
+    parser.add_argument(
+        "--ema_power",
+        type=float,
+        default=1
+    )
+    parser.add_argument(
+        "--ema_max_decay",
+        type=float,
+        default=0.9999
+    )
+    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
+        default=True,
         help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
     parser.add_argument(
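One thing to note in the hunk above: action="store_true" combined with default=True means --use_ema (and likewise --use_8bit_adam) can never be switched off from the command line, since the flag can only set the value to True. A common argparse pattern for a toggleable default-on flag, as a sketch only (the --no_ema name is hypothetical, not part of this commit):

parser.add_argument("--use_ema", dest="use_ema", action="store_true")
parser.add_argument("--no_ema", dest="use_ema", action="store_false")
parser.set_defaults(use_ema=True)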
@@ -172,7 +195,7 @@ def parse_args():
     parser.add_argument(
         "--adam_weight_decay",
         type=float,
-        default=0,
+        default=1e-2,
         help="Weight decay to use."
     )
     parser.add_argument(
@@ -298,6 +321,7 @@ class Checkpointer:
         accelerator,
         vae,
         unet,
+        ema_unet,
         tokenizer,
         text_encoder,
         output_dir: Path,
@@ -311,6 +335,7 @@ class Checkpointer:
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
+        self.ema_unet = ema_unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.output_dir = output_dir
@@ -324,7 +349,8 @@ class Checkpointer:
     def checkpoint(self):
         print("Saving model...")
 
-        unwrapped = self.accelerator.unwrap_model(self.unet)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
@@ -346,7 +372,8 @@
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(self.unet)
+        unwrapped = self.accelerator.unwrap_model(
+            self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
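With these two changes, both checkpoint() and save_samples() read from ema_unet.averaged_model, the shadow copy maintained by EMAModel, whenever EMA is enabled; the optimizer meanwhile keeps updating the live unet untouched.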
@@ -602,6 +629,13 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None
+
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
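With the defaults added in this commit (inv_gamma=0.1, power=1), the schedule simplifies to decay = 1 - 1/(1 + 10·step), so the shadow model warms up very quickly. A quick check, assuming the decay formula sketched earlier:

for step in (1, 10, 100, 1000):
    decay = min(0.9999, 1 - (1 + step / 0.1) ** -1.0)
    print(step, round(decay, 6))
# -> 1 0.909091, 10 0.990099, 100 0.999001, 1000 0.9999 (capped by ema_max_decay)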
@@ -643,6 +677,7 @@ def main():
         accelerator=accelerator,
         vae=vae,
         unet=unet,
+        ema_unet=ema_unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         output_dir=basepath,
@@ -737,6 +772,9 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_unet.step(unet)
+
                 local_progress_bar.update(1)
                 global_progress_bar.update(1)
 
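Finally, ema_unet.step(unet) runs once per real optimizer step (gated on sync_gradients, so gradient accumulation does not inflate the EMA step count). Conceptually it lerps the shadow weights toward the live weights; a simplified sketch of what EMAModel does internally (exact details vary between diffusers versions):

import torch

@torch.no_grad()
def ema_step(averaged_model, model, decay):
    for ema_p, p in zip(averaged_model.parameters(), model.parameters()):
        # shadow <- decay * shadow + (1 - decay) * current
        ema_p.mul_(decay).add_(p, alpha=1 - decay)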