//! Embedding provider abstractions: input/output types, error taxonomy, and
//! the [`EmbeddingProvider`] trait implemented by concrete backends.

use std::fmt;

/// A batch of input texts to be embedded in a single request.
pub type EmbeddingInput = Vec<String>;

/// A single embedding vector (one `f32` component per dimension).
pub type EmbeddingVector = Vec<f32>;
14#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
16pub struct EmbeddingUsage {
17 pub total_tokens: u32,
19}
20
21#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
23pub struct EmbeddingResponse {
24 pub embeddings: Vec<EmbeddingVector>,
26 pub model: String,
28 pub dimension: usize,
30 pub usage: EmbeddingUsage,
32}
33
34#[derive(Debug, thiserror::Error)]
36pub enum EmbeddingError {
37 #[error("embedding provider not configured: {0}")]
39 NotConfigured(String),
40
41 #[error("embedding API error (status {status}): {message}")]
43 Api {
44 status: u16,
46 message: String,
48 },
49
50 #[error("embedding network error: {0}")]
52 Network(String),
53
54 #[error("embedding dimension mismatch: expected {expected}, got {actual}")]
56 DimensionMismatch {
57 expected: usize,
59 actual: usize,
61 },
62
63 #[error("embedding batch too large: {size} exceeds max {max}")]
65 BatchTooLarge {
66 size: usize,
68 max: usize,
70 },
71
72 #[error("embedding rate limited, retry after {retry_after_secs}s")]
74 RateLimited {
75 retry_after_secs: u64,
77 },
78
79 #[error("embedding internal error: {0}")]
81 Internal(String),
82}
83
84impl fmt::Display for EmbeddingUsage {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 write!(f, "EmbeddingUsage(tokens={})", self.total_tokens)
87 }
88}
89
90#[async_trait::async_trait]
95pub trait EmbeddingProvider: Send + Sync {
96 fn name(&self) -> &str;
98
99 fn dimension(&self) -> usize;
101
102 fn model_id(&self) -> &str;
104
105 async fn embed(&self, inputs: EmbeddingInput) -> Result<EmbeddingResponse, EmbeddingError>;
107
108 async fn health_check(&self) -> Result<(), EmbeddingError>;
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embedding_usage_default_is_zero() {
        let usage = EmbeddingUsage::default();
        assert_eq!(usage.total_tokens, 0);
    }

    #[test]
    fn embedding_usage_display() {
        let usage = EmbeddingUsage { total_tokens: 42 };
        assert_eq!(usage.to_string(), "EmbeddingUsage(tokens=42)");
    }

    #[test]
    fn embedding_response_fields() {
        let response = EmbeddingResponse {
            embeddings: vec![vec![0.1, 0.2, 0.3]],
            model: "test-model".to_string(),
            dimension: 3,
            usage: EmbeddingUsage { total_tokens: 10 },
        };
        assert_eq!(response.embeddings.len(), 1);
        assert_eq!(response.dimension, 3);
        assert_eq!(response.model, "test-model");
        assert_eq!(response.usage.total_tokens, 10);
    }

    #[test]
    fn embedding_response_serde_roundtrip() {
        let response = EmbeddingResponse {
            embeddings: vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            model: "test".to_string(),
            dimension: 2,
            usage: EmbeddingUsage { total_tokens: 5 },
        };
        let json = serde_json::to_string(&response).expect("serialize");
        let deserialized: EmbeddingResponse = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.embeddings.len(), 2);
        assert_eq!(deserialized.dimension, 2);
        assert_eq!(deserialized.usage.total_tokens, 5);
    }

    #[test]
    fn embedding_error_display_not_configured() {
        let err = EmbeddingError::NotConfigured("missing api_key".to_string());
        assert!(err.to_string().contains("not configured"));
    }

    #[test]
    fn embedding_error_display_api() {
        let err = EmbeddingError::Api {
            status: 500,
            message: "server error".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("500"));
        assert!(msg.contains("server error"));
    }

    #[test]
    fn embedding_error_display_dimension_mismatch() {
        let err = EmbeddingError::DimensionMismatch {
            expected: 768,
            actual: 1536,
        };
        let msg = err.to_string();
        assert!(msg.contains("768"));
        assert!(msg.contains("1536"));
    }

    #[test]
    fn embedding_error_display_batch_too_large() {
        let err = EmbeddingError::BatchTooLarge {
            size: 3000,
            max: 2048,
        };
        let msg = err.to_string();
        assert!(msg.contains("3000"));
        assert!(msg.contains("2048"));
    }

    #[test]
    fn embedding_error_display_rate_limited() {
        let err = EmbeddingError::RateLimited {
            retry_after_secs: 30,
        };
        assert!(err.to_string().contains("30"));
    }

    #[test]
    fn embedding_error_display_network() {
        let err = EmbeddingError::Network("connection refused".to_string());
        let msg = err.to_string();
        assert!(msg.contains("connection refused"));
        assert!(msg.contains("network"));
    }

    #[test]
    fn embedding_error_display_internal() {
        let err = EmbeddingError::Internal("something broke".to_string());
        let msg = err.to_string();
        assert!(msg.contains("something broke"));
        assert!(msg.contains("internal"));
    }

    #[test]
    fn embedding_usage_display_zero() {
        let usage = EmbeddingUsage::default();
        assert_eq!(usage.to_string(), "EmbeddingUsage(tokens=0)");
    }

    #[test]
    fn embedding_response_empty_vectors() {
        let response = EmbeddingResponse {
            embeddings: vec![],
            model: "empty".to_string(),
            dimension: 0,
            usage: EmbeddingUsage::default(),
        };
        assert!(response.embeddings.is_empty());
        assert_eq!(response.dimension, 0);
    }

    #[test]
    fn embedding_usage_serde_roundtrip() {
        let usage = EmbeddingUsage { total_tokens: 100 };
        let json = serde_json::to_string(&usage).expect("serialize");
        let deserialized: EmbeddingUsage = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deserialized.total_tokens, 100);
    }
}