| author | Volpeon <git@volpeon.ink> | 2023-01-14 22:42:44 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-14 22:42:44 +0100 |
| commit | f00877a13bce50b02cfc3790f2d18a325e9ff95b (patch) | |
| tree | ebbda04024081e9c3c00400fae98124f3db2cc9c /training | |
| parent | Update (diff) | |
Update
Diffstat (limited to 'training')
 training/functional.py |  34
 training/util.py       | 112
2 files changed, 24 insertions, 122 deletions
```diff
diff --git a/training/functional.py b/training/functional.py
index c100ea2..c5b514a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -25,17 +25,31 @@ def const(result=None):
     return fn
 
 
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
 def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
+    accelerator: Accelerator,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    unet: UNet2DConditionModel,
+    tokenizer: MultiCLIPTokenizer,
+    sample_scheduler: DPMSolverMultistepScheduler,
     data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
+    sample_batch_size: int,
+    sample_image_size: int,
+    sample_steps: int
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
@@ -52,7 +66,7 @@ def generate_class_images(
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=sample_scheduler,
     ).to(accelerator.device)
     pipeline.set_progress_bar_config(dynamic_ncols=True)
 
```
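For context, a minimal sketch of how the relocated helpers might be called together after this change. The model path and the `train_items` list are illustrative assumptions, not code from this commit; only the function names, signatures, and return tuple come from the diff above.

```python
# Hypothetical usage sketch (assumes a pretrained model path and a prepared
# list of training items; neither appears in this commit).
from accelerate import Accelerator

from training.functional import get_models, generate_class_images

accelerator = Accelerator()

# get_models returns the full model stack plus the patched embeddings.
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = \
    get_models("runwayml/stable-diffusion-v1-5")  # placeholder model path

generate_class_images(
    accelerator=accelerator,
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    sample_scheduler=sample_scheduler,
    data_train=train_items,  # items exposing .class_image_path, .cprompt, .nprompt
    sample_batch_size=4,
    sample_image_size=512,
    sample_steps=20,
)
```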
```diff
diff --git a/training/util.py b/training/util.py
index a292edd..f46cc61 100644
--- a/training/util.py
+++ b/training/util.py
@@ -14,29 +14,6 @@ from models.clip.tokenizer import MultiCLIPTokenizer
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 
 
-class TrainingStrategy():
-    @property
-    def main_model(self) -> torch.nn.Module:
-        ...
-
-    @contextmanager
-    def on_train(self, epoch: int):
-        yield
-
-    @contextmanager
-    def on_eval(self):
-        yield
-
-    def on_before_optimize(self, epoch: int):
-        ...
-
-    def on_after_optimize(self, lr: float):
-        ...
-
-    def on_log():
-        return {}
-
-
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
     info["args"].update(extra)
@@ -44,95 +21,6 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def generate_class_images(
-    accelerator,
-    text_encoder,
-    vae,
-    unet,
-    tokenizer,
-    scheduler,
-    data_train,
-    sample_batch_size,
-    sample_image_size,
-    sample_steps
-):
-    missing_data = [item for item in data_train if not item.class_image_path.exists()]
-
-    if len(missing_data) == 0:
-        return
-
-    batched_data = [
-        missing_data[i:i+sample_batch_size]
-        for i in range(0, len(missing_data), sample_batch_size)
-    ]
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    ).to(accelerator.device)
-    pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-    with torch.inference_mode():
-        for batch in batched_data:
-            image_name = [item.class_image_path for item in batch]
-            prompt = [item.cprompt for item in batch]
-            nprompt = [item.nprompt for item in batch]
-
-            images = pipeline(
-                prompt=prompt,
-                negative_prompt=nprompt,
-                height=sample_image_size,
-                width=sample_image_size,
-                num_inference_steps=sample_steps
-            ).images
-
-            for i, image in enumerate(images):
-                image.save(image_name[i])
-
-    del pipeline
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
-def add_placeholder_tokens(
-    tokenizer: MultiCLIPTokenizer,
-    embeddings: ManagedCLIPTextEmbeddings,
-    placeholder_tokens: list[str],
-    initializer_tokens: list[str],
-    num_vectors: Union[list[int], int]
-):
-    initializer_token_ids = [
-        tokenizer.encode(token, add_special_tokens=False)
-        for token in initializer_tokens
-    ]
-    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
-
-    embeddings.resize(len(tokenizer))
-
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id)
-
-    return placeholder_token_ids, initializer_token_ids
-
-
 class AverageMeter:
     def __init__(self, name=None):
         self.name = name
```
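Because `get_models` and `generate_class_images` now live in `training/functional.py`, any call sites that previously imported them from `training.util` would need their imports updated. A sketch of the change, assuming callers imported the names directly:

```python
# Before this commit:
# from training.util import get_models, generate_class_images

# After this commit:
from training.functional import get_models, generate_class_images
```

Note that `TrainingStrategy` and `add_placeholder_tokens` are deleted outright here; this diff (limited to the 'training' directory) does not show a new home for them.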
