Skip to main content

tt_shared/
capability_check.rs

//! Capability and context-window guard for the routing / failover path.
//!
//! [`RequiredCapabilities`] is derived from a [`ChatCompletionRequest`] and
//! checked against a candidate model's [`ModelInfo`] before a route rewrite or
//! failover dispatch is committed. The check is intentionally permissive:
//!
//! - When `ModelInfo` is **unknown** for a candidate (not in the registry
//!   catalog), the candidate is allowed through — we only skip when we
//!   *positively know* a capability is missing.
//! - A capability that the request needs but the model info does **not** list
//!   causes the candidate to be skipped (the caller emits a tracing event and
//!   tries the next candidate or falls back to the original model).
//!
//! # Token counting
//!
//! [`message_text_for_estimation`] concatenates the text of every message and
//! returns it; the caller passes that string to its own tokenizer (e.g.
//! `tt_tokenize::estimate_tokens`, keyed on `provider_id` so tiktoken is used
//! for OpenAI/Anthropic and a char/4 heuristic elsewhere). This keeps
//! `tt-shared` free of a `tt-tokenize` dependency. Image/audio bytes are not
//! measured — the guard is a best-effort floor, not an exact window-packing
//! count.

22use crate::{
23    messages::{ContentPart, Message, MessageContent},
24    pricing::{Capability, ModelInfo},
25    ChatCompletionRequest,
26};
27
/// Capabilities a candidate model must advertise before it may serve a given
/// [`ChatCompletionRequest`].
///
/// Usually built with [`RequiredCapabilities::from_request`]; the `Default`
/// value requires no capabilities at all.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct RequiredCapabilities {
    /// Some message in the conversation carries media: an `image_url` **or**
    /// an `input_audio` content part (audio is folded into this flag).
    pub vision: bool,
    /// Function calling is in play: the request declares `tools`, or the
    /// history contains assistant `tool_calls` or tool-result messages.
    pub tools: bool,
    /// Structured output was requested: `response_format.type` is
    /// `"json_object"` or `"json_schema"`.
    pub json_mode: bool,
}

