Skip to main content

vona_gemini_live/
lib.rs

1use base64::Engine as _;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use thiserror::Error;
5use vona_core::{
6    AudioInputFrame, RealtimeVoiceCapabilities, RealtimeVoiceInput, RealtimeVoiceModelFamily,
7    RealtimeVoiceOutput, RealtimeVoiceSessionConfig,
8};
9
/// Base URL used for `GeminiLiveConfig::api_base` when none is configured.
pub const DEFAULT_API_BASE: &str = "https://generativelanguage.googleapis.com";
/// API version used when none is configured; it is interpolated into the
/// WebSocket path built by `GeminiLiveConfig::websocket_url`.
pub const DEFAULT_API_VERSION: &str = "v1alpha";
/// Model identifier used when none is configured, written without the
/// `models/` prefix (`setup_message` adds that prefix).
pub const DEFAULT_MODEL: &str = "gemini-2.5-flash-native-audio-preview-12-2025";
/// Prebuilt voice name used when none is configured.
pub const DEFAULT_VOICE: &str = "Kore";
14
/// Connection and session settings for the Gemini Live adapter.
///
/// Build one with [`Default::default`] or [`GeminiLiveConfig::from_env`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeminiLiveConfig {
    /// Base URL of the API host. `websocket_url` trims trailing slashes and
    /// maps the `https://` / `http://` scheme to `wss://` / `ws://`.
    pub api_base: String,
    /// API key, if one is configured. `from_env` reads it from
    /// `GEMINI_API_KEY` and stores `None` when the variable is unset or empty.
    pub api_key: Option<String>,
    /// API version segment placed in the WebSocket path and echoed in the
    /// session metadata.
    pub api_version: String,
    /// Model identifier without the `models/` prefix.
    pub model: String,
    /// Prebuilt voice name sent in the setup message.
    pub voice: String,
    /// Sample rate, in Hz, reported as the session's input sample rate.
    pub input_sample_rate_hz: u32,
}
24
25impl Default for GeminiLiveConfig {
26    fn default() -> Self {
27        Self {
28            api_base: DEFAULT_API_BASE.to_string(),
29            api_key: None,
30            api_version: DEFAULT_API_VERSION.to_string(),
31            model: DEFAULT_MODEL.to_string(),
32            voice: DEFAULT_VOICE.to_string(),
33            input_sample_rate_hz: 16_000,
34        }
35    }
36}
37
38impl GeminiLiveConfig {
39    pub fn from_env() -> Self {
40        Self {
41            api_base: std::env::var("GEMINI_API_BASE")
42                .unwrap_or_else(|_| DEFAULT_API_BASE.to_string()),
43            api_key: std::env::var("GEMINI_API_KEY")
44                .ok()
45                .filter(|value| !value.is_empty()),
46            api_version: std::env::var("GEMINI_LIVE_API_VERSION")
47                .unwrap_or_else(|_| DEFAULT_API_VERSION.to_string()),
48            model: std::env::var("GEMINI_LIVE_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.to_string()),
49            voice: std::env::var("GEMINI_LIVE_VOICE").unwrap_or_else(|_| DEFAULT_VOICE.to_string()),
50            input_sample_rate_hz: std::env::var("GEMINI_LIVE_INPUT_SAMPLE_RATE")
51                .ok()
52                .and_then(|value| value.parse().ok())
53                .unwrap_or(16_000),
54        }
55    }
56
57    pub fn websocket_url(&self) -> String {
58        let base = self
59            .api_base
60            .trim_end_matches('/')
61            .replace("https://", "wss://")
62            .replace("http://", "ws://");
63        format!(
64            "{base}/ws/google.ai.generativelanguage.{}/GenerativeService.BidiGenerateContent",
65            self.api_version
66        )
67    }
68
69    pub fn session_config(&self, session_id: impl Into<String>) -> RealtimeVoiceSessionConfig {
70        RealtimeVoiceSessionConfig {
71            session_id: session_id.into(),
72            input_sample_rate_hz: self.input_sample_rate_hz,
73            output_sample_rate_hz: 24_000,
74            channels: 1,
75            model_family: RealtimeVoiceModelFamily::HostedRealtimeApi {
76                provider: "gemini".to_string(),
77                model: self.model.clone(),
78            },
79            metadata: json!({
80                "voice": self.voice,
81                "api_version": self.api_version,
82            }),
83        }
84    }
85
86    pub fn capabilities(&self) -> RealtimeVoiceCapabilities {
87        RealtimeVoiceCapabilities {
88            supports_full_duplex: true,
89            supports_streaming_audio_input: true,
90            supports_streaming_audio_output: true,
91            supports_tool_calls: true,
92            supports_interruption: true,
93            supports_context_injection: true,
94            is_hosted_service: true,
95            max_input_chunk_ms: None,
96        }
97    }
98}
99
/// A client-to-server message, held as an already-built JSON payload.
///
/// The payload is marked `#[serde(flatten)]`, so its keys are (de)serialized
/// at the top level of the message rather than nested under a `payload` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GeminiLiveClientMessage {
    /// The JSON body of the message; the builders in this file always produce
    /// an object with a single top-level key such as `setup` or
    /// `realtimeInput`.
    #[serde(flatten)]
    pub payload: Value,
}
105
/// Errors produced while translating between Vona events and Gemini Live
/// JSON messages.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum GeminiLiveMappingError {
    /// The input event has no Gemini Live equivalent. The string is the
    /// `Debug` rendering of the rejected control event.
    #[error("Gemini Live does not accept Vona event: {0}")]
    UnsupportedInput(String),
    /// A server message lacked a usable field; the string names it.
    /// `server_message_to_output` also returns this when `inlineData.data` is
    /// present but is not valid base64.
    #[error("Gemini Live server message is missing required field: {0}")]
    MissingField(&'static str),
}
113
114pub fn setup_message(config: &GeminiLiveConfig) -> GeminiLiveClientMessage {
115    GeminiLiveClientMessage {
116        payload: json!({
117            "setup": {
118                "model": format!("models/{}", config.model),
119                "generationConfig": {
120                    "responseModalities": ["AUDIO"],
121                    "speechConfig": {
122                        "voiceConfig": {
123                            "prebuiltVoiceConfig": { "voiceName": config.voice }
124                        }
125                    }
126                }
127            }
128        }),
129    }
130}
131
132pub fn input_to_client_message(
133    input: RealtimeVoiceInput,
134    sample_rate_hz: u32,
135) -> Result<GeminiLiveClientMessage, GeminiLiveMappingError> {
136    match input {
137        RealtimeVoiceInput::Audio(frame) => Ok(audio_message(&frame, sample_rate_hz)),
138        RealtimeVoiceInput::Text { text } => Ok(GeminiLiveClientMessage {
139            payload: json!({
140                "clientContent": {
141                    "turns": [{ "role": "user", "parts": [{ "text": text }] }],
142                    "turnComplete": true
143                }
144            }),
145        }),
146        RealtimeVoiceInput::Control(vona_core::RealtimeVoiceControl::CommitInput) => {
147            Ok(GeminiLiveClientMessage {
148                payload: json!({ "clientContent": { "turnComplete": true } }),
149            })
150        }
151        RealtimeVoiceInput::Control(control) => Err(GeminiLiveMappingError::UnsupportedInput(
152            format!("{control:?}"),
153        )),
154        RealtimeVoiceInput::ToolResult(event) => Ok(GeminiLiveClientMessage {
155            payload: json!({
156                "toolResponse": {
157                    "functionResponses": [{
158                        "name": event.source,
159                        "response": event.payload,
160                    }]
161                }
162            }),
163        }),
164    }
165}
166
167pub fn audio_message(frame: &AudioInputFrame, sample_rate_hz: u32) -> GeminiLiveClientMessage {
168    GeminiLiveClientMessage {
169        payload: json!({
170            "realtimeInput": {
171                "mediaChunks": [{
172                    "mimeType": format!("audio/pcm;rate={sample_rate_hz}"),
173                    "data": base64::engine::general_purpose::STANDARD.encode(samples_to_pcm16_le(&frame.samples)),
174                }]
175            }
176        }),
177    }
178}
179
180pub fn server_message_to_output(
181    message: &Value,
182) -> Result<Option<RealtimeVoiceOutput>, GeminiLiveMappingError> {
183    let Some(parts) = message
184        .pointer("/serverContent/modelTurn/parts")
185        .and_then(Value::as_array)
186    else {
187        return Ok(None);
188    };
189
190    for part in parts {
191        if let Some(data) = part.pointer("/inlineData/data").and_then(Value::as_str) {
192            let pcm = base64::engine::general_purpose::STANDARD
193                .decode(data)
194                .map_err(|_| GeminiLiveMappingError::MissingField("inlineData.data"))?;
195            return Ok(Some(RealtimeVoiceOutput::Audio(
196                vona_core::AudioOutputFrame {
197                    sequence: 0,
198                    sample_rate_hz: 24_000,
199                    channels: 1,
200                    samples: pcm16_le_to_samples(&pcm),
201                    is_filler: false,
202                },
203            )));
204        }
205        if let Some(text) = part.get("text").and_then(Value::as_str) {
206            return Ok(Some(RealtimeVoiceOutput::TranscriptFragment {
207                text: text.to_string(),
208                final_fragment: false,
209            }));
210        }
211    }
212    Ok(None)
213}
214
/// Converts normalized `f32` samples into signed 16-bit little-endian PCM.
///
/// Each sample is clamped to `[-1.0, 1.0]`, then scaled by 32768 when negative
/// and by 32767 otherwise, so the two extremes map exactly to `i16::MIN` and
/// `i16::MAX`. The scaled value is rounded to the nearest integer. The output
/// holds two bytes per input sample.
fn samples_to_pcm16_le(samples: &[f32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(samples.len() * 2);
    for &raw in samples {
        let bounded = raw.clamp(-1.0, 1.0);
        // Asymmetric scale: the i16 range reaches one step further below zero.
        let scale = if bounded < 0.0 { 32768.0 } else { 32767.0 };
        let quantized = (bounded * scale).round() as i16;
        encoded.extend_from_slice(&quantized.to_le_bytes());
    }
    encoded
}
229
/// Converts signed 16-bit little-endian PCM bytes into normalized `f32`
/// samples.
///
/// Bytes are consumed two at a time and each decoded `i16` is divided by
/// 32768, giving values in `[-1.0, 1.0)`. A trailing odd byte is ignored.
fn pcm16_le_to_samples(bytes: &[u8]) -> Vec<f32> {
    const FULL_SCALE: f32 = 32768.0;

    let mut decoded = Vec::with_capacity(bytes.len() / 2);
    for pair in bytes.chunks_exact(2) {
        let value = i16::from_le_bytes([pair[0], pair[1]]);
        decoded.push(f32::from(value) / FULL_SCALE);
    }
    decoded
}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Base64 of the bytes `00 00 FF 7F 00 80`: three PCM16-LE samples at
    /// zero, positive full scale, and negative full scale.
    const THREE_SAMPLES_B64: &str = "AAD/fwCA";

    /// The setup message carries the default model (with its `models/`
    /// prefix) and the default voice.
    #[test]
    fn setup_message_uses_native_audio_model_and_voice() {
        let config = GeminiLiveConfig::default();
        let message = setup_message(&config);
        let setup = &message.payload["setup"];

        assert_eq!(
            setup["model"],
            "models/gemini-2.5-flash-native-audio-preview-12-2025"
        );

        let voice_config = &setup["generationConfig"]["speechConfig"]["voiceConfig"];
        assert_eq!(voice_config["prebuiltVoiceConfig"]["voiceName"], "Kore");
    }

    /// An audio frame is sent as one media chunk with a rate-tagged PCM MIME
    /// type and base64 PCM16-LE data.
    #[test]
    fn audio_message_includes_pcm_rate_mime_type() {
        let frame = AudioInputFrame {
            sequence: 1,
            sample_rate_hz: 16_000,
            channels: 1,
            samples: vec![0.0, 1.0, -1.0],
        };

        let message = audio_message(&frame, 16_000);
        let chunk = &message.payload["realtimeInput"]["mediaChunks"][0];

        assert_eq!(chunk["mimeType"], "audio/pcm;rate=16000");
        assert_eq!(chunk["data"], THREE_SAMPLES_B64);
    }

    /// Inline audio in a server message decodes to a 24 kHz mono output frame.
    #[test]
    fn server_inline_audio_decodes_to_vona_output() {
        let message = json!({
            "serverContent": {
                "modelTurn": {
                    "parts": [{
                        "inlineData": {
                            "mimeType": "audio/pcm;rate=24000",
                            "data": THREE_SAMPLES_B64
                        }
                    }]
                }
            }
        });

        let expected = RealtimeVoiceOutput::Audio(vona_core::AudioOutputFrame {
            sequence: 0,
            sample_rate_hz: 24_000,
            channels: 1,
            samples: vec![0.0, 32767.0 / 32768.0, -1.0],
            is_filler: false,
        });

        let output = server_message_to_output(&message).unwrap().unwrap();
        assert_eq!(output, expected);
    }
}