zoo_embedding/
model_type.rs1use std::fmt;
2use std::hash::Hash;
3
4use crate::zoo_embedding_errors::ZooEmbeddingError;
5
/// Plain-string alias used where an [`EmbeddingModelType`] is carried as text.
///
/// NOTE(review): presumably this holds the model's `Display` form (its tag);
/// confirm at call sites.
pub type EmbeddingModelTypeString = std::string::String;
7
8#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, Hash)]
9pub enum EmbeddingModelType {
10 OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference),
11}
12
13impl EmbeddingModelType {
14 pub fn from_string(s: &str) -> Result<Self, ZooEmbeddingError> {
15 OllamaTextEmbeddingsInference::from_string(s)
16 .map(EmbeddingModelType::OllamaTextEmbeddingsInference)
17 .map_err(|_| ZooEmbeddingError::InvalidModelArchitecture)
18 }
19
20 pub fn max_input_token_count(&self) -> usize {
21 match self {
22 EmbeddingModelType::OllamaTextEmbeddingsInference(model) => model.max_input_token_count(),
23 }
24 }
25
26 pub fn embedding_normalization_factor(&self) -> f32 {
27 match self {
28 EmbeddingModelType::OllamaTextEmbeddingsInference(model) => model.embedding_normalization_factor(),
29 }
30 }
31
32 pub fn vector_dimensions(&self) -> Result<usize, ZooEmbeddingError> {
33 match self {
34 EmbeddingModelType::OllamaTextEmbeddingsInference(model) => model.vector_dimensions(),
35 }
36 }
37}
38
39impl fmt::Display for EmbeddingModelType {
40 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
41 match self {
42 EmbeddingModelType::OllamaTextEmbeddingsInference(model) => write!(f, "{}", model),
43 }
44 }
45}
46
47#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
48pub enum OllamaTextEmbeddingsInference {
49 AllMiniLML6v2,
50 #[serde(alias = "SnowflakeArcticEmbed_M")]
51 SnowflakeArcticEmbedM,
52 JinaEmbeddingsV2BaseEs,
53 Other(String),
54}
55
56impl OllamaTextEmbeddingsInference {
57 const ALL_MINI_LML6V2: &'static str = "all-minilm:l6-v2";
58 const SNOWFLAKE_ARCTIC_EMBED_M: &'static str = "snowflake-arctic-embed:xs";
59 const JINA_EMBEDDINGS_V2_BASE_ES: &'static str = "jina/jina-embeddings-v2-base-es:latest";
60
61 pub fn from_string(s: &str) -> Result<Self, ZooEmbeddingError> {
62 match s {
63 Self::ALL_MINI_LML6V2 => Ok(Self::AllMiniLML6v2),
64 Self::SNOWFLAKE_ARCTIC_EMBED_M => Ok(Self::SnowflakeArcticEmbedM),
65 Self::JINA_EMBEDDINGS_V2_BASE_ES => Ok(Self::JinaEmbeddingsV2BaseEs),
66 _ => Err(ZooEmbeddingError::InvalidModelArchitecture),
67 }
68 }
69
70 pub fn max_input_token_count(&self) -> usize {
71 match self {
72 Self::JinaEmbeddingsV2BaseEs => 1024,
73 _ => 512,
74 }
75 }
76
77 pub fn embedding_normalization_factor(&self) -> f32 {
78 match self {
79 Self::JinaEmbeddingsV2BaseEs => 1.5,
80 _ => 1.0,
81 }
82 }
83
84 pub fn vector_dimensions(&self) -> Result<usize, ZooEmbeddingError> {
85 match self {
86 Self::SnowflakeArcticEmbedM => Ok(384),
87 Self::JinaEmbeddingsV2BaseEs => Ok(768),
88 _ => Err(ZooEmbeddingError::UnimplementedModelDimensions(format!(
89 "{:?}",
90 self
91 ))),
92 }
93 }
94}
95
96impl fmt::Display for OllamaTextEmbeddingsInference {
97 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98 match self {
99 Self::AllMiniLML6v2 => write!(f, "{}", Self::ALL_MINI_LML6V2),
100 Self::SnowflakeArcticEmbedM => write!(f, "{}", Self::SNOWFLAKE_ARCTIC_EMBED_M),
101 Self::JinaEmbeddingsV2BaseEs => write!(f, "{}", Self::JINA_EMBEDDINGS_V2_BASE_ES),
102 Self::Other(name) => write!(f, "{}", name),
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_snowflake_arctic_embed_xs() {
        let model_str = "snowflake-arctic-embed:xs";
        let parsed_model = OllamaTextEmbeddingsInference::from_string(model_str);
        assert_eq!(parsed_model, Ok(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM));
    }

    #[test]
    fn test_parse_jina_embeddings_v2_base_es() {
        let model_str = "jina/jina-embeddings-v2-base-es:latest";
        let parsed_model = OllamaTextEmbeddingsInference::from_string(model_str);
        assert_eq!(parsed_model, Ok(OllamaTextEmbeddingsInference::JinaEmbeddingsV2BaseEs));
    }

    #[test]
    fn test_parse_all_minilm_l6_v2() {
        let parsed_model = OllamaTextEmbeddingsInference::from_string("all-minilm:l6-v2");
        assert_eq!(parsed_model, Ok(OllamaTextEmbeddingsInference::AllMiniLML6v2));
    }

    #[test]
    fn test_parse_snowflake_arctic_embed_xs_as_embedding_model_type() {
        let model_str = "snowflake-arctic-embed:xs";
        let parsed_model = EmbeddingModelType::from_string(model_str);
        assert_eq!(
            parsed_model,
            Ok(EmbeddingModelType::OllamaTextEmbeddingsInference(
                OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM
            ))
        );
    }

    #[test]
    fn test_parse_jina_embeddings_v2_base_es_as_embedding_model_type() {
        let model_str = "jina/jina-embeddings-v2-base-es:latest";
        let parsed_model = EmbeddingModelType::from_string(model_str);
        assert_eq!(
            parsed_model,
            Ok(EmbeddingModelType::OllamaTextEmbeddingsInference(
                OllamaTextEmbeddingsInference::JinaEmbeddingsV2BaseEs
            ))
        );
    }

    /// Parsing is exact and case-sensitive; anything else is rejected by both
    /// entry points with the same error.
    #[test]
    fn test_parse_rejects_unknown_tags() {
        for bad in ["", "unknown-model", "snowflake-arctic-embed:m", "ALL-MINILM:L6-V2"] {
            assert_eq!(
                OllamaTextEmbeddingsInference::from_string(bad),
                Err(ZooEmbeddingError::InvalidModelArchitecture)
            );
            assert_eq!(
                EmbeddingModelType::from_string(bad),
                Err(ZooEmbeddingError::InvalidModelArchitecture)
            );
        }
    }

    /// `Display` output of every named variant must parse back to itself,
    /// both directly and through the `EmbeddingModelType` wrapper.
    #[test]
    fn test_display_round_trips_through_from_string() {
        let known = [
            OllamaTextEmbeddingsInference::AllMiniLML6v2,
            OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM,
            OllamaTextEmbeddingsInference::JinaEmbeddingsV2BaseEs,
        ];
        for model in known {
            let tag = model.to_string();
            assert_eq!(OllamaTextEmbeddingsInference::from_string(&tag), Ok(model.clone()));
            let wrapped = EmbeddingModelType::OllamaTextEmbeddingsInference(model);
            assert_eq!(wrapped.to_string(), tag);
            assert_eq!(EmbeddingModelType::from_string(&tag), Ok(wrapped));
        }
    }

    #[test]
    fn test_other_displays_raw_name_and_has_no_dimensions() {
        let model = OllamaTextEmbeddingsInference::Other("custom:tag".to_string());
        assert_eq!(model.to_string(), "custom:tag");
        assert!(matches!(
            model.vector_dimensions(),
            Err(ZooEmbeddingError::UnimplementedModelDimensions(_))
        ));
    }

    #[test]
    fn test_model_metrics_through_wrapper() {
        let jina = EmbeddingModelType::OllamaTextEmbeddingsInference(
            OllamaTextEmbeddingsInference::JinaEmbeddingsV2BaseEs,
        );
        assert_eq!(jina.max_input_token_count(), 1024);
        assert_eq!(jina.embedding_normalization_factor(), 1.5);
        assert_eq!(jina.vector_dimensions(), Ok(768));

        let snowflake = EmbeddingModelType::OllamaTextEmbeddingsInference(
            OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM,
        );
        assert_eq!(snowflake.max_input_token_count(), 512);
        assert_eq!(snowflake.embedding_normalization_factor(), 1.0);
        assert_eq!(snowflake.vector_dimensions(), Ok(384));
    }
}