// zoo_embedding/model_type.rs

1use std::fmt;
2use std::hash::Hash;
3
4use crate::zoo_embedding_errors::ZooEmbeddingError;
5
/// Alias for `String`, named for values that hold an embedding model type in textual form.
pub type EmbeddingModelTypeString = std::string::String;
7
8#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, Hash)]
9pub enum EmbeddingModelType {
10    OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference),
11}
12
13impl EmbeddingModelType {
14    pub fn from_string(s: &str) -> Result<Self, ZooEmbeddingError> {
15        OllamaTextEmbeddingsInference::from_string(s)
16            .map(EmbeddingModelType::OllamaTextEmbeddingsInference)
17            .map_err(|_| ZooEmbeddingError::InvalidModelArchitecture)
18    }
19
20    pub fn max_input_token_count(&self) -> usize {
21        match self {
22            EmbeddingModelType::OllamaTextEmbeddingsInference(model) => model.max_input_token_count(),
23        }
24    }
25
26    pub fn embedding_normalization_factor(&self) -> f32 {
27        match self {
28            EmbeddingModelType::OllamaTextEmbeddingsInference(model) => model.embedding_normalization_factor(),
29        }
30    }
31
32    pub fn vector_dimensions(&self) -> Result<usize, ZooEmbeddingError> {
33        match self {
34            EmbeddingModelType::OllamaTextEmbeddingsInference(model) => model.vector_dimensions(),
35        }
36    }
37}
38
39impl fmt::Display for EmbeddingModelType {
40    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
41        match self {
42            EmbeddingModelType::OllamaTextEmbeddingsInference(model) => write!(f, "{}", model),
43        }
44    }
45}
46
47#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
48pub enum OllamaTextEmbeddingsInference {
49    AllMiniLML6v2,
50    #[serde(alias = "SnowflakeArcticEmbed_M")]
51    SnowflakeArcticEmbedM,
52    JinaEmbeddingsV2BaseEs,
53    Other(String),
54}
55
56impl OllamaTextEmbeddingsInference {
57    const ALL_MINI_LML6V2: &'static str = "all-minilm:l6-v2";
58    const SNOWFLAKE_ARCTIC_EMBED_M: &'static str = "snowflake-arctic-embed:xs";
59    const JINA_EMBEDDINGS_V2_BASE_ES: &'static str = "jina/jina-embeddings-v2-base-es:latest";
60
61    pub fn from_string(s: &str) -> Result<Self, ZooEmbeddingError> {
62        match s {
63            Self::ALL_MINI_LML6V2 => Ok(Self::AllMiniLML6v2),
64            Self::SNOWFLAKE_ARCTIC_EMBED_M => Ok(Self::SnowflakeArcticEmbedM),
65            Self::JINA_EMBEDDINGS_V2_BASE_ES => Ok(Self::JinaEmbeddingsV2BaseEs),
66            _ => Err(ZooEmbeddingError::InvalidModelArchitecture),
67        }
68    }
69
70    pub fn max_input_token_count(&self) -> usize {
71        match self {
72            Self::JinaEmbeddingsV2BaseEs => 1024,
73            _ => 512,
74        }
75    }
76
77    pub fn embedding_normalization_factor(&self) -> f32 {
78        match self {
79            Self::JinaEmbeddingsV2BaseEs => 1.5,
80            _ => 1.0,
81        }
82    }
83
84    pub fn vector_dimensions(&self) -> Result<usize, ZooEmbeddingError> {
85        match self {
86            Self::SnowflakeArcticEmbedM => Ok(384),
87            Self::JinaEmbeddingsV2BaseEs => Ok(768),
88            _ => Err(ZooEmbeddingError::UnimplementedModelDimensions(format!(
89                "{:?}",
90                self
91            ))),
92        }
93    }
94}
95
96impl fmt::Display for OllamaTextEmbeddingsInference {
97    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98        match self {
99            Self::AllMiniLML6v2 => write!(f, "{}", Self::ALL_MINI_LML6V2),
100            Self::SnowflakeArcticEmbedM => write!(f, "{}", Self::SNOWFLAKE_ARCTIC_EMBED_M),
101            Self::JinaEmbeddingsV2BaseEs => write!(f, "{}", Self::JINA_EMBEDDINGS_V2_BASE_ES),
102            Self::Other(name) => write!(f, "{}", name),
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Every recognized Ollama tag paired with the variant it parses to.
    fn known_models() -> Vec<(&'static str, OllamaTextEmbeddingsInference)> {
        vec![
            ("all-minilm:l6-v2", OllamaTextEmbeddingsInference::AllMiniLML6v2),
            ("snowflake-arctic-embed:xs", OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM),
            (
                "jina/jina-embeddings-v2-base-es:latest",
                OllamaTextEmbeddingsInference::JinaEmbeddingsV2BaseEs,
            ),
        ]
    }

    #[test]
    fn test_parse_snowflake_arctic_embed_xs() {
        let model_str = "snowflake-arctic-embed:xs";
        let parsed_model = OllamaTextEmbeddingsInference::from_string(model_str);
        assert_eq!(parsed_model, Ok(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM));
    }

    #[test]
    fn test_parse_jina_embeddings_v2_base_es() {
        let model_str = "jina/jina-embeddings-v2-base-es:latest";
        let parsed_model = OllamaTextEmbeddingsInference::from_string(model_str);
        assert_eq!(parsed_model, Ok(OllamaTextEmbeddingsInference::JinaEmbeddingsV2BaseEs));
    }

    #[test]
    fn test_parse_all_minilm_l6_v2() {
        let parsed_model = OllamaTextEmbeddingsInference::from_string("all-minilm:l6-v2");
        assert_eq!(parsed_model, Ok(OllamaTextEmbeddingsInference::AllMiniLML6v2));
    }

    #[test]
    fn test_parse_snowflake_arctic_embed_xs_as_embedding_model_type() {
        let model_str = "snowflake-arctic-embed:xs";
        let parsed_model = EmbeddingModelType::from_string(model_str);
        assert_eq!(
            parsed_model,
            Ok(EmbeddingModelType::OllamaTextEmbeddingsInference(
                OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM
            ))
        );
    }

    #[test]
    fn test_parse_jina_embeddings_v2_base_es_as_embedding_model_type() {
        let model_str = "jina/jina-embeddings-v2-base-es:latest";
        let parsed_model = EmbeddingModelType::from_string(model_str);
        assert_eq!(
            parsed_model,
            Ok(EmbeddingModelType::OllamaTextEmbeddingsInference(
                OllamaTextEmbeddingsInference::JinaEmbeddingsV2BaseEs
            ))
        );
    }

    /// `Display` must emit exactly the tag that `from_string` accepts, for both types.
    #[test]
    fn test_display_round_trips_known_models() {
        for (tag, model) in known_models() {
            assert_eq!(model.to_string(), tag);
            assert_eq!(OllamaTextEmbeddingsInference::from_string(tag), Ok(model.clone()));
            let wrapped = EmbeddingModelType::OllamaTextEmbeddingsInference(model);
            assert_eq!(wrapped.to_string(), tag);
            assert_eq!(EmbeddingModelType::from_string(tag), Ok(wrapped));
        }
    }

    /// Unknown, partial, empty and wrong-case tags are all rejected.
    #[test]
    fn test_parse_rejects_unknown_tags() {
        for tag in ["", "not-a-model", "all-minilm", "ALL-MINILM:L6-V2"] {
            assert_eq!(
                OllamaTextEmbeddingsInference::from_string(tag),
                Err(ZooEmbeddingError::InvalidModelArchitecture)
            );
            assert_eq!(
                EmbeddingModelType::from_string(tag),
                Err(ZooEmbeddingError::InvalidModelArchitecture)
            );
        }
    }

    #[test]
    fn test_other_displays_raw_name() {
        let model = OllamaTextEmbeddingsInference::Other("custom-model:latest".to_string());
        assert_eq!(model.to_string(), "custom-model:latest");
        assert!(model.vector_dimensions().is_err());
    }

    #[test]
    fn test_model_properties() {
        let jina = OllamaTextEmbeddingsInference::JinaEmbeddingsV2BaseEs;
        assert_eq!(jina.max_input_token_count(), 1024);
        assert_eq!(jina.embedding_normalization_factor(), 1.5);
        assert_eq!(jina.vector_dimensions(), Ok(768));

        let snowflake = OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM;
        assert_eq!(snowflake.max_input_token_count(), 512);
        assert_eq!(snowflake.embedding_normalization_factor(), 1.0);
        assert_eq!(snowflake.vector_dimensions(), Ok(384));
    }

    /// The wrapper type must report exactly what the wrapped model reports.
    #[test]
    fn test_embedding_model_type_delegates_to_inner() {
        for (_, model) in known_models() {
            let wrapped = EmbeddingModelType::OllamaTextEmbeddingsInference(model.clone());
            assert_eq!(wrapped.max_input_token_count(), model.max_input_token_count());
            assert_eq!(
                wrapped.embedding_normalization_factor(),
                model.embedding_normalization_factor()
            );
            assert_eq!(wrapped.vector_dimensions(), model.vector_dimensions());
        }
    }
}