author    | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100
committer | Volpeon <git@volpeon.ink> | 2022-12-25 23:50:24 +0100
commit    | 7505f7e843dc719622a15f4ee301609813763d77
tree      | fe67640dce9fec4f625d6d1600c696cd7de006ee
parent    | Update
Code simplifications, avoid autocast
-rw-r--r-- | infer.py            |  48
-rw-r--r-- | train_dreambooth.py |  12
-rw-r--r-- | train_ti.py         |   8
-rw-r--r-- | training/util.py    | 106

4 files changed, 92 insertions, 82 deletions
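The recurring change in this commit is to stop wrapping inference in `torch.autocast("cuda")` and instead run sampling under `torch.inference_mode()` with the models already held in the target dtype. A minimal sketch of the resulting pattern, assuming the stock diffusers `StableDiffusionPipeline` as a stand-in for this repo's `VlpnStableDiffusion` and a placeholder model id:

```python
import torch
from diffusers import StableDiffusionPipeline

@torch.inference_mode()  # replaces `with torch.autocast("cuda"), torch.inference_mode():`
def generate(pipeline, prompt, seed=0):
    generator = torch.Generator(device="cuda").manual_seed(seed)
    return pipeline(prompt=prompt, generator=generator).images

# Half precision comes from the weights themselves rather than from autocast.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder model id
    torch_dtype=torch.float16,
).to("cuda")

generate(pipe, "a photo of a cat")[0].save("sample.png")
```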
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -209,6 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     return pipeline
 
 
+@torch.inference_mode()
 def generate(output_dir, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
@@ -245,30 +246,29 @@ def generate(output_dir, pipeline, args):
     elif args.scheduler == "kdpm2_a":
         pipeline.scheduler = KDPM2AncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
 
-    with torch.autocast("cuda"), torch.inference_mode():
-        for i in range(args.batch_num):
-            pipeline.set_progress_bar_config(
-                desc=f"Batch {i + 1} of {args.batch_num}",
-                dynamic_ncols=True
-            )
-
-            generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
-            images = pipeline(
-                prompt=args.prompt,
-                negative_prompt=args.negative_prompt,
-                height=args.height,
-                width=args.width,
-                num_images_per_prompt=args.batch_size,
-                num_inference_steps=args.steps,
-                guidance_scale=args.guidance_scale,
-                generator=generator,
-                image=init_image,
-                strength=args.image_noise,
-            ).images
-
-            for j, image in enumerate(images):
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
-                image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
+    for i in range(args.batch_num):
+        pipeline.set_progress_bar_config(
+            desc=f"Batch {i + 1} of {args.batch_num}",
+            dynamic_ncols=True
+        )
+
+        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
+        images = pipeline(
+            prompt=args.prompt,
+            negative_prompt=args.negative_prompt,
+            height=args.height,
+            width=args.width,
+            num_images_per_prompt=args.batch_size,
+            num_inference_steps=args.steps,
+            guidance_scale=args.guidance_scale,
+            generator=generator,
+            image=init_image,
+            strength=args.image_noise,
+        ).images
+
+        for j, image in enumerate(images):
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
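With the autocast block gone, `generate()` keeps its per-batch seeding: each batch gets a generator seeded with `args.seed + i`, so any single batch can be reproduced on its own. A CPU-only sketch of that scheme (the script uses `device="cuda"`):

```python
import torch

def batch_generator(seed: int, batch_index: int) -> torch.Generator:
    # one deterministic generator per batch, as in generate() above
    return torch.Generator().manual_seed(seed + batch_index)

a = torch.randn(2, 3, generator=batch_generator(42, 1))
b = torch.randn(2, 3, generator=batch_generator(42, 1))
assert torch.equal(a, b)  # the same base seed and batch index reproduce the same noise
```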
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e239833..2c765ec 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -389,6 +389,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
+        weight_dtype,
         datamodule,
         accelerator,
         vae,
@@ -416,6 +417,7 @@ class Checkpointer(CheckpointerBase):
             sample_batch_size=sample_batch_size
         )
 
+        self.weight_dtype = weight_dtype
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -452,6 +454,12 @@ class Checkpointer(CheckpointerBase):
         unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
+        orig_unet_dtype = unet.dtype
+        orig_text_encoder_dtype = text_encoder.dtype
+
+        unet.to(dtype=self.weight_dtype)
+        text_encoder.to(dtype=self.weight_dtype)
+
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=self.vae,
@@ -463,6 +471,9 @@ class Checkpointer(CheckpointerBase):
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
+        unet.to(dtype=orig_unet_dtype)
+        text_encoder.to(dtype=orig_text_encoder_dtype)
+
         del unet
         del text_encoder
         del pipeline
@@ -798,6 +809,7 @@ def main():
     max_acc_val = 0.0
 
     checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
         datamodule=datamodule,
         accelerator=accelerator,
         vae=vae,
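Because autocast no longer papers over dtype mismatches, the Checkpointer now casts the unwrapped `unet` and `text_encoder` to `weight_dtype` before building the sampling pipeline and restores the original dtypes afterwards. A minimal sketch of that round-trip with an illustrative helper and a toy module (the training script does this inline, without a helper or try/finally):

```python
import torch
import torch.nn as nn

def with_sampling_dtype(model: nn.Module, weight_dtype: torch.dtype, run_sampling):
    orig_dtype = next(model.parameters()).dtype   # remember the training dtype
    model.to(dtype=weight_dtype)                  # cast for sample generation
    try:
        return run_sampling(model)
    finally:
        model.to(dtype=orig_dtype)                # restore so training continues unchanged

net = nn.Linear(4, 4)  # stand-in for unet / text_encoder
seen = with_sampling_dtype(net, torch.float16, lambda m: next(m.parameters()).dtype)
assert seen == torch.float16                          # sampling saw half-precision weights
assert next(net.parameters()).dtype == torch.float32  # training dtype restored
```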
diff --git a/train_ti.py b/train_ti.py
index 5f37d54..a228795 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -361,6 +361,7 @@ def parse_args():
 class Checkpointer(CheckpointerBase):
     def __init__(
         self,
+        weight_dtype,
         datamodule,
         accelerator,
         vae,
@@ -387,6 +388,7 @@ class Checkpointer(CheckpointerBase):
             sample_batch_size=sample_batch_size
         )
 
+        self.weight_dtype = weight_dtype
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -417,8 +419,9 @@ class Checkpointer(CheckpointerBase):
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
+        orig_dtype = text_encoder.dtype
+        text_encoder.to(dtype=self.weight_dtype)
 
-        # Save a sample image
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=self.vae,
@@ -430,6 +433,8 @@ class Checkpointer(CheckpointerBase):
 
         super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
+        text_encoder.to(dtype=orig_dtype)
+
         del text_encoder
         del pipeline
 
@@ -739,6 +744,7 @@ def main():
     max_acc_val = 0.0
 
     checkpointer = Checkpointer(
+        weight_dtype=weight_dtype,
         datamodule=datamodule,
         accelerator=accelerator,
         vae=vae,
diff --git a/training/util.py b/training/util.py
index 5c056a6..a0c15cd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -60,7 +60,7 @@ class CheckpointerBase:
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
@@ -68,65 +68,57 @@ class CheckpointerBase:
         val_data = self.datamodule.val_dataloader()
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
 
         grid_cols = min(self.sample_batch_size, 4)
         grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
 
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, grid_rows, grid_cols)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]:
+            all_samples = []
+            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
+
+            data_enum = enumerate(data)
+
+            batches = [
+                batch
+                for j, batch in data_enum
+                if j * data.batch_size < self.sample_batch_size * self.sample_batches
+            ]
+            prompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["prompts"]
+            ]
+            nprompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ]
+
+            for i in range(self.sample_batches):
+                prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+
+                samples = pipeline(
+                    prompt=prompt,
+                    negative_prompt=nprompt,
+                    height=self.sample_image_size,
+                    width=self.sample_image_size,
+                    generator=gen,
+                    guidance_scale=guidance_scale,
+                    eta=eta,
+                    num_inference_steps=num_inference_steps,
+                    output_type='pil'
+                ).images
+
+                all_samples += samples
+
+                del samples
+
+            image_grid = make_grid(all_samples, grid_rows, grid_cols)
+            image_grid.save(file_path, quality=85)
+
+            del all_samples
+            del image_grid
 
         del generator
-        del stable_latents
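The `stable_latents` tensor that used to be pre-generated for the "stable" sample pool is dropped; that pool now receives the generator seeded from `self.seed` directly, so its samples stay deterministic across checkpoints without caching any latents. A small sketch of why a fixed seed makes the cached tensor unnecessary (shapes are illustrative):

```python
import torch

shape = (4, 4, 64, 64)  # batch, latent channels, height // 8, width // 8

def draws(seed: int, n: int):
    # a freshly seeded generator yields the same sequence of noise tensors every time
    gen = torch.Generator().manual_seed(seed)
    return [torch.randn(shape, generator=gen) for _ in range(n)]

first_run = draws(1, 3)
second_run = draws(1, 3)
assert all(torch.equal(a, b) for a, b in zip(first_run, second_run))
```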