From f00877a13bce50b02cfc3790f2d18a325e9ff95b Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 14 Jan 2023 22:42:44 +0100 Subject: Update --- training/util.py | 112 ------------------------------------------------------- 1 file changed, 112 deletions(-) (limited to 'training/util.py') diff --git a/training/util.py b/training/util.py index a292edd..f46cc61 100644 --- a/training/util.py +++ b/training/util.py @@ -14,29 +14,6 @@ from models.clip.tokenizer import MultiCLIPTokenizer from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings -class TrainingStrategy(): - @property - def main_model(self) -> torch.nn.Module: - ... - - @contextmanager - def on_train(self, epoch: int): - yield - - @contextmanager - def on_eval(self): - yield - - def on_before_optimize(self, epoch: int): - ... - - def on_after_optimize(self, lr: float): - ... - - def on_log(): - return {} - - def save_args(basepath: Path, args, extra={}): info = {"args": vars(args)} info["args"].update(extra) @@ -44,95 +21,6 @@ def save_args(basepath: Path, args, extra={}): json.dump(info, f, indent=4) -def generate_class_images( - accelerator, - text_encoder, - vae, - unet, - tokenizer, - scheduler, - data_train, - sample_batch_size, - sample_image_size, - sample_steps -): - missing_data = [item for item in data_train if not item.class_image_path.exists()] - - if len(missing_data) == 0: - return - - batched_data = [ - missing_data[i:i+sample_batch_size] - for i in range(0, len(missing_data), sample_batch_size) - ] - - pipeline = VlpnStableDiffusion( - text_encoder=text_encoder, - vae=vae, - unet=unet, - tokenizer=tokenizer, - scheduler=scheduler, - ).to(accelerator.device) - pipeline.set_progress_bar_config(dynamic_ncols=True) - - with torch.inference_mode(): - for batch in batched_data: - image_name = [item.class_image_path for item in batch] - prompt = [item.cprompt for item in batch] - nprompt = [item.nprompt for item in batch] - - images = pipeline( - prompt=prompt, - negative_prompt=nprompt, - height=sample_image_size, - width=sample_image_size, - num_inference_steps=sample_steps - ).images - - for i, image in enumerate(images): - image.save(image_name[i]) - - del pipeline - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def get_models(pretrained_model_name_or_path: str): - tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') - text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') - unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') - noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') - sample_scheduler = DPMSolverMultistepScheduler.from_pretrained( - pretrained_model_name_or_path, subfolder='scheduler') - - embeddings = patch_managed_embeddings(text_encoder) - - return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings - - -def add_placeholder_tokens( - tokenizer: MultiCLIPTokenizer, - embeddings: ManagedCLIPTextEmbeddings, - placeholder_tokens: list[str], - initializer_tokens: list[str], - num_vectors: Union[list[int], int] -): - initializer_token_ids = [ - tokenizer.encode(token, add_special_tokens=False) - for token in initializer_tokens - ] - placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors) - - embeddings.resize(len(tokenizer)) - - for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): - embeddings.add_embed(placeholder_token_id, initializer_token_id) - - return placeholder_token_ids, initializer_token_ids - - class AverageMeter: def __init__(self, name=None): self.name = name -- cgit v1.2.3-54-g00ecf