Skip to main content

rust_genai/
model_capabilities.rs

1//! Model capability checks and feature gating.
2
3use rust_genai_types::content::{Content, PartKind};
4use rust_genai_types::tool::Tool;
5
6use crate::error::{Error, Result};
7
/// Bit set describing which optional features a model supports.
///
/// Each capability occupies one bit of `flags`. The `Default` value has no
/// bits set, so every `supports_*` query returns `false` for it.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelCapabilities {
    flags: u8,
}

impl ModelCapabilities {
    /// Media (inline or file data) inside function-response parts.
    const FUNCTION_RESPONSE_MEDIA: u8 = 1 << 0;
    /// Image inputs combined with the code-execution tool.
    const CODE_EXECUTION_IMAGES: u8 = 1 << 1;
    /// Native audio.
    const NATIVE_AUDIO: u8 = 1 << 2;
    /// Thinking.
    const THINKING: u8 = 1 << 3;

    /// Wraps a raw bit pattern built from the associated flag constants.
    const fn new(flags: u8) -> Self {
        Self { flags }
    }

    /// Returns `true` when at least one bit of `mask` is set in `flags`.
    const fn has(self, mask: u8) -> bool {
        (self.flags & mask) != 0
    }

    /// Whether the `FUNCTION_RESPONSE_MEDIA` bit is set.
    #[must_use]
    pub const fn supports_function_response_media(self) -> bool {
        self.has(Self::FUNCTION_RESPONSE_MEDIA)
    }

    /// Whether the `CODE_EXECUTION_IMAGES` bit is set.
    #[must_use]
    pub const fn supports_code_execution_images(self) -> bool {
        self.has(Self::CODE_EXECUTION_IMAGES)
    }

    /// Whether the `NATIVE_AUDIO` bit is set.
    #[must_use]
    pub const fn supports_native_audio(self) -> bool {
        self.has(Self::NATIVE_AUDIO)
    }

