Skip to main content

rust_genai_types/
config.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::collections::HashMap;
4
5use crate::base64_serde;
6use crate::enums::{
7    FeatureSelectionPreference, HarmBlockMethod, HarmBlockThreshold, HarmCategory, MediaResolution,
8    Modality, ThinkingLevel,
9};
10use crate::tool::Schema;
11
12/// 生成配置。
13#[derive(Debug, Clone, Serialize, Deserialize, Default)]
14#[serde(rename_all = "camelCase")]
15pub struct GenerationConfig {
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub temperature: Option<f32>,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub top_p: Option<f32>,
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub top_k: Option<f32>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub max_output_tokens: Option<i32>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub candidate_count: Option<i32>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub seed: Option<i32>,
28    /// 是否返回 logprobs 结果。
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub response_logprobs: Option<bool>,
31    /// 每个 token 返回的候选数(0-20)。
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub logprobs: Option<i32>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub thinking_config: Option<ThinkingConfig>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub speech_config: Option<SpeechConfig>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub image_config: Option<ImageConfig>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub media_resolution: Option<MediaResolution>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub response_mime_type: Option<String>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub response_schema: Option<Schema>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub response_json_schema: Option<Value>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub response_modalities: Option<Vec<Modality>>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub stop_sequences: Option<Vec<String>>,
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub audio_timestamp: Option<bool>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub presence_penalty: Option<f32>,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub frequency_penalty: Option<f32>,
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub enable_enhanced_civic_answers: Option<bool>,
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub enable_affective_dialog: Option<bool>,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub model_selection_config: Option<ModelSelectionConfig>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub routing_config: Option<GenerationConfigRoutingConfig>,
66}
67
68/// 安全设置。
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub struct SafetySetting {
72    pub category: HarmCategory,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub threshold: Option<HarmBlockThreshold>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub method: Option<HarmBlockMethod>,
77}
78
79/// Configuration for Model Armor integrations of prompt and responses.
80///
81/// This data type is not supported in Gemini API.
82#[derive(Debug, Clone, Serialize, Deserialize, Default)]
83#[serde(rename_all = "camelCase")]
84pub struct ModelArmorConfig {
85    /// Optional. The name of the Model Armor template to use for prompt sanitization.
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub prompt_template_name: Option<String>,
88    /// Optional. The name of the Model Armor template to use for response sanitization.
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub response_template_name: Option<String>,
91}
92
93/// Thinking 配置。
94#[derive(Debug, Clone, Serialize, Deserialize, Default)]
95#[serde(rename_all = "camelCase")]
96pub struct ThinkingConfig {
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub thinking_budget: Option<i32>,
99    #[serde(skip_serializing_if = "Option::is_none")]
100    pub thinking_level: Option<ThinkingLevel>,
101    #[serde(skip_serializing_if = "Option::is_none")]
102    pub include_thoughts: Option<bool>,
103}
104
105/// 语音合成配置。
106#[derive(Debug, Clone, Serialize, Deserialize, Default)]
107#[serde(rename_all = "camelCase")]
108pub struct SpeechConfig {
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub voice_config: Option<VoiceConfig>,
111    #[serde(skip_serializing_if = "Option::is_none")]
112    pub language_code: Option<String>,
113    #[serde(skip_serializing_if = "Option::is_none")]
114    pub multi_speaker_voice_config: Option<MultiSpeakerVoiceConfig>,
115    /// Forward-compatible extension fields.
116    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
117    #[serde(flatten)]
118    pub extra: HashMap<String, Value>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize, Default)]
122#[serde(rename_all = "camelCase")]
123pub struct VoiceConfig {
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub replicated_voice_config: Option<ReplicatedVoiceConfig>,
126    #[serde(skip_serializing_if = "Option::is_none")]
127    pub prebuilt_voice_config: Option<PrebuiltVoiceConfig>,
128    /// Forward-compatible extension fields.
129    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
130    #[serde(flatten)]
131    pub extra: HashMap<String, Value>,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize, Default)]
135#[serde(rename_all = "camelCase")]
136pub struct ReplicatedVoiceConfig {
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub mime_type: Option<String>,
139    #[serde(
140        default,
141        skip_serializing_if = "Option::is_none",
142        with = "base64_serde::option"
143    )]
144    pub voice_sample_audio: Option<Vec<u8>>,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize, Default)]
148#[serde(rename_all = "camelCase")]
149pub struct PrebuiltVoiceConfig {
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub voice_name: Option<String>,
152    /// Forward-compatible extension fields.
153    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
154    #[serde(flatten)]
155    pub extra: HashMap<String, Value>,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, Default)]
159#[serde(rename_all = "camelCase")]
160pub struct SpeakerVoiceConfig {
161    #[serde(skip_serializing_if = "Option::is_none")]
162    pub speaker: Option<String>,
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub voice_config: Option<VoiceConfig>,
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize, Default)]
168#[serde(rename_all = "camelCase")]
169pub struct MultiSpeakerVoiceConfig {
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub speaker_voice_configs: Option<Vec<SpeakerVoiceConfig>>,
172}
173
174/// 图像生成配置。
175#[derive(Debug, Clone, Serialize, Deserialize, Default)]
176#[serde(rename_all = "camelCase")]
177pub struct ImageConfig {
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub aspect_ratio: Option<String>,
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub image_size: Option<String>,
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub person_generation: Option<String>,
184    #[serde(skip_serializing_if = "Option::is_none")]
185    pub output_mime_type: Option<String>,
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub output_compression_quality: Option<i32>,
188}
189
190/// 模型选择配置。
191#[derive(Debug, Clone, Serialize, Deserialize, Default)]
192#[serde(rename_all = "camelCase")]
193pub struct ModelSelectionConfig {
194    #[serde(skip_serializing_if = "Option::is_none")]
195    pub feature_selection_preference: Option<FeatureSelectionPreference>,
196}
197
198/// 路由配置。
199#[derive(Debug, Clone, Serialize, Deserialize, Default)]
200#[serde(rename_all = "camelCase")]
201pub struct GenerationConfigRoutingConfig {
202    #[serde(skip_serializing_if = "Option::is_none")]
203    pub auto_routing_mode: Option<GenerationConfigRoutingConfigAutoRoutingMode>,
204    #[serde(skip_serializing_if = "Option::is_none")]
205    pub manual_routing_mode: Option<GenerationConfigRoutingConfigManualRoutingMode>,
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize, Default)]
209#[serde(rename_all = "camelCase")]
210pub struct GenerationConfigRoutingConfigAutoRoutingMode {
211    #[serde(skip_serializing_if = "Option::is_none")]
212    pub model_routing_preference: Option<String>,
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize, Default)]
216#[serde(rename_all = "camelCase")]
217pub struct GenerationConfigRoutingConfigManualRoutingMode {
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub model_name: Option<String>,
220}
221
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent module's private `use` imports
    // (`MediaResolution`, `Modality`, `ThinkingLevel`, `Schema`, `HarmCategory`, ...),
    // so no separate imports are needed.
    use super::*;

    /// Populated fields use camelCase names, and every `None` field is omitted
    /// entirely, both at the top level and in nested configs.
    #[test]
    fn generation_config_serializes_camel_case() {
        let config = GenerationConfig {
            temperature: Some(0.7),
            max_output_tokens: Some(128),
            response_mime_type: Some("application/json".into()),
            response_modalities: Some(vec![Modality::Text, Modality::Audio]),
            media_resolution: Some(MediaResolution::MediaResolutionHigh),
            response_schema: Some(Schema::string()),
            thinking_config: Some(ThinkingConfig {
                thinking_level: Some(ThinkingLevel::High),
                include_thoughts: Some(true),
                thinking_budget: None,
            }),
            ..Default::default()
        };

        let value = serde_json::to_value(&config).unwrap();
        let object = value
            .as_object()
            .expect("GenerationConfig must serialize to a JSON object");

        for key in [
            "temperature",
            "maxOutputTokens",
            "responseMimeType",
            "responseModalities",
            "mediaResolution",
            "responseSchema",
            "thinkingConfig",
        ] {
            assert!(object.contains_key(key), "missing key `{key}` in {value}");
        }
        assert_eq!(value["maxOutputTokens"], 128);
        assert_eq!(value["responseMimeType"], "application/json");

        // Exactly the seven populated keys: no `null` entries for unset fields
        // and no snake_case leftovers.
        assert_eq!(object.len(), 7, "unexpected keys in {value}");

        let thinking = &value["thinkingConfig"];
        assert_eq!(thinking["includeThoughts"], true);
        assert!(thinking.get("thinkingLevel").is_some());
        assert!(thinking.get("thinkingBudget").is_none());
    }

    /// An all-`None` config serializes to `{}` and `{}` deserializes back.
    #[test]
    fn default_generation_config_serializes_to_empty_object() {
        let value = serde_json::to_value(GenerationConfig::default()).unwrap();
        assert_eq!(value, serde_json::json!({}));

        let decoded: GenerationConfig = serde_json::from_str("{}").unwrap();
        assert!(decoded.temperature.is_none());
        assert!(decoded.thinking_config.is_none());
    }

    /// Every field survives a round trip, and a `None` method is not emitted.
    #[test]
    fn safety_setting_roundtrip() {
        let setting = SafetySetting {
            category: HarmCategory::HarmCategoryHarassment,
            threshold: Some(HarmBlockThreshold::BlockOnlyHigh),
            method: None,
        };

        let value = serde_json::to_value(&setting).unwrap();
        assert!(value.get("category").is_some());
        assert!(value.get("threshold").is_some());
        assert!(value.get("method").is_none(), "`None` method must be omitted: {value}");

        let json = serde_json::to_string(&setting).unwrap();
        let decoded: SafetySetting = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.category, HarmCategory::HarmCategoryHarassment);
        assert!(matches!(
            decoded.threshold,
            Some(HarmBlockThreshold::BlockOnlyHigh)
        ));
        assert!(decoded.method.is_none());
    }

    /// Unknown keys land in the flattened `extra` map and are written back inline.
    #[test]
    fn prebuilt_voice_config_preserves_unknown_fields() {
        let input = serde_json::json!({ "voiceName": "Kore", "futureField": 1 });
        let config: PrebuiltVoiceConfig = serde_json::from_value(input.clone()).unwrap();

        assert_eq!(config.voice_name.as_deref(), Some("Kore"));
        assert_eq!(config.extra.len(), 1);
        assert_eq!(config.extra["futureField"], 1);
        assert_eq!(serde_json::to_value(&config).unwrap(), input);
    }

    /// Audio bytes round-trip through the base64 adapter, and an absent key
    /// still deserializes thanks to the explicit `default`.
    #[test]
    fn replicated_voice_config_audio_roundtrip() {
        let config = ReplicatedVoiceConfig {
            mime_type: Some("audio/wav".into()),
            voice_sample_audio: Some(vec![0, 1, 2, 254, 255]),
        };
        let json = serde_json::to_string(&config).unwrap();
        let decoded: ReplicatedVoiceConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.voice_sample_audio, config.voice_sample_audio);
        assert_eq!(decoded.mime_type.as_deref(), Some("audio/wav"));

        let empty: ReplicatedVoiceConfig = serde_json::from_str("{}").unwrap();
        assert!(empty.voice_sample_audio.is_none());
    }
}