|
|
|
|
@ -16,13 +16,6 @@ from ..config import log
|
|
|
|
|
from ..schemas import ModelType
|
|
|
|
|
from .base import InferenceModel
|
|
|
|
|
|
|
|
|
|
_ST_TO_JINA_MODEL_NAME = {
|
|
|
|
|
"clip-ViT-B-16": "ViT-B-16::openai",
|
|
|
|
|
"clip-ViT-B-32": "ViT-B-32::openai",
|
|
|
|
|
"clip-ViT-B-32-multilingual-v1": "M-CLIP/XLM-Roberta-Large-Vit-B-32",
|
|
|
|
|
"clip-ViT-L-14": "ViT-L-14::openai",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CLIPEncoder(InferenceModel):
|
|
|
|
|
_model_type = ModelType.CLIP
|
|
|
|
|
@ -36,11 +29,10 @@ class CLIPEncoder(InferenceModel):
|
|
|
|
|
) -> None:
|
|
|
|
|
if mode is not None and mode not in ("text", "vision"):
|
|
|
|
|
raise ValueError(f"Mode must be 'text', 'vision', or omitted; got '{mode}'")
|
|
|
|
|
if "vit-b" not in model_name.lower():
|
|
|
|
|
raise ValueError(f"Only ViT-B models are currently supported; got '{model_name}'")
|
|
|
|
|
if model_name not in _MODELS:
|
|
|
|
|
raise ValueError(f"Unknown model name {model_name}.")
|
|
|
|
|
self.mode = mode
|
|
|
|
|
jina_model_name = self._get_jina_model_name(model_name)
|
|
|
|
|
super().__init__(jina_model_name, cache_dir, **model_kwargs)
|
|
|
|
|
super().__init__(model_name, cache_dir, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
def _download(self) -> None:
|
|
|
|
|
models: tuple[tuple[str, str], tuple[str, str]] = _MODELS[self.model_name]
|
|
|
|
|
@ -104,20 +96,6 @@ class CLIPEncoder(InferenceModel):
|
|
|
|
|
|
|
|
|
|
return outputs[0][0].tolist()
|
|
|
|
|
|
|
|
|
|
def _get_jina_model_name(self, model_name: str) -> str:
|
|
|
|
|
if model_name in _MODELS:
|
|
|
|
|
return model_name
|
|
|
|
|
elif model_name in _ST_TO_JINA_MODEL_NAME:
|
|
|
|
|
log.warn(
|
|
|
|
|
(
|
|
|
|
|
f"Sentence-Transformer models like '{model_name}' are not supported."
|
|
|
|
|
f"Using '{_ST_TO_JINA_MODEL_NAME[model_name]}' instead as it is the best match for '{model_name}'."
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
return _ST_TO_JINA_MODEL_NAME[model_name]
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unknown model name {model_name}.")
|
|
|
|
|
|
|
|
|
|
def _download_model(self, model_name: str, model_md5: str) -> bool:
|
|
|
|
|
# downloading logic is adapted from clip-server's CLIPOnnxModel class
|
|
|
|
|
download_model(
|
|
|
|
|
|