1use crate::{
23 messages::{ContentPart, Message, MessageContent},
24 pricing::{Capability, ModelInfo},
25 ChatCompletionRequest,
26};
27
/// Capabilities a request needs from the model that will serve it.
///
/// All flags default to `false`, so [`Default::default`] describes a request
/// with no special requirements. The type is a plain set of three booleans, so
/// it is `Copy` and `Hash` and can be passed by value or used as a map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct RequiredCapabilities {
    /// The model must support vision input. `from_request` sets this when a
    /// user or system message carries an image or input-audio content part.
    pub vision: bool,
    /// The model must support tool calling. `from_request` sets this when the
    /// request declares tools or its history contains tool calls or tool
    /// messages.
    pub tools: bool,
    /// The model must support JSON output. `from_request` sets this when the
    /// response format type is `"json_object"` or `"json_schema"`.
    pub json_mode: bool,
}
39
40impl RequiredCapabilities {
41 pub fn from_request(req: &ChatCompletionRequest) -> Self {
43 let mut caps = Self::default();
44
45 if !req.tools.is_empty() {
47 caps.tools = true;
48 }
49
50 if let Some(rf) = &req.response_format {
52 if rf.r#type == "json_object" || rf.r#type == "json_schema" {
53 caps.json_mode = true;
54 }
55 }
56
57 for msg in &req.messages {
59 match msg {
60 Message::User { content, .. } | Message::System { content } => {
61 if let MessageContent::Parts(parts) = content {
62 for part in parts {
63 match part {
64 ContentPart::ImageUrl { .. } | ContentPart::InputAudio { .. } => {
65 caps.vision = true;
66 }
67 ContentPart::Text { .. } => {}
68 }
69 }
70 }
71 }
72 Message::Assistant { tool_calls, .. } => {
73 if !tool_calls.is_empty() {
74 caps.tools = true;
75 }
76 }
77 Message::Tool { .. } => {
78 caps.tools = true;
81 }
82 }
83 }
84
85 caps
86 }
87
88 #[must_use]
93 pub fn satisfied_by(&self, info: &ModelInfo, estimated_tokens: u64) -> bool {
94 if self.vision && !info.capabilities.contains(&Capability::Vision) {
95 return false;
96 }
97 if self.tools && !info.capabilities.contains(&Capability::Tools) {
98 return false;
99 }
100 if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
101 return false;
102 }
103 if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
104 return false;
105 }
106 true
107 }
108
109 pub fn skip_reasons(&self, info: &ModelInfo, estimated_tokens: u64) -> Vec<&'static str> {
112 let mut reasons = Vec::new();
113 if self.vision && !info.capabilities.contains(&Capability::Vision) {
114 reasons.push("vision_not_supported");
115 }
116 if self.tools && !info.capabilities.contains(&Capability::Tools) {
117 reasons.push("tools_not_supported");
118 }
119 if self.json_mode && !info.capabilities.contains(&Capability::JsonMode) {
120 reasons.push("json_mode_not_supported");
121 }
122 if estimated_tokens > 0 && info.max_input_tokens < estimated_tokens {
123 reasons.push("context_window_too_small");
124 }
125 reasons
126 }
127}
128
129pub fn message_text_for_estimation(req: &ChatCompletionRequest) -> String {
135 req.messages
136 .iter()
137 .map(|m| match m {
138 Message::User { content, .. } | Message::System { content } => extract_text(content),
139 Message::Assistant { content, .. } => {
140 content.as_ref().map(extract_text).unwrap_or_default()
141 }
142 Message::Tool { content, .. } => extract_text(content),
143 })
144 .collect()
145}
146
147fn extract_text(content: &MessageContent) -> String {
148 match content {
149 MessageContent::Text(s) => s.clone(),
150 MessageContent::Parts(parts) => parts
151 .iter()
152 .filter_map(|p| match p {
153 ContentPart::Text { text } => Some(text.as_str()),
154 _ => None,
155 })
156 .collect::<Vec<_>>()
157 .join(""),
158 }
159}
160
161pub fn request_has_images(req: &ChatCompletionRequest) -> bool {
166 req.messages
167 .iter()
168 .any(|m| content_of(m).is_some_and(has_image_part))
169}
170
171pub fn request_has_audio(req: &ChatCompletionRequest) -> bool {
173 req.messages
174 .iter()
175 .any(|m| content_of(m).is_some_and(has_audio_part))
176}
177
178fn content_of(m: &Message) -> Option<&MessageContent> {
180 match m {
181 Message::User { content, .. }
182 | Message::System { content }
183 | Message::Tool { content, .. } => Some(content),
184 Message::Assistant { content, .. } => content.as_ref(),
185 }
186}
187
188fn has_image_part(c: &MessageContent) -> bool {
189 matches!(c, MessageContent::Parts(parts)
190 if parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. })))
191}
192
193fn has_audio_part(c: &MessageContent) -> bool {
194 matches!(c, MessageContent::Parts(parts)
195 if parts.iter().any(|p| matches!(p, ContentPart::InputAudio { .. })))
196}
197
198pub fn request_input_text(req: &ChatCompletionRequest) -> String {
202 req.messages
203 .iter()
204 .filter_map(|m| match m {
205 Message::User { content, .. } | Message::System { content } => {
206 Some(extract_text(content))
207 }
208 _ => None,
209 })
210 .collect::<Vec<_>>()
211 .join("\n")
212}
213
214#[cfg(test)]
215mod tests {
216 use std::collections::HashMap;
217
218 use super::*;
219 use crate::{
220 messages::{
221 ImageUrl, InputAudio, ResponseFormat, Tool, ToolCall, ToolCallFunction, ToolFunction,
222 },
223 pricing::Capability,
224 ModelInfo,
225 };
226
227 fn text_model() -> ModelInfo {
228 ModelInfo {
229 id: "text-only".into(),
230 provider: "mock".into(),
231 capabilities: vec![Capability::Text],
232 max_input_tokens: 4096,
233 max_output_tokens: 1024,
234 }
235 }
236
237 fn vision_model() -> ModelInfo {
238 ModelInfo {
239 id: "vision-model".into(),
240 provider: "mock".into(),
241 capabilities: vec![Capability::Text, Capability::Vision, Capability::Tools],
242 max_input_tokens: 128_000,
243 max_output_tokens: 4096,
244 }
245 }
246
247 fn small_model() -> ModelInfo {
248 ModelInfo {
249 id: "small-ctx".into(),
250 provider: "mock".into(),
251 capabilities: vec![Capability::Text],
252 max_input_tokens: 100,
253 max_output_tokens: 100,
254 }
255 }
256
257 fn base_req() -> ChatCompletionRequest {
258 ChatCompletionRequest {
259 model: "gpt-4o".into(),
260 messages: vec![],
261 temperature: None,
262 top_p: None,
263 max_tokens: None,
264 stream: false,
265 tools: vec![],
266 tool_choice: None,
267 response_format: None,
268 stop: vec![],
269 presence_penalty: None,
270 frequency_penalty: None,
271 n: None,
272 seed: None,
273 user: None,
274 tt_extras: HashMap::new(),
275 }
276 }
277
278 #[test]
279 fn plain_text_request_has_no_required_caps() {
280 let req = base_req();
281 let caps = RequiredCapabilities::from_request(&req);
282 assert!(!caps.vision);
283 assert!(!caps.tools);
284 assert!(!caps.json_mode);
285 }
286
287 #[test]
288 fn image_url_part_sets_vision() {
289 let mut req = base_req();
290 req.messages = vec![Message::User {
291 content: MessageContent::Parts(vec![
292 ContentPart::Text {
293 text: "describe this".into(),
294 },
295 ContentPart::ImageUrl {
296 image_url: ImageUrl {
297 url: "data:image/png;base64,abc".into(),
298 detail: None,
299 },
300 },
301 ]),
302 name: None,
303 }];
304 let caps = RequiredCapabilities::from_request(&req);
305 assert!(caps.vision);
306 assert!(!caps.tools);
307 }
308
309 #[test]
310 fn tools_field_sets_tools_cap() {
311 let mut req = base_req();
312 req.tools = vec![Tool {
313 r#type: "function".into(),
314 function: ToolFunction {
315 name: "get_weather".into(),
316 description: None,
317 parameters: serde_json::json!({}),
318 },
319 }];
320 let caps = RequiredCapabilities::from_request(&req);
321 assert!(caps.tools);
322 }
323
324 #[test]
325 fn assistant_tool_calls_in_history_sets_tools_cap() {
326 let mut req = base_req();
327 req.messages = vec![Message::Assistant {
328 content: None,
329 tool_calls: vec![ToolCall {
330 id: "call_1".into(),
331 r#type: "function".into(),
332 function: ToolCallFunction {
333 name: "get_weather".into(),
334 arguments: "{}".into(),
335 },
336 }],
337 name: None,
338 }];
339 let caps = RequiredCapabilities::from_request(&req);
340 assert!(caps.tools);
341 }
342
343 #[test]
344 fn json_object_response_format_sets_json_mode() {
345 let mut req = base_req();
346 req.response_format = Some(ResponseFormat {
347 r#type: "json_object".into(),
348 json_schema: None,
349 });
350 let caps = RequiredCapabilities::from_request(&req);
351 assert!(caps.json_mode);
352 }
353
354 #[test]
355 fn vision_request_not_satisfied_by_text_model() {
356 let mut req = base_req();
357 req.messages = vec![Message::User {
358 content: MessageContent::Parts(vec![ContentPart::ImageUrl {
359 image_url: ImageUrl {
360 url: "data:image/png;base64,abc".into(),
361 detail: None,
362 },
363 }]),
364 name: None,
365 }];
366 let caps = RequiredCapabilities::from_request(&req);
367 assert!(!caps.satisfied_by(&text_model(), 0));
368 }
369
370 #[test]
371 fn vision_request_satisfied_by_vision_model() {
372 let mut req = base_req();
373 req.messages = vec![Message::User {
374 content: MessageContent::Parts(vec![ContentPart::ImageUrl {
375 image_url: ImageUrl {
376 url: "data:image/png;base64,abc".into(),
377 detail: None,
378 },
379 }]),
380 name: None,
381 }];
382 let caps = RequiredCapabilities::from_request(&req);
383 assert!(caps.satisfied_by(&vision_model(), 0));
384 }
385
386 #[test]
387 fn exceeds_context_window_not_satisfied() {
388 let caps = RequiredCapabilities::default();
389 assert!(!caps.satisfied_by(&small_model(), 200));
390 }
391
392 #[test]
393 fn within_context_window_satisfied() {
394 let caps = RequiredCapabilities::default();
395 assert!(caps.satisfied_by(&small_model(), 50));
396 }
397
398 #[test]
399 fn zero_estimated_tokens_skips_window_check() {
400 let caps = RequiredCapabilities::default();
401 assert!(caps.satisfied_by(&small_model(), 0));
402 }
403
404 #[test]
405 fn skip_reasons_lists_all_failures() {
406 let caps = RequiredCapabilities {
407 vision: true,
408 tools: true,
409 ..Default::default()
410 };
411 let reasons = caps.skip_reasons(&text_model(), 9999);
412 assert!(reasons.contains(&"vision_not_supported"));
413 assert!(reasons.contains(&"tools_not_supported"));
414 assert!(reasons.contains(&"context_window_too_small"));
415 }
416
417 #[test]
418 fn request_has_images_detects_image_part() {
419 let mut req = base_req();
420 req.messages = vec![Message::User {
421 content: MessageContent::Parts(vec![
422 ContentPart::Text {
423 text: "look".into(),
424 },
425 ContentPart::ImageUrl {
426 image_url: ImageUrl {
427 url: "data:image/png;base64,abc".into(),
428 detail: None,
429 },
430 },
431 ]),
432 name: None,
433 }];
434 assert!(request_has_images(&req));
435 assert!(!request_has_audio(&req));
436 }
437
438 #[test]
439 fn request_has_audio_detects_audio_part() {
440 let mut req = base_req();
441 req.messages = vec![Message::User {
442 content: MessageContent::Parts(vec![ContentPart::InputAudio {
443 input_audio: InputAudio {
444 data: "abc".into(),
445 format: "wav".into(),
446 },
447 }]),
448 name: None,
449 }];
450 assert!(request_has_audio(&req));
451 assert!(!request_has_images(&req));
452 }
453
454 #[test]
455 fn plain_text_request_has_no_modality() {
456 let req = base_req();
457 assert!(!request_has_images(&req));
458 assert!(!request_has_audio(&req));
459 }
460
461 #[test]
462 fn request_input_text_user_and_system_only() {
463 let mut req = base_req();
464 req.messages = vec![
465 Message::System {
466 content: MessageContent::Text("sys ctx".into()),
467 },
468 Message::User {
469 content: MessageContent::Text("Confidential matter".into()),
470 name: None,
471 },
472 Message::Assistant {
473 content: Some(MessageContent::Text("legal advice".into())),
474 tool_calls: vec![],
475 name: None,
476 },
477 ];
478 let t = request_input_text(&req);
479 assert!(t.contains("sys ctx"));
480 assert!(t.contains("Confidential matter"));
481 assert!(
482 !t.contains("legal advice"),
483 "assistant output must be excluded"
484 );
485 }
486}