Diffstat (limited to 'training')
 training/functional.py          | 50
 training/optimization.py        |  2
 training/strategy/dreambooth.py |  6
 training/strategy/ti.py         |  2
 4 files changed, 32 insertions(+), 28 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index b830261..990c4cd 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from schedulers.scheduling_deis_multistep import DEISMultistepScheduler
 from training.util import AverageMeter
 
 
@@ -74,19 +73,12 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"):
-    if noise_scheduler == "deis":
-        noise_scheduler_cls = DEISMultistepScheduler
-    elif noise_scheduler == "ddpm":
-        noise_scheduler_cls = DDPMScheduler
-    else:
-        raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}")
-
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
@@ -232,9 +224,6 @@ def generate_class_images(
 
     del pipeline
 
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
 
 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
@@ -274,26 +263,41 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
     latents = latents * vae.config.scaling_factor
 
+    bsz = latents.shape[0]
+
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
 
     # Sample noise that we'll add to the latents
-    noise = torch.randn(
-        latents.shape,
-        dtype=latents.dtype,
-        layout=latents.layout,
-        device=latents.device,
-        generator=generator
-    )
-    if low_freq_noise > 0:
-        noise += low_freq_noise * torch.randn(
+    if low_freq_noise == 0:
+        noise = torch.randn(
+            latents.shape,
+            dtype=latents.dtype,
+            layout=latents.layout,
+            device=latents.device,
+            generator=generator
+        )
+    else:
+        noise = (1 - low_freq_noise) * torch.randn(
+            latents.shape,
+            dtype=latents.dtype,
+            layout=latents.layout,
+            device=latents.device,
+            generator=generator
+        ) + low_freq_noise * torch.randn(
             latents.shape[0], latents.shape[1], 1, 1,
             dtype=latents.dtype,
             layout=latents.layout,
             device=latents.device,
             generator=generator
         )
+        # noise += low_freq_noise * torch.randn(
+        #     bsz, 1, 1, 1,
+        #     dtype=latents.dtype,
+        #     layout=latents.layout,
+        #     device=latents.device,
+        #     generator=generator
+        # )
 
-    bsz = latents.shape[0]
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
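Note on the loss_step change above: the patch switches from adding low-frequency ("offset") noise on top of unit-variance noise to linearly blending the two components, so low_freq_noise acts as an interpolation factor rather than an extra additive term. A minimal sketch of the new sampling as a standalone helper (the name sample_noise is ours, not part of the patch):

import torch

def sample_noise(latents: torch.Tensor, low_freq_noise: float,
                 generator: torch.Generator = None) -> torch.Tensor:
    # Per-pixel Gaussian noise with the same shape as the latents.
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    if low_freq_noise == 0:
        return noise
    # One value per (sample, channel), broadcast across H and W:
    # the low-frequency component.
    offset = torch.randn(
        latents.shape[0], latents.shape[1], 1, 1,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    # Interpolate instead of adding, as the new branch in the diff does.
    return (1 - low_freq_noise) * noise + low_freq_noise * offset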
diff --git a/training/optimization.py b/training/optimization.py
index 6c9a35d..7d8d55a 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -113,7 +113,7 @@ def get_scheduler(
 ):
     num_training_steps_per_epoch = math.ceil(
         num_training_steps_per_epoch / gradient_accumulation_steps
-    ) * gradient_accumulation_steps
+    )  # * gradient_accumulation_steps
     num_training_steps = train_epochs * num_training_steps_per_epoch
     num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
 
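Note on the scheduler change: with the trailing "* gradient_accumulation_steps" commented out, num_training_steps_per_epoch now counts optimizer updates rather than micro-batches, which suggests the scheduler is stepped once per update. With hypothetical numbers:

import math

batches_per_epoch = 1000
gradient_accumulation_steps = 4
train_epochs = 10

# After the patch: one scheduler step per optimizer update.
steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
num_training_steps = train_epochs * steps_per_epoch  # 10 * 250 = 2500

# Before the patch the count was scaled back up to micro-batches:
# math.ceil(1000 / 4) * 4 = 1000 steps per epoch, 10000 in total.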
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0290327..e5e84c8 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks(
 
     def on_prepare():
         unet.requires_grad_(True)
-        text_encoder.requires_grad_(True)
-        text_encoder.text_model.embeddings.requires_grad_(False)
+        text_encoder.text_model.encoder.requires_grad_(True)
+        text_encoder.text_model.final_layer_norm.requires_grad_(True)
 
         if ema_unet is not None:
             ema_unet.to(accelerator.device)
@@ -203,7 +203,7 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({})
+    return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},)
 
 
 dreambooth_strategy = TrainingStrategy(
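Note on dreambooth.py: the on_prepare change expresses the same intent as an allowlist, enabling gradients only on the transformer encoder and final layer norm while the embeddings stay frozen. The dreambooth_prepare change is a genuine bug fix: "({})" is only a parenthesized empty dict, so concatenating it to the tuple returned by accelerator.prepare raised a TypeError, while "({},)" is a one-element tuple. A self-contained illustration with stand-in values:

prepared = ("text_encoder", "unet")  # stand-in for accelerator.prepare(...)

result = prepared + ({},)   # OK: ("text_encoder", "unet", {})
# result = prepared + ({})  # TypeError: can only concatenate tuple (not "dict") to tuple
print(result)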
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 732cd74..bd0d178 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -130,7 +130,7 @@ def textual_inversion_strategy_callbacks(
         if lambda_ != 0:
             w = text_encoder.text_model.embeddings.temp_token_embedding.weight
 
-            mask = torch.zeros(w.size(0), dtype=torch.bool)
+            mask = torch.zeros(w.shape[0], dtype=torch.bool)
             mask[text_encoder.text_model.embeddings.temp_token_ids] = True
             mask[zero_ids] = False
 
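Note on the ti.py change: w.size(0) and w.shape[0] are equivalent in PyTorch, so the rewrite is purely cosmetic. For context, the surrounding mask selects which embedding rows the lambda_ regularization applies to (all values here are hypothetical):

import torch

w = torch.randn(49410, 768)                    # hypothetical embedding weight
temp_token_ids = torch.tensor([49408, 49409])  # hypothetical placeholder rows
zero_ids = [49409]                             # hypothetical excluded rows

mask = torch.zeros(w.shape[0], dtype=torch.bool)
mask[temp_token_ids] = True   # rows of the trainable placeholder tokens
mask[zero_ids] = False        # minus tokens that are pinned to zero
print(mask.sum())             # tensor(1)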
