author     Volpeon <git@volpeon.ink>   2022-11-30 14:02:35 +0100
committer  Volpeon <git@volpeon.ink>   2022-11-30 14:02:35 +0100
commit     329ad48b307e782b0e23fce80ae9087a4003e73d
tree       0c72434a8d45ae933582064849b43bd7419f7ee8
parent     Adjusted training to upstream
Update: raise the default training resolution and sample image size to 768, lower the default sample step count to 15, simplify Checkpointer.save_samples, and add sliced VAE decoding to VlpnStableDiffusion.
 dreambooth.py                                        | 19
 pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 27
 2 files changed, 34 insertions(+), 12 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index 49d4447..3dd0920 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -115,7 +115,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -267,7 +267,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -297,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=25,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -459,7 +459,7 @@ class Checkpointer:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped_unet = self.accelerator.unwrap_model(
@@ -474,13 +474,14 @@ class Checkpointer:
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
+        pipeline.enable_vae_slicing()
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
         stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
             device=pipeline.device,
             generator=generator,
         )
@@ -875,9 +876,7 @@ def main():
     )
 
     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(0, args.sample_steps)
 
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -1089,9 +1088,7 @@ def main():
             max_acc_val = avg_acc_val
 
         if sample_checkpoint and accelerator.is_main_process:
-            checkpointer.save_samples(
-                global_step,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+            checkpointer.save_samples(global_step, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
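With this change, the fixed-seed latents used for sample generation are derived from the checkpointer's sample_image_size instead of the old height/width arguments, and the call sites only pass the step and the number of inference steps. At the new 768 default this yields 96x96 latents, since the VAE downsamples by a factor of 8. A small standalone sketch of that shape rule (the batch size and channel count are assumptions for illustration, not values taken from the script):

import torch

# Assumed values for illustration; the training script reads these from its own config.
sample_batch_size = 4      # hypothetical sample batch size
unet_in_channels = 4       # latent channel count of the Stable Diffusion UNet
sample_image_size = 768    # new default introduced by this commit

# Same shape rule as stable_latents in save_samples: latent H/W = image size // 8.
stable_latents = torch.randn(
    (sample_batch_size, unet_in_channels,
     sample_image_size // 8, sample_image_size // 8),
)
print(stable_latents.shape)  # torch.Size([4, 4, 96, 96])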
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 85b0216..c77c4d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,6 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
     @property
     def execution_device(self):
         r"""
@@ -280,12 +297,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae_decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
+    def vae_decode(self, latents):
+        if self.use_slicing:
+            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
+            decoded = torch.cat(decoded_slices)
+            return DecoderOutput(sample=decoded)
+        else:
+            return self.vae.decode(latents)
+
     @torch.no_grad()
     def __call__(
         self,
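The sliced decode added here trades a little throughput for peak memory: when slicing is enabled, each latent in the batch is decoded on its own and the image tensors are concatenated afterwards, so the VAE decoder only holds activations for one image at a time, while the non-sliced path is unchanged. A minimal standalone sketch of the same idea, assuming a diffusers AutoencoderKL (the checkpoint name and tensor sizes below are illustrative assumptions, not part of this commit):

import torch
from diffusers import AutoencoderKL
from diffusers.models.vae import DecoderOutput

# Assumed checkpoint for illustration; any AutoencoderKL behaves the same way here.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
latents = torch.randn(4, 4, 96, 96)  # e.g. a batch of four 768x768 latents

with torch.no_grad():
    # Decode one latent at a time so decoder memory scales with a single image,
    # not with the whole batch, then stitch the results back together.
    decoded_slices = [vae.decode(chunk).sample for chunk in latents.split(1)]
    decoded = torch.cat(decoded_slices)

out = DecoderOutput(sample=decoded)
print(out.sample.shape)  # torch.Size([4, 3, 768, 768])

Because the sliced path is wrapped back into a DecoderOutput, decode_latents can call vae_decode and read .sample without caring whether slicing is enabled.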
