tldr_core/search/embedding_client.rs
//! HTTP client for the Python embedding service.
//!
//! Provides semantic search embeddings by calling an external Python service.
//! The service should expose a REST API for generating embeddings.
//! The request methods are currently placeholders: they probe the service
//! over TCP but do not issue an HTTP request yet.
//!
//! # Mitigation M8
//! This client implements graceful degradation: if the embedding service
//! is unavailable, hybrid search falls back to BM25-only mode.
9
10use std::time::Duration;
11
12use serde::{Deserialize, Serialize};
13
14use crate::error::TldrError;
15use crate::TldrResult;
16
/// Base URL used when a client is built through `EmbeddingClient::default`.
pub const DEFAULT_EMBEDDING_URL: &str = "http://localhost:8765";

/// Timeout assigned to clients created with `EmbeddingClient::new` (ten seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Number of components in an embedding vector (sized for BGE-large-en-v1.5).
pub const EMBEDDING_DIM: usize = 1024;
25
26/// Request to the embedding service
27#[derive(Debug, Clone, Serialize)]
28struct EmbeddingRequest {
29 /// Text to embed
30 text: String,
31 /// Optional batch of texts
32 #[serde(skip_serializing_if = "Option::is_none")]
33 texts: Option<Vec<String>>,
34}
35
36/// Response from the embedding service
37#[derive(Debug, Clone, Deserialize)]
38struct _EmbeddingResponse {
39 /// Single embedding vector
40 #[serde(default)]
41 _embedding: Vec<f32>,
42 /// Batch of embedding vectors
43 #[serde(default)]
44 _embeddings: Vec<Vec<f32>>,
45}
46
47/// Semantic search result from embedding service
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SemanticResult {
50 /// Document ID / file path
51 pub doc_id: String,
52 /// Cosine similarity score (0-1)
53 pub score: f64,
54 /// Start line of matching region
55 pub line_start: u32,
56 /// End line of matching region
57 pub line_end: u32,
58 /// Snippet of content
59 pub snippet: String,
60}
61
62/// Search request to the embedding service
63#[derive(Debug, Clone, Serialize)]
64struct SearchRequest {
65 /// Query text
66 query: String,
67 /// Number of results to return
68 top_k: usize,
69 /// Project path to search in
70 project: String,
71}
72
73/// Search response from embedding service
74#[derive(Debug, Clone, Deserialize)]
75struct _SearchResponse {
76 /// Search results
77 _results: Vec<SemanticResult>,
78}
79
/// Client for the Python embedding service.
///
/// All methods are synchronous (blocking); there is nothing to `.await`.
///
/// # Example
/// ```ignore
/// use tldr_core::search::embedding_client::EmbeddingClient;
///
/// let client = EmbeddingClient::new("http://localhost:8765");
/// if client.is_available() {
///     let results = client.search("process data", "src/", 10)?;
/// }
/// ```
#[derive(Debug, Clone)]
pub struct EmbeddingClient {
    /// Base URL of the embedding service, stored without trailing slashes
    base_url: String,
    /// Request timeout (reserved for future HTTP client integration)
    _timeout: Duration,
}
98
99impl Default for EmbeddingClient {
100 fn default() -> Self {
101 Self::new(DEFAULT_EMBEDDING_URL)
102 }
103}
104
105impl EmbeddingClient {
106 /// Create a new embedding client
107 ///
108 /// # Arguments
109 /// * `base_url` - Base URL of the embedding service (e.g., "http://localhost:8765")
110 pub fn new(base_url: &str) -> Self {
111 Self {
112 base_url: base_url.trim_end_matches('/').to_string(),
113 _timeout: DEFAULT_TIMEOUT,
114 }
115 }
116
117 /// Create a client with custom timeout
118 pub fn with_timeout(base_url: &str, timeout: Duration) -> Self {
119 Self {
120 base_url: base_url.trim_end_matches('/').to_string(),
121 _timeout: timeout,
122 }
123 }
124
125 /// Get the base URL
126 pub fn base_url(&self) -> &str {
127 &self.base_url
128 }
129
130 /// Check if the embedding service is available
131 ///
132 /// # Returns
133 /// `true` if the service responds to health check, `false` otherwise
134 pub fn is_available(&self) -> bool {
135 // Synchronous check - try to connect
136 // This is a simplified check; in production, use async
137 std::net::TcpStream::connect_timeout(&self.parse_address(), Duration::from_secs(1)).is_ok()
138 }
139
140 /// Parse the base URL into a socket address
141 fn parse_address(&self) -> std::net::SocketAddr {
142 let url = self
143 .base_url
144 .strip_prefix("http://")
145 .unwrap_or(&self.base_url);
146 let url = url.strip_prefix("https://").unwrap_or(url);
147
148 // Default port
149 let (host, port) = if let Some((h, p)) = url.split_once(':') {
150 (h, p.parse().unwrap_or(8765))
151 } else {
152 (url, 8765)
153 };
154
155 // Resolve to socket address
156 use std::net::ToSocketAddrs;
157 format!("{}:{}", host, port)
158 .to_socket_addrs()
159 .ok()
160 .and_then(|mut addrs| addrs.next())
161 .unwrap_or_else(|| std::net::SocketAddr::from(([127, 0, 0, 1], port)))
162 }
163
164 /// Perform semantic search
165 ///
166 /// # Arguments
167 /// * `query` - Search query text
168 /// * `project` - Project path to search in
169 /// * `top_k` - Number of results to return
170 ///
171 /// # Returns
172 /// Vector of semantic search results, or error if service unavailable
173 ///
174 /// # Mitigation M8
175 /// Returns ConnectionFailed error if service is unavailable,
176 /// allowing caller to fall back to BM25-only search.
177 pub fn search(
178 &self,
179 query: &str,
180 project: &str,
181 top_k: usize,
182 ) -> TldrResult<Vec<SemanticResult>> {
183 // For sync implementation, check availability first
184 if !self.is_available() {
185 return Err(TldrError::ConnectionFailed(format!(
186 "Embedding service at {} is not available",
187 self.base_url
188 )));
189 }
190
191 // In a real implementation, this would make an HTTP request
192 // For now, return empty results to allow compilation and testing
193 // The actual HTTP request would use reqwest or similar
194
195 // Placeholder: return empty results
196 // Real implementation would POST to {base_url}/search
197 let _request = SearchRequest {
198 query: query.to_string(),
199 top_k,
200 project: project.to_string(),
201 };
202
203 // TODO: Implement actual HTTP request
204 // For now, signal that service call would happen
205 Ok(Vec::new())
206 }
207
208 /// Get embedding for a single text
209 ///
210 /// # Arguments
211 /// * `text` - Text to embed
212 ///
213 /// # Returns
214 /// Embedding vector of dimension EMBEDDING_DIM
215 pub fn embed(&self, text: &str) -> TldrResult<Vec<f32>> {
216 if !self.is_available() {
217 return Err(TldrError::ConnectionFailed(format!(
218 "Embedding service at {} is not available",
219 self.base_url
220 )));
221 }
222
223 let _request = EmbeddingRequest {
224 text: text.to_string(),
225 texts: None,
226 };
227
228 // TODO: Implement actual HTTP request
229 // Return placeholder zeros for now
230 Ok(vec![0.0; EMBEDDING_DIM])
231 }
232
233 /// Get embeddings for multiple texts
234 ///
235 /// # Arguments
236 /// * `texts` - Texts to embed
237 ///
238 /// # Returns
239 /// Vector of embedding vectors
240 pub fn embed_batch(&self, texts: &[String]) -> TldrResult<Vec<Vec<f32>>> {
241 if !self.is_available() {
242 return Err(TldrError::ConnectionFailed(format!(
243 "Embedding service at {} is not available",
244 self.base_url
245 )));
246 }
247
248 let _request = EmbeddingRequest {
249 text: String::new(),
250 texts: Some(texts.to_vec()),
251 };
252
253 // TODO: Implement actual HTTP request
254 // Return placeholder zeros for now
255 Ok(texts.iter().map(|_| vec![0.0; EMBEDDING_DIM]).collect())
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// URL on a port that is unlikely to have a listener on the test machine.
    const CLOSED_PORT_URL: &str = "http://localhost:59999";

    /// The base URL is stored verbatim when it has no trailing slash.
    #[test]
    fn test_client_creation() {
        let url = "http://localhost:8765";
        assert_eq!(EmbeddingClient::new(url).base_url(), url);
    }

    /// A trailing slash is stripped from the stored base URL.
    #[test]
    fn test_client_with_trailing_slash() {
        let client = EmbeddingClient::new("http://localhost:8765/");
        assert_eq!(client.base_url(), "http://localhost:8765");
    }

    /// The availability probe reports `false` when nothing is listening.
    #[test]
    fn test_client_unavailable() {
        assert!(!EmbeddingClient::new(CLOSED_PORT_URL).is_available());
    }

    /// `search` surfaces `ConnectionFailed` when the service cannot be reached.
    #[test]
    fn test_search_unavailable_service() {
        let outcome = EmbeddingClient::new(CLOSED_PORT_URL).search("query", "project", 10);
        assert!(outcome.is_err());

        match outcome {
            Err(TldrError::ConnectionFailed(msg)) => assert!(msg.contains("not available")),
            _ => panic!("Expected ConnectionFailed error"),
        }
    }
}