1use serde::{Deserialize, Serialize};
15
/// Top-level configuration for inference readiness checking.
///
/// Every check is independent and optional: a field left as `None`
/// disables that check entirely (and is omitted when serialized).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceReadinessConfig {
    /// Active probe that sends a small completion request to the backend.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub inference_probe: Option<InferenceProbeConfig>,

    /// Per-model status endpoint polling.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_status: Option<ModelStatusConfig>,

    /// Queue-depth thresholds (depth read from header, body field, or endpoint).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub queue_depth: Option<QueueDepthConfig>,

    /// Latency-based cold/warm model detection.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub warmth_detection: Option<WarmthDetectionConfig>,
}
38
/// Configuration for the active inference probe.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceProbeConfig {
    /// Completion endpoint path to probe (default: `/v1/completions`).
    #[serde(default = "default_probe_endpoint")]
    pub endpoint: String,

    /// Model name to probe. Required — no default.
    pub model: String,

    /// Prompt sent with the probe request (default: `"."`).
    #[serde(default = "default_probe_prompt")]
    pub prompt: String,

    /// Token budget for the probe response (default: 1, keeps the probe cheap).
    #[serde(default = "default_probe_max_tokens")]
    pub max_tokens: u32,

    /// Probe request timeout in seconds (default: 30).
    #[serde(default = "default_probe_timeout")]
    pub timeout_secs: u64,

    /// Optional latency ceiling in milliseconds. NOTE(review): presumably a
    /// probe slower than this is treated as not ready — confirm in the checker.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_latency_ms: Option<u64>,
}
68
/// Configuration for polling per-model status endpoints.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelStatusConfig {
    /// Endpoint pattern containing a `{model}` placeholder
    /// (default: `/v1/models/{model}/status`).
    #[serde(default = "default_status_endpoint")]
    pub endpoint_pattern: String,

    /// Models whose status must be checked. Required — no default.
    pub models: Vec<String>,

    /// Status value considered ready (default: `"ready"`).
    #[serde(default = "default_expected_status")]
    pub expected_status: String,

    /// Name of the response field that carries the status (default: `"status"`).
    #[serde(default = "default_status_field")]
    pub status_field: String,

    /// Request timeout in seconds (default: 5).
    #[serde(default = "default_status_timeout")]
    pub timeout_secs: u64,
}
94
/// Configuration for queue-depth based readiness.
///
/// The depth may be sourced from a response header, a body field, or a
/// dedicated endpoint — all three sources are optional here; which one is
/// consulted is decided by the checker that consumes this config.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct QueueDepthConfig {
    /// Response header carrying the queue depth (e.g. `x-queue-depth`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub header: Option<String>,

    /// JSON body field carrying the queue depth.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub body_field: Option<String>,

    /// Dedicated endpoint to query for the queue depth.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub endpoint: Option<String>,

    /// Depth threshold for the degraded state. Required — no default.
    pub degraded_threshold: u64,

    /// Depth threshold for the unhealthy state. Required — no default.
    pub unhealthy_threshold: u64,

    /// Request timeout in seconds (default: 5).
    #[serde(default = "default_queue_timeout")]
    pub timeout_secs: u64,
}
123
/// Configuration for latency-based cold-model detection.
///
/// `PartialEq`/`Eq` are implemented manually (not derived) because of the
/// `f64` field — see the `impl PartialEq` that compares it bitwise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarmthDetectionConfig {
    /// Number of latency samples to consider (default: 10).
    #[serde(default = "default_warmth_sample_size")]
    pub sample_size: u32,

    /// Latency multiplier that flags a model as cold (default: 3.0).
    /// NOTE(review): the baseline it multiplies is defined by the detector,
    /// not here — confirm against the consuming code.
    #[serde(default = "default_cold_threshold_multiplier")]
    pub cold_threshold_multiplier: f64,

    /// Idle time in seconds after which a model is treated as cold (default: 300).
    #[serde(default = "default_idle_cold_timeout")]
    pub idle_cold_timeout_secs: u64,

    /// Action taken when a model is detected cold (default: `LogOnly`).
    #[serde(default)]
    pub cold_action: ColdModelAction,
}
147
148impl PartialEq for WarmthDetectionConfig {
149 fn eq(&self, other: &Self) -> bool {
150 self.sample_size == other.sample_size
151 && self.cold_threshold_multiplier.to_bits() == other.cold_threshold_multiplier.to_bits()
152 && self.idle_cold_timeout_secs == other.idle_cold_timeout_secs
153 && self.cold_action == other.cold_action
154 }
155}
156
157impl Eq for WarmthDetectionConfig {}
158
/// Action to take when warmth detection deems a model cold
/// (see `WarmthDetectionConfig::cold_action`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ColdModelAction {
    /// Record the detection only; serialized as `"log_only"`. This is the default.
    #[default]
    LogOnly,
    /// Serialized as `"mark_degraded"`.
    MarkDegraded,
    /// Serialized as `"mark_unhealthy"`.
    MarkUnhealthy,
}
171
/// Default completion endpoint for the inference probe.
fn default_probe_endpoint() -> String {
    String::from("/v1/completions")
}
177
/// Default (minimal) prompt sent by the inference probe.
fn default_probe_prompt() -> String {
    String::from(".")
}
181
/// Default token budget for the probe: a single token keeps it cheap.
fn default_probe_max_tokens() -> u32 {
    1
}
185
/// Default probe timeout, in seconds.
fn default_probe_timeout() -> u64 {
    30
}
189
/// Default model-status endpoint pattern; `{model}` is the substitution slot.
fn default_status_endpoint() -> String {
    String::from("/v1/models/{model}/status")
}
193
/// Default status value that counts as ready.
fn default_expected_status() -> String {
    String::from("ready")
}
197
/// Default name of the response field holding the model status.
fn default_status_field() -> String {
    String::from("status")
}
201
/// Default model-status request timeout, in seconds.
fn default_status_timeout() -> u64 {
    5
}
205
/// Default queue-depth request timeout, in seconds.
fn default_queue_timeout() -> u64 {
    5
}
209
/// Default number of latency samples for warmth detection.
fn default_warmth_sample_size() -> u32 {
    10
}
213
/// Default cold-detection latency multiplier.
fn default_cold_threshold_multiplier() -> f64 {
    3.0
}
217
/// Default idle time (seconds) before a model is presumed cold.
fn default_idle_cold_timeout() -> u64 {
    300
}
221
#[cfg(test)]
mod tests {
    use super::*;

