From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- training/functional.py | 221 +++++++++++++++++++++++++++------------- training/lr.py | 4 +- training/optimization.py | 38 +++++-- training/sampler.py | 2 +- training/strategy/dreambooth.py | 29 +++--- training/strategy/lora.py | 41 +++++--- training/strategy/ti.py | 27 +++-- 7 files changed, 245 insertions(+), 117 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index fd3f9f4..f68faf9 100644 --- a/training/functional.py +++ b/training/functional.py @@ -14,7 +14,13 @@ import numpy as np from accelerate import Accelerator from transformers import CLIPTextModel -from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, + DDPMScheduler, + UniPCMultistepScheduler, + SchedulerMixin, +) from tqdm.auto import tqdm @@ -33,11 +39,12 @@ from util.noise import perlin_noise def const(result=None): def fn(*args, **kwargs): return result + return fn @dataclass -class TrainingCallbacks(): +class TrainingCallbacks: on_log: Callable[[], dict[str, Any]] = const({}) on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) on_before_optimize: Callable[[int], Any] = const() @@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol): train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, - **kwargs - ) -> Tuple: ... + **kwargs, + ) -> Tuple: + ... @dataclass -class TrainingStrategy(): +class TrainingStrategy: callbacks: Callable[..., TrainingCallbacks] prepare: TrainingStrategyPrepareCallable def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): - tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') - text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype) - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype) - unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype) - noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') - sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') + tokenizer = MultiCLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, subfolder="tokenizer" + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype + ) + vae = AutoencoderKL.from_pretrained( + pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype + ) + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype + ) + noise_scheduler = DDPMScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder="scheduler" + ) + sample_scheduler = UniPCMultistepScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder="scheduler" + ) return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler @@ -113,7 +133,9 @@ def save_samples( generator = torch.Generator(device=accelerator.device).manual_seed(seed) - datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", 
train_dataloader, None)] + datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [ + ("train", train_dataloader, None) + ] if val_dataloader is not None: datasets.append(("stable", val_dataloader, generator)) @@ -124,17 +146,11 @@ def save_samples( file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" file_path.parent.mkdir(parents=True, exist_ok=True) - batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) - prompt_ids = [ - prompt - for batch in batches - for prompt in batch["prompt_ids"] - ] - nprompt_ids = [ - prompt - for batch in batches - for prompt in batch["nprompt_ids"] - ] + batches = list( + itertools.islice(itertools.cycle(data), batch_size * num_batches) + ) + prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]] + nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]] with torch.inference_mode(): for i in range(num_batches): @@ -165,7 +181,9 @@ def save_samples( pass image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) - image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] + image_grid = pipeline.numpy_to_pil( + image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() + )[0] image_grid.save(file_path, quality=85) del generator, pipeline @@ -184,15 +202,17 @@ def generate_class_images( train_dataset: VlpnDataset, sample_batch_size: int, sample_image_size: int, - sample_steps: int + sample_steps: int, ): - missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] + missing_data = [ + item for item in train_dataset.items if not item.class_image_path.exists() + ] if len(missing_data) == 0: return batched_data = [ - missing_data[i:i+sample_batch_size] + missing_data[i : i + sample_batch_size] for i in range(0, len(missing_data), sample_batch_size) ] @@ -216,7 +236,7 @@ def generate_class_images( negative_prompt=nprompt, height=sample_image_size, width=sample_image_size, - num_inference_steps=sample_steps + num_inference_steps=sample_steps, ).images for i, image in enumerate(images): @@ -245,8 +265,12 @@ def add_placeholder_tokens( embeddings.resize(len(tokenizer)) - for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): - embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) + for placeholder_token_id, initializer_token_id in zip( + placeholder_token_ids, initializer_token_ids + ): + embeddings.add_embed( + placeholder_token_id, initializer_token_id, initializer_noise + ) return placeholder_token_ids, initializer_token_ids @@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler): # Expand the tensors. 
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) @@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler): def get_original( - noise_scheduler, - model_output, - sample: torch.FloatTensor, - timesteps: torch.IntTensor + noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor ): alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() while len(sqrt_alphas_cumprod.shape) < len(sample.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(sample.shape) - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) @@ -329,7 +358,9 @@ def loss_step( eval: bool = False, ): images = batch["pixel_values"] - generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None + generator = ( + torch.Generator(device=images.device).manual_seed(seed + step) if eval else None + ) bsz = images.shape[0] # Convert images to latent space @@ -342,7 +373,7 @@ def loss_step( dtype=latents.dtype, layout=latents.layout, device=latents.device, - generator=generator + generator=generator, ) applied_noise = noise @@ -353,7 +384,7 @@ def loss_step( octaves=4, dtype=latents.dtype, device=latents.device, - generator=generator + generator=generator, ) if input_pertubation != 0: @@ -362,7 +393,7 @@ def loss_step( dtype=latents.dtype, layout=latents.layout, device=latents.device, - generator=generator + generator=generator, ) # Sample a random timestep for each image @@ -375,25 +406,27 @@ def loss_step( # Get the text embedding for conditioning encoder_hidden_states = get_extended_embeddings( - text_encoder, - batch["input_ids"], - batch["attention_mask"] + text_encoder, batch["input_ids"], batch["attention_mask"] ) encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] + model_pred = unet( + noisy_latents, timesteps, encoder_hidden_states, return_dict=False + )[0] 
if guidance_scale != 0: uncond_encoder_hidden_states = get_extended_embeddings( - text_encoder, - batch["negative_input_ids"], - batch["negative_attention_mask"] + text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] ) uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) - model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] - model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) + model_pred_uncond = unet( + noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False + )[0] + model_pred = model_pred_uncond + guidance_scale * ( + model_pred - model_pred_uncond + ) # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -401,7 +434,9 @@ def loss_step( elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) acc = (model_pred == target).float().mean() @@ -414,7 +449,9 @@ def loss_step( loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") + prior_loss = F.mse_loss( + model_pred_prior.float(), target_prior.float(), reduction="none" + ) # Add the prior loss to the instance loss. loss = loss + prior_loss_weight * prior_loss @@ -433,7 +470,10 @@ def loss_step( if min_snr_gamma != 0: snr = compute_snr(timesteps, noise_scheduler) mse_loss_weights = ( - torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] + / snr ) loss = loss * mse_loss_weights @@ -447,8 +487,14 @@ def loss_step( class LossCallable(Protocol): - def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any], - eval: bool = False) -> Tuple[Any, Any, int]: ... + def __call__( + self, + step: int, + batch: dict[Any, Any], + cache: dict[str, Any], + eval: bool = False, + ) -> Tuple[Any, Any, int]: + ... 
def train_loop( @@ -472,9 +518,14 @@ def train_loop( avg_acc_val: AverageMeter = AverageMeter(), callbacks: TrainingCallbacks = TrainingCallbacks(), ): - num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) - num_val_steps_per_epoch = math.ceil( - len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0 + num_training_steps_per_epoch = math.ceil( + len(train_dataloader) / gradient_accumulation_steps + ) + num_val_steps_per_epoch = ( + math.ceil(len(val_dataloader) / gradient_accumulation_steps) + if val_dataloader is not None + else 0 + ) num_training_steps = num_training_steps_per_epoch * num_epochs num_val_steps = num_val_steps_per_epoch * num_epochs @@ -488,14 +539,14 @@ def train_loop( local_progress_bar = tqdm( range(num_training_steps_per_epoch + num_val_steps_per_epoch), disable=not accelerator.is_local_main_process, - dynamic_ncols=True + dynamic_ncols=True, ) local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") global_progress_bar = tqdm( range(num_training_steps + num_val_steps), disable=not accelerator.is_local_main_process, - dynamic_ncols=True + dynamic_ncols=True, ) global_progress_bar.set_description("Total progress") @@ -513,7 +564,9 @@ def train_loop( try: import dadaptation - isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) + isDadaptation = isinstance( + optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan) + ) except ImportError: pass @@ -565,7 +618,10 @@ def train_loop( label = group_labels[i] if i < len(group_labels) else f"{i}" logs[f"lr/{label}"] = lr if isDadaptation: - lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + lr = ( + optimizer.param_groups[i]["d"] + * optimizer.param_groups[i]["lr"] + ) logs[f"d*lr/{label}"] = lr lrs[label] = lr @@ -573,8 +629,10 @@ def train_loop( local_progress_bar.set_postfix(**logs) - if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): - before_optimize_result = on_before_optimize(epoch) + if ((step + 1) % gradient_accumulation_steps == 0) or ( + (step + 1) == len(train_dataloader) + ): + before_optimize_result = on_before_optimize(cycle) optimizer.step() lr_scheduler.step() @@ -614,7 +672,9 @@ def train_loop( } local_progress_bar.set_postfix(**logs) - if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)): + if ((step + 1) % gradient_accumulation_steps == 0) or ( + (step + 1) == len(val_dataloader) + ): local_progress_bar.update(1) global_progress_bar.update(1) @@ -634,7 +694,8 @@ def train_loop( global_progress_bar.clear() accelerator.print( - f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") + f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}" + ) on_checkpoint(global_step, "milestone") best_acc_val = avg_acc_val.max else: @@ -644,7 +705,8 @@ def train_loop( global_progress_bar.clear() accelerator.print( - f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") + f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}" + ) on_checkpoint(global_step, "milestone") best_acc = avg_acc.max @@ -700,17 +762,32 @@ def train( avg_acc_val: AverageMeter = AverageMeter(), **kwargs, ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = 
strategy.prepare( - accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = strategy.prepare( + accelerator, + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + **kwargs, + ) vae.to(accelerator.device, dtype=dtype) vae.requires_grad_(False) vae.eval() - vae = torch.compile(vae, backend='hidet') + vae = torch.compile(vae, backend="hidet") if compile_unet: - unet = torch.compile(unet, backend='hidet') + unet = torch.compile(unet, backend="hidet") # unet = torch.compile(unet, mode="reduce-overhead") callbacks = strategy.callbacks( diff --git a/training/lr.py b/training/lr.py index f5b362f..a75078f 100644 --- a/training/lr.py +++ b/training/lr.py @@ -23,12 +23,12 @@ def plot_metrics( fig, ax_loss = plt.subplots() ax_acc = ax_loss.twinx() - ax_loss.plot(lrs, losses, color='red') + ax_loss.plot(lrs, losses, color="red") ax_loss.set_xscale("log") ax_loss.set_xlabel(f"Learning rate") ax_loss.set_ylabel("Loss") - ax_acc.plot(lrs, accs, color='blue') + ax_acc.plot(lrs, accs, color="blue") ax_acc.set_xscale("log") ax_acc.set_ylabel("Accuracy") diff --git a/training/optimization.py b/training/optimization.py index d22a900..55531bf 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -5,7 +5,10 @@ from functools import partial import torch from torch.optim.lr_scheduler import LambdaLR -from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup +from diffusers.optimization import ( + get_scheduler as get_scheduler_, + get_cosine_with_hard_restarts_schedule_with_warmup, +) from transformers.optimization import get_adafactor_schedule @@ -52,7 +55,7 @@ def get_one_cycle_schedule( annealing_exp: int = 1, min_lr: float = 0.04, mid_point: float = 0.3, - last_epoch: int = -1 + last_epoch: int = -1, ): if warmup == "linear": warmup_func = warmup_linear @@ -83,12 +86,16 @@ def get_one_cycle_schedule( def lr_lambda(current_step: int): phase = [p for p in phases if current_step >= p.step_min][-1] - return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) + return phase.min + phase.func( + (current_step - phase.step_min) / (phase.step_max - phase.step_min) + ) * (phase.max - phase.min) return LambdaLR(optimizer, lr_lambda, last_epoch) -def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1): +def get_exponential_growing_schedule( + optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1 +): def lr_lambda(base_lr: float, current_step: int): return (end_lr / base_lr) ** (current_step / num_training_steps) @@ -132,7 +139,14 @@ def get_scheduler( ) elif id == "exponential_growth": if cycles is None: - cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) + cycles = math.ceil( + math.sqrt( + ( + (num_training_steps - num_warmup_steps) + / num_training_steps_per_epoch + ) + ) + ) lr_scheduler = get_exponential_growing_schedule( optimizer=optimizer, @@ -141,7 +155,14 @@ def get_scheduler( ) elif id == "cosine_with_restarts": if cycles is None: - cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) + cycles = math.ceil( + math.sqrt( + ( + (num_training_steps - num_warmup_steps) + / num_training_steps_per_epoch + ) + ) + ) 
lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( optimizer=optimizer, @@ -150,10 +171,7 @@ def get_scheduler( num_cycles=cycles, ) elif id == "adafactor": - lr_scheduler = get_adafactor_schedule( - optimizer, - initial_lr=min_lr - ) + lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr) else: lr_scheduler = get_scheduler_( id, diff --git a/training/sampler.py b/training/sampler.py index bdb3e90..0487d66 100644 --- a/training/sampler.py +++ b/training/sampler.py @@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler): def weights(self): if not self._warmed_up(): return np.ones([self.num_timesteps], dtype=np.float64) - weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) weights /= np.sum(weights) weights *= 1 - self.uniform_prob weights += self.uniform_prob / len(weights) diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e6fcc89..88b441b 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks( sample_output_dir: Path, checkpoint_output_dir: Path, seed: int, - train_text_encoder_epochs: int, + train_text_encoder_cycles: int, max_grad_norm: float = 1.0, use_ema: bool = False, ema_inv_gamma: float = 1.0, @@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks( return nullcontext() @contextmanager - def on_train(epoch: int): + def on_train(cycle: int): unet.train() tokenizer.train() - if epoch < train_text_encoder_epochs: + if cycle < train_text_encoder_cycles: text_encoder.train() - elif epoch == train_text_encoder_epochs: - text_encoder.requires_grad_(False) - text_encoder.eval() + tokenizer.train() yield @@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks( with ema_context(): yield - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): params_to_clip = [unet.parameters()] - if epoch < train_text_encoder_epochs: + if cycle < train_text_encoder_cycles: params_to_clip.append(text_encoder.parameters()) accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) @@ -189,8 +187,16 @@ def dreambooth_prepare( lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) text_encoder.text_model.embeddings.requires_grad_(False) @@ -198,6 +204,5 @@ def dreambooth_prepare( dreambooth_strategy = TrainingStrategy( - callbacks=dreambooth_strategy_callbacks, - prepare=dreambooth_prepare + callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare ) diff --git a/training/strategy/lora.py b/training/strategy/lora.py index f942b76..14e3384 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -81,7 +81,7 @@ def lora_strategy_callbacks( tokenizer.eval() yield - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): if not pti_mode: accelerator.clip_grad_norm_( itertools.chain( @@ -89,7 +89,7 @@ def lora_strategy_callbacks( text_encoder.text_model.encoder.parameters(), text_encoder.text_model.final_layer_norm.parameters(), ), - max_grad_norm + max_grad_norm, ) if len(placeholder_tokens) != 0 and use_emb_decay: @@ 
-108,7 +108,9 @@ def lora_strategy_callbacks( if lambda_ != 0: norm = w[:, :].norm(dim=-1, keepdim=True) - w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + w[:].add_( + (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) + ) @torch.no_grad() def on_checkpoint(step, postfix): @@ -128,25 +130,32 @@ def lora_strategy_callbacks( if not pti_mode: lora_config = {} - state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) + state_dict = get_peft_model_state_dict( + unet_, state_dict=accelerator.get_state_dict(unet_) + ) lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) text_encoder_state_dict = get_peft_model_state_dict( text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) ) - text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} + text_encoder_state_dict = { + f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items() + } state_dict.update(text_encoder_state_dict) - lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) + lora_config[ + "text_encoder_peft_config" + ] = text_encoder_.get_peft_config_as_dict(inference=True) if len(placeholder_tokens) != 0: ti_state_dict = { f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) - for (token, ids) - in zip(placeholder_tokens, placeholder_token_ids) + for (token, ids) in zip(placeholder_tokens, placeholder_token_ids) } state_dict.update(ti_state_dict) - save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") + save_file( + state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors" + ) with open(checkpoint_output_dir / "lora_config.json", "w") as f: json.dump(lora_config, f) @@ -185,10 +194,18 @@ def lora_prepare( train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, - **kwargs + **kwargs, ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6bc1d7d..7373982 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( yield @torch.no_grad() - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): if use_emb_decay: params = [ p @@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) + ema_embeddings.step( + text_encoder.text_model.embeddings.token_embedding.parameters() + ) if use_emb_decay and w is not None: lr = lrs["emb"] if "emb" in lrs else lrs["0"] @@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks( if lambda_ != 0: norm = w[:, :].norm(dim=-1, keepdim=True) - w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + w[:].add_( + (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) + ) def on_log(): if ema_embeddings is not None: @@ 
-136,10 +140,10 @@ def textual_inversion_strategy_callbacks( print(f"Saving checkpoint for step {step}...") with ema_context(): - for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): + for token, ids in zip(placeholder_tokens, placeholder_token_ids): text_encoder.text_model.embeddings.save_embed( ids, - checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" + checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin", ) @torch.no_grad() @@ -183,7 +187,7 @@ def textual_inversion_prepare( val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, gradient_checkpointing: bool = False, - **kwargs + **kwargs, ): weight_dtype = torch.float32 if accelerator.state.mixed_precision == "fp16": @@ -191,8 +195,15 @@ def textual_inversion_prepare( elif accelerator.state.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) unet.to(accelerator.device, dtype=weight_dtype) unet.requires_grad_(False) -- cgit v1.2.3-70-g09d2
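Note: the `min_snr_gamma` branch of `loss_step` above weights the per-sample MSE loss by min(SNR, gamma) / SNR, using the SNR that `compute_snr` derives from the scheduler's alphas_cumprod. A minimal standalone sketch of that weighting follows for reference; the beta schedule, timesteps, and gamma value are illustrative assumptions, not values taken from this repository.

import torch

# Toy beta schedule and sampled timesteps -- illustrative only, not from the repo.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

timesteps = torch.tensor([10, 250, 500, 900])
min_snr_gamma = 5.0

# Same quantity compute_snr() returns: (alpha / sigma)^2 at each sampled timestep.
sqrt_alphas_cumprod = alphas_cumprod[timesteps] ** 0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod[timesteps]) ** 0.5
snr = (sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod) ** 2

# min(SNR, gamma) / SNR: low-noise (high-SNR) timesteps are down-weighted,
# high-noise timesteps keep a weight of 1.
mse_loss_weights = (
    torch.stack([snr, min_snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
)
print(mse_loss_weights)

With a typical linear beta schedule, the early (low-noise, high-SNR) timesteps receive weights well below 1 while the late (high-noise) timesteps keep a weight of 1, which is what caps the influence of the easy denoising steps on the training loss.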