Skip to main content

rust_genai/
model_capabilities.rs

1//! Model capability checks and feature gating.
2
3use rust_genai_types::content::{Content, PartKind};
4use rust_genai_types::tool::Tool;
5
6use crate::error::{Error, Result};
7
/// Compact bit set recording which optional request features a model accepts.
///
/// Values are produced by `capabilities_for`; the `Default` value has every feature
/// switched off.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelCapabilities {
    /// One bit per feature, using the associated constants below as masks.
    bits: u8,
}

impl ModelCapabilities {
    /// Media (inline or file data) inside `FunctionResponse` parts.
    const FUNCTION_RESPONSE_MEDIA: u8 = 0b0001;
    /// Image inputs combined with the code-execution tool.
    const CODE_EXECUTION_IMAGES: u8 = 0b0010;
    /// Native audio.
    const NATIVE_AUDIO: u8 = 0b0100;
    /// Thinking.
    const THINKING: u8 = 0b1000;

    /// Wraps a raw bit mask built from the associated constants.
    const fn new(flags: u8) -> Self {
        Self { bits: flags }
    }

    /// Returns `true` when any bit of `mask` is set in this capability set.
    const fn has(self, mask: u8) -> bool {
        (self.bits & mask) != 0
    }

    /// Whether `FunctionResponse` parts may carry inline or file media.
    #[must_use]
    pub const fn supports_function_response_media(self) -> bool {
        self.has(Self::FUNCTION_RESPONSE_MEDIA)
    }

    /// Whether image inputs may be sent alongside the code-execution tool.
    #[must_use]
    pub const fn supports_code_execution_images(self) -> bool {
        self.has(Self::CODE_EXECUTION_IMAGES)
    }

    /// Whether the model was flagged as a native-audio model.
    #[must_use]
    pub const fn supports_native_audio(self) -> bool {
        self.has(Self::NATIVE_AUDIO)
    }

