1use crate::{
23 messages::{ContentPart, Message, MessageContent},
24 pricing::{Capability, ModelInfo},
25 ChatCompletionRequest,
26};
27
/// Capabilities a model must advertise before it may serve a particular request.
///
/// Every flag defaults to `false`, meaning "no requirement". The flags are normally
/// derived from a request via [`RequiredCapabilities::from_request`].
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct RequiredCapabilities {
    /// Some user/system message carries a non-text content part (an image or input audio).
    pub vision: bool,
    /// The request declares tools, or its history contains tool calls or tool results.
    pub tools: bool,
    /// The request asks for a JSON `response_format` (`json_object` or `json_schema`).
    pub json_mode: bool,
}
39
40impl RequiredCapabilities {
41 pub fn from_request(req: &ChatCompletionRequest) -> Self {
43 let mut caps = Self::default();
44
45 if !req.tools.is_empty() {
47 caps.tools = true;
48 }
49
50 if let Some(rf) = &req.response_format {
52 if rf.r#type == "json_object" || rf.r#type == "json_schema" {
53 caps.json_mode = true;
54 }
55 }
56
57 for msg in &req.messages {
59 match msg {
60 Message::User { content, .. } | Message::System { content } => {
61 if let MessageContent::Parts(parts) = content {
62 for part in parts {
63 match part {
64 ContentPart::ImageUrl { .. } | ContentPart::InputAudio { .. } => {
65 caps.vision = true;
66 }
67 ContentPart::Text { .. } => {}
68 }
69 }
70 }
71 }
72 Message::Assistant { tool_calls, .. } => {
73 if !tool_calls.is_empty() {
74 caps.tools = true;
75 }
76 }
77 Message::Tool { .. } => {
78 caps.tools = true;
81 }
82 }
83 }
84
85 caps
86 }
87
88 #[must_use]
93 pub fn satisfied_by(&self, info: &ModelInfo, estimated_tokens: u64) -> bool {
94 if self.vision && !info.capabilities.contains(&Capability::Vision) {
95 return false;
96 }
97 if self.tools && !info.capabilities.contains(&Capability::Tools) {
98 return false;
99 }
100 if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
101 return false;
102 }
103 if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
104 return false;
105 }
106 true
107 }
108
109 pub fn skip_reasons(&self, info: &ModelInfo, estimated_tokens: u64) -> Vec<&'static str> {
112 let mut reasons = Vec::new();
113 if self.vision && !info.capabilities.contains(&Capability::Vision) {
114 reasons.push("vision_not_supported");
115 }
116 if self.tools && !info.capabilities.contains(&Capability::Tools) {
117 reasons.push("tools_not_supported");
118 }
119 if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
120 reasons.push("json_mode_not_supported");
121 }
122 if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
123 reasons.push("context_window_too_small");
124 }
125 reasons
126 }
127}
128
129pub fn message_text_for_estimation(req: &ChatCompletionRequest) -> String {
135 req.messages
136 .iter()
137 .map(|m| match m {
138 Message::User { content, .. } | Message::System { content } => extract_text(content),
139 Message::Assistant { content, .. } => {
140 content.as_ref().map(extract_text).unwrap_or_default()
141 }
142 Message::Tool { content, .. } => extract_text(content),
143 })
144 .collect()
145}
146
147fn extract_text(content: &MessageContent) -> String {
148 match content {
149 MessageContent::Text(s) => s.clone(),
150 MessageContent::Parts(parts) => parts
151 .iter()
152 .filter_map(|p| match p {
153 ContentPart::Text { text } => Some(text.as_str()),
154 _ => None,
155 })
156 .collect::<Vec<_>>()
157 .join(""),
158 }
159}
160
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{
        messages::{ImageUrl, ResponseFormat, Tool, ToolCall, ToolCallFunction, ToolFunction},
        pricing::Capability,
        ModelInfo,
    };

    // Model fixtures: a text-only model, a vision+tools model with a large window, and a
    // text-only model with a 100-token window for context-size checks.

    fn text_model() -> ModelInfo {
        ModelInfo {
            id: "text-only".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 4096,
            max_output_tokens: 1024,
        }
    }

    fn vision_model() -> ModelInfo {
        ModelInfo {
            id: "vision-model".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text, Capability::Vision, Capability::Tools],
            max_input_tokens: 128_000,
            max_output_tokens: 4096,
        }
    }

    fn small_model() -> ModelInfo {
        ModelInfo {
            id: "small-ctx".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 100,
            max_output_tokens: 100,
        }
    }

    /// A request with no messages, tools, or response format.
    fn base_req() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4o".into(),
            messages: vec![],
            temperature: None,
            top_p: None,
            max_tokens: None,
            stream: false,
            tools: vec![],
            tool_choice: None,
            response_format: None,
            stop: vec![],
            presence_penalty: None,
            frequency_penalty: None,
            n: None,
            seed: None,
            user: None,
            tt_extras: HashMap::new(),
        }
    }

    /// A single content part holding a small inline PNG.
    fn image_part() -> ContentPart {
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "data:image/png;base64,abc".into(),
                detail: None,
            },
        }
    }

    #[test]
    fn plain_text_request_has_no_required_caps() {
        let req = base_req();
        let caps = RequiredCapabilities::from_request(&req);
        assert!(!caps.vision);
        assert!(!caps.tools);
        assert!(!caps.json_mode);
    }

    #[test]
    fn image_url_part_sets_vision() {
        let mut req = base_req();
        req.messages = vec![Message::User {
            content: MessageContent::Parts(vec![
                ContentPart::Text {
                    text: "describe this".into(),
                },
                image_part(),
            ]),
            name: None,
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.vision);
        assert!(!caps.tools);
    }

    #[test]
    fn image_in_system_message_sets_vision() {
        let mut req = base_req();
        req.messages = vec![Message::System {
            content: MessageContent::Parts(vec![image_part()]),
        }];
        assert!(RequiredCapabilities::from_request(&req).vision);
    }

    #[test]
    fn tools_field_sets_tools_cap() {
        let mut req = base_req();
        req.tools = vec![Tool {
            r#type: "function".into(),
            function: ToolFunction {
                name: "get_weather".into(),
                description: None,
                parameters: serde_json::json!({}),
            },
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.tools);
    }

    #[test]
    fn assistant_tool_calls_in_history_sets_tools_cap() {
        let mut req = base_req();
        req.messages = vec![Message::Assistant {
            content: None,
            tool_calls: vec![ToolCall {
                id: "call_1".into(),
                r#type: "function".into(),
                function: ToolCallFunction {
                    name: "get_weather".into(),
                    arguments: "{}".into(),
                },
            }],
            name: None,
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.tools);
    }

    #[test]
    fn assistant_without_tool_calls_does_not_set_tools() {
        let mut req = base_req();
        req.messages = vec![Message::Assistant {
            content: Some(MessageContent::Text("hi".into())),
            tool_calls: vec![],
            name: None,
        }];
        assert!(!RequiredCapabilities::from_request(&req).tools);
    }

    #[test]
    fn json_object_response_format_sets_json_mode() {
        let mut req = base_req();
        req.response_format = Some(ResponseFormat {
            r#type: "json_object".into(),
            json_schema: None,
        });
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.json_mode);
    }

    #[test]
    fn json_schema_response_format_sets_json_mode() {
        let mut req = base_req();
        req.response_format = Some(ResponseFormat {
            r#type: "json_schema".into(),
            json_schema: None,
        });
        assert!(RequiredCapabilities::from_request(&req).json_mode);
    }

    #[test]
    fn text_response_format_does_not_set_json_mode() {
        let mut req = base_req();
        req.response_format = Some(ResponseFormat {
            r#type: "text".into(),
            json_schema: None,
        });
        assert!(!RequiredCapabilities::from_request(&req).json_mode);
    }

    #[test]
    fn vision_request_not_satisfied_by_text_model() {
        let mut req = base_req();
        req.messages = vec![Message::User {
            content: MessageContent::Parts(vec![image_part()]),
            name: None,
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(!caps.satisfied_by(&text_model(), 0));
    }

    #[test]
    fn vision_request_satisfied_by_vision_model() {
        let mut req = base_req();
        req.messages = vec![Message::User {
            content: MessageContent::Parts(vec![image_part()]),
            name: None,
        }];
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.satisfied_by(&vision_model(), 0));
    }

    #[test]
    fn json_mode_request_not_satisfied_without_json_capability() {
        let caps = RequiredCapabilities {
            json_mode: true,
            ..Default::default()
        };
        assert!(!caps.satisfied_by(&vision_model(), 0));
        assert_eq!(
            caps.skip_reasons(&vision_model(), 0),
            vec!["json_mode_not_supported"]
        );
    }

    #[test]
    fn exceeds_context_window_not_satisfied() {
        let caps = RequiredCapabilities::default();
        assert!(!caps.satisfied_by(&small_model(), 200));
    }

    #[test]
    fn within_context_window_satisfied() {
        let caps = RequiredCapabilities::default();
        assert!(caps.satisfied_by(&small_model(), 50));
    }

    #[test]
    fn estimate_equal_to_window_is_satisfied() {
        // The window check is strict `max_input_tokens < estimate`, so equality still fits.
        let caps = RequiredCapabilities::default();
        assert!(caps.satisfied_by(&small_model(), 100));
        assert!(caps.skip_reasons(&small_model(), 100).is_empty());
    }

    #[test]
    fn zero_estimated_tokens_skips_window_check() {
        let caps = RequiredCapabilities::default();
        assert!(caps.satisfied_by(&small_model(), 0));
    }

    #[test]
    fn skip_reasons_lists_all_failures() {
        let caps = RequiredCapabilities {
            vision: true,
            tools: true,
            ..Default::default()
        };
        let reasons = caps.skip_reasons(&text_model(), 9999);
        assert_eq!(
            reasons,
            vec![
                "vision_not_supported",
                "tools_not_supported",
                "context_window_too_small",
            ]
        );
    }

    #[test]
    fn satisfied_by_agrees_with_skip_reasons() {
        for vision in [false, true] {
            for tools in [false, true] {
                for json_mode in [false, true] {
                    let caps = RequiredCapabilities {
                        vision,
                        tools,
                        json_mode,
                    };
                    for model in [text_model(), vision_model(), small_model()] {
                        for tokens in [0, 50, 100, 200, 200_000] {
                            assert_eq!(
                                caps.satisfied_by(&model, tokens),
                                caps.skip_reasons(&model, tokens).is_empty(),
                                "caps={caps:?} model={} tokens={tokens}",
                                model.id,
                            );
                        }
                    }
                }
            }
        }
    }

    #[test]
    fn message_text_concatenates_text_parts_in_order() {
        let mut req = base_req();
        req.messages = vec![
            Message::System {
                content: MessageContent::Text("a".into()),
            },
            Message::User {
                content: MessageContent::Parts(vec![
                    ContentPart::Text { text: "b".into() },
                    image_part(),
                    ContentPart::Text { text: "c".into() },
                ]),
                name: None,
            },
            Message::Assistant {
                content: None,
                tool_calls: vec![],
                name: None,
            },
        ];
        assert_eq!(message_text_for_estimation(&req), "abc");
    }
}
361}