40impl RequiredCapabilities {
41    /// Derive the required capabilities from a chat completion request.
42    pub fn from_request(req: &ChatCompletionRequest) -> Self {
43        let mut caps = Self::default();
44
45        // tools / function-calling
46        if !req.tools.is_empty() {
47            caps.tools = true;
48        }
49
50        // response_format → json mode
51        if let Some(rf) = &req.response_format {
52            if rf.r#type == "json_object" || rf.r#type == "json_schema" {
53                caps.json_mode = true;
54            }
55        }
56
57        // scan messages for vision content and tool_calls
58        for msg in &req.messages {
59            match msg {
60                Message::User { content, .. } | Message::System { content } => {
61                    if let MessageContent::Parts(parts) = content {
62                        for part in parts {
63                            match part {
64                                ContentPart::ImageUrl { .. } | ContentPart::InputAudio { .. } => {
65                                    caps.vision = true;
66                                }
67                                ContentPart::Text { .. } => {}
68                            }
69                        }
70                    }
71                }
72                Message::Assistant { tool_calls, .. } => {
73                    if !tool_calls.is_empty() {
74                        caps.tools = true;
75                    }
76                }
77                Message::Tool { .. } => {
78                    // A Tool message in context means the conversation already
79                    // used tool-calling; the next turn may need it too.
80                    caps.tools = true;
81                }
82            }
83        }
84
85        caps
86    }
87
88    /// Returns `true` when all required capabilities are listed in
89    /// `info.capabilities` **and** `max_input_tokens >= estimated_tokens`.
90    ///
91    /// Pass `estimated_tokens = 0` to skip the context-window check.
92    #[must_use]
93    pub fn satisfied_by(&self, info: &ModelInfo, estimated_tokens: u64) -> bool {
94        if self.vision && !info.capabilities.contains(&Capability::Vision) {
95            return false;
96        }
97        if self.tools && !info.capabilities.contains(&Capability::Tools) {
98            return false;
99        }
100        if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
101            return false;
102        }
103        if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
104            return false;
105        }
106        true
107    }
108
109    /// Human-readable list of the reasons a candidate was skipped, for use in
110    /// the `route_skipped_capability` tracing event.
111    pub fn skip_reasons(&self, info: &ModelInfo, estimated_tokens: u64) -> Vec<&'static str> {
112        let mut reasons = Vec::new();
113        if self.vision && !info.capabilities.contains(&Capability::Vision) {
114            reasons.push("vision_not_supported");
115        }
116        if self.tools && !info.capabilities.contains(&Capability::Tools) {
117            reasons.push("tools_not_supported");
118        }
119        if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
120            reasons.push("json_mode_not_supported");
121        }
122        if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
123            reasons.push("context_window_too_small");
124        }
125        reasons
126    }
127}
128
129/// Concatenate all message text parts from a request for token estimation.
130///
131/// Image/audio bytes are excluded — the result is passed to the caller's
132/// tokenizer (e.g. `tt_tokenize::estimate_tokens`) so that `tt-shared` does
133/// not need to depend on `tt-tokenize`.
134pub fn message_text_for_estimation(req: &ChatCompletionRequest) -> String {
135    req.messages
136        .iter()
137        .map(|m| match m {
138            Message::User { content, .. } | Message::System { content } => extract_text(content),
139            Message::Assistant { content, .. } => {
140                content.as_ref().map(extract_text).unwrap_or_default()
141            }
142            Message::Tool { content, .. } => extract_text(content),
143        })
144        .collect()
145}
146
147fn extract_text(content: &MessageContent) -> String {
148    match content {
149        MessageContent::Text(s) => s.clone(),
150        MessageContent::Parts(parts) => parts
151            .iter()
152            .filter_map(|p| match p {
153                ContentPart::Text { text } => Some(text.as_str()),
154                _ => None,
155            })
156            .collect::<Vec<_>>()
157            .join(""),
158    }
159}
160
161/// True when any message carries an image (`ContentPart::ImageUrl`) content part.
162///
163/// Distinct from [`RequiredCapabilities`], which collapses image **and** audio
164/// into a single `vision` flag; routing needs to tell the two modalities apart.
165pub fn request_has_images(req: &ChatCompletionRequest) -> bool {
166    req.messages
167        .iter()
168        .any(|m| content_of(m).is_some_and(has_image_part))
169}
170
171/// True when any message carries an audio (`ContentPart::InputAudio`) content part.
172pub fn request_has_audio(req: &ChatCompletionRequest) -> bool {
173    req.messages
174        .iter()
175        .any(|m| content_of(m).is_some_and(has_audio_part))
176}
177
178/// The content of a message, if it has any (Assistant content is optional).
179fn content_of(m: &Message) -> Option<&MessageContent> {
180    match m {
181        Message::User { content, .. }
182        | Message::System { content }
183        | Message::Tool { content, .. } => Some(content),
184        Message::Assistant { content, .. } => content.as_ref(),
185    }
186}
187
188fn has_image_part(c: &MessageContent) -> bool {
189    matches!(c, MessageContent::Parts(parts)
190        if parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. })))
191}
192
193fn has_audio_part(c: &MessageContent) -> bool {
194    matches!(c, MessageContent::Parts(parts)
195        if parts.iter().any(|p| matches!(p, ContentPart::InputAudio { .. })))
196}
197
198/// Concatenated text of the **user + system** messages — the caller-controlled
199/// input, used for content/topic routing. Assistant/tool turns are excluded so a
200/// model's own output can't spuriously trigger a topic route.
201pub fn request_input_text(req: &ChatCompletionRequest) -> String {
202    req.messages
203        .iter()
204        .filter_map(|m| match m {
205            Message::User { content, .. } | Message::System { content } => {
206                Some(extract_text(content))
207            }
208            _ => None,
209        })
210        .collect::<Vec<_>>()
211        .join("\n")
212}
213
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{
        messages::{
            ImageUrl, InputAudio, ResponseFormat, Tool, ToolCall, ToolCallFunction, ToolFunction,
        },
        pricing::Capability,
        ModelInfo,
    };

    // ---- model fixtures -------------------------------------------------

    /// Text-only model with a 4096-token input window.
    fn text_model() -> ModelInfo {
        ModelInfo {
            id: "text-only".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 4096,
            max_output_tokens: 1024,
        }
    }

    /// Model advertising text, vision and tools with a 128k input window.
    fn vision_model() -> ModelInfo {
        ModelInfo {
            id: "vision-model".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text, Capability::Vision, Capability::Tools],
            max_input_tokens: 128_000,
            max_output_tokens: 4096,
        }
    }

    /// Text-only model whose input window is just 100 tokens.
    fn small_model() -> ModelInfo {
        ModelInfo {
            id: "small-ctx".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 100,
            max_output_tokens: 100,
        }
    }

    // ---- request / message builders ------------------------------------

    /// A `gpt-4o` request with no messages, tools or response format.
    fn base_req() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4o".into(),
            messages: vec![],
            temperature: None,
            top_p: None,
            max_tokens: None,
            stream: false,
            tools: vec![],
            tool_choice: None,
            response_format: None,
            stop: vec![],
            presence_penalty: None,
            frequency_penalty: None,
            n: None,
            seed: None,
            user: None,
            tt_extras: HashMap::new(),
            ..Default::default()
        }
    }

    /// `base_req()` with its message list replaced by `messages`.
    fn req_with(messages: Vec<Message>) -> ChatCompletionRequest {
        let mut req = base_req();
        req.messages = messages;
        req
    }

    /// A plain text content part.
    fn text_part(text: &str) -> ContentPart {
        ContentPart::Text { text: text.into() }
    }

    /// An inline PNG image content part.
    fn png_part() -> ContentPart {
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "data:image/png;base64,abc".into(),
                detail: None,
            },
        }
    }

    /// An unnamed user message made of the given parts.
    fn user_parts(parts: Vec<ContentPart>) -> Message {
        Message::User {
            content: MessageContent::Parts(parts),
            name: None,
        }
    }

    // ---- RequiredCapabilities::from_request -----------------------------

    #[test]
    fn plain_text_request_has_no_required_caps() {
        let caps = RequiredCapabilities::from_request(&base_req());
        assert!(!caps.vision);
        assert!(!caps.tools);
        assert!(!caps.json_mode);
    }

    #[test]
    fn image_url_part_sets_vision() {
        let req = req_with(vec![user_parts(vec![text_part("describe this"), png_part()])]);
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.vision);
        assert!(!caps.tools);
    }

    #[test]
    fn tools_field_sets_tools_cap() {
        let mut req = base_req();
        req.tools.push(Tool {
            r#type: "function".into(),
            function: ToolFunction {
                name: "get_weather".into(),
                description: None,
                parameters: serde_json::json!({}),
            },
        });
        assert!(RequiredCapabilities::from_request(&req).tools);
    }

    #[test]
    fn assistant_tool_calls_in_history_sets_tools_cap() {
        let call = ToolCall {
            id: "call_1".into(),
            r#type: "function".into(),
            function: ToolCallFunction {
                name: "get_weather".into(),
                arguments: "{}".into(),
            },
        };
        let req = req_with(vec![Message::Assistant {
            content: None,
            tool_calls: vec![call],
            name: None,
        }]);
        assert!(RequiredCapabilities::from_request(&req).tools);
    }

    #[test]
    fn json_object_response_format_sets_json_mode() {
        let mut req = base_req();
        req.response_format = Some(ResponseFormat {
            r#type: "json_object".into(),
            json_schema: None,
        });
        assert!(RequiredCapabilities::from_request(&req).json_mode);
    }

    // ---- satisfied_by / skip_reasons ------------------------------------

    #[test]
    fn vision_request_not_satisfied_by_text_model() {
        let req = req_with(vec![user_parts(vec![png_part()])]);
        let caps = RequiredCapabilities::from_request(&req);
        assert!(!caps.satisfied_by(&text_model(), 0));
    }

    #[test]
    fn vision_request_satisfied_by_vision_model() {
        let req = req_with(vec![user_parts(vec![png_part()])]);
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.satisfied_by(&vision_model(), 0));
    }

    #[test]
    fn exceeds_context_window_not_satisfied() {
        assert!(!RequiredCapabilities::default().satisfied_by(&small_model(), 200));
    }

    #[test]
    fn within_context_window_satisfied() {
        assert!(RequiredCapabilities::default().satisfied_by(&small_model(), 50));
    }

    #[test]
    fn zero_estimated_tokens_skips_window_check() {
        assert!(RequiredCapabilities::default().satisfied_by(&small_model(), 0));
    }

    #[test]
    fn skip_reasons_lists_all_failures() {
        let caps = RequiredCapabilities {
            vision: true,
            tools: true,
            ..Default::default()
        };
        let reasons = caps.skip_reasons(&text_model(), 9999);
        for expected in [
            "vision_not_supported",
            "tools_not_supported",
            "context_window_too_small",
        ] {
            assert!(reasons.contains(&expected));
        }
    }

    // ---- modality detection ---------------------------------------------

    #[test]
    fn request_has_images_detects_image_part() {
        let req = req_with(vec![user_parts(vec![text_part("look"), png_part()])]);
        assert!(request_has_images(&req));
        assert!(!request_has_audio(&req));
    }

    #[test]
    fn request_has_audio_detects_audio_part() {
        let audio = ContentPart::InputAudio {
            input_audio: InputAudio {
                data: "abc".into(),
                format: "wav".into(),
            },
        };
        let req = req_with(vec![user_parts(vec![audio])]);
        assert!(request_has_audio(&req));
        assert!(!request_has_images(&req));
    }

    #[test]
    fn plain_text_request_has_no_modality() {
        let req = base_req();
        assert!(!request_has_images(&req));
        assert!(!request_has_audio(&req));
    }

    // ---- request_input_text ---------------------------------------------

    #[test]
    fn request_input_text_user_and_system_only() {
        let req = req_with(vec![
            Message::System {
                content: MessageContent::Text("sys ctx".into()),
            },
            Message::User {
                content: MessageContent::Text("Confidential matter".into()),
                name: None,
            },
            Message::Assistant {
                content: Some(MessageContent::Text("legal advice".into())),
                tool_calls: vec![],
                name: None,
            },
        ]);
        let t = request_input_text(&req);
        assert!(t.contains("sys ctx"));
        assert!(t.contains("Confidential matter"));
        assert!(
            !t.contains("legal advice"),
            "assistant output must be excluded"
        );
    }
}