-rw-r--r--   infer.py                                             | 14
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py  | 41
-rw-r--r--   train_dreambooth.py                                  | 30
-rw-r--r--   train_lora.py                                        |  8
-rw-r--r--   train_ti.py                                          | 18
-rw-r--r--   training/functional.py                               |  7
-rw-r--r--   training/strategy/dreambooth.py                      |  3
-rw-r--r--   util/files.py                                        | 16
8 files changed, 69 insertions, 68 deletions
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -46,7 +46,7 @@ default_args = {
     "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
-    "lora_embedding": None,
+    "lora_embeddings_dir": None,
     "output_dir": "output/inference",
     "config": None,
 }
@@ -99,7 +99,7 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--lora_embedding",
+        "--lora_embeddings_dir",
         type=str,
     )
     parser.add_argument(
@@ -341,7 +341,7 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None):
 
 
 def create_pipeline(model, dtype):
-    print("Loading Stable Diffusion pipeline...")
+    print(f"Loading Stable Diffusion pipeline: {model}...")
 
     tokenizer = MultiCLIPTokenizer.from_pretrained(
         model, subfolder="tokenizer", torch_dtype=dtype
@@ -435,11 +435,11 @@ def generate(output_dir: Path, pipeline, args):
             negative_prompt=args.negative_prompt,
             height=args.height,
             width=args.width,
+            generator=generator,
+            guidance_scale=args.guidance_scale,
             num_images_per_prompt=args.batch_size,
             num_inference_steps=args.steps,
-            guidance_scale=args.guidance_scale,
             sag_scale=args.sag_scale,
-            generator=generator,
             image=init_image,
             strength=args.image_noise,
         ).images
@@ -527,8 +527,8 @@ def main():
 
     pipeline = create_pipeline(args.model, dtype)
 
-    load_embeddings_dir(pipeline, args.ti_embeddings_dir)
-    load_lora(pipeline, args.lora_embedding)
+    # load_embeddings_dir(pipeline, args.ti_embeddings_dir)
+    # load_lora(pipeline, args.lora_embeddings_dir)
     # pipeline.unet.load_attn_procs(args.lora_embeddings_dir)
 
     cmd_parser = create_cmd_parser()
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 98703d5..204276e 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -9,6 +9,7 @@ import torch.nn.functional as F
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import is_accelerate_available
 from diffusers import (
     AutoencoderKL,
@@ -161,6 +162,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -443,14 +445,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def decode_latents(self, latents):
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
     @torch.no_grad()
     def __call__(
         self,
@@ -544,6 +538,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         do_classifier_free_guidance = guidance_scale > 1.0
         do_self_attention_guidance = sag_scale > 0.0
         prep_from_image = isinstance(image, PIL.Image.Image)
+        if not prep_from_image:
+            strength = 1
 
         # 3. Encode input prompt
         prompt_embeds = self.encode_prompt(
@@ -577,7 +573,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
         else:
             latents = self.prepare_latents(
-                batch_size,
+                batch_size * num_images_per_prompt,
                 num_channels_latents,
                 height,
                 width,
@@ -623,9 +619,12 @@ class VlpnStableDiffusion(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (
                     noise_pred_text - noise_pred_uncond
                 )
-                noise_pred = rescale_noise_cfg(
-                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
-                )
+                if guidance_rescale > 0.0:
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred,
+                        noise_pred_text,
+                        guidance_rescale=guidance_rescale,
+                    )
 
             if do_self_attention_guidance:
                 # classifier-free guidance produces two chunks of attention map
@@ -690,17 +689,17 @@ class VlpnStableDiffusion(DiffusionPipeline):
 
         has_nsfw_concept = None
 
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(
+                latents / self.vae.config.scaling_factor, return_dict=False
+            )[0]
+        else:
             image = latents
-        elif output_type == "pil":
-            # 9. Post-processing
-            image = self.decode_latents(latents)
 
-            # 10. Convert to PIL
-            image = self.numpy_to_pil(image)
-        else:
-            # 9. Post-processing
-            image = self.decode_latents(latents)
+        do_denormalize = [True] * image.shape[0]
+        image = self.image_processor.postprocess(
+            image, output_type=output_type, do_denormalize=do_denormalize
+        )
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
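Note: the removed decode_latents helper is superseded by diffusers' VaeImageProcessor, whose postprocess() handles the [-1, 1] denormalization and the numpy/PIL conversion that were previously done by hand. The sketch below shows the equivalent standalone decoding path under stated assumptions: the checkpoint name is the one from infer.py's defaults, and the random latents are placeholders for illustration only.

# Minimal sketch: decoding latents via VaeImageProcessor instead of a custom
# decode_latents helper. Assumes a loaded AutoencoderKL and placeholder latents.
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="vae"
)
image_processor = VaeImageProcessor(
    vae_scale_factor=2 ** (len(vae.config.block_out_channels) - 1)
)

latents = torch.randn(1, 4, 64, 64)  # placeholder latent batch for illustration
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

# postprocess() un-normalizes from [-1, 1] and converts to PIL images
pil_images = image_processor.postprocess(
    image, output_type="pil", do_denormalize=[True] * image.shape[0]
)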
diff --git a/train_dreambooth.py b/train_dreambooth.py
index c8f03ea..be4da1a 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 
 # from diffusers.models.attention_processor import AttnProcessor
-from diffusers.utils.import_utils import is_xformers_available
 import transformers
 
 import numpy as np
@@ -48,25 +47,6 @@ hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
-def patch_xformers(dtype):
-    if is_xformers_available():
-        import xformers
-        import xformers.ops
-
-        orig_xformers_memory_efficient_attention = (
-            xformers.ops.memory_efficient_attention
-        )
-
-        def xformers_memory_efficient_attention(
-            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
-        ):
-            return orig_xformers_memory_efficient_attention(
-                query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
-            )
-
-        xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
-
-
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -224,6 +204,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -476,7 +462,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -622,8 +608,6 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    patch_xformers(weight_dtype)
-
    logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
 
     if args.seed is None:
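The --validation_prompts option added to all three training scripts uses nargs="*", so several prompts can follow a single flag and parse into a list. A short standalone argparse sketch; the prompt strings are invented for illustration:

# Minimal sketch of how nargs="*" collects multiple values into a list.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--validation_prompts",
    type=str,
    nargs="*",
    help="Prompts for additional validation images",
)

args = parser.parse_args(
    ["--validation_prompts", "a photo of a dog", "a watercolor landscape"]
)
print(args.validation_prompts)  # ['a photo of a dog', 'a watercolor landscape']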
diff --git a/train_lora.py b/train_lora.py
index fbec009..2a43252 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -236,6 +236,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -545,7 +551,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
diff --git a/train_ti.py b/train_ti.py
index 8c63493..89f4113 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -160,6 +160,12 @@ def parse_args():
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        nargs="*",
+        help="Prompts for additional validation images",
+    )
+    parser.add_argument(
         "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
@@ -456,7 +462,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=10,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -852,11 +858,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    optimizer = create_optimizer(
-        text_encoder.text_model.embeddings.token_embedding.parameters(),
-        lr=args.learning_rate,
-    )
-
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
     data_npgenerator = np.random.default_rng(args.seed)
 
@@ -957,6 +958,11 @@ def main():
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    optimizer = create_optimizer(
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        lr=args.learning_rate,
+    )
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
diff --git a/training/functional.py b/training/functional.py
index 43b03ac..546aaff 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union, Optional, Protocol
+from types import MethodType
 from functools import partial
 from pathlib import Path
 import itertools
@@ -108,6 +109,7 @@ def save_samples(
     output_dir: Path,
     seed: int,
     step: int,
+    validation_prompts: list[str] = [],
     cycle: int = 1,
     batch_size: int = 1,
     num_batches: int = 1,
@@ -136,7 +138,6 @@ def save_samples(
 
     if val_dataloader is not None:
         datasets.append(("stable", val_dataloader, generator))
-        datasets.append(("val", val_dataloader, None))
 
     for pool, data, gen in datasets:
         all_samples = []
@@ -165,7 +166,6 @@ def save_samples(
                 guidance_scale=guidance_scale,
                 sag_scale=0,
                 num_inference_steps=num_steps,
-                output_type=None,
             ).images
 
             all_samples.append(torch.from_numpy(samples))
@@ -803,4 +803,7 @@ def train(
     accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
     accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
 
+    text_encoder.forward = MethodType(text_encoder.forward, text_encoder)
+    unet.forward = MethodType(unet.forward, unet)
+
     accelerator.free_memory()
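The MethodType rebinding moved here from the DreamBooth checkpoint callback re-attaches forward as a method bound to the unwrapped module after accelerate's wrappers have swapped it for a plain function. The sketch below only illustrates the generic types.MethodType pattern, not accelerate's internals; TinyNet and patched_forward are invented for the example.

# Minimal sketch: binding a replacement forward to a module instance.
from types import MethodType

import torch


class TinyNet(torch.nn.Module):
    def forward(self, x):
        return x * 2


net = TinyNet()


def patched_forward(self, x):
    # stand-in for a wrapper that replaced the original forward
    return TinyNet.forward(self, x) + 1


# Assigning a plain function would lose the implicit `self`; MethodType binds it.
net.forward = MethodType(patched_forward, net)
print(net.forward(torch.tensor(1.0)))  # tensor(3.)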
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 0f64747..bd853e2 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -155,9 +155,6 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
-        unet_.forward = MethodType(unet_.forward, unet_)
-        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
-
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
diff --git a/util/files.py b/util/files.py
index 2712525..73ff802 100644
--- a/util/files.py
+++ b/util/files.py
@@ -8,7 +8,7 @@ from safetensors import safe_open
 
 
 def load_config(filename):
-    with open(filename, 'rt') as f:
+    with open(filename, "rt") as f:
         config = json.load(f)
 
     args = config["args"]
@@ -19,11 +19,17 @@ def load_config(filename):
     return args
 
 
-def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
+def load_embeddings_from_dir(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    embeddings_dir: Path,
+):
     if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-        return []
+        return [], []
 
-    filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    filenames = [
+        filename for filename in embeddings_dir.iterdir() if filename.is_file()
+    ]
     tokens = [filename.stem for filename in filenames]
 
     new_ids: list[list[int]] = []
@@ -39,7 +45,7 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedCLIPTextEmbeddings, embeddings_dir: Path):
 
     embeddings.resize(len(tokenizer))
 
-    for (new_id, embeds) in zip(new_ids, new_embeds):
+    for new_id, embeds in zip(new_ids, new_embeds):
         embeddings.add_embed(new_id, embeds)
 
     return tokens, new_ids