author     Volpeon <git@volpeon.ink>  2022-11-30 14:02:35 +0100
committer  Volpeon <git@volpeon.ink>  2022-11-30 14:02:35 +0100
commit     329ad48b307e782b0e23fce80ae9087a4003e73d (patch)
tree       0c72434a8d45ae933582064849b43bd7419f7ee8
parent     Adjusted training to upstream (diff)
Update
-rw-r--r--  dreambooth.py                                         | 19
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py   | 27
2 files changed, 34 insertions, 12 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 49d4447..3dd0920 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -115,7 +115,7 @@ def parse_args():
     parser.add_argument(
         "--resolution",
         type=int,
-        default=512,
+        default=768,
         help=(
             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
             " resolution"
@@ -267,7 +267,7 @@ def parse_args():
     parser.add_argument(
         "--sample_image_size",
         type=int,
-        default=512,
+        default=768,
         help="Size of sample images",
     )
     parser.add_argument(
@@ -297,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=25,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -459,7 +459,7 @@ class Checkpointer:
         torch.cuda.empty_cache()
 
     @torch.no_grad()
-    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped_unet = self.accelerator.unwrap_model(
@@ -474,13 +474,14 @@ class Checkpointer:
             scheduler=self.scheduler,
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
+        pipeline.enable_vae_slicing()
 
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
         stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
             device=pipeline.device,
             generator=generator,
         )
@@ -875,9 +876,7 @@ def main():
     )
 
     if accelerator.is_main_process:
-        checkpointer.save_samples(
-            0,
-            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+        checkpointer.save_samples(0, args.sample_steps)
 
     local_progress_bar = tqdm(
         range(num_update_steps_per_epoch + num_val_steps_per_epoch),
@@ -1089,9 +1088,7 @@ def main():
                 max_acc_val = avg_acc_val
 
         if sample_checkpoint and accelerator.is_main_process:
-            checkpointer.save_samples(
-                global_step,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+            checkpointer.save_samples(global_step, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
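
Note on the dreambooth.py changes: save_samples now reads the sample resolution from the Checkpointer instance (self.sample_image_size) and takes guidance_scale and eta as keyword defaults, so the call sites shrink to the step and the step count. A minimal, self-contained sketch of the old-versus-new calling convention (CheckpointerSketch is illustrative only, not the class from this repository):

# sketch of the signature change; the real Checkpointer also builds and runs the pipeline
class CheckpointerSketch:
    def __init__(self, sample_image_size=768):
        self.sample_image_size = sample_image_size  # resolution now lives on the instance

    def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        # old signature: save_samples(step, height, width, guidance_scale, eta, num_inference_steps)
        return dict(step=step, size=self.sample_image_size,
                    steps=num_inference_steps, cfg=guidance_scale, eta=eta)

print(CheckpointerSketch().save_samples(0, 15))  # what the new call sites resolve to
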
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 85b0216..c77c4d1 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -20,6 +20,7 @@ from diffusers import (
     PNDMScheduler,
 )
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.models.vae import DecoderOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
 from models.clip.prompt import PromptProcessor
@@ -69,6 +70,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler._internal_dict = FrozenDict(new_config)
 
         self.prompt_processor = PromptProcessor(tokenizer, text_encoder)
+        self.use_slicing = False
 
         self.register_modules(
             vae=vae,
@@ -136,6 +138,21 @@ class VlpnStableDiffusion(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.use_slicing = False
+
     @property
     def execution_device(self):
         r"""
@@ -280,12 +297,20 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
     def decode_latents(self, latents):
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents).sample
+        image = self.vae_decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
+    def vae_decode(self, latents):
+        if self.use_slicing:
+            decoded_slices = [self.vae.decode(latents_slice).sample for latents_slice in latents.split(1)]
+            decoded = torch.cat(decoded_slices)
+            return DecoderOutput(sample=decoded)
+        else:
+            return self.vae.decode(latents)
+
     @torch.no_grad()
     def __call__(
         self,
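
The vae_decode helper introduced above is where enable_vae_slicing takes effect: when use_slicing is set, the latent batch is split into single-sample slices with latents.split(1), each slice is decoded on its own, and the results are concatenated back into one DecoderOutput, keeping the VAE's peak memory near one sample's worth. A minimal standalone sketch of that pattern (DummyDecoder is a stand-in for the real VAE decoder, not part of this repository):

import torch
import torch.nn as nn

class DummyDecoder(nn.Module):
    """Stand-in for vae.decode: maps 4-channel latents to 3-channel images at 8x resolution."""
    def __init__(self):
        super().__init__()
        self.net = nn.ConvTranspose2d(4, 3, kernel_size=8, stride=8)

    def forward(self, latents):
        return self.net(latents)

decoder = DummyDecoder()
latents = torch.randn(4, 4, 64, 64)  # a batch of four latents for 512x512 images

with torch.no_grad():
    # use_slicing=True path: decode one latent at a time and reassemble the batch,
    # trading a little speed for a much smaller peak-memory footprint
    decoded = torch.cat([decoder(latent) for latent in latents.split(1)])

print(decoded.shape)  # torch.Size([4, 3, 512, 512])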