Skip to main content

rust_genai/
live.rs

1//! Live API (WebSocket) support.
2
3use std::sync::{Arc, Mutex};
4
5use futures_util::{SinkExt, StreamExt};
6use reqwest::Url;
7use tokio::sync::{mpsc, oneshot};
8use tokio::time::{timeout, Duration};
9use tokio_tungstenite::connect_async;
10use tokio_tungstenite::tungstenite::client::IntoClientRequest;
11use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue};
12use tokio_tungstenite::tungstenite::Message;
13
14use rust_genai_types::config::GenerationConfig;
15use rust_genai_types::content::{Blob, Content};
16use rust_genai_types::live_types::{
17    AudioTranscriptionConfig, ContextWindowCompressionConfig, LiveClientContent, LiveClientMessage,
18    LiveClientRealtimeInput, LiveClientSetup, LiveConnectConfig, LiveSendClientContentParameters,
19    LiveSendRealtimeInputParameters, LiveSendToolResponseParameters, LiveServerMessage,
20    SessionResumptionConfig,
21};
22use rust_genai_types::tool::Tool;
23
24use crate::client::{Backend, ClientInner};
25use crate::error::{Error, Result};
26use crate::live_music::LiveMusic;
27
/// Entry point for the Live (bidirectional WebSocket) API.
///
/// Cloning is cheap: the only field is an `Arc` to the shared client state.
#[derive(Clone)]
pub struct Live {
    /// Shared client state (client configuration and API client settings).
    pub(crate) inner: Arc<ClientInner>,
}
32
33impl Live {
34    pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
35        Self { inner }
36    }
37
38    /// 连接到 Live API。
39    ///
40    /// # Errors
41    /// 当连接失败或配置无效时返回错误。
42    pub async fn connect(
43        &self,
44        model: impl Into<String>,
45        config: LiveConnectConfig,
46    ) -> Result<LiveSession> {
47        Box::pin(
48            LiveSessionBuilder::new(self.inner.clone(), model.into())
49                .with_config(config)
50                .connect(),
51        )
52        .await
53    }
54
55    /// 创建 `LiveSessionBuilder`。
56    #[must_use]
57    pub fn builder(&self, model: impl Into<String>) -> LiveSessionBuilder {
58        LiveSessionBuilder::new(self.inner.clone(), model.into())
59    }
60
61    /// 访问 Live Music API。
62    #[must_use]
63    pub fn music(&self) -> LiveMusic {
64        LiveMusic::new(self.inner.clone())
65    }
66}
67
/// Builder for a [`LiveSession`]; call `connect` to open the session.
pub struct LiveSessionBuilder {
    /// Shared client state used to open the connection.
    inner: Arc<ClientInner>,
    /// Model name as given by the caller; it is normalized when the setup message is built.
    model: String,
    /// Connection configuration accumulated by the `with_*` methods.
    config: LiveConnectConfig,
}
73
74impl LiveSessionBuilder {
75    pub(crate) fn new(inner: Arc<ClientInner>, model: String) -> Self {
76        Self {
77            inner,
78            model,
79            config: LiveConnectConfig::default(),
80        }
81    }
82
83    /// 设置连接配置。
84    #[must_use]
85    pub fn with_config(mut self, config: LiveConnectConfig) -> Self {
86        self.config = config;
87        self
88    }
89
90    /// 设置系统指令。
91    #[must_use]
92    pub fn with_system_instruction(mut self, instruction: impl Into<String>) -> Self {
93        self.config.system_instruction = Some(Content::text(instruction));
94        self
95    }
96
97    /// 设置工具列表。
98    #[must_use]
99    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
100        self.config.tools = Some(tools);
101        self
102    }
103
104    /// 设置生成配置。
105    #[must_use]
106    pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
107        self.config.generation_config = Some(config);
108        self
109    }
110
111    /// 启用会话恢复(自动获取 resumption handle)。
112    #[must_use]
113    pub fn with_session_resumption(mut self) -> Self {
114        self.config.session_resumption = Some(SessionResumptionConfig {
115            handle: None,
116            transparent: None,
117        });
118        self
119    }
120
121    /// 使用指定的 resumption handle 恢复会话。
122    #[must_use]
123    pub fn with_session_resumption_handle(mut self, handle: impl Into<String>) -> Self {
124        self.config.session_resumption = Some(SessionResumptionConfig {
125            handle: Some(handle.into()),
126            transparent: None,
127        });
128        self
129    }
130
131    /// 配置上下文窗口压缩。
132    #[must_use]
133    pub fn with_context_window_compression(
134        mut self,
135        config: ContextWindowCompressionConfig,
136    ) -> Self {
137        self.config.context_window_compression = Some(config);
138        self
139    }
140
141    /// 配置输入音频转录。
142    #[must_use]
143    pub const fn with_input_audio_transcription(
144        mut self,
145        config: AudioTranscriptionConfig,
146    ) -> Self {
147        self.config.input_audio_transcription = Some(config);
148        self
149    }
150
151    /// 配置输出音频转录。
152    #[must_use]
153    pub const fn with_output_audio_transcription(
154        mut self,
155        config: AudioTranscriptionConfig,
156    ) -> Self {
157        self.config.output_audio_transcription = Some(config);
158        self
159    }
160
161    /// 连接并创建会话。
162    ///
163    /// # Errors
164    /// 当连接失败或配置无效时返回错误。
165    pub async fn connect(self) -> Result<LiveSession> {
166        connect_live_session(self.inner, self.model, self.config).await
167    }
168}
169
/// An open Live API session.
///
/// Outgoing client messages are queued on an unbounded channel and written by a background
/// task. Server messages produced by that task are read with `receive`.
pub struct LiveSession {
    /// Queue of client messages for the background writer task.
    outgoing_tx: mpsc::UnboundedSender<LiveClientMessage>,
    /// Parsed server messages (or errors) forwarded by the background task.
    incoming_rx: mpsc::UnboundedReceiver<Result<LiveServerMessage>>,
    /// One-shot signal that asks the background task to send a Close frame and stop.
    /// `None` once `close` has consumed it.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Session id reported by the server in `setup_complete`, if any.
    pub session_id: Option<String>,
    /// Latest session-resumption information, updated by the background task.
    resumption_state: Arc<Mutex<LiveSessionResumptionState>>,
    /// `time_left` from the most recent `GoAway` message, updated by the background task.
    go_away_time_left: Arc<Mutex<Option<String>>>,
}
179
/// Snapshot of the session-resumption information most recently reported by the server.
#[derive(Debug, Clone, Default)]
pub struct LiveSessionResumptionState {
    /// Handle to pass back when resuming. Replaced by any update that carries `new_handle`
    /// or `resumable`, so an update with `resumable` but no `new_handle` clears it.
    pub handle: Option<String>,
    /// Whether the session was resumable as of the last update that carried this field.
    pub resumable: Option<bool>,
    /// Last client message index the server reported as consumed (kept as sent, a string).
    pub last_consumed_client_message_index: Option<String>,
}
186
187impl LiveSession {
188    /// 发送文本(turn-based)。
189    ///
190    /// # Errors
191    /// 当发送失败或连接中断时返回错误。
192    pub async fn send_text(&self, text: impl Into<String>) -> Result<()> {
193        let message = LiveClientMessage {
194            setup: None,
195            client_content: Some(LiveClientContent {
196                turns: Some(vec![Content::text(text)]),
197                turn_complete: Some(true),
198            }),
199            realtime_input: None,
200            tool_response: None,
201        };
202        self.send_async(message).await
203    }
204
205    /// 发送音频(realtime)。
206    ///
207    /// # Errors
208    /// 当发送失败或连接中断时返回错误。
209    pub async fn send_audio(&self, data: Vec<u8>, mime_type: impl Into<String>) -> Result<()> {
210        let message = LiveClientMessage {
211            setup: None,
212            client_content: None,
213            realtime_input: Some(LiveClientRealtimeInput {
214                media_chunks: None,
215                audio: Some(Blob {
216                    mime_type: mime_type.into(),
217                    data,
218                    display_name: None,
219                }),
220                audio_stream_end: None,
221                video: None,
222                text: None,
223                activity_start: None,
224                activity_end: None,
225            }),
226            tool_response: None,
227        };
228        self.send_async(message).await
229    }
230
231    /// 发送 client content。
232    ///
233    /// # Errors
234    /// 当发送失败或连接中断时返回错误。
235    pub async fn send_client_content(&self, params: LiveSendClientContentParameters) -> Result<()> {
236        let message = LiveClientMessage {
237            setup: None,
238            client_content: Some(LiveClientContent {
239                turns: params.turns,
240                turn_complete: params.turn_complete,
241            }),
242            realtime_input: None,
243            tool_response: None,
244        };
245        self.send_async(message).await
246    }
247
248    /// 发送 realtime input。
249    ///
250    /// # Errors
251    /// 当发送失败或连接中断时返回错误。
252    pub async fn send_realtime_input(&self, params: LiveSendRealtimeInputParameters) -> Result<()> {
253        let message = LiveClientMessage {
254            setup: None,
255            client_content: None,
256            realtime_input: Some(LiveClientRealtimeInput {
257                media_chunks: params.media.map(|media| vec![media]),
258                audio: params.audio,
259                audio_stream_end: params.audio_stream_end,
260                video: params.video,
261                text: params.text,
262                activity_start: params.activity_start,
263                activity_end: params.activity_end,
264            }),
265            tool_response: None,
266        };
267        self.send_async(message).await
268    }
269
270    /// 发送工具响应。
271    ///
272    /// # Errors
273    /// 当发送失败或连接中断时返回错误。
274    pub async fn send_tool_response(&self, params: LiveSendToolResponseParameters) -> Result<()> {
275        let message = LiveClientMessage {
276            setup: None,
277            client_content: None,
278            realtime_input: None,
279            tool_response: Some(rust_genai_types::live_types::LiveClientToolResponse {
280                function_responses: params.function_responses,
281            }),
282        };
283        self.send_async(message).await
284    }
285
286    /// 接收服务器消息。
287    pub async fn receive(&mut self) -> Option<Result<LiveServerMessage>> {
288        self.incoming_rx.recv().await
289    }
290
291    /// 关闭会话。
292    ///
293    /// # Errors
294    /// 当发送关闭信号失败时返回错误。
295    pub async fn close(mut self) -> Result<()> {
296        if let Some(tx) = self.shutdown_tx.take() {
297            let _ = tx.send(());
298        }
299        tokio::task::yield_now().await;
300        Ok(())
301    }
302
303    /// 获取最新的会话恢复状态。
304    pub fn resumption_state(&self) -> LiveSessionResumptionState {
305        self.resumption_state
306            .lock()
307            .unwrap_or_else(std::sync::PoisonError::into_inner)
308            .clone()
309    }
310
311    /// 获取最新的 resumption handle。
312    pub fn resumption_handle(&self) -> Option<String> {
313        self.resumption_state
314            .lock()
315            .unwrap_or_else(std::sync::PoisonError::into_inner)
316            .handle
317            .clone()
318    }
319
320    /// 获取最近一次 `GoAway` 的 `time_left`。
321    pub fn last_go_away_time_left(&self) -> Option<String> {
322        self.go_away_time_left
323            .lock()
324            .unwrap_or_else(std::sync::PoisonError::into_inner)
325            .clone()
326    }
327
328    fn send(&self, message: LiveClientMessage) -> Result<()> {
329        self.outgoing_tx
330            .send(message)
331            .map_err(|_| Error::ChannelClosed)?;
332        Ok(())
333    }
334
335    async fn send_async(&self, message: LiveClientMessage) -> Result<()> {
336        self.send(message)?;
337        tokio::task::yield_now().await;
338        Ok(())
339    }
340}
341
342async fn connect_live_session(
343    inner: Arc<ClientInner>,
344    model: String,
345    config: LiveConnectConfig,
346) -> Result<LiveSession> {
347    if config.http_options.is_some() {
348        return Err(Error::InvalidConfig {
349            message: "LiveConnectConfig.http_options is not supported yet".into(),
350        });
351    }
352
353    if inner.config.backend == Backend::VertexAi {
354        return Err(Error::InvalidConfig {
355            message: "Live API for Vertex AI is not supported yet".into(),
356        });
357    }
358
359    let api_key = inner
360        .config
361        .api_key
362        .as_ref()
363        .ok_or_else(|| Error::InvalidConfig {
364            message: "API key required for Live API".into(),
365        })?;
366
367    let (url, headers) = build_live_ws_url(
368        &inner.api_client.base_url,
369        &inner.api_client.api_version,
370        api_key,
371    )?;
372
373    let setup_timeout_ms = inner.config.http_options.timeout.unwrap_or(30_000);
374    let request = build_ws_request(&url, &headers)?;
375    let (ws_stream, _) = timeout(
376        Duration::from_millis(setup_timeout_ms),
377        connect_async(request),
378    )
379    .await
380    .map_err(|_| Error::Timeout {
381        message: format!("Timed out connecting to Live API after {setup_timeout_ms}ms"),
382    })??;
383    let (mut write, mut read) = ws_stream.split();
384
385    let setup = build_live_setup(&model, &config);
386    let setup_message = LiveClientMessage {
387        setup: Some(setup),
388        client_content: None,
389        realtime_input: None,
390        tool_response: None,
391    };
392    let payload = serde_json::to_string(&setup_message)?;
393    write.send(Message::Text(payload.into())).await?;
394
395    let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
396    let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel();
397    let (shutdown_tx, shutdown_rx) = oneshot::channel();
398    let resumption_state = Arc::new(Mutex::new(LiveSessionResumptionState::default()));
399    let go_away_time_left = Arc::new(Mutex::new(None));
400
401    let session_id = timeout(Duration::from_millis(setup_timeout_ms), async {
402        loop {
403            match read.next().await {
404                Some(Ok(message)) => match message {
405                    Message::Close(frame) => {
406                        return Err(Error::Parse {
407                            message: format!("WebSocket closed before setup_complete: {frame:?}"),
408                        })
409                    }
410                    _ => {
411                        if let Some(msg) = parse_server_message(message)? {
412                            if let Some(setup) = msg.setup_complete.as_ref() {
413                                return Ok(setup.session_id.clone());
414                            }
415                        }
416                    }
417                },
418                Some(Err(err)) => return Err(Error::WebSocket { source: err }),
419                None => {
420                    return Err(Error::Parse {
421                        message: "WebSocket closed before setup_complete".into(),
422                    })
423                }
424            }
425        }
426    })
427    .await
428    .map_err(|_| Error::Timeout {
429        message: format!(
430            "Timed out waiting for Live API setup_complete after {setup_timeout_ms}ms"
431        ),
432    })??;
433
434    tokio::spawn(message_loop(
435        write,
436        read,
437        outgoing_rx,
438        incoming_tx,
439        shutdown_rx,
440        resumption_state.clone(),
441        go_away_time_left.clone(),
442    ));
443
444    Ok(LiveSession {
445        outgoing_tx,
446        incoming_rx,
447        shutdown_tx: Some(shutdown_tx),
448        session_id,
449        resumption_state,
450        go_away_time_left,
451    })
452}
453
454fn build_live_setup(model: &str, config: &LiveConnectConfig) -> LiveClientSetup {
455    let model = normalize_model_name(model);
456    let generation_config = merge_generation_config(config);
457
458    LiveClientSetup {
459        model: Some(model),
460        generation_config,
461        system_instruction: config.system_instruction.clone(),
462        tools: config.tools.clone(),
463        realtime_input_config: config.realtime_input_config.clone(),
464        session_resumption: config.session_resumption.clone(),
465        context_window_compression: config.context_window_compression.clone(),
466        input_audio_transcription: config.input_audio_transcription.clone(),
467        output_audio_transcription: config.output_audio_transcription.clone(),
468        proactivity: config.proactivity.clone(),
469        explicit_vad_signal: config.explicit_vad_signal,
470    }
471}
472
473fn merge_generation_config(config: &LiveConnectConfig) -> Option<GenerationConfig> {
474    let mut generation_config = config.generation_config.clone().unwrap_or_default();
475    let updated = config.generation_config.is_some()
476        || config.response_modalities.is_some()
477        || config.temperature.is_some()
478        || config.top_p.is_some()
479        || config.top_k.is_some()
480        || config.max_output_tokens.is_some()
481        || config.media_resolution.is_some()
482        || config.seed.is_some()
483        || config.speech_config.is_some()
484        || config.thinking_config.is_some()
485        || config.enable_affective_dialog.is_some();
486
487    if let Some(value) = config.response_modalities.clone() {
488        generation_config.response_modalities = Some(value);
489    }
490    if let Some(value) = config.temperature {
491        generation_config.temperature = Some(value);
492    }
493    if let Some(value) = config.top_p {
494        generation_config.top_p = Some(value);
495    }
496    if let Some(value) = config.top_k {
497        let top_k_value = i16::try_from(value).unwrap_or_else(|_| {
498            if value > i32::from(i16::MAX) {
499                i16::MAX
500            } else {
501                i16::MIN
502            }
503        });
504        generation_config.top_k = Some(f32::from(top_k_value));
505    }
506    if let Some(value) = config.max_output_tokens {
507        generation_config.max_output_tokens = Some(value);
508    }
509    if let Some(value) = config.media_resolution {
510        generation_config.media_resolution = Some(value);
511    }
512    if let Some(value) = config.seed {
513        generation_config.seed = Some(value);
514    }
515    if let Some(value) = config.speech_config.clone() {
516        generation_config.speech_config = Some(value);
517    }
518    if let Some(value) = config.thinking_config.clone() {
519        generation_config.thinking_config = Some(value);
520    }
521    if let Some(value) = config.enable_affective_dialog {
522        generation_config.enable_affective_dialog = Some(value);
523    }
524
525    updated.then_some(generation_config)
526}
527
528fn build_ws_request(
529    url: &Url,
530    headers: &HeaderMap,
531) -> Result<tokio_tungstenite::tungstenite::http::Request<()>> {
532    let mut request = url
533        .as_str()
534        .into_client_request()
535        .map_err(|err| Error::Parse {
536            message: err.to_string(),
537        })?;
538    {
539        let request_headers = request.headers_mut();
540        for (key, value) in headers {
541            request_headers.insert(key, value.clone());
542        }
543    }
544    Ok(request)
545}
546
547fn build_live_ws_url(base_url: &str, api_version: &str, api_key: &str) -> Result<(Url, HeaderMap)> {
548    if api_key.starts_with("auth_tokens/") && api_version != "v1alpha" {
549        return Err(Error::InvalidConfig {
550            message: "Ephemeral tokens require v1alpha for Live API".into(),
551        });
552    }
553    let mut url = Url::parse(base_url).map_err(|err| Error::InvalidConfig {
554        message: err.to_string(),
555    })?;
556
557    let scheme = match url.scheme() {
558        "http" | "ws" => "ws",
559        _ => "wss",
560    };
561    url.set_scheme(scheme).map_err(|()| Error::InvalidConfig {
562        message: "Invalid base_url scheme".into(),
563    })?;
564
565    let base_path = url.path().trim_end_matches('/');
566    let method = if api_key.starts_with("auth_tokens/") {
567        "BidiGenerateContentConstrained"
568    } else {
569        "BidiGenerateContent"
570    };
571    let path = format!(
572        "{base_path}/ws/google.ai.generativelanguage.{api_version}.GenerativeService.{method}"
573    );
574    url.set_path(&path);
575
576    let mut headers = HeaderMap::new();
577    if api_key.starts_with("auth_tokens/") {
578        headers.insert(
579            "authorization",
580            HeaderValue::from_str(&format!("Token {api_key}")).map_err(|_| {
581                Error::InvalidConfig {
582                    message: "Invalid ephemeral token".into(),
583                }
584            })?,
585        );
586    } else {
587        headers.insert(
588            "x-goog-api-key",
589            HeaderValue::from_str(api_key).map_err(|_| Error::InvalidConfig {
590                message: "Invalid API key".into(),
591            })?,
592        );
593    }
594
595    Ok((url, headers))
596}
597
/// Returns the fully qualified resource name for `model`.
///
/// Names that already carry a resource prefix are returned unchanged: `models/...` for base
/// models and `tunedModels/...` for tuned models. Bare ids such as `gemini-2.0-flash` get the
/// `models/` prefix.
fn normalize_model_name(model: &str) -> String {
    const RESOURCE_PREFIXES: [&str; 2] = ["models/", "tunedModels/"];
    if RESOURCE_PREFIXES
        .iter()
        .any(|&prefix| model.starts_with(prefix))
    {
        model.to_string()
    } else {
        format!("models/{model}")
    }
}
605
606fn parse_server_message(message: Message) -> Result<Option<LiveServerMessage>> {
607    match message {
608        Message::Text(text) => {
609            let msg = serde_json::from_str::<LiveServerMessage>(&text)?;
610            Ok(Some(msg))
611        }
612        Message::Binary(data) => {
613            let msg = serde_json::from_slice::<LiveServerMessage>(&data)?;
614            Ok(Some(msg))
615        }
616        Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => Ok(None),
617    }
618}
619
620fn update_resumption_state(
621    state: &Arc<Mutex<LiveSessionResumptionState>>,
622    message: &LiveServerMessage,
623) {
624    if let Some(update) = message.session_resumption_update.as_ref() {
625        let mut guard = state
626            .lock()
627            .unwrap_or_else(std::sync::PoisonError::into_inner);
628        if update.new_handle.is_some() || update.resumable.is_some() {
629            guard.handle.clone_from(&update.new_handle);
630        }
631        if update.resumable.is_some() {
632            guard.resumable = update.resumable;
633        }
634        if update.last_consumed_client_message_index.is_some() {
635            guard
636                .last_consumed_client_message_index
637                .clone_from(&update.last_consumed_client_message_index);
638        }
639    }
640}
641
642fn update_go_away(state: &Arc<Mutex<Option<String>>>, message: &LiveServerMessage) {
643    if let Some(go_away) = message.go_away.as_ref() {
644        let mut guard = state
645            .lock()
646            .unwrap_or_else(std::sync::PoisonError::into_inner);
647        guard.clone_from(&go_away.time_left);
648    }
649}
650
/// Background I/O task for a [`LiveSession`].
///
/// Multiplexes three event sources until one of them ends the session:
/// - `outgoing_rx`: each client message is serialized to JSON and written as a text frame.
///   A serialization failure is reported on `incoming_tx` and the loop continues. A write
///   failure is reported as [`Error::ChannelClosed`] and stops the loop.
/// - `read`: data frames are parsed into [`LiveServerMessage`]s. Each one updates the shared
///   resumption and `GoAway` state and is then forwarded on `incoming_tx`. Parse errors are
///   forwarded and the loop continues. Pings are answered with a Pong. A Close frame, end of
///   stream or a WebSocket error stops the loop.
/// - `shutdown_rx`: completes when `LiveSession::close` sends the signal or the sender is
///   dropped. A Close frame is sent (best effort) and the loop stops.
///
/// Delivery failures on `incoming_tx` (receiver dropped) are ignored.
async fn message_loop(
    mut write: futures_util::stream::SplitSink<
        tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
        Message,
    >,
    mut read: futures_util::stream::SplitStream<
        tokio_tungstenite::WebSocketStream<
            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
        >,
    >,
    mut outgoing_rx: mpsc::UnboundedReceiver<LiveClientMessage>,
    incoming_tx: mpsc::UnboundedSender<Result<LiveServerMessage>>,
    mut shutdown_rx: oneshot::Receiver<()>,
    resumption_state: Arc<Mutex<LiveSessionResumptionState>>,
    go_away_time_left: Arc<Mutex<Option<String>>>,
) {
    loop {
        tokio::select! {
            // Once every sender is dropped `recv()` yields `None`, which does not match this
            // pattern, so `select!` disables the branch for that iteration.
            Some(message) = outgoing_rx.recv() => {
                match serde_json::to_string(&message) {
                    Ok(payload) => {
                        if write.send(Message::Text(payload.into())).await.is_err() {
                            let _ = incoming_tx.send(Err(Error::ChannelClosed));
                            break;
                        }
                    }
                    Err(err) => {
                        // Only this message is lost; the connection stays usable.
                        let _ = incoming_tx.send(Err(Error::Serialization { source: err }));
                    }
                }
            }
            message = read.next() => {
                match message {
                    Some(Ok(message)) => {
                        match message {
                            Message::Ping(payload) => {
                                // Best effort: a failed Pong write is ignored here.
                                let _ = write.send(Message::Pong(payload)).await;
                            }
                            Message::Close(_) => break,
                            other => match parse_server_message(other) {
                                Ok(Some(parsed)) => {
                                    // Update shared state before handing the message to the
                                    // session, so accessors already reflect it.
                                    update_resumption_state(&resumption_state, &parsed);
                                    update_go_away(&go_away_time_left, &parsed);
                                    let _ = incoming_tx.send(Ok(parsed));
                                }
                                Ok(None) => {}
                                Err(err) => {
                                    let _ = incoming_tx.send(Err(err));
                                }
                            },
                        }
                    }
                    Some(Err(err)) => {
                        let _ = incoming_tx.send(Err(Error::WebSocket { source: err }));
                        break;
                    }
                    None => break,
                }
            }
            // Matches both an explicit signal and a dropped sender.
            _ = &mut shutdown_rx => {
                let _ = write.send(Message::Close(None)).await;
                break;
            }
        }
    }
}
719
720#[cfg(test)]
721mod tests {
722    use super::*;
723    use crate::test_support::test_client_inner_with_api_key;
724    use rust_genai_types::config::{SpeechConfig, ThinkingConfig};
725    use rust_genai_types::enums::{MediaResolution, Modality};
726    use rust_genai_types::live_types::{
727        LiveServerGoAway, LiveServerMessage, LiveServerSessionResumptionUpdate,
728    };
729    use tokio_tungstenite::tungstenite::Message;
730
731    #[test]
732    fn test_build_live_ws_url() {
733        let (url, headers) = build_live_ws_url(
734            "https://generativelanguage.googleapis.com/",
735            "v1beta",
736            "test-key",
737        )
738        .unwrap();
739        assert!(url.as_str().starts_with("wss://"));
740        assert_eq!(
741            url.as_str(),
742            "wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
743        );
744        assert!(headers.contains_key("x-goog-api-key"));
745    }
746
747    #[test]
748    fn test_build_live_ws_url_with_ephemeral_token() {
749        let (_url, headers) = build_live_ws_url(
750            "https://generativelanguage.googleapis.com/",
751            "v1alpha",
752            "auth_tokens/abc",
753        )
754        .unwrap();
755        assert!(headers.contains_key("authorization"));
756        assert!(!headers.contains_key("x-goog-api-key"));
757    }
758
759    #[test]
760    fn test_build_live_ws_url_invalid_key() {
761        let err = build_live_ws_url(
762            "https://generativelanguage.googleapis.com/",
763            "v1beta",
764            "bad\nkey",
765        )
766        .unwrap_err();
767        assert!(matches!(err, Error::InvalidConfig { .. }));
768    }
769
770    #[test]
771    fn test_merge_generation_config() {
772        let config = LiveConnectConfig {
773            response_modalities: Some(vec![Modality::Text]),
774            temperature: Some(0.7),
775            ..LiveConnectConfig::default()
776        };
777        let generation = merge_generation_config(&config).unwrap();
778        assert_eq!(generation.response_modalities.unwrap().len(), 1);
779        assert_eq!(generation.temperature, Some(0.7));
780    }
781
782    #[test]
783    fn test_build_live_setup_and_ws_request() {
784        let config = LiveConnectConfig {
785            response_modalities: Some(vec![Modality::Text]),
786            temperature: Some(0.5),
787            ..LiveConnectConfig::default()
788        };
789        let setup = build_live_setup("gemini-2.0-flash", &config);
790        assert_eq!(setup.model.as_deref(), Some("models/gemini-2.0-flash"));
791        assert!(setup.generation_config.is_some());
792
793        let (url, headers) =
794            build_live_ws_url("https://example.com/", "v1beta", "test-key").unwrap();
795        let request = build_ws_request(&url, &headers).unwrap();
796        assert!(request.headers().contains_key("x-goog-api-key"));
797    }
798
799    #[test]
800    fn test_live_builder_and_music_accessors() {
801        let inner = Arc::new(test_client_inner_with_api_key(
802            Backend::GeminiApi,
803            Some("key"),
804        ));
805        let live = Live::new(inner);
806        let builder = live.builder("gemini-2.0-flash");
807        assert_eq!(builder.model, "gemini-2.0-flash");
808        let _music = live.music();
809    }
810
811    #[test]
812    fn test_merge_generation_config_all_fields() {
813        let config = LiveConnectConfig {
814            response_modalities: Some(vec![Modality::Text]),
815            temperature: Some(0.7),
816            top_p: Some(0.9),
817            top_k: Some(32),
818            max_output_tokens: Some(256),
819            media_resolution: Some(MediaResolution::MediaResolutionHigh),
820            seed: Some(42),
821            speech_config: Some(SpeechConfig::default()),
822            thinking_config: Some(ThinkingConfig::default()),
823            enable_affective_dialog: Some(true),
824            ..LiveConnectConfig::default()
825        };
826        let generation = merge_generation_config(&config).unwrap();
827        assert_eq!(generation.top_p, Some(0.9));
828        assert_eq!(generation.top_k, Some(32.0));
829        assert_eq!(generation.max_output_tokens, Some(256));
830        assert_eq!(generation.seed, Some(42));
831        assert!(generation.speech_config.is_some());
832        assert!(generation.thinking_config.is_some());
833        assert_eq!(generation.enable_affective_dialog, Some(true));
834    }
835
836    #[test]
837    fn test_build_ws_request_invalid_scheme() {
838        let url = Url::parse("file:///tmp/socket").unwrap();
839        let err = build_ws_request(&url, &HeaderMap::new()).unwrap_err();
840        assert!(matches!(err, Error::Parse { .. }));
841    }
842
843    #[test]
844    fn test_build_live_ws_url_scheme_variants_and_invalid_token() {
845        let (url, _) = build_live_ws_url("ws://example.com/", "v1beta", "test-key").unwrap();
846        assert!(url.as_str().starts_with("ws://"));
847        let (url, _) = build_live_ws_url("wss://example.com/", "v1beta", "test-key").unwrap();
848        assert!(url.as_str().starts_with("wss://"));
849
850        let err =
851            build_live_ws_url("https://example.com/", "v1alpha", "auth_tokens/bad\n").unwrap_err();
852        assert!(matches!(err, Error::InvalidConfig { .. }));
853    }
854
855    #[test]
856    fn test_normalize_model_name_with_prefix() {
857        assert_eq!(
858            normalize_model_name("models/gemini-2.0-flash"),
859            "models/gemini-2.0-flash"
860        );
861    }
862
863    #[test]
864    fn test_poisoned_mutex_accessors() {
865        let state = Arc::new(Mutex::new(LiveSessionResumptionState {
866            handle: Some("handle".into()),
867            resumable: Some(true),
868            last_consumed_client_message_index: Some("idx".into()),
869        }));
870        let go_away = Arc::new(Mutex::new(Some("5s".into())));
871        let state_clone = Arc::clone(&state);
872        let go_away_clone = Arc::clone(&go_away);
873        let _ = std::thread::spawn(move || {
874            let _guard = state_clone.lock().unwrap();
875            let _guard2 = go_away_clone.lock().unwrap();
876            panic!("poison");
877        })
878        .join();
879
880        let (outgoing_tx, _outgoing_rx) = mpsc::unbounded_channel();
881        let (_incoming_tx, incoming_rx) = mpsc::unbounded_channel();
882        let session = LiveSession {
883            outgoing_tx,
884            incoming_rx,
885            shutdown_tx: None,
886            session_id: None,
887            resumption_state: state,
888            go_away_time_left: go_away,
889        };
890        assert_eq!(session.resumption_handle().as_deref(), Some("handle"));
891        assert_eq!(session.last_go_away_time_left().as_deref(), Some("5s"));
892        let state = session.resumption_state();
893        assert_eq!(
894            state.last_consumed_client_message_index.as_deref(),
895            Some("idx")
896        );
897    }
898
899    #[test]
900    fn test_parse_message_and_state_updates() {
901        let message = Message::Text(
902            serde_json::to_string(&LiveServerMessage {
903                session_resumption_update: Some(LiveServerSessionResumptionUpdate {
904                    new_handle: Some("handle".to_string()),
905                    resumable: Some(true),
906                    last_consumed_client_message_index: Some("1".to_string()),
907                }),
908                go_away: Some(LiveServerGoAway {
909                    time_left: Some("5s".to_string()),
910                }),
911                ..LiveServerMessage {
912                    setup_complete: None,
913                    server_content: None,
914                    tool_call: None,
915                    tool_call_cancellation: None,
916                    usage_metadata: None,
917                    voice_activity_detection_signal: None,
918                    session_resumption_update: None,
919                    go_away: None,
920                }
921            })
922            .unwrap()
923            .into(),
924        );
925
926        let parsed = parse_server_message(message).unwrap().unwrap();
927        let state = Arc::new(Mutex::new(LiveSessionResumptionState::default()));
928        update_resumption_state(&state, &parsed);
929        let guard = state.lock().unwrap();
930        assert_eq!(guard.handle.as_deref(), Some("handle"));
931        assert_eq!(guard.resumable, Some(true));
932        drop(guard);
933
934        let go_away = Arc::new(Mutex::new(None));
935        update_go_away(&go_away, &parsed);
936        assert_eq!(*go_away.lock().unwrap(), Some("5s".to_string()));
937
938        let bin_message = Message::Binary(
939            serde_json::to_vec(&LiveServerMessage {
940                setup_complete: None,
941                server_content: None,
942                tool_call: None,
943                tool_call_cancellation: None,
944                usage_metadata: None,
945                go_away: None,
946                session_resumption_update: None,
947                voice_activity_detection_signal: None,
948            })
949            .unwrap()
950            .into(),
951        );
952        assert!(parse_server_message(bin_message).unwrap().is_some());
953    }
954
955    #[test]
956    fn test_parse_server_message_variants() {
957        assert!(parse_server_message(Message::Ping(vec![1].into()))
958            .unwrap()
959            .is_none());
960        assert!(parse_server_message(Message::Close(None))
961            .unwrap()
962            .is_none());
963        assert!(parse_server_message(Message::Text("not-json".into())).is_err());
964    }
965
966    #[test]
967    fn test_update_state_with_partial_resumption_update() {
968        let message = LiveServerMessage {
969            session_resumption_update: Some(LiveServerSessionResumptionUpdate {
970                new_handle: None,
971                resumable: None,
972                last_consumed_client_message_index: Some("2".to_string()),
973            }),
974            setup_complete: None,
975            server_content: None,
976            tool_call: None,
977            tool_call_cancellation: None,
978            usage_metadata: None,
979            voice_activity_detection_signal: None,
980            go_away: None,
981        };
982        let state = Arc::new(Mutex::new(LiveSessionResumptionState {
983            handle: Some("keep".into()),
984            resumable: Some(false),
985            last_consumed_client_message_index: None,
986        }));
987        update_resumption_state(&state, &message);
988        let guard = state.lock().unwrap();
989        assert_eq!(guard.handle.as_deref(), Some("keep"));
990        assert_eq!(guard.resumable, Some(false));
991        assert_eq!(
992            guard.last_consumed_client_message_index.as_deref(),
993            Some("2")
994        );
995        drop(guard);
996
997        let go_away = Arc::new(Mutex::new(Some("stay".to_string())));
998        update_go_away(&go_away, &message);
999        assert_eq!(*go_away.lock().unwrap(), Some("stay".to_string()));
1000    }
1001
1002    #[test]
1003    fn test_live_builder_config_chain() {
1004        let inner = Arc::new(test_client_inner_with_api_key(
1005            Backend::GeminiApi,
1006            Some("key"),
1007        ));
1008        let builder = LiveSessionBuilder::new(inner, "gemini-2.0-flash".to_string())
1009            .with_system_instruction("sys")
1010            .with_tools(vec![Tool::default()])
1011            .with_generation_config(GenerationConfig::default())
1012            .with_session_resumption()
1013            .with_context_window_compression(ContextWindowCompressionConfig {
1014                trigger_tokens: None,
1015                sliding_window: None,
1016            })
1017            .with_input_audio_transcription(AudioTranscriptionConfig::default())
1018            .with_output_audio_transcription(AudioTranscriptionConfig::default());
1019
1020        assert_eq!(builder.model, "gemini-2.0-flash");
1021        assert!(builder.config.system_instruction.is_some());
1022        assert!(builder.config.tools.is_some());
1023        assert!(builder.config.generation_config.is_some());
1024        assert!(builder.config.session_resumption.is_some());
1025        assert!(builder.config.context_window_compression.is_some());
1026        assert!(builder.config.input_audio_transcription.is_some());
1027        assert!(builder.config.output_audio_transcription.is_some());
1028
1029        let builder = builder.with_session_resumption_handle("handle");
1030        assert_eq!(
1031            builder
1032                .config
1033                .session_resumption
1034                .as_ref()
1035                .and_then(|cfg| cfg.handle.as_deref()),
1036            Some("handle")
1037        );
1038    }
1039
1040    #[tokio::test]
1041    async fn test_live_session_send_and_close() {
1042        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel();
1043        let (_incoming_tx, incoming_rx) = mpsc::unbounded_channel();
1044        let (shutdown_tx, shutdown_rx) = oneshot::channel();
1045        let session = LiveSession {
1046            outgoing_tx,
1047            incoming_rx,
1048            shutdown_tx: Some(shutdown_tx),
1049            session_id: Some("session".to_string()),
1050            resumption_state: Arc::new(Mutex::new(LiveSessionResumptionState::default())),
1051            go_away_time_left: Arc::new(Mutex::new(None)),
1052        };
1053
1054        session.send_text("hi").await.unwrap();
1055        let msg = outgoing_rx.recv().await.unwrap();
1056        assert!(msg.client_content.is_some());
1057        assert!(msg.realtime_input.is_none());
1058
1059        session
1060            .send_audio(vec![1, 2, 3], "audio/pcm")
1061            .await
1062            .unwrap();
1063        let msg = outgoing_rx.recv().await.unwrap();
1064        assert!(msg.realtime_input.as_ref().unwrap().audio.is_some());
1065
1066        session
1067            .send_client_content(LiveSendClientContentParameters {
1068                turns: Some(vec![Content::text("turn")]),
1069                turn_complete: Some(false),
1070            })
1071            .await
1072            .unwrap();
1073        let msg = outgoing_rx.recv().await.unwrap();
1074        assert!(msg.client_content.is_some());
1075
1076        session
1077            .send_realtime_input(LiveSendRealtimeInputParameters {
1078                media: Some(Blob {
1079                    mime_type: "audio/pcm".to_string(),
1080                    data: vec![9],
1081                    display_name: None,
1082                }),
1083                audio: None,
1084                audio_stream_end: Some(true),
1085                video: None,
1086                text: Some("rt".to_string()),
1087                activity_start: None,
1088                activity_end: None,
1089            })
1090            .await
1091            .unwrap();
1092        let msg = outgoing_rx.recv().await.unwrap();
1093        assert!(msg.realtime_input.is_some());
1094
1095        session
1096            .send_tool_response(LiveSendToolResponseParameters {
1097                function_responses: None,
1098            })
1099            .await
1100            .unwrap();
1101        let msg = outgoing_rx.recv().await.unwrap();
1102        assert!(msg.tool_response.is_some());
1103
1104        session.close().await.unwrap();
1105        assert!(shutdown_rx.await.is_ok());
1106    }
1107
1108    #[tokio::test]
1109    async fn test_live_session_send_channel_closed() {
1110        let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel();
1111        drop(outgoing_rx);
1112        let (_incoming_tx, incoming_rx) = mpsc::unbounded_channel();
1113        let session = LiveSession {
1114            outgoing_tx,
1115            incoming_rx,
1116            shutdown_tx: None,
1117            session_id: None,
1118            resumption_state: Arc::new(Mutex::new(LiveSessionResumptionState::default())),
1119            go_away_time_left: Arc::new(Mutex::new(None)),
1120        };
1121        let err = session.send_text("hi").await.unwrap_err();
1122        assert!(matches!(err, Error::ChannelClosed));
1123    }
1124
1125    #[test]
1126    fn test_live_session_state_accessors() {
1127        let (outgoing_tx, _outgoing_rx) = mpsc::unbounded_channel();
1128        let (_incoming_tx, incoming_rx) = mpsc::unbounded_channel();
1129        let state = Arc::new(Mutex::new(LiveSessionResumptionState {
1130            handle: Some("h".to_string()),
1131            resumable: Some(true),
1132            last_consumed_client_message_index: Some("7".to_string()),
1133        }));
1134        let go_away = Arc::new(Mutex::new(Some("10s".to_string())));
1135        let session = LiveSession {
1136            outgoing_tx,
1137            incoming_rx,
1138            shutdown_tx: None,
1139            session_id: None,
1140            resumption_state: state,
1141            go_away_time_left: go_away,
1142        };
1143        assert_eq!(session.resumption_handle().as_deref(), Some("h"));
1144        assert_eq!(session.last_go_away_time_left().as_deref(), Some("10s"));
1145        let state = session.resumption_state();
1146        assert_eq!(
1147            state.last_consumed_client_message_index.as_deref(),
1148            Some("7")
1149        );
1150    }
1151
1152    #[tokio::test]
1153    async fn test_connect_live_session_errors() {
1154        let inner = Arc::new(test_client_inner_with_api_key(
1155            Backend::GeminiApi,
1156            Some("key"),
1157        ));
1158        let config = LiveConnectConfig {
1159            http_options: Some(rust_genai_types::http::HttpOptions::default()),
1160            ..Default::default()
1161        };
1162        let err = connect_live_session(inner, "model".to_string(), config)
1163            .await
1164            .err()
1165            .unwrap();
1166        assert!(matches!(err, Error::InvalidConfig { .. }));
1167
1168        let inner = Arc::new(test_client_inner_with_api_key(
1169            Backend::VertexAi,
1170            Some("key"),
1171        ));
1172        let err = connect_live_session(inner, "model".to_string(), LiveConnectConfig::default())
1173            .await
1174            .err()
1175            .unwrap();
1176        assert!(matches!(err, Error::InvalidConfig { .. }));
1177
1178        let inner = Arc::new(test_client_inner_with_api_key(Backend::GeminiApi, None));
1179        let err = connect_live_session(inner, "model".to_string(), LiveConnectConfig::default())
1180            .await
1181            .err()
1182            .unwrap();
1183        assert!(matches!(err, Error::InvalidConfig { .. }));
1184    }
1185
1186    #[test]
1187    fn test_build_live_ws_url_ephemeral_requires_v1alpha() {
1188        let err = build_live_ws_url(
1189            "https://generativelanguage.googleapis.com/",
1190            "v1beta",
1191            "auth_tokens/abc",
1192        )
1193        .unwrap_err();
1194        assert!(matches!(err, Error::InvalidConfig { .. }));
1195    }
1196
1197    #[test]
1198    fn test_build_live_ws_url_invalid_base_url() {
1199        let err = build_live_ws_url("://bad-url", "v1beta", "test-key").unwrap_err();
1200        assert!(matches!(err, Error::InvalidConfig { .. }));
1201    }
1202}