1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::{error::Error, num::NonZeroUsize};
6
7pub mod prelude {
8 pub use crate::{
9 EmbeddingDimension, EmbeddingDistanceMetric, EmbeddingError, EmbeddingIndexKind,
10 EmbeddingModality, EmbeddingModelName, EmbeddingNormalizationKind, EmbeddingSearchKind,
11 EmbeddingVectorFormat, EmbeddingVectorId, EmbeddingVectorShape,
12 };
13}
14
/// Generates a public newtype `$name` around a trimmed, non-empty `String`.
///
/// The generated type can only be built through validating constructors
/// (`new`, `FromStr`, `TryFrom<&str>`, `TryFrom<String>`), all of which route
/// through `non_empty_text` and therefore fail with `EmbeddingError::Empty`
/// when the input is empty or whitespace-only.
macro_rules! embedding_text_newtype {
    ($name:ident) => {
        /// Trimmed, non-empty text value.
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Validates `value` and stores its trimmed form.
            ///
            /// # Errors
            ///
            /// Returns `EmbeddingError::Empty` when `value` contains nothing
            /// but whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, EmbeddingError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the stored text.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = EmbeddingError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = EmbeddingError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Owned counterpart of `TryFrom<&str>` so callers holding a `String`
        // get the same validation without spelling out a borrow.
        impl TryFrom<String> for $name {
            type Error = EmbeddingError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Lets callers take the inner text back without cloning via `as_str`.
        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.0
            }
        }
    };
}
59
/// Generates a public, field-less label enum `$name`.
///
/// Each `Variant => "label"` pair defines one variant and its canonical
/// string. The generated type gets `as_str`, `Display`, `FromStr`, and
/// `TryFrom<&str>`. Parsing runs the input through `normalized_label` first,
/// so it is matched against the canonical labels in normalized form.
///
/// Optional attributes (including `///` doc comments) may precede the enum
/// name and each variant; they are forwarded to the generated items. Plain
/// invocations without attributes remain valid.
macro_rules! embedding_enum {
    (
        $(#[$enum_meta:meta])*
        $name:ident {
            $($(#[$variant_meta:meta])* $variant:ident => $label:literal),+ $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($(#[$variant_meta])* $variant),+
        }

        impl $name {
            /// Returns the canonical label of this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = EmbeddingError;

            /// Parses a label after normalization.
            ///
            /// Fails with whatever `normalized_label` reports for the input,
            /// or with `EmbeddingError::UnknownLabel` when the normalized
            /// text matches no variant.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(EmbeddingError::UnknownLabel),
                }
            }
        }

        // Mirrors the text newtypes, which offer `TryFrom<&str>` next to
        // `FromStr`; both entry points share one parsing path.
        impl TryFrom<&str> for $name {
            type Error = EmbeddingError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                <Self as FromStr>::from_str(value)
            }
        }
    };
}
93
// Validated text newtypes: each wraps a `String` that was trimmed and checked
// to be non-empty at construction (see `embedding_text_newtype!`).
embedding_text_newtype!(EmbeddingModelName);
embedding_text_newtype!(EmbeddingVectorId);
96
97#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
98pub struct EmbeddingDimension(NonZeroUsize);
99
100impl EmbeddingDimension {
101 pub fn new(value: usize) -> Result<Self, EmbeddingError> {
102 NonZeroUsize::new(value)
103 .map(Self)
104 .ok_or(EmbeddingError::Zero)
105 }
106
107 pub const fn get(self) -> usize {
108 self.0.get()
109 }
110}
111
112#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
113pub struct EmbeddingVectorShape {
114 dimension: EmbeddingDimension,
115}
116
117impl EmbeddingVectorShape {
118 pub const fn new(dimension: EmbeddingDimension) -> Self {
119 Self { dimension }
120 }
121
122 pub const fn dimension(self) -> EmbeddingDimension {
123 self.dimension
124 }
125}
126
// Modality labels. `as_str`/`Display` yield the string on the right of each
// arrow; `FromStr` accepts that string after trimming, ASCII-lowercasing, and
// mapping `_`/space to `-` (see `normalized_label`).
embedding_enum!(EmbeddingModality {
    Text => "text",
    Image => "image",
    Audio => "audio",
    Video => "video",
    Code => "code",
    Tabular => "tabular",
    Graph => "graph",
    Multimodal => "multimodal",
    Other => "other",
});
138
// Distance metric labels. Because parsing normalizes `_` and spaces to `-`,
// `DotProduct` is also reachable from inputs such as "dot product" or
// "dot_product".
embedding_enum!(EmbeddingDistanceMetric {
    Cosine => "cosine",
    DotProduct => "dot-product",
    Euclidean => "euclidean",
    Manhattan => "manhattan",
    Hamming => "hamming",
    Jaccard => "jaccard",
    Custom => "custom",
});
148
// Normalization labels. The `None` variant is this enum's own variant (label
// "none"), not `Option::None`; refer to it as
// `EmbeddingNormalizationKind::None`.
embedding_enum!(EmbeddingNormalizationKind {
    None => "none",
    Unit => "unit",
    MeanCentered => "mean-centered",
    Standardized => "standardized",
    Custom => "custom",
});
156
// Index kind labels. `IvfPq` uses the hyphenated label "ivf-pq", which also
// parses from "ivf pq" / "ivf_pq" after normalization.
// NOTE(review): the variant names look like approximate-nearest-neighbor index
// families (HNSW, IVF, PQ, ...) — confirm the intended meanings against the
// README before documenting them per variant.
embedding_enum!(EmbeddingIndexKind {
    Flat => "flat",
    Hnsw => "hnsw",
    Ivf => "ivf",
    Pq => "pq",
    IvfPq => "ivf-pq",
    Annoy => "annoy",
    Scann => "scann",
    Other => "other",
});
167
// Search kind labels. Every label is a single word, so accepted inputs differ
// from the canonical label only in ASCII case and surrounding whitespace.
embedding_enum!(EmbeddingSearchKind {
    Exact => "exact",
    Approximate => "approximate",
    Hybrid => "hybrid",
    Filtered => "filtered",
    Reranked => "reranked",
});
175
// Vector format labels. Every label is a single word, so accepted inputs
// differ from the canonical label only in ASCII case and surrounding
// whitespace.
embedding_enum!(EmbeddingVectorFormat {
    Dense => "dense",
    Sparse => "sparse",
    Binary => "binary",
    Quantized => "quantized",
    Mixed => "mixed",
});
183
/// Validation failures produced by the constructors and parsers in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EmbeddingError {
    /// The supplied text was empty after trimming whitespace.
    Empty,
    /// A dimension of zero was supplied.
    Zero,
    /// The normalized text matched none of an enum's labels.
    UnknownLabel,
}

