```
 infer.py                                            | 14
 pipelines/stable_diffusion/vlpn_stable_diffusion.py | 41
 train_dreambooth.py                                 | 30
 train_lora.py                                       |  8
 train_ti.py                                         | 18
 training/functional.py                              |  7
 training/strategy/dreambooth.py                     |  3
 util/files.py                                       | 16
 8 files changed, 69 insertions(+), 68 deletions(-)
```
```diff
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -46,7 +46,7 @@ default_args = {
     "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
-    "lora_embedding": None,
+    "lora_embeddings_dir": None,
     "output_dir": "output/inference",
     "config": None,
 }
@@ -99,7 +99,7 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--lora_embedding",
+        "--lora_embeddings_dir",
         type=str,
     )
     parser.add_argument(
@@ -341,7 +341,7 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
 
 
 def create_pipeline(model, dtype):
-    print("Loading Stable Diffusion pipeline...")
+    print(f"Loading Stable Diffusion pipeline: {model}...")
 
     tokenizer = MultiCLIPTokenizer.from_pretrained(
         model, subfolder="tokenizer", torch_dtype=dtype
@@ -435,11 +435,11 @@ def generate(output_dir: Path, pipeline, args):
             negative_prompt=args.negative_prompt,
             height=args.height,
             width=args.width,
+            generator=generator,
+            guidance_scale=args.guidance_scale,
             num_images_per_prompt=args.batch_size,
             num_inference_steps=args.steps,
-            guidance_scale=args.guidance_scale,
             sag_scale=args.sag_scale,
-            generator=generator,
             image=init_image,
             strength=args.image_noise,
         ).images
@@ -527,8 +527,8 @@ def main():
 
     pipeline = create_pipeline(args.model, dtype)
 
-    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
-    load_lora(pipeline, args.lora_embedding)
+    # load_embeddings_dir(pipeline, args.ti_embeddings_dir)
+    # load_lora(pipeline, args.lora_embeddings_dir)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
```
```diff
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 98703d5..204276e 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import is_accelerate_available
 from diffusers import (
     AutoencoderKL,
@@ -161,6 +162,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -443,14 +445,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def decode_latents(self, latents):
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
     @torch.no_grad()
     def __call__(
         self,
@@ -544,6 +538,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
+        if not prep_from_image:
+            strength = 1
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
```
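With the hunk above, `strength` is forced to 1 whenever no init image is supplied, so text-to-image runs always use the full denoising schedule. As an illustration only (assuming the usual img2img timestep-skipping scheme; this pipeline's own image-latent preparation is not shown in the diff), this is roughly what `strength` controls:

```python
# Sketch: how `strength` typically maps to skipped denoising steps in img2img-style
# pipelines. At strength = 1.0 nothing is skipped, which is what pure text-to-image
# (no init image) requires.
def timesteps_to_run(num_inference_steps: int, strength: float) -> int:
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start


print(timesteps_to_run(30, 0.6))  # 18 of 30 steps
print(timesteps_to_run(30, 1.0))  # all 30 steps
```

The diff for this file continues below.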
```diff
@@ -577,7 +573,7 @@
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
@@ -623,9 +619,12 @@
                 noise_pred = noise_pred_uncond + guidance_scale * (
                     noise_pred_text - noise_pred_uncond
                 )
-                noise_pred = rescale_noise_cfg(
-                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
-                )
+                if guidance_rescale > 0.0:
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred,
+                        noise_pred_text,
+                        guidance_rescale=guidance_rescale,
+                    )
 
             if do_self_attention_guidance:
                 # classifier-free guidance produces two chunks of attention map
```
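`rescale_noise_cfg` is now applied only when `guidance_rescale > 0.0`, skipping a pass that is effectively a no-op at the default setting. For context, a minimal sketch of what such a rescale step computes, following the formulation from "Common Diffusion Noise Schedules and Sample Steps Are Flawed" (assumed here; the repository's actual implementation may differ in detail):

```python
import torch


def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor,
                      guidance_rescale: float = 0.0) -> torch.Tensor:
    # Rescale the CFG result toward the std of the text-conditioned prediction
    # to counteract over-exposure at large guidance scales.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    # Blend the rescaled prediction with the original; at 0.0 this returns noise_cfg unchanged.
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
```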
```diff
@@ -690,17 +689,17 @@
 
         has_nsfw_concept = None
 
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(
+                latents / self.vae.config.scaling_factor, return_dict=False
+            )[0]
+        else:
             image = latents
-        elif output_type == "pil":
-            # 9. Post-processing
-            image = self.decode_latents(latents)
 
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 9. Post-processing
-            image = self.decode_latents(latents)
+        do_denormalize = [True] * image.shape[0]
+        image = self.image_processor.postprocess(
+            image, output_type=output_type, do_denormalize=do_denormalize
+        )
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
```
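The bespoke `decode_latents` helper is gone: decoding now happens inline via `vae.decode`, and tensor-to-image conversion is delegated to the `VaeImageProcessor` created in `__init__`, which denormalizes from [-1, 1] and converts to the requested `output_type`. For reference, a minimal sketch of the equivalent manual post-processing, mirroring the tail of the removed helper (not the library code):

```python
import numpy as np
import torch


def postprocess_manual(decoded: torch.Tensor) -> np.ndarray:
    # Roughly what the removed decode_latents did after vae.decode():
    # denormalize from [-1, 1] to [0, 1], move channels last, cast to float32.
    images = (decoded / 2 + 0.5).clamp(0, 1)
    return images.cpu().permute(0, 2, 3, 1).float().numpy()
```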
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c8f03ea..be4da1a 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 
 # from diffusers.models.attention_processor import AttnProcessor
-from diffusers.utils.import_utils import is_xformers_available
 import transformers
 
 import numpy as np
@@ -48,25 +47,6 @@ hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
-def patch_xformers(dtype):
-    if is_xformers_available():
-        import xformers
-        import xformers.ops
-
-        orig_xformers_memory_efficient_attention = (
-            xformers.ops.memory_efficient_attention
-        )
-
-        def xformers_memory_efficient_attention(
-            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
-        ):
-            return orig_xformers_memory_efficient_attention(
-                query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
-            )
-
-        xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
-
-
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -224,6 +204,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -476,7 +462,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -622,8 +608,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    patch_xformers(weight_dtype)
-
    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
```
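The new `--validation_prompts` flag (added identically to train_lora.py and train_ti.py below) is declared with `nargs="*"`, so several prompts can follow the flag; how they are consumed downstream is not visible in this diff. A self-contained sketch of how such a flag parses (the prompt strings are made up):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--validation_prompts",
    type=str,
    nargs="*",
    help="Prompts for additional validation images",
)

# nargs="*" collects every following token into a list (None if the flag is absent).
args = parser.parse_args(
    ["--validation_prompts", "a photo of a dog", "a dog at the beach"]
)
print(args.validation_prompts)  # ['a photo of a dog', 'a dog at the beach']
```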
```diff
diff --git a/train_lora.py b/train_lora.py
index fbec009..2a43252 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -236,6 +236,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -545,7 +551,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
```
```diff
diff --git a/train_ti.py b/train_ti.py
index 8c63493..89f4113 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -160,6 +160,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -456,7 +462,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -852,11 +858,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=args.learning_rate,
-    )
-
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
     data_npgenerator = np.random.default_rng(args.seed)
 
@@ -957,6 +958,11 @@
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    optimizer = create_optimizer(
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        lr=args.learning_rate,
+    )
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
```
```diff
diff --git a/training/functional.py b/training/functional.py
index 43b03ac..546aaff 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union, Optional, Protocol
+from types import MethodType
 from functools import partial
 from pathlib import Path
 import itertools
@@ -108,6 +109,7 @@ def save_samples(
     output_dir: Path,
     seed: int,
     step: int,
+    validation_prompts: list[str] = [],
     cycle: int = 1,
     batch_size: int = 1,
     num_batches: int = 1,
@@ -136,7 +138,6 @@ def save_samples(
 
     if val_dataloader is not None:
         datasets.append(("stable", val_dataloader, generator))
-        datasets.append(("val", val_dataloader, None))
 
     for pool, data, gen in datasets:
         all_samples = []
@@ -165,7 +166,6 @@
                 guidance_scale=guidance_scale,
                 sag_scale=0,
                 num_inference_steps=num_steps,
-                output_type=None,
             ).images
 
             all_samples.append(torch.from_numpy(samples))
@@ -803,4 +803,7 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
+    text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
+    unet.forward = MethodType(unet.forward, unet)
+
     accelerator.free_memory()
```
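The `MethodType` rebinding of `forward` moves out of the DreamBooth strategy (see the next file) into the tail of `train()`, right after the models are unwrapped. A standalone sketch of what the rebinding does (illustrative class, not repository code): `MethodType(obj.forward, obj)` takes whatever is currently stored in the instance's `forward` attribute, for example a plain function installed by a wrapper, and turns it back into a method bound to `obj`:

```python
from types import MethodType


class Model:
    def forward(self, x):
        return x + 1


def patched_forward(self, x):
    # Stand-in for a wrapper that a mixed-precision/acceleration library might
    # install; it expects to be called as a method (receiving `self`).
    return Model.forward(self, x) * 2


m = Model()
m.forward = patched_forward           # plain function stored on the instance: not bound
m.forward = MethodType(m.forward, m)  # re-bind so `m` is passed as `self` again
print(m.forward(3))                   # (3 + 1) * 2 == 8
```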
```diff
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0f64747..bd853e2 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -155,9 +155,6 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
-        unet_.forward = MethodType(unet_.forward, unet_)
-        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
-
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
```
```diff
diff --git a/util/files.py b/util/files.py
index 2712525..73ff802 100644
--- a/util/files.py
+++ b/util/files.py
@@ -8,7 +8,7 @@ from safetensors import safe_open
 
 
 def load_config(filename):
-    with open(filename, 'rt') as f:
+    with open(filename, "rt") as f:
         config = json.load(f)
 
     args = config["args"]
@@ -19,11 +19,17 @@ def load_config(filename):
     return args
 
 
-def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
+def load_embeddings_from_dir(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    embeddings_dir: Path,
+):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-        return []
+        return [], []
 
-    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    filenames = [
+        filename for filename in embeddings_dir.iterdir() if filename.is_file()
+    ]
     tokens = [filename.stem for filename in filenames]
 
     new_ids: list[list[int]] = []
@@ -39,7 +45,7 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
 
     embeddings.resize(len(tokenizer))
 
-    for (new_id, embeds) in zip(new_ids, new_embeds):
+    for new_id, embeds in zip(new_ids, new_embeds):
         embeddings.add_embed(new_id, embeds)
 
     return tokens, new_ids
```
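The early return now yields `[], []`, matching the `return tokens, new_ids` at the end of the function, so callers that unpack two values no longer break on a missing or empty embeddings directory. A minimal self-contained sketch of the failure the change avoids (the token and id values are made up):

```python
# Why the early return changed from `[]` to `[], []`: callers unpack two values.
def load_old(found: bool):
    if not found:
        return []          # old behaviour
    return ["<token>"], [[49408]]


def load_new(found: bool):
    if not found:
        return [], []      # new behaviour
    return ["<token>"], [[49408]]


tokens, new_ids = load_new(found=False)      # fine: two empty lists
try:
    tokens, new_ids = load_old(found=False)  # fails: only one value to unpack
except ValueError as err:
    print(err)
```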
