Diffstat (limited to 'training')
-rw-r--r--   training/functional.py            | 50
-rw-r--r--   training/optimization.py          |  2
-rw-r--r--   training/strategy/dreambooth.py   |  6
-rw-r--r--   training/strategy/ti.py           |  2
4 files changed, 32 insertions(+), 28 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index b830261..990c4cd 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from schedulers.scheduling_deis_multistep import DEISMultistepScheduler
 from training.util import AverageMeter


@@ -74,19 +73,12 @@ def make_grid(images, rows, cols):
     return grid


-def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"):
-    if noise_scheduler == "deis":
-        noise_scheduler_cls = DEISMultistepScheduler
-    elif noise_scheduler == "ddpm":
-        noise_scheduler_cls = DDPMScheduler
-    else:
-        raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}")
-
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')

@@ -232,9 +224,6 @@ def generate_class_images(

     del pipeline

-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-

 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
@@ -274,26 +263,41 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
     latents = latents * vae.config.scaling_factor

+    bsz = latents.shape[0]
+
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

     # Sample noise that we'll add to the latents
-    noise = torch.randn(
-        latents.shape,
-        dtype=latents.dtype,
-        layout=latents.layout,
-        device=latents.device,
-        generator=generator
-    )
-    if low_freq_noise > 0:
-        noise += low_freq_noise * torch.randn(
+    if low_freq_noise == 0:
+        noise = torch.randn(
+            latents.shape,
+            dtype=latents.dtype,
+            layout=latents.layout,
+            device=latents.device,
+            generator=generator
+        )
+    else:
+        noise = (1 - low_freq_noise) * torch.randn(
+            latents.shape,
+            dtype=latents.dtype,
+            layout=latents.layout,
+            device=latents.device,
+            generator=generator
+        ) + low_freq_noise * torch.randn(
             latents.shape[0], latents.shape[1], 1, 1,
             dtype=latents.dtype,
             layout=latents.layout,
             device=latents.device,
             generator=generator
         )
+        # noise += low_freq_noise * torch.randn(
+        #     bsz, 1, 1, 1,
+        #     dtype=latents.dtype,
+        #     layout=latents.layout,
+        #     device=latents.device,
+        #     generator=generator
+        # )

-    bsz = latents.shape[0]
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
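
Note on the loss_step hunk above: the low-frequency ("offset") noise is now blended into the base noise as a weighted mix rather than added on top, so the overall noise magnitude stays roughly constant as low_freq_noise grows. Below is a rough standalone sketch of that blend, reusing the tensor names from the diff; the helper function itself is hypothetical and not part of the repository.

from typing import Optional

import torch


def blend_offset_noise(latents: torch.Tensor, low_freq_noise: float,
                       generator: Optional[torch.Generator] = None) -> torch.Tensor:
    # Full-resolution Gaussian noise, one value per latent element.
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    if low_freq_noise == 0:
        return noise
    # One value per (sample, channel), broadcast over height/width: the low-frequency offset.
    offset = torch.randn(
        latents.shape[0], latents.shape[1], 1, 1,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    return (1 - low_freq_noise) * noise + low_freq_noise * offset
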
diff --git a/training/optimization.py b/training/optimization.py
index 6c9a35d..7d8d55a 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -113,7 +113,7 @@ def get_scheduler(
 ):
     num_training_steps_per_epoch = math.ceil(
         num_training_steps_per_epoch / gradient_accumulation_steps
-    ) * gradient_accumulation_steps
+    )  # * gradient_accumulation_steps
     num_training_steps = train_epochs * num_training_steps_per_epoch
     num_warmup_steps = warmup_epochs * num_training_steps_per_epoch

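
For context on the optimization.py change: with the multiplier commented out, the scheduler length is counted in optimizer steps (one per accumulated batch) rather than in raw batches. A purely illustrative arithmetic sketch with made-up numbers:

import math

# Hypothetical values, for illustration only.
batches_per_epoch = 100
gradient_accumulation_steps = 4
train_epochs = 10

steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 25
# Previous behaviour multiplied back up: 25 * 4 = 100 scheduler steps per epoch.
num_training_steps = train_epochs * steps_per_epoch  # now 250 instead of 1000
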
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0290327..e5e84c8 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks(

     def on_prepare():
         unet.requires_grad_(True)
-        text_encoder.requires_grad_(True)
-        text_encoder.text_model.embeddings.requires_grad_(False)
+        text_encoder.text_model.encoder.requires_grad_(True)
+        text_encoder.text_model.final_layer_norm.requires_grad_(True)

         if ema_unet is not None:
             ema_unet.to(accelerator.device)
@@ -203,7 +203,7 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({})
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)


 dreambooth_strategy = TrainingStrategy(
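
Two notes on the dreambooth.py hunks. First, in Hugging Face's CLIPTextModel the text_model submodule contains embeddings, encoder, and final_layer_norm; unfreezing only encoder and final_layer_norm trains the transformer blocks and output norm while leaving the token/position embeddings untouched, instead of enabling gradients everywhere and then disabling them on the embeddings. Second, accelerator.prepare returns a tuple, and ({}) is just a dict in parentheses, so the old return raised a TypeError; ({},) is a one-element tuple and concatenates cleanly. A tiny sketch of that second point, using stand-in values rather than the real objects:

prepared = ("text_encoder", "unet", "optimizer")  # stand-in for accelerator.prepare(...)

# prepared + ({})   # TypeError: can only concatenate tuple (not "dict") to tuple
result = prepared + ({},)
print(result)       # ('text_encoder', 'unet', 'optimizer', {})
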
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 732cd74..bd0d178 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -130,7 +130,7 @@ def textual_inversion_strategy_callbacks(
        if lambda_ != 0:
            w = text_encoder.text_model.embeddings.temp_token_embedding.weight

-           mask = torch.zeros(w.size(0), dtype=torch.bool)
+           mask = torch.zeros(w.shape[0], dtype=torch.bool)
            mask[text_encoder.text_model.embeddings.temp_token_ids] = True
            mask[zero_ids] = False

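
The ti.py change swaps w.size(0) for the equivalent w.shape[0]; the surrounding code builds a boolean row mask over the temporary token embedding. A minimal standalone sketch of that mask logic, assuming made-up shapes and ids:

import torch

w = torch.randn(49412, 768)                       # hypothetical (vocab_size, hidden_dim) weight
temp_token_ids = torch.tensor([49408, 49409, 49410, 49411])
zero_ids = torch.tensor([49410])

mask = torch.zeros(w.shape[0], dtype=torch.bool)  # w.shape[0] == w.size(0)
mask[temp_token_ids] = True                       # select placeholder-token rows...
mask[zero_ids] = False                            # ...except those excluded via zero_ids
selected = w[mask]                                # rows presumably used by the lambda_ term
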