Skip to main content

tldr_core/search/
embedding_client.rs

//! HTTP client for the Python embedding service.
//!
//! Provides semantic-search embeddings by calling an external Python service
//! that exposes a REST API for generating embeddings. Note: the request
//! methods are currently placeholders — the actual HTTP calls are still TODO.
//!
//! # Mitigation M8
//! This client degrades gracefully: if the embedding service is unavailable,
//! its methods return `ConnectionFailed`, so hybrid search falls back to BM25-only mode.
9
10use std::time::Duration;
11
12use serde::{Deserialize, Serialize};
13
14use crate::error::TldrError;
15use crate::TldrResult;
16
/// Base URL used by `EmbeddingClient::default()` when no endpoint is given.
pub const DEFAULT_EMBEDDING_URL: &str = "http://localhost:8765";

/// Request timeout stored by `EmbeddingClient::new`.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Length of every embedding vector; 1024 is the output size of BGE-large-en-v1.5.
pub const EMBEDDING_DIM: usize = 1_024;
25
26/// Request to the embedding service
27#[derive(Debug, Clone, Serialize)]
28struct EmbeddingRequest {
29    /// Text to embed
30    text: String,
31    /// Optional batch of texts
32    #[serde(skip_serializing_if = "Option::is_none")]
33    texts: Option<Vec<String>>,
34}
35
36/// Response from the embedding service
37#[derive(Debug, Clone, Deserialize)]
38struct _EmbeddingResponse {
39    /// Single embedding vector
40    #[serde(default)]
41    _embedding: Vec<f32>,
42    /// Batch of embedding vectors
43    #[serde(default)]
44    _embeddings: Vec<Vec<f32>>,
45}
46
47/// Semantic search result from embedding service
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SemanticResult {
50    /// Document ID / file path
51    pub doc_id: String,
52    /// Cosine similarity score (0-1)
53    pub score: f64,
54    /// Start line of matching region
55    pub line_start: u32,
56    /// End line of matching region
57    pub line_end: u32,
58    /// Snippet of content
59    pub snippet: String,
60}
61
62/// Search request to the embedding service
63#[derive(Debug, Clone, Serialize)]
64struct SearchRequest {
65    /// Query text
66    query: String,
67    /// Number of results to return
68    top_k: usize,
69    /// Project path to search in
70    project: String,
71}
72
73/// Search response from embedding service
74#[derive(Debug, Clone, Deserialize)]
75struct _SearchResponse {
76    /// Search results
77    _results: Vec<SemanticResult>,
78}
79
/// Blocking client for the Python embedding service.
///
/// All methods are synchronous. [`EmbeddingClient::is_available`] performs a
/// blocking TCP probe, and the request methods run that probe before doing
/// anything else.
///
/// # Example
/// ```ignore
/// use tldr_core::search::embedding_client::EmbeddingClient;
///
/// let client = EmbeddingClient::new("http://localhost:8765");
/// if client.is_available() {
///     // `search` is synchronous and returns a `TldrResult`, so no `.await`.
///     let results = client.search("process data", "src/", 10)?;
/// }
/// ```
#[derive(Debug, Clone)]
pub struct EmbeddingClient {
    /// Base URL of the embedding service, stored without trailing slashes.
    base_url: String,
    /// Request timeout (reserved for future HTTP client integration).
    _timeout: Duration,
}
98
99impl Default for EmbeddingClient {
100    fn default() -> Self {
101        Self::new(DEFAULT_EMBEDDING_URL)
102    }
103}
104
105impl EmbeddingClient {
106    /// Create a new embedding client
107    ///
108    /// # Arguments
109    /// * `base_url` - Base URL of the embedding service (e.g., "http://localhost:8765")
110    pub fn new(base_url: &str) -> Self {
111        Self {
112            base_url: base_url.trim_end_matches('/').to_string(),
113            _timeout: DEFAULT_TIMEOUT,
114        }
115    }
116
117    /// Create a client with custom timeout
118    pub fn with_timeout(base_url: &str, timeout: Duration) -> Self {
119        Self {
120            base_url: base_url.trim_end_matches('/').to_string(),
121            _timeout: timeout,
122        }
123    }
124
125    /// Get the base URL
126    pub fn base_url(&self) -> &str {
127        &self.base_url
128    }
129
130    /// Check if the embedding service is available
131    ///
132    /// # Returns
133    /// `true` if the service responds to health check, `false` otherwise
134    pub fn is_available(&self) -> bool {
135        // Synchronous check - try to connect
136        // This is a simplified check; in production, use async
137        std::net::TcpStream::connect_timeout(&self.parse_address(), Duration::from_secs(1)).is_ok()
138    }
139
140    /// Parse the base URL into a socket address
141    fn parse_address(&self) -> std::net::SocketAddr {
142        let url = self
143            .base_url
144            .strip_prefix("http://")
145            .unwrap_or(&self.base_url);
146        let url = url.strip_prefix("https://").unwrap_or(url);
147
148        // Default port
149        let (host, port) = if let Some((h, p)) = url.split_once(':') {
150            (h, p.parse().unwrap_or(8765))
151        } else {
152            (url, 8765)
153        };
154
155        // Resolve to socket address
156        use std::net::ToSocketAddrs;
157        format!("{}:{}", host, port)
158            .to_socket_addrs()
159            .ok()
160            .and_then(|mut addrs| addrs.next())
161            .unwrap_or_else(|| std::net::SocketAddr::from(([127, 0, 0, 1], port)))
162    }
163
164    /// Perform semantic search
165    ///
166    /// # Arguments
167    /// * `query` - Search query text
168    /// * `project` - Project path to search in
169    /// * `top_k` - Number of results to return
170    ///
171    /// # Returns
172    /// Vector of semantic search results, or error if service unavailable
173    ///
174    /// # Mitigation M8
175    /// Returns ConnectionFailed error if service is unavailable,
176    /// allowing caller to fall back to BM25-only search.
177    pub fn search(
178        &self,
179        query: &str,
180        project: &str,
181        top_k: usize,
182    ) -> TldrResult<Vec<SemanticResult>> {
183        // For sync implementation, check availability first
184        if !self.is_available() {
185            return Err(TldrError::ConnectionFailed(format!(
186                "Embedding service at {} is not available",
187                self.base_url
188            )));
189        }
190
191        // In a real implementation, this would make an HTTP request
192        // For now, return empty results to allow compilation and testing
193        // The actual HTTP request would use reqwest or similar
194
195        // Placeholder: return empty results
196        // Real implementation would POST to {base_url}/search
197        let _request = SearchRequest {
198            query: query.to_string(),
199            top_k,
200            project: project.to_string(),
201        };
202
203        // TODO: Implement actual HTTP request
204        // For now, signal that service call would happen
205        Ok(Vec::new())
206    }
207
208    /// Get embedding for a single text
209    ///
210    /// # Arguments
211    /// * `text` - Text to embed
212    ///
213    /// # Returns
214    /// Embedding vector of dimension EMBEDDING_DIM
215    pub fn embed(&self, text: &str) -> TldrResult<Vec<f32>> {
216        if !self.is_available() {
217            return Err(TldrError::ConnectionFailed(format!(
218                "Embedding service at {} is not available",
219                self.base_url
220            )));
221        }
222
223        let _request = EmbeddingRequest {
224            text: text.to_string(),
225            texts: None,
226        };
227
228        // TODO: Implement actual HTTP request
229        // Return placeholder zeros for now
230        Ok(vec![0.0; EMBEDDING_DIM])
231    }
232
233    /// Get embeddings for multiple texts
234    ///
235    /// # Arguments
236    /// * `texts` - Texts to embed
237    ///
238    /// # Returns
239    /// Vector of embedding vectors
240    pub fn embed_batch(&self, texts: &[String]) -> TldrResult<Vec<Vec<f32>>> {
241        if !self.is_available() {
242            return Err(TldrError::ConnectionFailed(format!(
243                "Embedding service at {} is not available",
244                self.base_url
245            )));
246        }
247
248        let _request = EmbeddingRequest {
249            text: String::new(),
250            texts: Some(texts.to_vec()),
251        };
252
253        // TODO: Implement actual HTTP request
254        // Return placeholder zeros for now
255        Ok(texts.iter().map(|_| vec![0.0; EMBEDDING_DIM]).collect())
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;

    /// Bind an OS-assigned ephemeral port on 127.0.0.1.
    ///
    /// Returns the live listener together with an `http://` URL pointing at it.
    fn listening_url() -> (TcpListener, String) {
        let listener =
            TcpListener::bind("127.0.0.1:0").expect("failed to bind an ephemeral port");
        let port = listener
            .local_addr()
            .expect("listener has no local address")
            .port();
        (listener, format!("http://127.0.0.1:{}", port))
    }

    /// URL of a port that was free a moment ago (bound, then released).
    ///
    /// Far less collision-prone than a hard-coded "unlikely" port number.
    fn unused_url() -> String {
        let (listener, url) = listening_url();
        drop(listener);
        url
    }

    /// Assert that `result` is a `ConnectionFailed` error about unavailability.
    fn assert_unavailable<T>(result: TldrResult<T>) {
        match result {
            Err(TldrError::ConnectionFailed(msg)) => {
                assert!(msg.contains("not available"), "unexpected message: {}", msg)
            }
            Ok(_) => panic!("Expected ConnectionFailed error, got Ok"),
            Err(_) => panic!("Expected ConnectionFailed error, got a different error"),
        }
    }

    #[test]
    fn test_client_creation() {
        let client = EmbeddingClient::new("http://localhost:8765");
        assert_eq!(client.base_url(), "http://localhost:8765");
    }

    #[test]
    fn test_client_with_trailing_slash() {
        let client = EmbeddingClient::new("http://localhost:8765/");
        assert_eq!(client.base_url(), "http://localhost:8765");
        // Every trailing slash is removed, not just the last one.
        let client = EmbeddingClient::new("http://localhost:8765//");
        assert_eq!(client.base_url(), "http://localhost:8765");
    }

    #[test]
    fn test_with_timeout_trims_trailing_slash() {
        let client =
            EmbeddingClient::with_timeout("http://localhost:8765/", Duration::from_secs(3));
        assert_eq!(client.base_url(), "http://localhost:8765");
    }

    #[test]
    fn test_default_uses_default_url() {
        assert_eq!(EmbeddingClient::default().base_url(), DEFAULT_EMBEDDING_URL);
    }

    #[test]
    fn test_client_unavailable() {
        let client = EmbeddingClient::new(&unused_url());
        assert!(!client.is_available());
    }

    #[test]
    fn test_client_available_when_listening() {
        // Keep the listener alive for the whole test so the probe can connect.
        let (_listener, url) = listening_url();
        assert!(EmbeddingClient::new(&url).is_available());
    }

    #[test]
    fn test_search_unavailable_service() {
        let client = EmbeddingClient::new(&unused_url());
        assert_unavailable(client.search("query", "project", 10));
    }

    #[test]
    fn test_embed_unavailable_service() {
        let client = EmbeddingClient::new(&unused_url());
        assert_unavailable(client.embed("text"));
    }

    #[test]
    fn test_embed_batch_unavailable_service() {
        let client = EmbeddingClient::new(&unused_url());
        assert_unavailable(client.embed_batch(&["a".to_string(), "b".to_string()]));
    }
}