From 59bf501198d7ff6c0c03c45e92adef14069d5ac6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 12:33:52 +0100
Subject: Update

---
 data/csv.py             |  11 ++----
 train_ti.py             |  74 ++++++++++++++++++-----------------
 training/functional.py  | 100 +++++++++++-------------------------------
 training/lr.py          |  29 +++++++-------
 training/strategy/ti.py |  54 +++++++++++++-------------
 5 files changed, 106 insertions(+), 162 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index b058a3e..5de3ac7 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,28 +100,25 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments


-def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
-    with_prior = all("class_prompt_ids" in example for example in examples)
-
+def collate_fn(dtype: torch.dtype, tokenizer: CLIPTokenizer, with_prior_preservation: bool, examples):
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]

     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]

-    if with_prior:
+    if with_prior_preservation:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]

     pixel_values = torch.stack(pixel_values)
-    pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
+    pixel_values = pixel_values.to(dtype=dtype, memory_format=torch.contiguous_format)

     prompts = unify_input_ids(tokenizer, prompt_ids)
     nprompts = unify_input_ids(tokenizer, nprompt_ids)
     inputs = unify_input_ids(tokenizer, input_ids)

     batch = {
-        "with_prior": torch.tensor([with_prior] * len(examples)),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
@@ -285,7 +282,7 @@ class VlpnDataModule():
             size=self.size,
             interpolation=self.interpolation,
         )
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)

         self.train_dataloader = DataLoader(
             train_dataset,
diff --git a/train_ti.py b/train_ti.py
index 3c9810f..4bac736 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -15,11 +15,11 @@ from slugify import slugify

 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
 from training.lr import LRFinder
-from training.util import EMAModel, save_args
+from training.util import save_args

 logger = get_logger(__name__)

@@ -82,7 +82,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=1,
+        default=0,
         help="How many class images to generate."
     )
     parser.add_argument(
@@ -398,7 +398,7 @@ def parse_args():
     )
     parser.add_argument(
         "--emb_decay_factor",
-        default=0,
+        default=1,
         type=float,
         help="Embedding decay factor."
     )
@@ -540,16 +540,6 @@ def main():
     placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
     print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")

-    if args.use_ema:
-        ema_embeddings = EMAModel(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-            inv_gamma=args.ema_inv_gamma,
-            power=args.ema_power,
-            max_value=args.ema_max_decay,
-        )
-    else:
-        ema_embeddings = None
-
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
@@ -654,23 +644,13 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )

-    if args.use_ema:
-        ema_embeddings.to(accelerator.device)
-
-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        text_encoder=text_encoder,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    strategy = textual_inversion_strategy(
+    vae.to(accelerator.device, dtype=weight_dtype)
+
+    callbacks = textual_inversion_strategy(
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
@@ -679,7 +659,6 @@ def main():
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=weight_dtype,
         output_dir=output_dir,
         seed=args.seed,
         placeholder_tokens=args.placeholder_tokens,
@@ -700,31 +679,54 @@ def main():
         sample_image_size=args.sample_image_size,
     )

+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        args.num_class_images != 0,
+        args.prior_loss_weight,
+        args.seed,
+    )
+
     if args.find_lr:
         lr_finder = LRFinder(
             accelerator=accelerator,
             optimizer=optimizer,
-            model=text_encoder,
             train_dataloader=train_dataloader,
             val_dataloader=val_dataloader,
-            **strategy,
+            callbacks=callbacks,
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)

         plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
-        trainer(
+        if accelerator.is_main_process:
+            accelerator.init_trackers("textual_inversion")
+
+        train_loop(
+            accelerator=accelerator,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
-            num_train_epochs=args.num_train_epochs,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
-            prior_loss_weight=args.prior_loss_weight,
-            callbacks=strategy,
+            callbacks=callbacks,
         )

+    accelerator.end_training()
+

 if __name__ == "__main__":
     main()
diff --git a/training/functional.py b/training/functional.py
index 4ca7470..c01595a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -33,6 +33,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[float], None] = const()
+    on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], None] = const()
@@ -267,6 +268,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
+    with_prior_preservation: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -322,7 +324,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

-    if batch["with_prior"].all():
+    if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -347,7 +349,6 @@ def train_loop(
     accelerator: Accelerator,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    model: torch.nn.Module,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -387,28 +388,37 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")

+    model = callbacks.on_model()
+    on_log = callbacks.on_log
+    on_train = callbacks.on_train
+    on_before_optimize = callbacks.on_before_optimize
+    on_after_optimize = callbacks.on_after_optimize
+    on_eval = callbacks.on_eval
+    on_sample = callbacks.on_sample
+    on_checkpoint = callbacks.on_checkpoint
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    callbacks.on_sample(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)

                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    callbacks.on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")

             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             model.train()

-            with callbacks.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(train_dataloader):
                     with accelerator.accumulate(model):
                         loss, acc, bsz = loss_step(step, batch)

                         accelerator.backward(loss)

-                        callbacks.on_before_optimize(epoch)
+                        on_before_optimize(epoch)

                         optimizer.step()
                         lr_scheduler.step()
@@ -419,7 +429,7 @@ def train_loop(

                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
-                        callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])

                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
@@ -433,7 +443,7 @@ def train_loop(
                         "train/cur_acc": acc.item(),
                         "lr": lr_scheduler.get_last_lr()[0],
                     }
-                    logs.update(callbacks.on_log())
+                    logs.update(on_log())

                     accelerator.log(logs, step=global_step)

@@ -449,7 +459,7 @@ def train_loop(
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()

-            with torch.inference_mode(), callbacks.on_eval():
+            with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, True)

@@ -485,80 +495,16 @@ def train_loop(
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    callbacks.on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()

         # Create the pipeline using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            callbacks.on_checkpoint(global_step + global_step_offset, "end")
-            callbacks.on_sample(global_step + global_step_offset)
-            accelerator.end_training()
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)

     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            callbacks.on_checkpoint(global_step + global_step_offset, "end")
-            accelerator.end_training()
-
-
-def train(
-    accelerator: Accelerator,
-    unet: UNet2DConditionModel,
-    text_encoder: CLIPTextModel,
-    vae: AutoencoderKL,
-    noise_scheduler: DDPMScheduler,
-    train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
-    dtype: torch.dtype,
-    seed: int,
-    optimizer: torch.optim.Optimizer,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    num_train_epochs: int = 100,
-    sample_frequency: int = 20,
-    checkpoint_frequency: int = 50,
-    global_step_offset: int = 0,
-    prior_loss_weight: float = 0,
-    callbacks: TrainingCallbacks = TrainingCallbacks(),
-):
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=dtype)
-
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        prior_loss_weight,
-        seed,
-    )
-
-    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
-
-    train_loop(
-        accelerator=accelerator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model=text_encoder,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        loss_step=loss_step_,
-        sample_frequency=sample_frequency,
-        checkpoint_frequency=checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        num_epochs=num_train_epochs,
-        callbacks=callbacks,
-    )
-
-    accelerator.free_memory()
+            on_checkpoint(global_step + global_step_offset, "end")
diff --git a/training/lr.py b/training/lr.py
index 7584ba2..902c4eb 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -9,6 +9,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm

+from training.functional import TrainingCallbacks
 from training.util import AverageMeter


@@ -24,26 +25,19 @@ class LRFinder():
     def __init__(
         self,
         accelerator,
-        model,
         optimizer,
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-        on_before_optimize: Callable[[int], None] = noop,
-        on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+        callbacks: TrainingCallbacks = TrainingCallbacks()
     ):
         self.accelerator = accelerator
-        self.model = model
+        self.model = callbacks.on_model()
         self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
-        self.on_train = on_train
-        self.on_before_optimize = on_before_optimize
-        self.on_after_optimize = on_after_optimize
-        self.on_eval = on_eval
+        self.callbacks = callbacks

         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
@@ -82,6 +76,13 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")

+        self.callbacks.on_prepare()
+
+        on_train = self.callbacks.on_train
+        on_before_optimize = self.callbacks.on_before_optimize
+        on_after_optimize = self.callbacks.on_after_optimize
+        on_eval = self.callbacks.on_eval
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")

@@ -90,7 +91,7 @@ class LRFinder():

             self.model.train()

-            with self.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -100,21 +101,21 @@ class LRFinder():

                     self.accelerator.backward(loss)

-                    self.on_before_optimize(epoch)
+                    on_before_optimize(epoch)

                     self.optimizer.step()
                     lr_scheduler.step()
                     self.optimizer.zero_grad(set_to_none=True)

                     if self.accelerator.sync_gradients:
-                        self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])

                     progress_bar.update(1)

             self.model.eval()

             with torch.inference_mode():
-                with self.on_eval():
+                with on_eval():
                     for step, batch in enumerate(self.val_dataloader):
                         if step >= num_val_batches:
                             break
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6f8384f..753dce0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,6 @@ def textual_inversion_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
@@ -48,6 +47,12 @@ def textual_inversion_strategy(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
@@ -58,7 +63,7 @@ def textual_inversion_strategy(
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=dtype,
+        dtype=weight_dtype,
         output_dir=output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -78,6 +83,17 @@ def textual_inversion_strategy(
     else:
         ema_embeddings = None

+    def ema_context():
+        if use_ema:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
+    def on_model():
+        return text_encoder
+
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

@@ -89,24 +105,15 @@ def textual_inversion_strategy(

     @contextmanager
     def on_train(epoch: int):
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
+        tokenizer.train()
+        yield

     @contextmanager
     def on_eval():
-        try:
-            tokenizer.eval()
+        tokenizer.eval()

-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
+        with ema_context():
+            yield

     @torch.no_grad()
     def on_after_optimize(lr: float):
@@ -131,13 +138,7 @@ def textual_inversion_strategy(
         checkpoints_path = output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        text_encoder = accelerator.unwrap_model(text_encoder)
-
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
@@ -146,15 +147,12 @@ def textual_inversion_strategy(

     @torch.no_grad()
     def on_sample(step):
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             save_samples_(step=step)

     return TrainingCallbacks(
         on_prepare=on_prepare,
+        on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
--
cgit v1.2.3-54-g00ecf
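
The core of this change is the TrainingCallbacks contract in training/functional.py: train_loop and LRFinder no longer take a model plus loose hook arguments, they pull everything from one callbacks object produced by a strategy factory such as textual_inversion_strategy. For orientation, a custom strategy only has to return that dataclass; below is a minimal sketch against the contract as it stands after this commit, assuming this repository's modules. The strategy name and the stand-in model are hypothetical (not part of the patch), and any hook left unset keeps its no-op default from the dataclass.

    from contextlib import contextmanager

    import torch

    from training.functional import TrainingCallbacks


    def minimal_strategy(model: torch.nn.Module) -> TrainingCallbacks:
        @contextmanager
        def on_train(epoch: int):
            # Switch method-specific pieces into training mode here;
            # the TI strategy calls tokenizer.train() at this point.
            yield

        @contextmanager
        def on_eval():
            # The TI strategy calls tokenizer.eval() and temporarily
            # applies the EMA embedding weights inside this context.
            yield

        def on_after_optimize(lr: float):
            # Runs after each optimizer step when gradients synced;
            # e.g. an EMA update or embedding decay.
            pass

        return TrainingCallbacks(
            on_model=lambda: model,  # train_loop/LRFinder fetch the module here
            on_train=on_train,
            on_eval=on_eval,
            on_after_optimize=on_after_optimize,
        )

This is why the EMA bookkeeping could move out of train_ti.py: the loop only sees opaque hooks, and everything embedding-specific now lives behind textual_inversion_strategy.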