    /// Whether the `THINKING` bit is set.
    #[must_use]
    pub const fn supports_thinking(self) -> bool {
        self.has(Self::THINKING)
    }
}
43
44#[must_use]
45pub fn capabilities_for(model: &str) -> ModelCapabilities {
46    let name = normalize_model_name(model);
47    let is_gemini_3 = name.starts_with("gemini-3");
48    let supports_native_audio = name.contains("native-audio");
49    let mut flags = 0;
50    if is_gemini_3 {
51        flags |= ModelCapabilities::FUNCTION_RESPONSE_MEDIA;
52        flags |= ModelCapabilities::CODE_EXECUTION_IMAGES;
53    }
54    if supports_native_audio {
55        flags |= ModelCapabilities::NATIVE_AUDIO;
56    }
57    if is_gemini_3 || name.contains("thinking") {
58        flags |= ModelCapabilities::THINKING;
59    }
60    ModelCapabilities::new(flags)
61}
62
63/// # Errors
64/// 当模型不支持功能响应多媒体时返回错误。
65pub fn validate_function_response_media(model: &str, contents: &[Content]) -> Result<()> {
66    if !has_function_response_media(contents) {
67        return Ok(());
68    }
69    let caps = capabilities_for(model);
70    if !caps.supports_function_response_media() {
71        return Err(Error::InvalidConfig {
72            message: format!("Model {model} does not support media in FunctionResponse parts"),
73        });
74    }
75    Ok(())
76}
77
78/// # Errors
79/// 当模型不支持带图像的代码执行时返回错误。
80pub fn validate_code_execution_image_inputs(
81    model: &str,
82    contents: &[Content],
83    tools: Option<&[Tool]>,
84) -> Result<()> {
85    if !has_code_execution_tool(tools) || !has_image_inputs(contents) {
86        return Ok(());
87    }
88    let caps = capabilities_for(model);
89    if !caps.supports_code_execution_images() {
90        return Err(Error::InvalidConfig {
91            message: format!("Model {model} does not support code execution with image inputs"),
92        });
93    }
94    Ok(())
95}
96
/// Returns the portion of `model` after its last `/` as an owned string.
///
/// A name that contains no `/` is returned unchanged.
fn normalize_model_name(model: &str) -> String {
    model
        .rsplit_once('/')
        .map_or(model, |(_, short_name)| short_name)
        .to_owned()
}
100
101fn has_function_response_media(contents: &[Content]) -> bool {
102    for content in contents {
103        for part in &content.parts {
104            if let PartKind::FunctionResponse { function_response } = &part.kind {
105                if let Some(parts) = &function_response.parts {
106                    if parts
107                        .iter()
108                        .any(|p| p.inline_data.is_some() || p.file_data.is_some())
109                    {
110                        return true;
111                    }
112                }
113            }
114        }
115    }
116    false
117}
118
119fn has_code_execution_tool(tools: Option<&[Tool]>) -> bool {
120    tools.is_some_and(|items| items.iter().any(|tool| tool.code_execution.is_some()))
121}
122
123fn has_image_inputs(contents: &[Content]) -> bool {
124    for content in contents {
125        for part in &content.parts {
126            match &part.kind {
127                PartKind::InlineData { inline_data }
128                    if inline_data.mime_type.starts_with("image/") =>
129                {
130                    return true;
131                }
132                PartKind::FileData { file_data } if file_data.mime_type.starts_with("image/") => {
133                    return true;
134                }
135                _ => {}
136            }
137        }
138    }
139    false
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use rust_genai_types::content::{
        Content, FunctionResponse, FunctionResponseBlob, FunctionResponseFileData,
        FunctionResponsePart, Part, Role,
    };
    use rust_genai_types::tool::{CodeExecution, Tool};
    use serde_json::json;

    /// Builds a model-role content holding one function response with `parts` attached.
    fn function_response_content(parts: Vec<FunctionResponsePart>) -> Content {
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: Some(parts),
            id: Some("id".into()),
            name: Some("fn".into()),
            response: Some(json!({"ok": true})),
        };
        Content::from_parts(vec![Part::function_response(response)], Role::Model)
    }

    /// Builds a user-role content with a single inline PNG part.
    fn inline_png_content() -> Content {
        Content::from_parts(
            vec![Part::inline_data(vec![1, 2, 3], "image/png")],
            Role::User,
        )
    }

    /// Builds a tool with only code execution enabled.
    fn code_execution_tool() -> Tool {
        Tool {
            code_execution: Some(CodeExecution {}),
            ..Default::default()
        }
    }

    #[test]
    fn capabilities_detect_features() {
        let gemini_3 = capabilities_for("models/gemini-3-thinking");
        assert!(gemini_3.supports_function_response_media());
        assert!(gemini_3.supports_code_execution_images());
        assert!(gemini_3.supports_thinking());

        let native_audio = capabilities_for("gemini-2.0-flash-native-audio");
        assert!(native_audio.supports_native_audio());
        assert!(!native_audio.supports_function_response_media());
    }

    #[test]
    fn validate_function_response_media_blocks_unsupported_models() {
        let content = function_response_content(vec![FunctionResponsePart {
            inline_data: Some(FunctionResponseBlob {
                mime_type: "image/png".into(),
                data: vec![1, 2, 3],
                display_name: None,
            }),
            file_data: None,
        }]);
        let rejected = validate_function_response_media("gemini-2.0-flash", &[content]);
        assert!(rejected.is_err());

        // No contents means no media, so the check passes.
        let accepted = validate_function_response_media("gemini-3", &[]);
        assert!(accepted.is_ok());
    }

    #[test]
    fn validate_code_execution_image_inputs_blocks_unsupported_models() {
        let content = inline_png_content();
        let tools = [code_execution_tool()];

        let rejected = validate_code_execution_image_inputs(
            "gemini-2.0-flash",
            std::slice::from_ref(&content),
            Some(&tools[..]),
        );
        assert!(rejected.is_err());

        let accepted = validate_code_execution_image_inputs(
            "gemini-3",
            std::slice::from_ref(&content),
            Some(&tools[..]),
        );
        assert!(accepted.is_ok());
    }

    #[test]
    fn has_function_response_media_detects_parts() {
        let inline_part = FunctionResponsePart {
            inline_data: Some(FunctionResponseBlob {
                mime_type: "image/png".into(),
                data: vec![1],
                display_name: None,
            }),
            file_data: None,
        };
        let file_part = FunctionResponsePart {
            inline_data: None,
            file_data: Some(FunctionResponseFileData {
                file_uri: "files/abc".into(),
                mime_type: "image/png".into(),
                display_name: None,
            }),
        };
        let content = function_response_content(vec![inline_part, file_part]);
        assert!(has_function_response_media(&[content]));
    }

    #[test]
    fn has_image_inputs_detects_inline_and_file() {
        let inline = inline_png_content();
        let file = Content::from_parts(vec![Part::file_data("files/abc", "image/png")], Role::User);
        assert!(has_image_inputs(&[inline]));
        assert!(has_image_inputs(&[file]));
    }
}