    /// Whether the model was flagged as supporting thinking.
    #[must_use]
    pub const fn supports_thinking(self) -> bool {
        self.has(Self::THINKING)
    }
}
43
44#[must_use]
45pub fn capabilities_for(model: &str) -> ModelCapabilities {
46    let name = normalize_model_name(model);
47    let is_gemini_3 = name.starts_with("gemini-3");
48    let supports_native_audio = name.contains("native-audio");
49    let mut flags = 0;
50    if is_gemini_3 {
51        flags |= ModelCapabilities::FUNCTION_RESPONSE_MEDIA;
52        flags |= ModelCapabilities::CODE_EXECUTION_IMAGES;
53    }
54    if supports_native_audio {
55        flags |= ModelCapabilities::NATIVE_AUDIO;
56    }
57    if is_gemini_3 || name.contains("thinking") {
58        flags |= ModelCapabilities::THINKING;
59    }
60    ModelCapabilities::new(flags)
61}
62
63/// # Errors
64/// 当模型不支持功能响应多媒体时返回错误。
65pub fn validate_function_response_media(model: &str, contents: &[Content]) -> Result<()> {
66    if !has_function_response_media(contents) {
67        return Ok(());
68    }
69    let caps = capabilities_for(model);
70    if !caps.supports_function_response_media() {
71        return Err(Error::InvalidConfig {
72            message: format!("Model {model} does not support media in FunctionResponse parts"),
73        });
74    }
75    Ok(())
76}
77
78/// # Errors
79/// 当模型不支持带图像的代码执行时返回错误。
80pub fn validate_code_execution_image_inputs(
81    model: &str,
82    contents: &[Content],
83    tools: Option<&[Tool]>,
84) -> Result<()> {
85    if !has_code_execution_tool(tools) || !has_image_inputs(contents) {
86        return Ok(());
87    }
88    let caps = capabilities_for(model);
89    if !caps.supports_code_execution_images() {
90        return Err(Error::InvalidConfig {
91            message: format!("Model {model} does not support code execution with image inputs"),
92        });
93    }
94    Ok(())
95}
96
/// Returns the final `/`-separated segment of a model identifier.
///
/// For example, `models/gemini-3-pro` becomes `gemini-3-pro`. A name without `/`
/// is returned unchanged, and a name ending in `/` yields an empty string.
/// The result borrows from `model`, so no allocation happens.
fn normalize_model_name(model: &str) -> &str {
    model.rsplit_once('/').map_or(model, |(_, name)| name)
}
100
101fn has_function_response_media(contents: &[Content]) -> bool {
102    for content in contents {
103        for part in &content.parts {
104            if let PartKind::FunctionResponse { function_response } = &part.kind {
105                if let Some(parts) = &function_response.parts {
106                    if parts
107                        .iter()
108                        .any(|p| p.inline_data.is_some() || p.file_data.is_some())
109                    {
110                        return true;
111                    }
112                }
113            }
114        }
115    }
116    false
117}
118
119fn has_code_execution_tool(tools: Option<&[Tool]>) -> bool {
120    tools.is_some_and(|items| items.iter().any(|tool| tool.code_execution.is_some()))
121}
122
123fn has_image_inputs(contents: &[Content]) -> bool {
124    for content in contents {
125        for part in &content.parts {
126            match &part.kind {
127                PartKind::InlineData { inline_data } => {
128                    if inline_data.mime_type.starts_with("image/") {
129                        return true;
130                    }
131                }
132                PartKind::FileData { file_data } => {
133                    if file_data.mime_type.starts_with("image/") {
134                        return true;
135                    }
136                }
137                _ => {}
138            }
139        }
140    }
141    false
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use rust_genai_types::content::{
        Content, FunctionResponse, FunctionResponseBlob, FunctionResponseFileData,
        FunctionResponsePart, Part, Role,
    };
    use rust_genai_types::tool::{CodeExecution, Tool};
    use serde_json::json;

    /// Wraps `parts` in a single `FunctionResponse` inside a model-role `Content`.
    fn function_response_content(parts: Vec<FunctionResponsePart>) -> Content {
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: Some(parts),
            id: Some("id".into()),
            name: Some("fn".into()),
            response: Some(json!({"ok": true})),
        };
        Content::from_parts(vec![Part::function_response(response)], Role::Model)
    }

    /// A function-response part carrying inline PNG bytes.
    fn inline_image_response_part() -> FunctionResponsePart {
        FunctionResponsePart {
            inline_data: Some(FunctionResponseBlob {
                mime_type: "image/png".into(),
                data: vec![1, 2, 3],
                display_name: None,
            }),
            file_data: None,
        }
    }

    /// A function-response part referencing an uploaded PNG file.
    fn file_image_response_part() -> FunctionResponsePart {
        FunctionResponsePart {
            inline_data: None,
            file_data: Some(FunctionResponseFileData {
                file_uri: "files/abc".into(),
                mime_type: "image/png".into(),
                display_name: None,
            }),
        }
    }

    /// A tool entry that enables code execution.
    fn code_execution_tool() -> Tool {
        Tool {
            code_execution: Some(CodeExecution {}),
            ..Default::default()
        }
    }

    /// A user turn holding one inline part with the given MIME type.
    fn inline_user_content(mime_type: &'static str) -> Content {
        Content::from_parts(
            vec![Part::inline_data(vec![1, 2, 3], mime_type)],
            Role::User,
        )
    }

    #[test]
    fn capabilities_detect_features() {
        let caps = capabilities_for("models/gemini-3-thinking");
        assert!(caps.supports_function_response_media());
        assert!(caps.supports_code_execution_images());
        assert!(caps.supports_thinking());
        assert!(!caps.supports_native_audio());

        let caps = capabilities_for("gemini-2.0-flash-native-audio");
        assert!(caps.supports_native_audio());
        assert!(!caps.supports_function_response_media());
        assert!(!caps.supports_code_execution_images());
        assert!(!caps.supports_thinking());

        assert!(capabilities_for("gemini-2.0-flash-thinking-exp").supports_thinking());
        assert!(!capabilities_for("gemini-2.0-flash").supports_thinking());
    }

    #[test]
    fn validate_function_response_media_blocks_unsupported_models() {
        let content = function_response_content(vec![inline_image_response_part()]);
        let contents = std::slice::from_ref(&content);

        let result = validate_function_response_media("gemini-2.0-flash", contents);
        assert!(matches!(result, Err(Error::InvalidConfig { .. })));

        // The same media-bearing request must pass on a supporting model.
        assert!(validate_function_response_media("gemini-3", contents).is_ok());
    }

    #[test]
    fn validate_function_response_media_ignores_requests_without_media() {
        assert!(validate_function_response_media("gemini-2.0-flash", &[]).is_ok());

        let no_media = function_response_content(Vec::new());
        assert!(validate_function_response_media("gemini-2.0-flash", &[no_media]).is_ok());
    }

    #[test]
    fn validate_code_execution_image_inputs_blocks_unsupported_models() {
        let content = inline_user_content("image/png");
        let contents = std::slice::from_ref(&content);
        let tools = [code_execution_tool()];

        let result =
            validate_code_execution_image_inputs("gemini-2.0-flash", contents, Some(&tools[..]));
        assert!(matches!(result, Err(Error::InvalidConfig { .. })));

        assert!(
            validate_code_execution_image_inputs("gemini-3", contents, Some(&tools[..])).is_ok()
        );
    }

    #[test]
    fn validate_code_execution_image_inputs_requires_tool_and_image() {
        let image = inline_user_content("image/png");
        let audio = inline_user_content("audio/wav");
        let tools = [code_execution_tool()];

        // Without a code-execution tool, images are accepted on any model.
        assert!(validate_code_execution_image_inputs(
            "gemini-2.0-flash",
            std::slice::from_ref(&image),
            None,
        )
        .is_ok());

        // Code execution without image inputs is accepted on any model.
        assert!(validate_code_execution_image_inputs(
            "gemini-2.0-flash",
            std::slice::from_ref(&audio),
            Some(&tools[..]),
        )
        .is_ok());
    }

    #[test]
    fn has_function_response_media_detects_parts() {
        let inline = function_response_content(vec![inline_image_response_part()]);
        let file = function_response_content(vec![file_image_response_part()]);
        let mixed = function_response_content(vec![
            inline_image_response_part(),
            file_image_response_part(),
        ]);
        let empty = function_response_content(Vec::new());

        assert!(has_function_response_media(&[inline]));
        assert!(has_function_response_media(&[file]));
        assert!(has_function_response_media(&[mixed]));
        assert!(!has_function_response_media(&[empty]));
        assert!(!has_function_response_media(&[]));
    }

    #[test]
    fn has_image_inputs_detects_inline_and_file() {
        let file = Content::from_parts(vec![Part::file_data("files/abc", "image/png")], Role::User);

        assert!(has_image_inputs(&[inline_user_content("image/png")]));
        assert!(has_image_inputs(&[file]));
        assert!(!has_image_inputs(&[inline_user_content("audio/wav")]));
        assert!(!has_image_inputs(&[]));
    }
}