1use async_trait::async_trait;
25use pingora_core::{Error, ErrorType::CustomCode, Result};
26use pingora_load_balancing::health_check::HealthCheck as PingoraHealthCheck;
27use pingora_load_balancing::Backend;
28use serde::Deserialize;
29use std::time::Duration;
30use tokio::io::{AsyncReadExt, AsyncWriteExt};
31use tokio::net::TcpStream;
32use tracing::{debug, trace, warn};
33
pub struct InferenceHealthCheck {
    /// HTTP path probed on each backend (e.g. "/v1/models" or "/api/tags").
    endpoint: String,
    /// Model ids that must appear in the backend's model listing; an empty
    /// list disables model verification (only the HTTP 200 check runs).
    expected_models: Vec<String>,
    /// Timeout applied to the connect and read phases of a probe.
    timeout: Duration,
    /// Consecutive successful checks required before a backend is considered healthy.
    pub consecutive_success: usize,
    /// Consecutive failed checks required before a backend is considered unhealthy.
    pub consecutive_failure: usize,
}
50
/// OpenAI-compatible models listing envelope: `{"data": [{"id": ...}, ...]}`.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    /// The list of models advertised by the backend.
    data: Vec<ModelInfo>,
}
56
57#[derive(Debug, Deserialize)]
59struct ModelInfo {
60 id: String,
61 #[serde(default)]
62 object: String,
63}
64
65impl InferenceHealthCheck {
66 pub fn new(endpoint: String, expected_models: Vec<String>, timeout: Duration) -> Self {
74 Self {
75 endpoint,
76 expected_models,
77 timeout,
78 consecutive_success: 1,
79 consecutive_failure: 1,
80 }
81 }
82
83 async fn check_backend(&self, addr: &str) -> Result<(), String> {
85 let socket_addr: std::net::SocketAddr = addr
87 .parse()
88 .map_err(|e| format!("Invalid address '{}': {}", addr, e))?;
89
90 let stream = tokio::time::timeout(self.timeout, TcpStream::connect(socket_addr))
92 .await
93 .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
94 .map_err(|e| format!("Connection failed: {}", e))?;
95
96 let request = format!(
98 "GET {} HTTP/1.1\r\n\
99 Host: {}\r\n\
100 User-Agent: Sentinel-HealthCheck/1.0\r\n\
101 Accept: application/json\r\n\
102 Connection: close\r\n\r\n",
103 self.endpoint, addr
104 );
105
106 let mut stream = stream;
108 stream
109 .write_all(request.as_bytes())
110 .await
111 .map_err(|e| format!("Failed to send request: {}", e))?;
112
113 let mut response = vec![0u8; 65536]; let n = tokio::time::timeout(self.timeout, stream.read(&mut response))
116 .await
117 .map_err(|_| "Response timeout".to_string())?
118 .map_err(|e| format!("Failed to read response: {}", e))?;
119
120 if n == 0 {
121 return Err("Empty response".to_string());
122 }
123
124 let response_str = String::from_utf8_lossy(&response[..n]);
125
126 let status_code = self.parse_status_code(&response_str)?;
128 if status_code != 200 {
129 return Err(format!("HTTP {} (expected 200)", status_code));
130 }
131
132 if self.expected_models.is_empty() {
134 trace!(
135 addr = %addr,
136 endpoint = %self.endpoint,
137 "Inference health check passed (no model verification)"
138 );
139 return Ok(());
140 }
141
142 let body = self.extract_body(&response_str)?;
144
145 let models = self.parse_models_response(body)?;
147
148 self.verify_models(&models)?;
150
151 trace!(
152 addr = %addr,
153 endpoint = %self.endpoint,
154 model_count = models.len(),
155 expected_models = ?self.expected_models,
156 "Inference health check passed"
157 );
158
159 Ok(())
160 }
161
162 fn parse_status_code(&self, response: &str) -> Result<u16, String> {
164 response
165 .lines()
166 .next()
167 .and_then(|line| line.split_whitespace().nth(1))
168 .and_then(|code| code.parse().ok())
169 .ok_or_else(|| "Failed to parse HTTP status".to_string())
170 }
171
172 fn extract_body<'a>(&self, response: &'a str) -> Result<&'a str, String> {
174 response
175 .find("\r\n\r\n")
176 .map(|pos| &response[pos + 4..])
177 .ok_or_else(|| "Could not find response body".to_string())
178 }
179
180 fn parse_models_response(&self, body: &str) -> Result<Vec<String>, String> {
182 let json_body = if body.starts_with(|c: char| c.is_ascii_hexdigit()) {
184 body.lines()
186 .skip(1)
187 .take_while(|line| !line.is_empty() && *line != "0")
188 .collect::<Vec<_>>()
189 .join("\n")
190 } else {
191 body.to_string()
192 };
193
194 if let Ok(response) = serde_json::from_str::<ModelsResponse>(&json_body) {
196 return Ok(response.data.into_iter().map(|m| m.id).collect());
197 }
198
199 if let Ok(models) = serde_json::from_str::<Vec<ModelInfo>>(&json_body) {
201 return Ok(models.into_iter().map(|m| m.id).collect());
202 }
203
204 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&json_body) {
206 if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
207 let models: Vec<String> = data
208 .iter()
209 .filter_map(|m| m.get("id").and_then(|id| id.as_str()))
210 .map(String::from)
211 .collect();
212 if !models.is_empty() {
213 return Ok(models);
214 }
215 }
216
217 if let Some(models_arr) = json.get("models").and_then(|m| m.as_array()) {
219 let models: Vec<String> = models_arr
220 .iter()
221 .filter_map(|m| {
222 m.get("id")
223 .or_else(|| m.get("name"))
224 .and_then(|id| id.as_str())
225 })
226 .map(String::from)
227 .collect();
228 if !models.is_empty() {
229 return Ok(models);
230 }
231 }
232 }
233
234 Err(format!(
235 "Failed to parse models response. Body preview: {}",
236 &json_body[..json_body.len().min(200)]
237 ))
238 }
239
240 fn verify_models(&self, available_models: &[String]) -> Result<(), String> {
242 let mut missing = Vec::new();
243
244 for expected in &self.expected_models {
245 let found = available_models.iter().any(|m| {
247 m == expected || m.starts_with(expected) || expected.starts_with(m)
248 });
249
250 if !found {
251 missing.push(expected.as_str());
252 }
253 }
254
255 if missing.is_empty() {
256 Ok(())
257 } else {
258 Err(format!(
259 "Missing models: {}. Available: {:?}",
260 missing.join(", "),
261 available_models
262 ))
263 }
264 }
265}
266
267#[async_trait]
268impl PingoraHealthCheck for InferenceHealthCheck {
269 async fn check(&self, backend: &Backend) -> Result<()> {
273 let addr = backend.addr.to_string();
274
275 match self.check_backend(&addr).await {
276 Ok(()) => {
277 trace!(
278 addr = %addr,
279 endpoint = %self.endpoint,
280 expected_models = ?self.expected_models,
281 "Inference health check passed"
282 );
283 Ok(())
284 }
285 Err(error) => {
286 debug!(
287 addr = %addr,
288 endpoint = %self.endpoint,
289 error = %error,
290 "Inference health check failed"
291 );
292 Err(Error::explain(
293 CustomCode("inference health check", 1),
294 error,
295 ))
296 }
297 }
298 }
299
300 fn health_threshold(&self, success: bool) -> usize {
305 if success {
306 self.consecutive_success
307 } else {
308 self.consecutive_failure
309 }
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a check with a 5s timeout for the given endpooint and
    /// expected model ids.
    fn make_check(endpoint: &str, expected: &[&str]) -> InferenceHealthCheck {
        InferenceHealthCheck::new(
            endpoint.to_string(),
            expected.iter().map(|s| s.to_string()).collect(),
            Duration::from_secs(5),
        )
    }

    #[test]
    fn test_parse_openai_models_response() {
        let check = make_check("/v1/models", &["gpt-4"]);
        let body = r#"{"object":"list","data":[{"id":"gpt-4","object":"model"},{"id":"gpt-3.5-turbo","object":"model"}]}"#;

        let models = check.parse_models_response(body).unwrap();

        assert_eq!(models.len(), 2);
        assert!(models.contains(&"gpt-4".to_string()));
        assert!(models.contains(&"gpt-3.5-turbo".to_string()));
    }

    #[test]
    fn test_parse_ollama_models_response() {
        let check = make_check("/api/tags", &["llama3"]);
        let body = r#"{"models":[{"name":"llama3:latest"},{"name":"codellama:7b"}]}"#;

        let models = check.parse_models_response(body).unwrap();

        assert_eq!(models.len(), 2);
        assert!(models.contains(&"llama3:latest".to_string()));
    }

    #[test]
    fn test_verify_models_exact_match() {
        let check = make_check("/v1/models", &["gpt-4", "gpt-3.5-turbo"]);
        let available = vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()];

        assert!(check.verify_models(&available).is_ok());
    }

    #[test]
    fn test_verify_models_prefix_match() {
        let check = make_check("/v1/models", &["gpt-4"]);
        // "gpt-4" should match "gpt-4-turbo" via prefix matching.
        let available = vec!["gpt-4-turbo".to_string(), "gpt-3.5-turbo".to_string()];

        assert!(check.verify_models(&available).is_ok());
    }

    #[test]
    fn test_verify_models_missing() {
        let check = make_check("/v1/models", &["gpt-4", "claude-3"]);
        let available = vec!["gpt-4".to_string(), "gpt-3.5-turbo".to_string()];

        let result = check.verify_models(&available);

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("claude-3"));
    }

    #[test]
    fn test_parse_status_code() {
        let check = make_check("/v1/models", &[]);

        assert_eq!(check.parse_status_code("HTTP/1.1 200 OK\r\n"), Ok(200));
        assert_eq!(
            check.parse_status_code("HTTP/1.1 404 Not Found\r\n"),
            Ok(404)
        );
    }

    #[test]
    fn test_extract_body() {
        let check = make_check("/v1/models", &[]);
        let response = "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{\"data\":[]}";

        assert_eq!(check.extract_body(response).unwrap(), "{\"data\":[]}");
    }
}