    // An empty JSON object must deserialize with every check disabled.
    #[test]
    fn test_inference_readiness_config_defaults() {
        let config: InferenceReadinessConfig = serde_json::from_str("{}").unwrap();
        assert!(config.inference_probe.is_none());
        assert!(config.model_status.is_none());
        assert!(config.queue_depth.is_none());
        assert!(config.warmth_detection.is_none());
    }

    // Only `model` is required; every other probe field falls back to its default.
    #[test]
    fn test_inference_probe_config_defaults() {
        let json = r#"{"model": "gpt-4"}"#;
        let config: InferenceProbeConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.endpoint, "/v1/completions");
        assert_eq!(config.model, "gpt-4");
        assert_eq!(config.prompt, ".");
        assert_eq!(config.max_tokens, 1);
        assert_eq!(config.timeout_secs, 30);
        assert!(config.max_latency_ms.is_none());
    }

    // Only `models` is required; pattern, status, field, and timeout default.
    #[test]
    fn test_model_status_config_defaults() {
        let json = r#"{"models": ["gpt-4", "gpt-3.5-turbo"]}"#;
        let config: ModelStatusConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.endpoint_pattern, "/v1/models/{model}/status");
        assert_eq!(config.models, vec!["gpt-4", "gpt-3.5-turbo"]);
        assert_eq!(config.expected_status, "ready");
        assert_eq!(config.status_field, "status");
        assert_eq!(config.timeout_secs, 5);
    }

    // Thresholds are required; the unused source fields stay None; timeout defaults.
    #[test]
    fn test_queue_depth_config() {
        let json = r#"{
            "header": "x-queue-depth",
            "degraded_threshold": 50,
            "unhealthy_threshold": 200
        }"#;
        let config: QueueDepthConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.header, Some("x-queue-depth".to_string()));
        assert!(config.body_field.is_none());
        assert_eq!(config.degraded_threshold, 50);
        assert_eq!(config.unhealthy_threshold, 200);
        assert_eq!(config.timeout_secs, 5);
    }

    // Warmth detection is fully defaultable from an empty object.
    #[test]
    fn test_warmth_detection_defaults() {
        let json = "{}";
        let config: WarmthDetectionConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.sample_size, 10);
        assert!((config.cold_threshold_multiplier - 3.0).abs() < f64::EPSILON);
        assert_eq!(config.idle_cold_timeout_secs, 300);
        assert_eq!(config.cold_action, ColdModelAction::LogOnly);
    }

    // Enum variants serialize in snake_case per the serde rename rule.
    #[test]
    fn test_cold_model_action_serialization() {
        assert_eq!(
            serde_json::to_string(&ColdModelAction::LogOnly).unwrap(),
            r#""log_only""#
        );
        assert_eq!(
            serde_json::to_string(&ColdModelAction::MarkDegraded).unwrap(),
            r#""mark_degraded""#
        );
        assert_eq!(
            serde_json::to_string(&ColdModelAction::MarkUnhealthy).unwrap(),
            r#""mark_unhealthy""#
        );
    }

    // A fully-populated config must survive a serialize/deserialize roundtrip;
    // the final assert_eq! also exercises the manual PartialEq (f64 bitwise).
    #[test]
    fn test_full_config_roundtrip() {
        let config = InferenceReadinessConfig {
            inference_probe: Some(InferenceProbeConfig {
                endpoint: "/v1/completions".to_string(),
                model: "gpt-4".to_string(),
                prompt: ".".to_string(),
                max_tokens: 1,
                timeout_secs: 30,
                max_latency_ms: Some(5000),
            }),
            model_status: None,
            queue_depth: Some(QueueDepthConfig {
                header: Some("x-queue-depth".to_string()),
                body_field: None,
                endpoint: None,
                degraded_threshold: 50,
                unhealthy_threshold: 200,
                timeout_secs: 5,
            }),
            warmth_detection: Some(WarmthDetectionConfig {
                sample_size: 10,
                cold_threshold_multiplier: 3.0,
                idle_cold_timeout_secs: 300,
                cold_action: ColdModelAction::MarkDegraded,
            }),
        };

        let json = serde_json::to_string(&config).unwrap();
        let parsed: InferenceReadinessConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }
}