From e2d3a62bce63fcde940395a1c5618c4eb43385a9 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 14 Jan 2023 09:25:13 +0100
Subject: Cleanup

---
 data/csv.py              | 21 ++++-------
 infer.py                 | 19 ++--------
 models/clip/tokenizer.py |  5 ++-
 train_dreambooth.py      | 33 ++++++----------
 train_ti.py              | 33 ++++++----------
 training/common.py       | 97 ++++++++++++++++++++++++++----------------------
 training/util.py         | 26 ++++++-------
 7 files changed, 103 insertions(+), 131 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index a3fef30..df3ee77 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,20 +100,16 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(
-    num_class_images: int,
-    weight_dtype: torch.dtype,
-    tokenizer: CLIPTokenizer,
-    examples
-):
+def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
+    with_prior = all("class_prompt_ids" in example for example in examples)
+
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    # concat class and instance examples for prior preservation
-    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
+    if with_prior:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
@@ -125,6 +121,7 @@ def collate_fn(
     inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
+        "with_prior": torch.tensor(with_prior),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
@@ -166,7 +163,6 @@ class VlpnDataModule():
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
-        num_workers: int = 0
     ):
         super().__init__()
 
@@ -194,7 +190,6 @@ class VlpnDataModule():
         self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
-        self.num_workers = num_workers
         self.batch_size = batch_size
         self.dtype = dtype
 
@@ -290,16 +285,16 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
 
         self.train_dataloader = DataLoader(
             train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_
        )
 
         self.val_dataloader = DataLoader(
             val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
 
diff --git a/infer.py b/infer.py
index 2b07b21..36b5a2c 100644
--- a/infer.py
+++ b/infer.py
@@ -214,21 +214,10 @@ def load_embeddings(pipeline, embeddings_dir):
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
-    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
-
-    patch_managed_embeddings(text_encoder)
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    )
+    pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype)
+
+    patch_managed_embeddings(pipeline.text_encoder)
+
     pipeline.enable_xformers_memory_efficient_attention()
     pipeline.enable_vae_slicing()
     pipeline.to("cuda")
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 39c41ed..789b525 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,6 +55,9 @@ def shuffle_auto(tokens: list[int]):
     return shuffle_all(tokens)
 
 
+ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]]
+
+
 class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -73,7 +76,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def set_dropout(self, dropout: float):
         self.dropout = dropout
 
-    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
+    def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm):
         if algorithm == "leading":
             self.vector_shuffle = shuffle_leading
         elif algorithm == "trailing":
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a1802a0..c180170 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -194,15 +194,6 @@ def parse_args():
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
     parser.add_argument(
         "--num_train_epochs",
         type=int,
@@ -577,24 +568,24 @@ def main():
     )
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
@@ -618,7 +609,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -627,7 +618,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_unet = EMAModel(
@@ -726,7 +719,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -830,7 +822,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images,
         args.prior_loss_weight,
         args.seed,
     )
@@ -848,7 +839,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -873,7 +865,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e2)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -886,7 +878,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=0,
             num_epochs=args.num_train_epochs,
diff --git a/train_ti.py b/train_ti.py
index d2ca7eb..d752927 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -180,15 +180,6 @@ def parse_args():
         default="auto",
         help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
     )
-    parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
     parser.add_argument(
         "--num_train_epochs",
         type=int,
@@ -575,24 +566,24 @@ def main():
     global_step_offset = args.global_step
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
@@ -616,7 +607,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -625,7 +616,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_embeddings = EMAModel(
@@ -708,7 +701,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -807,7 +799,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -825,7 +816,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -849,7 +841,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -862,7 +854,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
             num_epochs=args.num_train_epochs,
diff --git a/training/common.py b/training/common.py
index b6964a3..f5ab326 100644
--- a/training/common.py
+++ b/training/common.py
@@ -45,42 +45,44 @@ def generate_class_images(
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
-    if len(missing_data) != 0:
-        batched_data = [
-            missing_data[i:i+sample_batch_size]
-            for i in range(0, len(missing_data), sample_batch_size)
-        ]
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-        ).to(accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        with torch.inference_mode():
-            for batch in batched_data:
-                image_name = [item.class_image_path for item in batch]
-                prompt = [item.cprompt for item in batch]
-                nprompt = [item.nprompt for item in batch]
-
-                images = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=sample_image_size,
-                    width=sample_image_size,
-                    num_inference_steps=sample_steps
-                ).images
-
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
-
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 
 def get_models(pretrained_model_name_or_path: str):
@@ -119,7 +121,7 @@ def add_placeholder_tokens(
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
         embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-    return placeholder_token_ids
+    return placeholder_token_ids, initializer_token_ids
 
 
 def loss_step(
@@ -127,7 +129,6 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -138,16 +139,23 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
     latents = latents * 0.18215
 
+    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+
     # Sample noise that we'll add to the latents
-    noise = torch.randn_like(latents)
+    noise = torch.randn(
+        latents.shape,
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+        generator=generator
+    )
     bsz = latents.shape[0]
     # Sample a random timestep for each image
-    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
     timesteps = torch.randint(
         0,
         noise_scheduler.config.num_train_timesteps,
         (bsz,),
-        generator=timesteps_gen,
+        generator=generator,
         device=latents.device,
     )
     timesteps = timesteps.long()
@@ -176,7 +184,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if with_prior:
+    if batch["with_prior"]:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -207,7 +215,6 @@ def train_loop(
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
-    sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
@@ -251,7 +258,7 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+                    checkpointer.save_samples(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -353,7 +360,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
diff --git a/training/util.py b/training/util.py
index cc4cdee..1008021 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,32 +44,29 @@ class CheckpointerBase:
         train_dataloader,
         val_dataloader,
         output_dir: Path,
-        sample_image_size: int,
-        sample_batches: int,
-        sample_batch_size: int,
+        sample_steps: int = 20,
+        sample_guidance_scale: float = 7.5,
+        sample_image_size: int = 768,
+        sample_batches: int = 1,
+        sample_batch_size: int = 1,
         seed: Optional[int] = None
     ):
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed if seed is not None else torch.random.seed()
+        self.sample_steps = sample_steps
+        self.sample_guidance_scale = sample_guidance_scale
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
+        self.seed = seed if seed is not None else torch.random.seed()
 
     @torch.no_grad()
     def checkpoint(self, step: int, postfix: str):
         pass
 
     @torch.inference_mode()
-    def save_samples(
-        self,
-        pipeline,
-        step: int,
-        num_inference_steps: int,
-        guidance_scale: float = 7.5,
-        eta: float = 0.0
-    ):
+    def save_samples(self, pipeline, step: int):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
@@ -110,9 +107,8 @@ class CheckpointerBase:
                         height=self.sample_image_size,
                         width=self.sample_image_size,
                         generator=gen,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
+                        guidance_scale=self.sample_guidance_scale,
+                        num_inference_steps=self.sample_steps,
                         output_type='pil'
                     ).images
 
-- 
cgit v1.2.3-70-g09d2