Skip to main content

tt_shared/
capability_check.rs

//! Capability and context-window guard for the routing / failover path.
//!
//! [`RequiredCapabilities`] is derived from a [`ChatCompletionRequest`] and
//! checked against a candidate model's [`ModelInfo`] before a route rewrite or
//! failover dispatch is committed. The check is intentionally permissive:
//!
//! - When `ModelInfo` is **unknown** for a candidate (not in the registry
//!   catalog), the caller lets it through — a candidate is skipped only when we
//!   *positively know* a capability is missing. (This module only evaluates a
//!   known `ModelInfo`; handling the unknown case is the caller's job.)
//! - A capability that the request needs but the model info does **not** list
//!   causes the candidate to be skipped (the caller emits a tracing event and
//!   tries the next candidate or falls back to the original model).
//!
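//! # Token counting
//!
//! [`message_text_for_estimation`] concatenates the text of every message;
//! image/audio parts are dropped. The caller feeds that string to its own
//! tokenizer (e.g. `tt_tokenize::estimate_tokens`, which picks tiktoken for
//! OpenAI/Anthropic and a char/4 heuristic elsewhere based on `provider_id`).
//! The caller then passes the resulting count to
//! [`RequiredCapabilities::satisfied_by`], so `tt-shared` never depends on
//! `tt-tokenize`. Because image/audio bytes are not measured, the guard is a
//! best-effort floor, not an exact window-packing count.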
21
22use crate::{
23    messages::{ContentPart, Message, MessageContent},
24    pricing::{Capability, ModelInfo},
25    ChatCompletionRequest,
26};
27
28/// The set of capabilities a [`ChatCompletionRequest`] requires.
29#[derive(Debug, Clone, Default, PartialEq, Eq)]
30pub struct RequiredCapabilities {
31    /// At least one message contains an image_url or input_audio content part.
32    pub vision: bool,
33    /// The request has non-empty `tools`, or any assistant message contains
34    /// `tool_calls`.
35    pub tools: bool,
36    /// `response_format.type` is `"json_object"` or `"json_schema"`.
37    pub json_mode: bool,
38}
39
40impl RequiredCapabilities {
41    /// Derive the required capabilities from a chat completion request.
42    pub fn from_request(req: &ChatCompletionRequest) -> Self {
43        let mut caps = Self::default();
44
45        // tools / function-calling
46        if !req.tools.is_empty() {
47            caps.tools = true;
48        }
49
50        // response_format → json mode
51        if let Some(rf) = &req.response_format {
52            if rf.r#type == "json_object" || rf.r#type == "json_schema" {
53                caps.json_mode = true;
54            }
55        }
56
57        // scan messages for vision content and tool_calls
58        for msg in &req.messages {
59            match msg {
60                Message::User { content, .. } | Message::System { content } => {
61                    if let MessageContent::Parts(parts) = content {
62                        for part in parts {
63                            match part {
64                                ContentPart::ImageUrl { .. } | ContentPart::InputAudio { .. } => {
65                                    caps.vision = true;
66                                }
67                                ContentPart::Text { .. } => {}
68                            }
69                        }
70                    }
71                }
72                Message::Assistant { tool_calls, .. } => {
73                    if !tool_calls.is_empty() {
74                        caps.tools = true;
75                    }
76                }
77                Message::Tool { .. } => {
78                    // A Tool message in context means the conversation already
79                    // used tool-calling; the next turn may need it too.
80                    caps.tools = true;
81                }
82            }
83        }
84
85        caps
86    }
87
88    /// Returns `true` when all required capabilities are listed in
89    /// `info.capabilities` **and** `max_input_tokens >= estimated_tokens`.
90    ///
91    /// Pass `estimated_tokens = 0` to skip the context-window check.
92    #[must_use]
93    pub fn satisfied_by(&self, info: &ModelInfo, estimated_tokens: u64) -> bool {
94        if self.vision && !info.capabilities.contains(&Capability::Vision) {
95            return false;
96        }
97        if self.tools && !info.capabilities.contains(&Capability::Tools) {
98            return false;
99        }
100        if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
101            return false;
102        }
103        if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
104            return false;
105        }
106        true
107    }
108
109    /// Human-readable list of the reasons a candidate was skipped, for use in
110    /// the `route_skipped_capability` tracing event.
111    pub fn skip_reasons(&self, info: &ModelInfo, estimated_tokens: u64) -> Vec<&'static str> {
112        let mut reasons = Vec::new();
113        if self.vision && !info.capabilities.contains(&Capability::Vision) {
114            reasons.push("vision_not_supported");
115        }
116        if self.tools && !info.capabilities.contains(&Capability::Tools) {
117            reasons.push("tools_not_supported");
118        }
119        if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
120            reasons.push("json_mode_not_supported");
121        }
122        if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
123            reasons.push("context_window_too_small");
124        }
125        reasons
126    }
127}
128
129/// Concatenate all message text parts from a request for token estimation.
130///
131/// Image/audio bytes are excluded — the result is passed to the caller's
132/// tokenizer (e.g. `tt_tokenize::estimate_tokens`) so that `tt-shared` does
133/// not need to depend on `tt-tokenize`.
134pub fn message_text_for_estimation(req: &ChatCompletionRequest) -> String {
135    req.messages
136        .iter()
137        .map(|m| match m {
138            Message::User { content, .. } | Message::System { content } => extract_text(content),
139            Message::Assistant { content, .. } => {
140                content.as_ref().map(extract_text).unwrap_or_default()
141            }
142            Message::Tool { content, .. } => extract_text(content),
143        })
144        .collect()
145}
146
147fn extract_text(content: &MessageContent) -> String {
148    match content {
149        MessageContent::Text(s) => s.clone(),
150        MessageContent::Parts(parts) => parts
151            .iter()
152            .filter_map(|p| match p {
153                ContentPart::Text { text } => Some(text.as_str()),
154                _ => None,
155            })
156            .collect::<Vec<_>>()
157            .join(""),
158    }
159}
160
161/// True when any message carries an image (`ContentPart::ImageUrl`) content part.
162///
163/// Distinct from [`RequiredCapabilities`], which collapses image **and** audio
164/// into a single `vision` flag; routing needs to tell the two modalities apart.
165pub fn request_has_images(req: &ChatCompletionRequest) -> bool {
166    req.messages
167        .iter()
168        .any(|m| content_of(m).is_some_and(has_image_part))
169}
170
171/// True when any message carries an audio (`ContentPart::InputAudio`) content part.
172pub fn request_has_audio(req: &ChatCompletionRequest) -> bool {
173    req.messages
174        .iter()
175        .any(|m| content_of(m).is_some_and(has_audio_part))
176}
177
178/// The content of a message, if it has any (Assistant content is optional).
179fn content_of(m: &Message) -> Option<&MessageContent> {
180    match m {
181        Message::User { content, .. }
182        | Message::System { content }
183        | Message::Tool { content, .. } => Some(content),
184        Message::Assistant { content, .. } => content.as_ref(),
185    }
186}
187
188fn has_image_part(c: &MessageContent) -> bool {
189    matches!(c, MessageContent::Parts(parts)
190        if parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. })))
191}
192
193fn has_audio_part(c: &MessageContent) -> bool {
194    matches!(c, MessageContent::Parts(parts)
195        if parts.iter().any(|p| matches!(p, ContentPart::InputAudio { .. })))
196}
197
198/// Concatenated text of the **user + system** messages — the caller-controlled
199/// input, used for content/topic routing. Assistant/tool turns are excluded so a
200/// model's own output can't spuriously trigger a topic route.
201pub fn request_input_text(req: &ChatCompletionRequest) -> String {
202    req.messages
203        .iter()
204        .filter_map(|m| match m {
205            Message::User { content, .. } | Message::System { content } => {
206                Some(extract_text(content))
207            }
208            _ => None,
209        })
210        .collect::<Vec<_>>()
211        .join("\n")
212}
213
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{
        messages::{
            ImageUrl, InputAudio, ResponseFormat, Tool, ToolCall, ToolCallFunction, ToolFunction,
        },
        pricing::Capability,
        ModelInfo,
    };

    // ---- model fixtures ---------------------------------------------------

    /// Text-only model with a 4096-token input window.
    fn text_model() -> ModelInfo {
        ModelInfo {
            id: "text-only".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 4096,
            max_output_tokens: 1024,
        }
    }

    /// Text + vision + tools model (no JSON mode) with a 128k input window.
    fn vision_model() -> ModelInfo {
        ModelInfo {
            id: "vision-model".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text, Capability::Vision, Capability::Tools],
            max_input_tokens: 128_000,
            max_output_tokens: 4096,
        }
    }

    /// Text-only model with a deliberately tiny 100-token input window.
    fn small_model() -> ModelInfo {
        ModelInfo {
            id: "small-ctx".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 100,
            max_output_tokens: 100,
        }
    }

    // ---- request fixtures -------------------------------------------------

    /// A request with no messages, tools or response format.
    fn base_req() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4o".into(),
            messages: vec![],
            temperature: None,
            top_p: None,
            max_tokens: None,
            stream: false,
            tools: vec![],
            tool_choice: None,
            response_format: None,
            stop: vec![],
            presence_penalty: None,
            frequency_penalty: None,
            n: None,
            seed: None,
            user: None,
            tt_extras: HashMap::new(),
        }
    }

    /// `base_req()` with the given message history.
    fn req_with(messages: Vec<Message>) -> ChatCompletionRequest {
        let mut req = base_req();
        req.messages = messages;
        req
    }

    /// A small inline-PNG `image_url` part.
    fn image_part() -> ContentPart {
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "data:image/png;base64,abc".into(),
                detail: None,
            },
        }
    }

    /// A small WAV `input_audio` part.
    fn audio_part() -> ContentPart {
        ContentPart::InputAudio {
            input_audio: InputAudio {
                data: "abc".into(),
                format: "wav".into(),
            },
        }
    }

    /// A plain text part.
    fn text_part(text: &'static str) -> ContentPart {
        ContentPart::Text { text: text.into() }
    }

    /// An un-named user message built from content parts.
    fn user_parts(parts: Vec<ContentPart>) -> Message {
        Message::User {
            content: MessageContent::Parts(parts),
            name: None,
        }
    }

    /// A `response_format` with the given `type` and no schema.
    fn response_format(kind: &'static str) -> ResponseFormat {
        ResponseFormat {
            r#type: kind.into(),
            json_schema: None,
        }
    }

    // ---- RequiredCapabilities::from_request --------------------------------

    #[test]
    fn plain_text_request_has_no_required_caps() {
        let caps = RequiredCapabilities::from_request(&base_req());
        assert_eq!(caps, RequiredCapabilities::default());
    }

    #[test]
    fn image_url_part_sets_vision() {
        let req = req_with(vec![user_parts(vec![text_part("describe this"), image_part()])]);
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.vision);
        assert!(!caps.tools);
    }

    #[test]
    fn audio_part_sets_vision() {
        let req = req_with(vec![user_parts(vec![audio_part()])]);
        assert!(RequiredCapabilities::from_request(&req).vision);
    }

    #[test]
    fn tools_field_sets_tools_cap() {
        let mut req = base_req();
        req.tools = vec![Tool {
            r#type: "function".into(),
            function: ToolFunction {
                name: "get_weather".into(),
                description: None,
                parameters: serde_json::json!({}),
            },
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.tools);
    }

    #[test]
    fn assistant_tool_calls_in_history_sets_tools_cap() {
        let req = req_with(vec![Message::Assistant {
            content: None,
            tool_calls: vec![ToolCall {
                id: "call_1".into(),
                r#type: "function".into(),
                function: ToolCallFunction {
                    name: "get_weather".into(),
                    arguments: "{}".into(),
                },
            }],
            name: None,
        }]);
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.tools);
    }

    #[test]
    fn json_object_response_format_sets_json_mode() {
        let mut req = base_req();
        req.response_format = Some(response_format("json_object"));
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.json_mode);
    }

    #[test]
    fn json_schema_response_format_sets_json_mode() {
        let mut req = base_req();
        req.response_format = Some(response_format("json_schema"));
        assert!(RequiredCapabilities::from_request(&req).json_mode);
    }

    #[test]
    fn text_response_format_does_not_set_json_mode() {
        let mut req = base_req();
        req.response_format = Some(response_format("text"));
        assert!(!RequiredCapabilities::from_request(&req).json_mode);
    }

    // ---- satisfied_by / skip_reasons ---------------------------------------

    #[test]
    fn vision_request_not_satisfied_by_text_model() {
        let req = req_with(vec![user_parts(vec![image_part()])]);
        let caps = RequiredCapabilities::from_request(&req);
        assert!(!caps.satisfied_by(&text_model(), 0));
    }

    #[test]
    fn vision_request_satisfied_by_vision_model() {
        let req = req_with(vec![user_parts(vec![image_part()])]);
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.satisfied_by(&vision_model(), 0));
    }

    #[test]
    fn exceeds_context_window_not_satisfied() {
        let caps = RequiredCapabilities::default();
        assert!(!caps.satisfied_by(&small_model(), 200));
    }

    #[test]
    fn within_context_window_satisfied() {
        let caps = RequiredCapabilities::default();
        assert!(caps.satisfied_by(&small_model(), 50));
    }

    #[test]
    fn zero_estimated_tokens_skips_window_check() {
        let caps = RequiredCapabilities::default();
        assert!(caps.satisfied_by(&small_model(), 0));
    }

    #[test]
    fn skip_reasons_lists_all_failures() {
        let caps = RequiredCapabilities {
            vision: true,
            tools: true,
            ..Default::default()
        };
        let reasons = caps.skip_reasons(&text_model(), 9999);
        assert!(reasons.contains(&"vision_not_supported"));
        assert!(reasons.contains(&"tools_not_supported"));
        assert!(reasons.contains(&"context_window_too_small"));
    }

    #[test]
    fn skip_reasons_empty_when_satisfied() {
        let caps = RequiredCapabilities {
            vision: true,
            tools: true,
            ..Default::default()
        };
        assert!(caps.satisfied_by(&vision_model(), 1_000));
        assert!(caps.skip_reasons(&vision_model(), 1_000).is_empty());
    }

    #[test]
    fn missing_json_mode_is_the_only_reason_for_vision_model() {
        let caps = RequiredCapabilities {
            json_mode: true,
            ..Default::default()
        };
        assert_eq!(
            caps.skip_reasons(&vision_model(), 0),
            vec!["json_mode_not_supported"]
        );
    }

    #[test]
    fn skip_reasons_agree_with_satisfied_by() {
        let requirement_sets = [
            RequiredCapabilities::default(),
            RequiredCapabilities {
                vision: true,
                ..Default::default()
            },
            RequiredCapabilities {
                tools: true,
                json_mode: true,
                ..Default::default()
            },
            RequiredCapabilities {
                vision: true,
                tools: true,
                json_mode: true,
            },
        ];
        let models = [text_model(), vision_model(), small_model()];
        for caps in &requirement_sets {
            for model in &models {
                for &tokens in &[0u64, 50, 200, 200_000] {
                    assert_eq!(
                        caps.satisfied_by(model, tokens),
                        caps.skip_reasons(model, tokens).is_empty(),
                        "{:?} at {} tokens",
                        caps,
                        tokens
                    );
                }
            }
        }
    }

    // ---- modality / text helpers -------------------------------------------

    #[test]
    fn request_has_images_detects_image_part() {
        let req = req_with(vec![user_parts(vec![text_part("look"), image_part()])]);
        assert!(request_has_images(&req));
        assert!(!request_has_audio(&req));
    }

    #[test]
    fn request_has_audio_detects_audio_part() {
        let req = req_with(vec![user_parts(vec![audio_part()])]);
        assert!(request_has_audio(&req));
        assert!(!request_has_images(&req));
    }

    #[test]
    fn plain_text_request_has_no_modality() {
        let req = base_req();
        assert!(!request_has_images(&req));
        assert!(!request_has_audio(&req));
    }

    #[test]
    fn message_text_for_estimation_keeps_only_text() {
        let req = req_with(vec![
            Message::System {
                content: MessageContent::Text("sys|".into()),
            },
            user_parts(vec![text_part("a"), image_part(), text_part("b"), audio_part()]),
            Message::Assistant {
                content: None,
                tool_calls: vec![],
                name: None,
            },
            Message::Assistant {
                content: Some(MessageContent::Text("c".into())),
                tool_calls: vec![],
                name: None,
            },
        ]);
        assert_eq!(message_text_for_estimation(&req), "sys|abc");
    }

    #[test]
    fn request_input_text_user_and_system_only() {
        let req = req_with(vec![
            Message::System {
                content: MessageContent::Text("sys ctx".into()),
            },
            Message::User {
                content: MessageContent::Text("Confidential matter".into()),
                name: None,
            },
            Message::Assistant {
                content: Some(MessageContent::Text("legal advice".into())),
                tool_calls: vec![],
                name: None,
            },
        ]);
        let t = request_input_text(&req);
        assert!(t.contains("sys ctx"));
        assert!(t.contains("Confidential matter"));
        assert!(
            !t.contains("legal advice"),
            "assistant output must be excluded"
        );
    }

    #[test]
    fn request_input_text_joins_turns_with_newlines() {
        let req = req_with(vec![
            Message::System {
                content: MessageContent::Text("sys".into()),
            },
            user_parts(vec![text_part("a"), image_part(), text_part("b")]),
        ]);
        assert_eq!(request_input_text(&req), "sys\nab");
    }
}