sentinel_proxy/upstream/inference_health.rs

//! Inference-specific health check for LLM backends.
//!
//! This module provides a Pingora-compatible health check that verifies:
//! - The inference server is responding (HTTP 200)
//! - Expected models are available in the `/v1/models` response
//!
//! # Example Configuration
//!
//! ```kdl
//! upstream "llm-pool" {
//!     targets {
//!         target { address "gpu-1:8080" }
//!     }
//!     health-check {
//!         type "inference" {
//!             endpoint "/v1/models"
//!             expected-models "gpt-4" "llama-3"
//!         }
//!         interval-secs 30
//!     }
//! }
//! ```
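//!
//! # Example Usage
//!
//! A minimal sketch of wiring this check into a Pingora load balancer; the
//! exact `pingora_load_balancing` API surface varies between versions, so
//! treat this as illustrative rather than definitive:
//!
//! ```ignore
//! use std::time::Duration;
//! use pingora_load_balancing::{selection::RoundRobin, LoadBalancer};
//!
//! let mut upstreams: LoadBalancer<RoundRobin> =
//!     LoadBalancer::try_from_iter(["gpu-1:8080"]).unwrap();
//!
//! let hc = InferenceHealthCheck::new(
//!     "/v1/models".to_string(),
//!     vec!["gpt-4".to_string(), "llama-3".to_string()],
//!     Duration::from_secs(5),
//! );
//! upstreams.set_health_check(Box::new(hc));
//! upstreams.health_check_frequency = Some(Duration::from_secs(30));
//! ```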

use async_trait::async_trait;
use pingora_core::{Error, ErrorType::CustomCode, Result};
use pingora_load_balancing::health_check::HealthCheck as PingoraHealthCheck;
use pingora_load_balancing::Backend;
use serde::Deserialize;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, trace};

/// Inference health check for LLM/AI backends.
///
/// Implements Pingora's HealthCheck trait to integrate with the load balancing
/// infrastructure. Verifies both server availability and model availability.
pub struct InferenceHealthCheck {
    /// Endpoint to probe (default: /v1/models)
    endpoint: String,
    /// Models that must be present for the backend to be healthy
    expected_models: Vec<String>,
    /// Connection/response timeout
    timeout: Duration,
    /// Consecutive successes needed to mark healthy
    pub consecutive_success: usize,
    /// Consecutive failures needed to mark unhealthy
    pub consecutive_failure: usize,
}

/// OpenAI-compatible models list response.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}

/// Individual model info in the response.
#[derive(Debug, Deserialize)]
struct ModelInfo {
    id: String,
    /// Object type (e.g., "model"); kept to mirror the OpenAI schema but
    /// never read by the check itself.
    #[serde(default)]
    #[allow(dead_code)]
    object: String,
}

impl InferenceHealthCheck {
    /// Create a new inference health check.
    ///
    /// # Arguments
    ///
    /// * `endpoint` - The models endpoint path (e.g., "/v1/models")
    /// * `expected_models` - List of model IDs that must be available
    /// * `timeout` - Connection and response timeout
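    ///
    /// # Example
    ///
    /// Construction as exercised by the unit tests below:
    ///
    /// ```ignore
    /// let check = InferenceHealthCheck::new(
    ///     "/v1/models".to_string(),
    ///     vec!["gpt-4".to_string()],
    ///     Duration::from_secs(5),
    /// );
    /// ```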
    pub fn new(endpoint: String, expected_models: Vec<String>, timeout: Duration) -> Self {
        Self {
            endpoint,
            expected_models,
            timeout,
            consecutive_success: 1,
            consecutive_failure: 1,
        }
    }

    /// Perform the actual health check against a target.
    async fn check_backend(&self, addr: &str) -> Result<(), String> {
        // Parse address
        let socket_addr: std::net::SocketAddr = addr
            .parse()
            .map_err(|e| format!("Invalid address '{}': {}", addr, e))?;

        // Connect with timeout
        let mut stream = tokio::time::timeout(self.timeout, TcpStream::connect(socket_addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Build HTTP request
        let request = format!(
            "GET {} HTTP/1.1\r\n\
             Host: {}\r\n\
             User-Agent: Sentinel-HealthCheck/1.0\r\n\
             Accept: application/json\r\n\
             Connection: close\r\n\r\n",
            self.endpoint, addr
        );

        // Send request
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        // Read the full response with a timeout. `Connection: close` asks the
        // server to close the stream when done, so read_to_end terminates
        // instead of stopping after a single (possibly partial) read.
        let mut response = Vec::with_capacity(65536); // pre-size for a models list
        tokio::time::timeout(self.timeout, stream.read_to_end(&mut response))
            .await
            .map_err(|_| "Response timeout".to_string())?
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if response.is_empty() {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response);

        // Parse HTTP status
        let status_code = self.parse_status_code(&response_str)?;
        if status_code != 200 {
            return Err(format!("HTTP {} (expected 200)", status_code));
        }

        // If no expected models specified, just check HTTP 200
        if self.expected_models.is_empty() {
            trace!(
                addr = %addr,
                endpoint = %self.endpoint,
                "Inference health check passed (no model verification)"
            );
            return Ok(());
        }

        // Extract JSON body
        let body = self.extract_body(&response_str)?;

        // Parse models response
        let models = self.parse_models_response(body)?;

        // Verify expected models are present
        self.verify_models(&models)?;

        trace!(
            addr = %addr,
            endpoint = %self.endpoint,
            model_count = models.len(),
            expected_models = ?self.expected_models,
            "Inference health check passed"
        );

        Ok(())
    }

    /// Parse HTTP status code from response.
    fn parse_status_code(&self, response: &str) -> Result<u16, String> {
        response
            .lines()
            .next()
            .and_then(|line| line.split_whitespace().nth(1))
            .and_then(|code| code.parse().ok())
            .ok_or_else(|| "Failed to parse HTTP status".to_string())
    }

    /// Extract body from HTTP response.
    fn extract_body<'a>(&self, response: &'a str) -> Result<&'a str, String> {
        response
            .find("\r\n\r\n")
            .map(|pos| &response[pos + 4..])
            .ok_or_else(|| "Could not find response body".to_string())
    }

    /// Parse the models list from JSON response.
    fn parse_models_response(&self, body: &str) -> Result<Vec<String>, String> {
        // Best-effort handling of chunked transfer encoding: assume the JSON
        // payload fits in a single chunk and strip the chunk-size framing.
        let json_body = if body.starts_with(|c: char| c.is_ascii_hexdigit()) {
            // Skip the chunk size line; stop at the terminating zero chunk
            body.lines()
                .skip(1)
                .take_while(|line| !line.is_empty() && *line != "0")
                .collect::<Vec<_>>()
                .join("\n")
        } else {
            body.to_string()
        };

        // Try parsing as OpenAI-compatible format
        if let Ok(response) = serde_json::from_str::<ModelsResponse>(&json_body) {
            return Ok(response.data.into_iter().map(|m| m.id).collect());
        }

        // Fallback: try parsing as simple array of model objects
        if let Ok(models) = serde_json::from_str::<Vec<ModelInfo>>(&json_body) {
            return Ok(models.into_iter().map(|m| m.id).collect());
        }

        // Last fallback: extract model IDs from any JSON with "id" fields
        if let Ok(json) = serde_json::from_str::<serde_json::Value>(&json_body) {
            if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
                let models: Vec<String> = data
                    .iter()
                    .filter_map(|m| m.get("id").and_then(|id| id.as_str()))
                    .map(String::from)
                    .collect();
                if !models.is_empty() {
                    return Ok(models);
                }
            }

            // Check for a "models" array directly
            if let Some(models_arr) = json.get("models").and_then(|m| m.as_array()) {
                let models: Vec<String> = models_arr
                    .iter()
                    .filter_map(|m| {
                        m.get("id")
                            .or_else(|| m.get("name"))
                            .and_then(|id| id.as_str())
                    })
                    .map(String::from)
                    .collect();
                if !models.is_empty() {
                    return Ok(models);
                }
            }
        }

        // Truncate the preview on a char boundary; slicing at a fixed byte
        // offset could panic on multi-byte UTF-8 in the body.
        let preview: String = json_body.chars().take(200).collect();
        Err(format!(
            "Failed to parse models response. Body preview: {}",
            preview
        ))
    }

    /// Verify that all expected models are present.
    fn verify_models(&self, available_models: &[String]) -> Result<(), String> {
        let mut missing = Vec::new();

        for expected in &self.expected_models {
            // Accept an exact match, or a prefix match in either direction so
            // that an expected "llama3" matches a served "llama3:latest" and
            // an expected "gpt-4-0613" matches a served "gpt-4"
            let found = available_models.iter().any(|m| {
                m == expected || m.starts_with(expected) || expected.starts_with(m)
            });

            if !found {
                missing.push(expected.as_str());
            }
        }

        if missing.is_empty() {
            Ok(())
        } else {
            Err(format!(
                "Missing models: {}. Available: {:?}",
                missing.join(", "),
                available_models
            ))
        }
    }
}

#[async_trait]
impl PingoraHealthCheck for InferenceHealthCheck {
    /// Check if the backend is healthy.
    ///
    /// Returns `Ok(())` if healthy, `Err` with a message if not.
    async fn check(&self, backend: &Backend) -> Result<()> {
        let addr = backend.addr.to_string();

        match self.check_backend(&addr).await {
            Ok(()) => {
                trace!(
                    addr = %addr,
                    endpoint = %self.endpoint,
                    expected_models = ?self.expected_models,
                    "Inference health check passed"
                );
                Ok(())
            }
            Err(error) => {
                debug!(
                    addr = %addr,
                    endpoint = %self.endpoint,
                    error = %error,
                    "Inference health check failed"
                );
                Err(Error::explain(
                    CustomCode("inference health check", 1),
                    error,
                ))
            }
        }
    }

    /// Return the health threshold for flipping health status.
    ///
    /// * `success: true` - returns `consecutive_success` (unhealthy -> healthy)
    /// * `success: false` - returns `consecutive_failure` (healthy -> unhealthy)
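    ///
    /// Both thresholds default to 1; set the public `consecutive_success` and
    /// `consecutive_failure` fields after construction to add hysteresis.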
    fn health_threshold(&self, success: bool) -> usize {
        if success {
            self.consecutive_success
        } else {
            self.consecutive_failure
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_openai_models_response() {
        let check = InferenceHealthCheck::new(
            "/v1/models".to_string(),
            vec!["gpt-4".to_string()],
            Duration::from_secs(5),
        );

        let body = r#"{"object":"list","data":[{"id":"gpt-4","object":"model"},{"id":"gpt-3.5-turbo","object":"model"}]}"#;
        let models = check.parse_models_response(body).unwrap();

        assert_eq!(models.len(), 2);
        assert!(models.contains(&"gpt-4".to_string()));
        assert!(models.contains(&"gpt-3.5-turbo".to_string()));
    }

    #[test]
    fn test_parse_ollama_models_response() {
        let check = InferenceHealthCheck::new(
            "/api/tags".to_string(),
            vec!["llama3".to_string()],
            Duration::from_secs(5),
        );

        // Ollama uses a "models" array with a "name" field
        let body = r#"{"models":[{"name":"llama3:latest"},{"name":"codellama:7b"}]}"#;
        let models = check.parse_models_response(body).unwrap();

        assert_eq!(models.len(), 2);
        assert!(models.contains(&"llama3:latest".to_string()));
    }
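
    // Exercises the best-effort chunked-encoding path in
    // `parse_models_response`: a chunk-size line, a single JSON chunk, then
    // the terminating zero-length chunk.
    #[test]
    fn test_parse_chunked_models_response() {
        let check = InferenceHealthCheck::new(
            "/v1/models".to_string(),
            vec![],
            Duration::from_secs(5),
        );

        let body = "3a\r\n{\"object\":\"list\",\"data\":[{\"id\":\"gpt-4\",\"object\":\"model\"}]}\r\n0\r\n\r\n";
        let models = check.parse_models_response(body).unwrap();

        assert_eq!(models, vec!["gpt-4".to_string()]);
    }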

    #[test]
    fn test_verify_models_exact_match() {
        let check = InferenceHealthCheck::new(
            "/v1/models".to_string(),
            vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()],
            Duration::from_secs(5),
        );

        let available = vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()];
        assert!(check.verify_models(&available).is_ok());
    }

    #[test]
    fn test_verify_models_prefix_match() {
        let check = InferenceHealthCheck::new(
            "/v1/models".to_string(),
            vec!["gpt-4".to_string()],
            Duration::from_secs(5),
        );

        // Should match "gpt-4-turbo" when looking for "gpt-4"
        let available = vec!["gpt-4-turbo".to_string(), "gpt-3.5-turbo".to_string()];
        assert!(check.verify_models(&available).is_ok());
    }

    #[test]
    fn test_verify_models_missing() {
        let check = InferenceHealthCheck::new(
            "/v1/models".to_string(),
            vec!["gpt-4".to_string(), "claude-3".to_string()],
            Duration::from_secs(5),
        );

        let available = vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()];
        let result = check.verify_models(&available);

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("claude-3"));
    }

    #[test]
    fn test_parse_status_code() {
        let check = InferenceHealthCheck::new(
            "/v1/models".to_string(),
            vec![],
            Duration::from_secs(5),
        );

        assert_eq!(
            check.parse_status_code("HTTP/1.1 200 OK\r\n"),
            Ok(200)
        );
        assert_eq!(
            check.parse_status_code("HTTP/1.1 404 Not Found\r\n"),
            Ok(404)
        );
    }

    #[test]
    fn test_extract_body() {
        let check = InferenceHealthCheck::new(
            "/v1/models".to_string(),
            vec![],
            Duration::from_secs(5),
        );

        let response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"data\":[]}";
        let body = check.extract_body(response).unwrap();
        assert_eq!(body, "{\"data\":[]}");
    }
}
419}