path: root/training/functional.py
author     Volpeon <git@volpeon.ink>  2023-01-15 10:12:04 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-15 10:12:04 +0100
commit     34648b763fa60e3161a5b5f1243ed1b5c3b0188e (patch)
tree       4c2b8104a8d1af26955561959591249d9281a02f /training/functional.py
parent     Added functional trainer (diff)
Added functional TI strategy
Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  118
1 file changed, 118 insertions, 0 deletions
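
The bulk of the patch is a new save_samples helper that periodically renders image grids from the train and validation prompts. A minimal sketch of how it might be called from a training script, assuming the models come from get_models and the dataloaders yield batches with "prompt_ids" and "nprompt_ids" keys (the model path and variable names below are illustrative, not part of this commit):

    from pathlib import Path

    import torch
    from accelerate import Accelerator

    from training.functional import get_models, save_samples

    accelerator = Accelerator()
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "runwayml/stable-diffusion-v1-5"  # any diffusers-format checkpoint
    )

    # train_dataloader / val_dataloader are assumed to exist and to yield dicts
    # containing "prompt_ids" and "nprompt_ids", which save_samples slices per batch.
    save_samples(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=torch.float16,       # sampling dtype; original dtypes are restored afterwards
        output_dir=Path("output/example-run"),
        seed=42,
        step=500,
        batch_size=4,
        num_batches=2,             # 8 images per grid: grid_cols = min(4, 4) = 4, grid_rows = 2
        image_size=512,
    )
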
diff --git a/training/functional.py b/training/functional.py
index 1f2ca6d..e54c9c8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -2,6 +2,8 @@ import math
 from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union, Optional
 from functools import partial
+from pathlib import Path
+import itertools
 
 import torch
 import torch.nn.functional as F
@@ -26,6 +28,14 @@ def const(result=None):
     return fn
 
 
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols*w, i//cols*h))
+    return grid
+
+
 def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -40,6 +50,107 @@ def get_models(pretrained_model_name_or_path: str):
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
 
+def save_samples(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    output_dir: Path,
+    seed: int,
+    step: int,
+    batch_size: int = 1,
+    num_batches: int = 1,
+    num_steps: int = 20,
+    guidance_scale: float = 7.5,
+    image_size: Optional[int] = None,
+):
+    print(f"Saving samples for step {step}...")
+
+    samples_path = output_dir.joinpath("samples")
+
+    grid_cols = min(batch_size, 4)
+    grid_rows = (num_batches * batch_size) // grid_cols
+
+    unet = accelerator.unwrap_model(unet)
+    text_encoder = accelerator.unwrap_model(text_encoder)
+
+    orig_unet_dtype = unet.dtype
+    orig_text_encoder_dtype = text_encoder.dtype
+
+    unet.to(dtype=dtype)
+    text_encoder.to(dtype=dtype)
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=sample_scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    generator = torch.Generator(device=accelerator.device).manual_seed(seed)
+
+    for pool, data, gen in [
+        ("stable", val_dataloader, generator),
+        ("val", val_dataloader, None),
+        ("train", train_dataloader, None)
+    ]:
+        all_samples = []
+        file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+
+        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
+        prompt_ids = [
+            prompt
+            for batch in batches
+            for prompt in batch["prompt_ids"]
+        ]
+        nprompt_ids = [
+            prompt
+            for batch in batches
+            for prompt in batch["nprompt_ids"]
+        ]
+
+        for i in range(num_batches):
+            start = i * batch_size
+            end = (i + 1) * batch_size
+            prompt = prompt_ids[start:end]
+            nprompt = nprompt_ids[start:end]
+
+            samples = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=image_size,
+                width=image_size,
+                generator=gen,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                output_type='pil'
+            ).images
+
+            all_samples += samples
+
+        image_grid = make_grid(all_samples, grid_rows, grid_cols)
+        image_grid.save(file_path, quality=85)
+
+    unet.to(dtype=orig_unet_dtype)
+    text_encoder.to(dtype=orig_text_encoder_dtype)
+
+    del unet
+    del text_encoder
+    del generator
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 def generate_class_images(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
@@ -109,6 +220,10 @@ def get_models(pretrained_model_name_or_path: str):
 
     embeddings = patch_managed_embeddings(text_encoder)
 
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
 
@@ -427,6 +542,9 @@ def train(
         seed,
     )
 
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
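
For reference, the make_grid helper introduced above lays images out row-major: image i lands in column i % cols and row i // cols, and the paste box (i % cols*w, i//cols*h) works because % and * share precedence and evaluate left to right. A small standalone check of that layout, assuming Pillow is installed (the test tiles are illustrative, not from this commit):

    from PIL import Image

    def make_grid(images, rows, cols):
        # Copied from the patch above: each image is pasted at ((i % cols) * w, (i // cols) * h).
        w, h = images[0].size
        grid = Image.new('RGB', size=(cols*w, rows*h))
        for i, image in enumerate(images):
            grid.paste(image, box=(i % cols*w, i//cols*h))
        return grid

    tiles = [Image.new('RGB', (64, 64), (i * 30, 0, 0)) for i in range(8)]
    grid = make_grid(tiles, rows=2, cols=4)
    assert grid.size == (256, 128)  # 4 columns * 64 px wide, 2 rows * 64 px tall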