 dreambooth.py                                        | 19
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 27
 2 files changed, 34 insertions(+), 12 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 49d4447..3dd0920 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -115,7 +115,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -267,7 +267,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -297,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=25,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
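Taken together, the three argparse hunks above raise the default training resolution and sample image size to 768 and drop the default sampling step count from 25 to 15. A minimal standalone sketch of the resulting defaults (plain argparse, not importing dreambooth.py; names reproduced only for illustration):

    import argparse

    # Stand-ins for the three options touched by this diff, with their new defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--resolution", type=int, default=768)
    parser.add_argument("--sample_image_size", type=int, default=768)
    parser.add_argument("--sample_steps", type=int, default=15)

    args = parser.parse_args([])  # empty argv: the defaults apply
    print(args.resolution, args.sample_image_size, args.sample_steps)  # -> 768 768 15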
@@ -459,7 +459,7 @@ class Checkpointer:
         torch.cuda.empty_cache()

     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")

         unwrapped_unet = self.accelerator.unwrap_model(
@@ -474,13 +474,14 @@ class Checkpointer:
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
+        pipeline.enable_vae_slicing()

         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()

         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
         stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
             device=pipeline.device,
             generator=generator,
         )
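With the fixed height/width parameters gone, the stable latents are now shaped from self.sample_image_size alone; the VAE downsamples by a factor of 8, so the new 768px default gives 96x96 latents. A small sketch of that shape computation (batch size and channel count are assumed values; 4 is the usual unet.in_channels for Stable Diffusion):

    import torch

    sample_batch_size = 4      # assumed for illustration
    in_channels = 4            # typical pipeline.unet.in_channels for Stable Diffusion UNets
    sample_image_size = 768    # the new default

    stable_latents = torch.randn(
        (sample_batch_size, in_channels, sample_image_size // 8, sample_image_size // 8)
    )
    print(stable_latents.shape)  # torch.Size([4, 4, 96, 96])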
@@ -875,9 +876,7 @@ def main():
     )

     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(0, args.sample_steps)

     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -1089,9 +1088,7 @@ def main():
                 max_acc_val = avg_acc_val

         if sample_checkpoint and accelerator.is_main_process:
-            checkpointer.save_samples(
-                global_step,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+            checkpointer.save_samples(global_step, args.sample_steps)

     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
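Both call sites now pass only the step and the step count, and the sample size comes from the Checkpointer itself; guidance scale and eta fall back to the new keyword defaults. A hedged sketch of the updated call (checkpointer and args as set up earlier in main()):

    # New form: defaults guidance_scale=7.5, eta=0.0 apply.
    checkpointer.save_samples(global_step, args.sample_steps)

    # Equivalent explicit form, should a caller want different values:
    checkpointer.save_samples(global_step, args.sample_steps, guidance_scale=7.5, eta=0.0)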
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 85b0216..c77c4d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,6 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)

         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+        self.use_slicing = False

         self.register_modules(
             vae=vae,
@@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
         if cpu_offloaded_model is not None:
             cpu_offload(cpu_offloaded_model, device)

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
     @property
     def execution_device(self):
         r"""
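The new pair mirrors the enable_vae_slicing/disable_vae_slicing toggles on diffusers' StableDiffusionPipeline: the methods only flip a flag, and the actual slicing happens at decode time. A hypothetical usage sketch, assuming the pipeline's __call__ follows the standard StableDiffusionPipeline signature (construction of the VlpnStableDiffusion instance elided, prompt invented for illustration):

    # pipeline: an already-constructed VlpnStableDiffusion instance
    pipeline.enable_vae_slicing()   # decode latents one sample at a time
    images = pipeline(prompt=["a photo of a dog"] * 8, num_inference_steps=15).images
    pipeline.disable_vae_slicing()  # back to a single batched decode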
@@ -280,12 +297,20 @@ class VlpnStableDiffusion(DiffusionPipeline):

     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae_decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

+    def vae_decode(self, latents):
+        if self.use_slicing:
+            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
+            decoded = torch.cat(decoded_slices)
+            return DecoderOutput(sample=decoded)
+        else:
+            return self.vae.decode(latents)
+
     @torch.no_grad()
     def __call__(
         self,
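vae_decode is where the flag takes effect: with slicing on, the latent batch is split into single-sample slices, each decoded separately, and the results concatenated, so peak decoder memory scales with one image rather than the whole batch. A standalone sketch of the same splitting pattern with the VAE stubbed out (shapes assumed: 4 latent channels, 8x spatial upsampling):

    import torch

    def decode_one(latent_slice):
        # stand-in for self.vae.decode(latent_slice).sample: 8x upsampling, 3 output channels
        b, _, h, w = latent_slice.shape
        return torch.zeros(b, 3, h * 8, w * 8)

    latents = torch.randn(8, 4, 96, 96)                                # batch of 8 latent samples
    decoded = torch.cat([decode_one(s) for s in latents.split(1)])     # decode one sample at a time
    print(decoded.shape)                                               # torch.Size([8, 3, 768, 768])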