// tuitbot_core/llm/embedding.rs

//! Embedding provider abstraction for semantic search.
//!
//! Provides a trait-based abstraction for embedding providers (OpenAI, Ollama)
//! with typed responses, usage tracking, and health checking.
5
6use std::fmt;
7
/// Input texts to embed.
///
/// Each element is an independent text; the provider returns one vector per
/// element, in the same order (see [`EmbeddingResponse::embeddings`]).
pub type EmbeddingInput = Vec<String>;

/// A single embedding vector of `f32` components.
pub type EmbeddingVector = Vec<f32>;
13
/// Token usage from an embedding request.
///
/// `Default` yields `total_tokens == 0`. Serializes/deserializes via serde
/// for persistence and transport alongside [`EmbeddingResponse`].
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct EmbeddingUsage {
    /// Total tokens consumed across all inputs.
    pub total_tokens: u32,
}
20
/// Response from an embedding request.
///
/// Invariants expected by callers: `embeddings` holds one vector per input
/// text, in input order, and each inner vector should have `dimension`
/// elements (violations surface as [`EmbeddingError::DimensionMismatch`]).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EmbeddingResponse {
    /// One embedding vector per input text, in order.
    pub embeddings: Vec<EmbeddingVector>,
    /// The model that produced these embeddings.
    pub model: String,
    /// Dimensionality of each vector.
    pub dimension: usize,
    /// Token usage for this request.
    pub usage: EmbeddingUsage,
}
33
/// Errors from embedding operations.
///
/// Display messages come from the `thiserror` `#[error(...)]` attribute on
/// each variant.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// No embedding provider is configured.
    ///
    /// The payload describes what is missing (e.g. an API key).
    #[error("embedding provider not configured: {0}")]
    NotConfigured(String),

    /// The embedding API returned an error.
    #[error("embedding API error (status {status}): {message}")]
    Api {
        /// HTTP status code.
        status: u16,
        /// Error message from the provider.
        message: String,
    },

    /// Network-level failure communicating with the provider.
    #[error("embedding network error: {0}")]
    Network(String),

    /// Returned vectors have unexpected dimensions.
    ///
    /// Raised when a provider's output does not match the configured
    /// [`EmbeddingResponse::dimension`].
    #[error("embedding dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// The expected dimension.
        expected: usize,
        /// The actual dimension received.
        actual: usize,
    },

    /// Batch exceeds the provider's maximum.
    #[error("embedding batch too large: {size} exceeds max {max}")]
    BatchTooLarge {
        /// The batch size attempted.
        size: usize,
        /// The provider's maximum batch size.
        max: usize,
    },

    /// Provider rate limit hit.
    ///
    /// Callers may retry after the indicated delay.
    #[error("embedding rate limited, retry after {retry_after_secs}s")]
    RateLimited {
        /// Seconds to wait before retrying.
        retry_after_secs: u64,
    },

    /// Internal storage or processing error.
    #[error("embedding internal error: {0}")]
    Internal(String),
}
83
84impl fmt::Display for EmbeddingUsage {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        write!(f, "EmbeddingUsage(tokens={})", self.total_tokens)
87    }
88}
89
/// Trait abstracting embedding provider operations.
///
/// Implementations include `OpenAiEmbeddingProvider` and `OllamaEmbeddingProvider`.
/// Object-safe for use as `Box<dyn EmbeddingProvider>`; `Send + Sync` so a
/// boxed provider can be shared across async tasks.
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Display name of this provider (e.g., "openai", "ollama").
    fn name(&self) -> &str;

    /// Vector dimension produced by this provider's model.
    fn dimension(&self) -> usize;

    /// Model identifier string.
    fn model_id(&self) -> &str;

    /// Embed a batch of texts into vectors.
    ///
    /// On success, the response carries one vector per input, in order.
    /// Fails with an [`EmbeddingError`] variant such as `Api`, `Network`,
    /// `BatchTooLarge`, or `RateLimited`.
    async fn embed(&self, inputs: EmbeddingInput) -> Result<EmbeddingResponse, EmbeddingError>;

    /// Check if the provider is reachable and configured correctly.
    async fn health_check(&self) -> Result<(), EmbeddingError>;
}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `EmbeddingResponse` from parts; shared by the response tests.
    fn make_response(
        vectors: Vec<EmbeddingVector>,
        model: &str,
        dim: usize,
        tokens: u32,
    ) -> EmbeddingResponse {
        EmbeddingResponse {
            embeddings: vectors,
            model: model.to_string(),
            dimension: dim,
            usage: EmbeddingUsage { total_tokens: tokens },
        }
    }

    #[test]
    fn embedding_usage_default_is_zero() {
        assert_eq!(EmbeddingUsage::default().total_tokens, 0);
    }

    #[test]
    fn embedding_usage_display() {
        let rendered = EmbeddingUsage { total_tokens: 42 }.to_string();
        assert_eq!(rendered, "EmbeddingUsage(tokens=42)");
    }

    #[test]
    fn embedding_response_fields() {
        let resp = make_response(vec![vec![0.1, 0.2, 0.3]], "test-model", 3, 10);
        assert_eq!(resp.model, "test-model");
        assert_eq!(resp.dimension, 3);
        assert_eq!(resp.embeddings.len(), 1);
        assert_eq!(resp.usage.total_tokens, 10);
    }

    #[test]
    fn embedding_response_serde_roundtrip() {
        let original = make_response(vec![vec![1.0, 2.0], vec![3.0, 4.0]], "test", 2, 5);
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: EmbeddingResponse = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.embeddings.len(), 2);
        assert_eq!(restored.dimension, 2);
        assert_eq!(restored.usage.total_tokens, 5);
    }

    #[test]
    fn embedding_error_display_not_configured() {
        let rendered = EmbeddingError::NotConfigured("missing api_key".to_string()).to_string();
        assert!(rendered.contains("not configured"));
    }

    #[test]
    fn embedding_error_display_api() {
        let rendered = EmbeddingError::Api {
            status: 500,
            message: "server error".to_string(),
        }
        .to_string();
        assert!(rendered.contains("500"));
        assert!(rendered.contains("server error"));
    }

    #[test]
    fn embedding_error_display_dimension_mismatch() {
        let rendered = EmbeddingError::DimensionMismatch {
            expected: 768,
            actual: 1536,
        }
        .to_string();
        assert!(rendered.contains("768"));
        assert!(rendered.contains("1536"));
    }

    #[test]
    fn embedding_error_display_batch_too_large() {
        let rendered = EmbeddingError::BatchTooLarge { size: 3000, max: 2048 }.to_string();
        assert!(rendered.contains("3000"));
        assert!(rendered.contains("2048"));
    }

    #[test]
    fn embedding_error_display_rate_limited() {
        let rendered = EmbeddingError::RateLimited { retry_after_secs: 30 }.to_string();
        assert!(rendered.contains("30"));
    }

    #[test]
    fn embedding_error_display_network() {
        let rendered = EmbeddingError::Network("connection refused".to_string()).to_string();
        assert!(rendered.contains("connection refused"));
        assert!(rendered.contains("network"));
    }

    #[test]
    fn embedding_error_display_internal() {
        let rendered = EmbeddingError::Internal("something broke".to_string()).to_string();
        assert!(rendered.contains("something broke"));
        assert!(rendered.contains("internal"));
    }

    #[test]
    fn embedding_usage_display_zero() {
        assert_eq!(EmbeddingUsage::default().to_string(), "EmbeddingUsage(tokens=0)");
    }

    #[test]
    fn embedding_response_empty_vectors() {
        let resp = make_response(vec![], "empty", 0, 0);
        assert!(resp.embeddings.is_empty());
        assert_eq!(resp.dimension, 0);
    }

    #[test]
    fn embedding_usage_serde_roundtrip() {
        let json = serde_json::to_string(&EmbeddingUsage { total_tokens: 100 }).expect("serialize");
        let restored: EmbeddingUsage = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.total_tokens, 100);
    }
}