1use crate::{
23 messages::{ContentPart, Message, MessageContent},
24 pricing::{Capability, ModelInfo},
25 ChatCompletionRequest,
26};
27
/// Feature flags a model must support in order to serve a particular request.
///
/// The default value has every flag cleared, i.e. it places no capability
/// requirement on the model.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RequiredCapabilities {
    /// The model must advertise `Capability::Vision`.
    ///
    /// `from_request` raises this when a user or system message contains an
    /// image part or an audio part.
    pub vision: bool,
    /// The model must advertise `Capability::Tools`.
    ///
    /// `from_request` raises this when the request declares tools, when an
    /// assistant message carries tool calls, or when a tool message is present.
    pub tools: bool,
    /// The model must advertise `Capability::JsonMode`.
    ///
    /// `from_request` raises this when the response format type is
    /// `"json_object"` or `"json_schema"`.
    pub json_mode: bool,
}
39
40impl RequiredCapabilities {
41 pub fn from_request(req: &ChatCompletionRequest) -> Self {
43 let mut caps = Self::default();
44
45 if !req.tools.is_empty() {
47 caps.tools = true;
48 }
49
50 if let Some(rf) = &req.response_format {
52 if rf.r#type == "json_object" || rf.r#type == "json_schema" {
53 caps.json_mode = true;
54 }
55 }
56
57 for msg in &req.messages {
59 match msg {
60 Message::User { content, .. } | Message::System { content } => {
61 if let MessageContent::Parts(parts) = content {
62 for part in parts {
63 match part {
64 ContentPart::ImageUrl { .. } | ContentPart::InputAudio { .. } => {
65 caps.vision = true;
66 }
67 ContentPart::Text { .. } => {}
68 }
69 }
70 }
71 }
72 Message::Assistant { tool_calls, .. } => {
73 if !tool_calls.is_empty() {
74 caps.tools = true;
75 }
76 }
77 Message::Tool { .. } => {
78 caps.tools = true;
81 }
82 }
83 }
84
85 caps
86 }
87
88 #[must_use]
93 pub fn satisfied_by(&self, info: &ModelInfo, estimated_tokens: u64) -> bool {
94 if self.vision && !info.capabilities.contains(&Capability::Vision) {
95 return false;
96 }
97 if self.tools && !info.capabilities.contains(&Capability::Tools) {
98 return false;
99 }
100 if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
101 return false;
102 }
103 if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
104 return false;
105 }
106 true
107 }
108
109 pub fn skip_reasons(&self, info: &ModelInfo, estimated_tokens: u64) -> Vec<&'static str> {
112 let mut reasons = Vec::new();
113 if self.vision && !info.capabilities.contains(&Capability::Vision) {
114 reasons.push("vision_not_supported");
115 }
116 if self.tools && !info.capabilities.contains(&Capability::Tools) {
117 reasons.push("tools_not_supported");
118 }
119 if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
120 reasons.push("json_mode_not_supported");
121 }
122 if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
123 reasons.push("context_window_too_small");
124 }
125 reasons
126 }
127}
128
129pub fn message_text_for_estimation(req: &ChatCompletionRequest) -> String {
135 req.messages
136 .iter()
137 .map(|m| match m {
138 Message::User { content, .. } | Message::System { content } => extract_text(content),
139 Message::Assistant { content, .. } => {
140 content.as_ref().map(extract_text).unwrap_or_default()
141 }
142 Message::Tool { content, .. } => extract_text(content),
143 })
144 .collect()
145}
146
147fn extract_text(content: &MessageContent) -> String {
148 match content {
149 MessageContent::Text(s) => s.clone(),
150 MessageContent::Parts(parts) => parts
151 .iter()
152 .filter_map(|p| match p {
153 ContentPart::Text { text } => Some(text.as_str()),
154 _ => None,
155 })
156 .collect::<Vec<_>>()
157 .join(""),
158 }
159}
160
161pub fn request_has_images(req: &ChatCompletionRequest) -> bool {
166 req.messages
167 .iter()
168 .any(|m| content_of(m).is_some_and(has_image_part))
169}
170
171pub fn request_has_audio(req: &ChatCompletionRequest) -> bool {
173 req.messages
174 .iter()
175 .any(|m| content_of(m).is_some_and(has_audio_part))
176}
177
178fn content_of(m: &Message) -> Option<&MessageContent> {
180 match m {
181 Message::User { content, .. }
182 | Message::System { content }
183 | Message::Tool { content, .. } => Some(content),
184 Message::Assistant { content, .. } => content.as_ref(),
185 }
186}
187
188fn has_image_part(c: &MessageContent) -> bool {
189 matches!(c, MessageContent::Parts(parts)
190 if parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. })))
191}
192
193fn has_audio_part(c: &MessageContent) -> bool {
194 matches!(c, MessageContent::Parts(parts)
195 if parts.iter().any(|p| matches!(p, ContentPart::InputAudio { .. })))
196}
197
198pub fn request_input_text(req: &ChatCompletionRequest) -> String {
202 req.messages
203 .iter()
204 .filter_map(|m| match m {
205 Message::User { content, .. } | Message::System { content } => {
206 Some(extract_text(content))
207 }
208 _ => None,
209 })
210 .collect::<Vec<_>>()
211 .join("\n")
212}
213
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::{
        messages::{
            ImageUrl, InputAudio, ResponseFormat, Tool, ToolCall, ToolCallFunction, ToolFunction,
        },
        pricing::Capability,
        ModelInfo,
    };

    // ---- model fixtures -------------------------------------------------

    /// Text-only model with a 4096-token input window.
    fn text_model() -> ModelInfo {
        ModelInfo {
            id: "text-only".into(),
            provider: "mock".into(),
            capabilities: vec![Capability::Text],
            max_input_tokens: 4096,
            max_output_tokens: 1024,
        }
    }

    /// Large-window model that also supports vision and tools.
    fn vision_model() -> ModelInfo {
        ModelInfo {
            id: "vision-model".into(),
            capabilities: vec![Capability::Text, Capability::Vision, Capability::Tools],
            max_input_tokens: 128_000,
            max_output_tokens: 4096,
            ..text_model()
        }
    }

    /// Text-only model with a 100-token input window.
    fn small_model() -> ModelInfo {
        ModelInfo {
            id: "small-ctx".into(),
            max_input_tokens: 100,
            max_output_tokens: 100,
            ..text_model()
        }
    }

    // ---- request fixtures -----------------------------------------------

    /// Request with no messages, no tools and no response format.
    fn base_req() -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4o".into(),
            messages: vec![],
            temperature: None,
            top_p: None,
            max_tokens: None,
            stream: false,
            tools: vec![],
            tool_choice: None,
            response_format: None,
            stop: vec![],
            presence_penalty: None,
            frequency_penalty: None,
            n: None,
            seed: None,
            user: None,
            tt_extras: HashMap::new(),
            ..Default::default()
        }
    }

    /// `base_req()` carrying the given conversation.
    fn req_with(messages: Vec<Message>) -> ChatCompletionRequest {
        ChatCompletionRequest {
            messages,
            ..base_req()
        }
    }

    /// Unnamed user message with multi-part content.
    fn user_message(parts: Vec<ContentPart>) -> Message {
        Message::User {
            content: MessageContent::Parts(parts),
            name: None,
        }
    }

    fn text_part(text: &str) -> ContentPart {
        ContentPart::Text { text: text.into() }
    }

    /// Inline PNG data-URL image part.
    fn png_part() -> ContentPart {
        ContentPart::ImageUrl {
            image_url: ImageUrl {
                url: "data:image/png;base64,abc".into(),
                detail: None,
            },
        }
    }

    /// Requirements derived from a request holding a single image-only user message.
    fn image_only_caps() -> RequiredCapabilities {
        RequiredCapabilities::from_request(&req_with(vec![user_message(vec![png_part()])]))
    }

    // ---- from_request ---------------------------------------------------

    #[test]
    fn plain_text_request_has_no_required_caps() {
        let caps = RequiredCapabilities::from_request(&base_req());
        assert_eq!(caps, RequiredCapabilities::default());
    }

    #[test]
    fn image_url_part_sets_vision() {
        let req = req_with(vec![user_message(vec![
            text_part("describe this"),
            png_part(),
        ])]);
        let caps = RequiredCapabilities::from_request(&req);
        assert!(caps.vision);
        assert!(!caps.tools);
    }

    #[test]
    fn tools_field_sets_tools_cap() {
        let weather_tool = Tool {
            r#type: "function".into(),
            function: ToolFunction {
                name: "get_weather".into(),
                description: None,
                parameters: serde_json::json!({}),
            },
        };
        let req = ChatCompletionRequest {
            tools: vec![weather_tool],
            ..base_req()
        };
        assert!(RequiredCapabilities::from_request(&req).tools);
    }

    #[test]
    fn assistant_tool_calls_in_history_sets_tools_cap() {
        let call = ToolCall {
            id: "call_1".into(),
            r#type: "function".into(),
            function: ToolCallFunction {
                name: "get_weather".into(),
                arguments: "{}".into(),
            },
        };
        let req = req_with(vec![Message::Assistant {
            content: None,
            tool_calls: vec![call],
            name: None,
        }]);
        assert!(RequiredCapabilities::from_request(&req).tools);
    }

    #[test]
    fn json_object_response_format_sets_json_mode() {
        let req = ChatCompletionRequest {
            response_format: Some(ResponseFormat {
                r#type: "json_object".into(),
                json_schema: None,
            }),
            ..base_req()
        };
        assert!(RequiredCapabilities::from_request(&req).json_mode);
    }

    // ---- satisfied_by / skip_reasons ------------------------------------

    #[test]
    fn vision_request_not_satisfied_by_text_model() {
        assert!(!image_only_caps().satisfied_by(&text_model(), 0));
    }

    #[test]
    fn vision_request_satisfied_by_vision_model() {
        assert!(image_only_caps().satisfied_by(&vision_model(), 0));
    }

    #[test]
    fn exceeds_context_window_not_satisfied() {
        assert!(!RequiredCapabilities::default().satisfied_by(&small_model(), 200));
    }

    #[test]
    fn within_context_window_satisfied() {
        assert!(RequiredCapabilities::default().satisfied_by(&small_model(), 50));
    }

    #[test]
    fn zero_estimated_tokens_skips_window_check() {
        assert!(RequiredCapabilities::default().satisfied_by(&small_model(), 0));
    }

    #[test]
    fn skip_reasons_lists_all_failures() {
        let caps = RequiredCapabilities {
            vision: true,
            tools: true,
            ..Default::default()
        };
        let reasons = caps.skip_reasons(&text_model(), 9999);
        for expected in [
            "vision_not_supported",
            "tools_not_supported",
            "context_window_too_small",
        ] {
            assert!(reasons.contains(&expected));
        }
    }

    // ---- modality helpers -----------------------------------------------

    #[test]
    fn request_has_images_detects_image_part() {
        let req = req_with(vec![user_message(vec![text_part("look"), png_part()])]);
        assert!(request_has_images(&req));
        assert!(!request_has_audio(&req));
    }

    #[test]
    fn request_has_audio_detects_audio_part() {
        let clip = ContentPart::InputAudio {
            input_audio: InputAudio {
                data: "abc".into(),
                format: "wav".into(),
            },
        };
        let req = req_with(vec![user_message(vec![clip])]);
        assert!(request_has_audio(&req));
        assert!(!request_has_images(&req));
    }

    #[test]
    fn plain_text_request_has_no_modality() {
        let req = base_req();
        assert!(!request_has_images(&req));
        assert!(!request_has_audio(&req));
    }

    // ---- request_input_text ---------------------------------------------

    #[test]
    fn request_input_text_user_and_system_only() {
        let req = req_with(vec![
            Message::System {
                content: MessageContent::Text("sys ctx".into()),
            },
            Message::User {
                content: MessageContent::Text("Confidential matter".into()),
                name: None,
            },
            Message::Assistant {
                content: Some(MessageContent::Text("legal advice".into())),
                tool_calls: vec![],
                name: None,
            },
        ]);
        let t = request_input_text(&req);
        assert!(t.contains("sys ctx"));
        assert!(t.contains("Confidential matter"));
        assert!(
            !t.contains("legal advice"),
            "assistant output must be excluded"
        );
    }
}