impl fmt::Display for EmbeddingError {
    /// Writes a fixed, human-readable message for each variant.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            EmbeddingError::Empty => "embedding metadata text cannot be empty",
            EmbeddingError::Zero => "embedding dimension must be positive",
            EmbeddingError::UnknownLabel => "unknown embedding metadata label",
        };
        formatter.write_str(message)
    }
}

// No underlying cause to expose, so the default `Error` methods suffice.
impl Error for EmbeddingError {}
202
203fn non_empty_text(value: impl AsRef<str>) -> Result<String, EmbeddingError> {
204 let trimmed = value.as_ref().trim();
205 if trimmed.is_empty() {
206 Err(EmbeddingError::Empty)
207 } else {
208 Ok(trimmed.to_string())
209 }
210}
211
212fn normalized_label(value: &str) -> Result<String, EmbeddingError> {
213 let trimmed = value.trim();
214 if trimmed.is_empty() {
215 Err(EmbeddingError::Empty)
216 } else {
217 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::{
        EmbeddingDimension, EmbeddingDistanceMetric, EmbeddingError, EmbeddingIndexKind,
        EmbeddingModelName, EmbeddingNormalizationKind, EmbeddingVectorShape,
    };

    /// Text newtypes trim their input and reject blanks; dimensions reject zero.
    #[test]
    fn validates_embedding_names_and_dimensions() -> Result<(), EmbeddingError> {
        let model = EmbeddingModelName::new(" text-embedding ")?;
        let dimension = EmbeddingDimension::new(384)?;
        let shape = EmbeddingVectorShape::new(dimension);

        assert_eq!(model.as_str(), "text-embedding");
        assert_eq!(model.to_string(), "text-embedding");
        assert_eq!(shape.dimension().get(), 384);
        assert_eq!(EmbeddingDimension::new(0), Err(EmbeddingError::Zero));
        assert_eq!(EmbeddingModelName::new("   "), Err(EmbeddingError::Empty));
        assert_eq!(
            "".parse::<EmbeddingModelName>(),
            Err(EmbeddingError::Empty)
        );
        Ok(())
    }

    /// Enum labels parse from loosely formatted input and display canonically.
    #[test]
    fn displays_and_parses_embedding_enums() -> Result<(), EmbeddingError> {
        assert_eq!(
            "dot product".parse::<EmbeddingDistanceMetric>()?,
            EmbeddingDistanceMetric::DotProduct
        );
        assert_eq!(
            "mean centered".parse::<EmbeddingNormalizationKind>()?,
            EmbeddingNormalizationKind::MeanCentered
        );
        assert_eq!(
            "ivf pq".parse::<EmbeddingIndexKind>()?,
            EmbeddingIndexKind::IvfPq
        );
        // Case and underscores are normalized as well.
        assert_eq!(
            " IVF_PQ ".parse::<EmbeddingIndexKind>()?,
            EmbeddingIndexKind::IvfPq
        );

        // The display half that the test name promises.
        assert_eq!(
            EmbeddingDistanceMetric::DotProduct.to_string(),
            "dot-product"
        );
        assert_eq!(
            EmbeddingNormalizationKind::MeanCentered.to_string(),
            "mean-centered"
        );
        assert_eq!(EmbeddingIndexKind::IvfPq.as_str(), "ivf-pq");
        Ok(())
    }

    /// Blank input and unrecognized labels map to distinct error variants.
    #[test]
    fn rejects_blank_and_unknown_labels() {
        assert_eq!(
            "   ".parse::<EmbeddingIndexKind>(),
            Err(EmbeddingError::Empty)
        );
        assert_eq!(
            "faiss".parse::<EmbeddingIndexKind>(),
            Err(EmbeddingError::UnknownLabel)
        );
    }
}