Skip to main content

rust_genai/
live.rs

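//! Live API (WebSocket) support: the `Live` entry point, the session builder,
//! the `LiveSession` handle and the background message loop that drives the socket.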
2
3use std::sync::{Arc, Mutex};
4
5use futures_util::{SinkExt, StreamExt};
6use reqwest::Url;
7use tokio::sync::{mpsc, oneshot};
8use tokio::time::{timeout, Duration};
9use tokio_tungstenite::connect_async;
10use tokio_tungstenite::tungstenite::client::IntoClientRequest;
11use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue};
12use tokio_tungstenite::tungstenite::Message;
13
14use rust_genai_types::config::GenerationConfig;
15use rust_genai_types::content::{Blob, Content};
16use rust_genai_types::live_types::{
17    AudioTranscriptionConfig, ContextWindowCompressionConfig, LiveClientContent, LiveClientMessage,
18    LiveClientRealtimeInput, LiveClientSetup, LiveConnectConfig, LiveSendClientContentParameters,
19    LiveSendRealtimeInputParameters, LiveSendToolResponseParameters, LiveServerMessage,
20    SessionResumptionConfig,
21};
22use rust_genai_types::tool::Tool;
23
24use crate::client::{Backend, ClientInner};
25use crate::error::{Error, Result};
26use crate::live_music::LiveMusic;
27
28#[derive(Clone)]
29pub struct Live {
30    pub(crate) inner: Arc<ClientInner>,
31}
32
33impl Live {
34    pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
35        Self { inner }
36    }
37
38    /// 连接到 Live API。
39    ///
40    /// # Errors
41    /// 当连接失败或配置无效时返回错误。
42    pub async fn connect(
43        &self,
44        model: impl Into<String>,
45        config: LiveConnectConfig,
46    ) -> Result<LiveSession> {
47        Box::pin(
48            LiveSessionBuilder::new(self.inner.clone(), model.into())
49                .with_config(config)
50                .connect(),
51        )
52        .await
53    }
54
55    /// 创建 `LiveSessionBuilder`。
56    #[must_use]
57    pub fn builder(&self, model: impl Into<String>) -> LiveSessionBuilder {
58        LiveSessionBuilder::new(self.inner.clone(), model.into())
59    }
60
61    /// 访问 Live Music API。
62    #[must_use]
63    pub fn music(&self) -> LiveMusic {
64        LiveMusic::new(self.inner.clone())
65    }
66}
67
68pub struct LiveSessionBuilder {
69    inner: Arc<ClientInner>,
70    model: String,
71    config: LiveConnectConfig,
72}
73
74impl LiveSessionBuilder {
75    pub(crate) fn new(inner: Arc<ClientInner>, model: String) -> Self {
76        Self {
77            inner,
78            model,
79            config: LiveConnectConfig::default(),
80        }
81    }
82
83    /// 设置连接配置。
84    #[must_use]
85    pub fn with_config(mut self, config: LiveConnectConfig) -> Self {
86        self.config = config;
87        self
88    }
89
90    /// 设置系统指令。
91    #[must_use]
92    pub fn with_system_instruction(mut self, instruction: impl Into<String>) -> Self {
93        self.config.system_instruction = Some(Content::text(instruction));
94        self
95    }
96
97    /// 设置工具列表。
98    #[must_use]
99    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
100        self.config.tools = Some(tools);
101        self
102    }
103
104    /// 设置生成配置。
105    #[must_use]
106    pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
107        self.config.generation_config = Some(config);
108        self
109    }
110
111    /// 启用会话恢复(自动获取 resumption handle)。
112    #[must_use]
113    pub fn with_session_resumption(mut self) -> Self {
114        self.config.session_resumption = Some(SessionResumptionConfig {
115            handle: None,
116            transparent: None,
117        });
118        self
119    }
120
121    /// 使用指定的 resumption handle 恢复会话。
122    #[must_use]
123    pub fn with_session_resumption_handle(mut self, handle: impl Into<String>) -> Self {
124        self.config.session_resumption = Some(SessionResumptionConfig {
125            handle: Some(handle.into()),
126            transparent: None,
127        });
128        self
129    }
130
131    /// 配置上下文窗口压缩。
132    #[must_use]
133    pub fn with_context_window_compression(
134        mut self,
135        config: ContextWindowCompressionConfig,
136    ) -> Self {
137        self.config.context_window_compression = Some(config);
138        self
139    }
140
141    /// 配置输入音频转录。
142    #[must_use]
143    pub const fn with_input_audio_transcription(
144        mut self,
145        config: AudioTranscriptionConfig,
146    ) -> Self {
147        self.config.input_audio_transcription = Some(config);
148        self
149    }
150
151    /// 配置输出音频转录。
152    #[must_use]
153    pub const fn with_output_audio_transcription(
154        mut self,
155        config: AudioTranscriptionConfig,
156    ) -> Self {
157        self.config.output_audio_transcription = Some(config);
158        self
159    }
160
161    /// 连接并创建会话。
162    ///
163    /// # Errors
164    /// 当连接失败或配置无效时返回错误。
165    pub async fn connect(self) -> Result<LiveSession> {
166        connect_live_session(self.inner, self.model, self.config).await
167    }
168}
169
170/// Live 会话。
171pub struct LiveSession {
172    outgoing_tx: mpsc::UnboundedSender<LiveClientMessage>,
173    incoming_rx: mpsc::UnboundedReceiver<Result<LiveServerMessage>>,
174    shutdown_tx: Option<oneshot::Sender<()>>,
175    pub session_id: Option<String>,
176    resumption_state: Arc<Mutex<LiveSessionResumptionState>>,
177    go_away_time_left: Arc<Mutex<Option<String>>>,
178}
179
/// Snapshot of the session-resumption information last reported by the server.
#[derive(Debug, Clone, Default)]
pub struct LiveSessionResumptionState {
    /// Most recent resumption handle, if the server provided one.
    pub handle: Option<String>,
    /// Whether the server last reported the session as resumable.
    pub resumable: Option<bool>,
    /// Last `last_consumed_client_message_index` value reported by the server.
    pub last_consumed_client_message_index: Option<String>,
}
186
187impl LiveSession {
188    /// 发送文本(turn-based)。
189    ///
190    /// # Errors
191    /// 当发送失败或连接中断时返回错误。
192    pub async fn send_text(&self, text: impl Into<String>) -> Result<()> {
193        let message = LiveClientMessage {
194            setup: None,
195            client_content: Some(LiveClientContent {
196                turns: Some(vec![Content::text(text)]),
197                turn_complete: Some(true),
198            }),
199            realtime_input: None,
200            tool_response: None,
201        };
202        self.send_async(message).await
203    }
204
205    /// 发送音频(realtime)。
206    ///
207    /// # Errors
208    /// 当发送失败或连接中断时返回错误。
209    pub async fn send_audio(&self, data: Vec<u8>, mime_type: impl Into<String>) -> Result<()> {
210        let message = LiveClientMessage {
211            setup: None,
212            client_content: None,
213            realtime_input: Some(LiveClientRealtimeInput {
214                media_chunks: None,
215                audio: Some(Blob {
216                    mime_type: mime_type.into(),
217                    data,
218                    display_name: None,
219                }),
220                audio_stream_end: None,
221                video: None,
222                text: None,
223                activity_start: None,
224                activity_end: None,
225            }),
226            tool_response: None,
227        };
228        self.send_async(message).await
229    }
230
231    /// 发送 client content。
232    ///
233    /// # Errors
234    /// 当发送失败或连接中断时返回错误。
235    pub async fn send_client_content(&self, params: LiveSendClientContentParameters) -> Result<()> {
236        let message = LiveClientMessage {
237            setup: None,
238            client_content: Some(LiveClientContent {
239                turns: params.turns,
240                turn_complete: params.turn_complete,
241            }),
242            realtime_input: None,
243            tool_response: None,
244        };
245        self.send_async(message).await
246    }
247
248    /// 发送 realtime input。
249    ///
250    /// # Errors
251    /// 当发送失败或连接中断时返回错误。
252    pub async fn send_realtime_input(&self, params: LiveSendRealtimeInputParameters) -> Result<()> {
253        let message = LiveClientMessage {
254            setup: None,
255            client_content: None,
256            realtime_input: Some(LiveClientRealtimeInput {
257                media_chunks: params.media.map(|media| vec![media]),
258                audio: params.audio,
259                audio_stream_end: params.audio_stream_end,
260                video: params.video,
261                text: params.text,
262                activity_start: params.activity_start,
263                activity_end: params.activity_end,
264            }),
265            tool_response: None,
266        };
267        self.send_async(message).await
268    }
269
270    /// 发送工具响应。
271    ///
272    /// # Errors
273    /// 当发送失败或连接中断时返回错误。
274    pub async fn send_tool_response(&self, params: LiveSendToolResponseParameters) -> Result<()> {
275        let message = LiveClientMessage {
276            setup: None,
277            client_content: None,
278            realtime_input: None,
279            tool_response: Some(rust_genai_types::live_types::LiveClientToolResponse {
280                function_responses: params.function_responses,
281            }),
282        };
283        self.send_async(message).await
284    }
285
286    /// 接收服务器消息。
287    pub async fn receive(&mut self) -> Option<Result<LiveServerMessage>> {
288        self.incoming_rx.recv().await
289    }
290
291    /// 关闭会话。
292    ///
293    /// # Errors
294    /// 当发送关闭信号失败时返回错误。
295    pub async fn close(mut self) -> Result<()> {
296        if let Some(tx) = self.shutdown_tx.take() {
297            let _ = tx.send(());
298        }
299        tokio::task::yield_now().await;
300        Ok(())
301    }
302
303    /// 获取最新的会话恢复状态。
304    pub fn resumption_state(&self) -> LiveSessionResumptionState {
305        self.resumption_state
306            .lock()
307            .unwrap_or_else(std::sync::PoisonError::into_inner)
308            .clone()
309    }
310
311    /// 获取最新的 resumption handle。
312    pub fn resumption_handle(&self) -> Option<String> {
313        self.resumption_state
314            .lock()
315            .unwrap_or_else(std::sync::PoisonError::into_inner)
316            .handle
317            .clone()
318    }
319
320    /// 获取最近一次 `GoAway` 的 `time_left`。
321    pub fn last_go_away_time_left(&self) -> Option<String> {
322        self.go_away_time_left
323            .lock()
324            .unwrap_or_else(std::sync::PoisonError::into_inner)
325            .clone()
326    }
327
328    fn send(&self, message: LiveClientMessage) -> Result<()> {
329        self.outgoing_tx
330            .send(message)
331            .map_err(|_| Error::ChannelClosed)?;
332        Ok(())
333    }
334
335    async fn send_async(&self, message: LiveClientMessage) -> Result<()> {
336        self.send(message)?;
337        tokio::task::yield_now().await;
338        Ok(())
339    }
340}
341
342async fn connect_live_session(
343    inner: Arc<ClientInner>,
344    model: String,
345    config: LiveConnectConfig,
346) -> Result<LiveSession> {
347    if config.http_options.is_some() {
348        return Err(Error::InvalidConfig {
349            message: "LiveConnectConfig.http_options is not supported yet".into(),
350        });
351    }
352
353    if inner.config.backend == Backend::VertexAi {
354        return Err(Error::InvalidConfig {
355            message: "Live API for Vertex AI is not supported yet".into(),
356        });
357    }
358
359    let api_key = inner
360        .config
361        .api_key
362        .as_ref()
363        .ok_or_else(|| Error::InvalidConfig {
364            message: "API key required for Live API".into(),
365        })?;
366
367    let (url, headers) = build_live_ws_url(
368        &inner.api_client.base_url,
369        &inner.api_client.api_version,
370        api_key,
371    )?;
372
373    let setup_timeout_ms = inner.config.http_options.timeout.unwrap_or(30_000);
374    let request = build_ws_request(&url, &headers)?;
375    let (ws_stream, _) = timeout(
376        Duration::from_millis(setup_timeout_ms),
377        connect_async(request),
378    )
379    .await
380    .map_err(|_| Error::Timeout {
381        message: format!("Timed out connecting to Live API after {setup_timeout_ms}ms"),
382    })??;
383    let (mut write, mut read) = ws_stream.split();
384
385    let setup = build_live_setup(&model, &config);
386    let setup_message = LiveClientMessage {
387        setup: Some(setup),
388        client_content: None,
389        realtime_input: None,
390        tool_response: None,
391    };
392    let payload = serde_json::to_string(&setup_message)?;
393    write.send(Message::Text(payload.into())).await?;
394
395    let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
396    let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel();
397    let (shutdown_tx, shutdown_rx) = oneshot::channel();
398    let resumption_state = Arc::new(Mutex::new(LiveSessionResumptionState::default()));
399    let go_away_time_left = Arc::new(Mutex::new(None));
400
401    let session_id = timeout(Duration::from_millis(setup_timeout_ms), async {
402        loop {
403            match read.next().await {
404                Some(Ok(message)) => match message {
405                    Message::Close(frame) => {
406                        return Err(Error::Parse {
407                            message: format!("WebSocket closed before setup_complete: {frame:?}"),
408                        })
409                    }
410                    _ => {
411                        if let Some(msg) = parse_server_message(message)? {
412                            if let Some(setup) = msg.setup_complete.as_ref() {
413                                return Ok(setup.session_id.clone());
414                            }
415                        }
416                    }
417                },
418                Some(Err(err)) => return Err(Error::WebSocket { source: err }),
419                None => {
420                    return Err(Error::Parse {
421                        message: "WebSocket closed before setup_complete".into(),
422                    })
423                }
424            }
425        }
426    })
427    .await
428    .map_err(|_| Error::Timeout {
429        message: format!(
430            "Timed out waiting for Live API setup_complete after {setup_timeout_ms}ms"
431        ),
432    })??;
433
434    tokio::spawn(message_loop(
435        write,
436        read,
437        outgoing_rx,
438        incoming_tx,
439        shutdown_rx,
440        resumption_state.clone(),
441        go_away_time_left.clone(),
442    ));
443
444    Ok(LiveSession {
445        outgoing_tx,
446        incoming_rx,
447        shutdown_tx: Some(shutdown_tx),
448        session_id,
449        resumption_state,
450        go_away_time_left,
451    })
452}
453
454fn build_live_setup(model: &str, config: &LiveConnectConfig) -> LiveClientSetup {
455    let model = normalize_model_name(model);
456    let generation_config = merge_generation_config(config);
457
458    LiveClientSetup {
459        model: Some(model),
460        generation_config,
461        system_instruction: config.system_instruction.clone(),
462        tools: config.tools.clone(),
463        realtime_input_config: config.realtime_input_config.clone(),
464        session_resumption: config.session_resumption.clone(),
465        context_window_compression: config.context_window_compression.clone(),
466        input_audio_transcription: config.input_audio_transcription.clone(),
467        output_audio_transcription: config.output_audio_transcription.clone(),
468        proactivity: config.proactivity.clone(),
469        explicit_vad_signal: config.explicit_vad_signal,
470    }
471}
472
473fn merge_generation_config(config: &LiveConnectConfig) -> Option<GenerationConfig> {
474    let mut generation_config = config.generation_config.clone().unwrap_or_default();
475    let updated = config.generation_config.is_some()
476        || config.response_modalities.is_some()
477        || config.temperature.is_some()
478        || config.top_p.is_some()
479        || config.top_k.is_some()
480        || config.max_output_tokens.is_some()
481        || config.media_resolution.is_some()
482        || config.seed.is_some()
483        || config.speech_config.is_some()
484        || config.thinking_config.is_some()
485        || config.enable_affective_dialog.is_some();
486
487    if let Some(value) = config.response_modalities.clone() {
488        generation_config.response_modalities = Some(value);
489    }
490    if let Some(value) = config.temperature {
491        generation_config.temperature = Some(value);
492    }
493    if let Some(value) = config.top_p {
494        generation_config.top_p = Some(value);
495    }
496    if let Some(value) = config.top_k {
497        let top_k_value = i16::try_from(value).unwrap_or_else(|_| {
498            if value > i32::from(i16::MAX) {
499                i16::MAX
500            } else {
501                i16::MIN
502            }
503        });
504        generation_config.top_k = Some(f32::from(top_k_value));
505    }
506    if let Some(value) = config.max_output_tokens {
507        generation_config.max_output_tokens = Some(value);
508    }
509    if let Some(value) = config.media_resolution {
510        generation_config.media_resolution = Some(value);
511    }
512    if let Some(value) = config.seed {
513        generation_config.seed = Some(value);
514    }
515    if let Some(value) = config.speech_config.clone() {
516        generation_config.speech_config = Some(value);
517    }
518    if let Some(value) = config.thinking_config.clone() {
519        generation_config.thinking_config = Some(value);
520    }
521    if let Some(value) = config.enable_affective_dialog {
522        generation_config.enable_affective_dialog = Some(value);
523    }
524
525    updated.then_some(generation_config)
526}
527
528fn build_ws_request(
529    url: &Url,
530    headers: &HeaderMap,
531) -> Result<tokio_tungstenite::tungstenite::http::Request<()>> {
532    let mut request = url
533        .as_str()
534        .into_client_request()
535        .map_err(|err| Error::Parse {
536            message: err.to_string(),
537        })?;
538    {
539        let request_headers = request.headers_mut();
540        for (key, value) in headers {
541            request_headers.insert(key, value.clone());
542        }
543    }
544    Ok(request)
545}
546
547fn build_live_ws_url(base_url: &str, api_version: &str, api_key: &str) -> Result<(Url, HeaderMap)> {
548    if api_key.starts_with("auth_tokens/") && api_version != "v1alpha" {
549        return Err(Error::InvalidConfig {
550            message: "Ephemeral tokens require v1alpha for Live API".into(),
551        });
552    }
553    let mut url = Url::parse(base_url).map_err(|err| Error::InvalidConfig {
554        message: err.to_string(),
555    })?;
556
557    let scheme = match url.scheme() {
558        "http" | "ws" => "ws",
559        _ => "wss",
560    };
561    url.set_scheme(scheme).map_err(|()| Error::InvalidConfig {
562        message: "Invalid base_url scheme".into(),
563    })?;
564
565    let base_path = url.path().trim_end_matches('/');
566    let method = if api_key.starts_with("auth_tokens/") {
567        "BidiGenerateContentConstrained"
568    } else {
569        "BidiGenerateContent"
570    };
571    let path = format!(
572        "{base_path}/ws/google.ai.generativelanguage.{api_version}.GenerativeService.{method}"
573    );
574    url.set_path(&path);
575
576    let mut headers = HeaderMap::new();
577    if api_key.starts_with("auth_tokens/") {
578        headers.insert(
579            "authorization",
580            HeaderValue::from_str(&format!("Token {api_key}")).map_err(|_| {
581                Error::InvalidConfig {
582                    message: "Invalid ephemeral token".into(),
583                }
584            })?,
585        );
586    } else {
587        headers.insert(
588            "x-goog-api-key",
589            HeaderValue::from_str(api_key).map_err(|_| Error::InvalidConfig {
590                message: "Invalid API key".into(),
591            })?,
592        );
593    }
594
595    Ok((url, headers))
596}
597
/// Returns `model` with a leading `models/` prefix.
///
/// Names that already start with `models/` are returned unchanged; every
/// other name gets the prefix prepended.
fn normalize_model_name(model: &str) -> String {
    match model.strip_prefix("models/") {
        Some(_) => model.to_owned(),
        None => format!("models/{model}"),
    }
}
605
606fn parse_server_message(message: Message) -> Result<Option<LiveServerMessage>> {
607    match message {
608        Message::Text(text) => {
609            let msg = serde_json::from_str::<LiveServerMessage>(&text)?;
610            Ok(Some(msg))
611        }
612        Message::Binary(data) => {
613            let msg = serde_json::from_slice::<LiveServerMessage>(&data)?;
614            Ok(Some(msg))
615        }
616        Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => Ok(None),
617    }
618}
619
620fn update_resumption_state(
621    state: &Arc<Mutex<LiveSessionResumptionState>>,
622    message: &LiveServerMessage,
623) {
624    if let Some(update) = message.session_resumption_update.as_ref() {
625        let mut guard = state
626            .lock()
627            .unwrap_or_else(std::sync::PoisonError::into_inner);
628        if update.new_handle.is_some() || update.resumable.is_some() {
629            guard.handle.clone_from(&update.new_handle);
630        }
631        if update.resumable.is_some() {
632            guard.resumable = update.resumable;
633        }
634        if update.last_consumed_client_message_index.is_some() {
635            guard
636                .last_consumed_client_message_index
637                .clone_from(&update.last_consumed_client_message_index);
638        }
639    }
640}
641
642fn update_go_away(state: &Arc<Mutex<Option<String>>>, message: &LiveServerMessage) {
643    if let Some(go_away) = message.go_away.as_ref() {
644        let mut guard = state
645            .lock()
646            .unwrap_or_else(std::sync::PoisonError::into_inner);
647        guard.clone_from(&go_away.time_left);
648    }
649}
650
651async fn message_loop(
652    mut write: futures_util::stream::SplitSink<
653        tokio_tungstenite::WebSocketStream<
654            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
655        >,
656        Message,
657    >,
658    mut read: futures_util::stream::SplitStream<
659        tokio_tungstenite::WebSocketStream<
660            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
661        >,
662    >,
663    mut outgoing_rx: mpsc::UnboundedReceiver<LiveClientMessage>,
664    incoming_tx: mpsc::UnboundedSender<Result<LiveServerMessage>>,
665    mut shutdown_rx: oneshot::Receiver<()>,
666    resumption_state: Arc<Mutex<LiveSessionResumptionState>>,
667    go_away_time_left: Arc<Mutex<Option<String>>>,
668) {
669    loop {
670        tokio::select! {
671            Some(message) = outgoing_rx.recv() => {
672                match serde_json::to_string(&message) {
673                    Ok(payload) => {
674                        if write.send(Message::Text(payload.into())).await.is_err() {
675                            let _ = incoming_tx.send(Err(Error::ChannelClosed));
676                            break;
677                        }
678                    }
679                    Err(err) => {
680                        let _ = incoming_tx.send(Err(Error::Serialization { source: err }));
681                    }
682                }
683            }
684            message = read.next() => {
685                match message {
686                    Some(Ok(message)) => {
687                        match message {
688                            Message::Ping(payload) => {
689                                let _ = write.send(Message::Pong(payload)).await;
690                            }
691                            Message::Close(_) => break,
692                            other => match parse_server_message(other) {
693                                Ok(Some(parsed)) => {
694                                    update_resumption_state(&resumption_state, &parsed);
695                                    update_go_away(&go_away_time_left, &parsed);
696                                    let _ = incoming_tx.send(Ok(parsed));
697                                }
698                                Ok(None) => {}
699                                Err(err) => {
700                                    let _ = incoming_tx.send(Err(err));
701                                }
702                            },
703                        }
704                    }
705                    Some(Err(err)) => {
706                        let _ = incoming_tx.send(Err(Error::WebSocket { source: err }));
707                        break;
708                    }
709                    None => break,
710                }
711            }
712            _ = &mut shutdown_rx => {
713                let _ = write.send(Message::Close(None)).await;
714                break;
715            }
716        }
717    }
718}
719
720#[cfg(test)]
721mod tests {
722    use super::*;
723    use crate::test_support::test_client_inner_with_api_key;
724    use rust_genai_types::config::{SpeechConfig, ThinkingConfig};
725    use rust_genai_types::enums::{MediaResolution, Modality};
726    use rust_genai_types::live_types::{
727        LiveServerGoAway, LiveServerMessage, LiveServerSessionResumptionUpdate,
728    };
729    use tokio_tungstenite::tungstenite::Message;
730
731    #[test]
732    fn test_build_live_ws_url() {
733        let (url, headers) = build_live_ws_url(
734            "https://generativelanguage.googleapis.com/",
735            "v1beta",
736            "test-key",
737        )
738        .unwrap();
739        assert!(url.as_str().starts_with("wss://"));
740        assert_eq!(
741            url.as_str(),
742            "wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
743        );
744        assert!(headers.contains_key("x-goog-api-key"));
745    }
746
747    #[test]
748    fn test_build_live_ws_url_with_ephemeral_token() {
749        let (_url, headers) = build_live_ws_url(
750            "https://generativelanguage.googleapis.com/",
751            "v1alpha",
752            "auth_tokens/abc",
753        )
754        .unwrap();
755        assert!(headers.contains_key("authorization"));
756        assert!(!headers.contains_key("x-goog-api-key"));
757    }
758
759    #[test]
760    fn test_build_live_ws_url_invalid_key() {
761        let err = build_live_ws_url(
762            "https://generativelanguage.googleapis.com/",
763            "v1beta",
764            "bad\nkey",
765        )
766        .unwrap_err();
767        assert!(matches!(err, Error::InvalidConfig { .. }));
768    }
769
770    #[test]
771    fn test_merge_generation_config() {
772        let config = LiveConnectConfig {
773            response_modalities: Some(vec![Modality::Text]),
774            temperature: Some(0.7),
775            ..LiveConnectConfig::default()
776        };
777        let generation = merge_generation_config(&config).unwrap();
778        assert_eq!(generation.response_modalities.unwrap().len(), 1);
779        assert_eq!(generation.temperature, Some(0.7));
780    }
781
782    #[test]
783    fn test_build_live_setup_and_ws_request() {
784        let config = LiveConnectConfig {
785            response_modalities: Some(vec![Modality::Text]),
786            temperature: Some(0.5),
787            ..LiveConnectConfig::default()
788        };
789        let setup = build_live_setup("gemini-3.1-flash-live-preview", &config);
790        assert_eq!(
791            setup.model.as_deref(),
792            Some("models/gemini-3.1-flash-live-preview")
793        );
794        assert!(setup.generation_config.is_some());
795
796        let (url, headers) =
797            build_live_ws_url("https://example.com/", "v1beta", "test-key").unwrap();
798        let request = build_ws_request(&url, &headers).unwrap();
799        assert!(request.headers().contains_key("x-goog-api-key"));
800    }
801
802    #[test]
803    fn test_live_builder_and_music_accessors() {
804        let inner = Arc::new(test_client_inner_with_api_key(
805            Backend::GeminiApi,
806            Some("key"),
807        ));
808        let live = Live::new(inner);
809        let builder = live.builder("gemini-3.1-flash-live-preview");
810        assert_eq!(builder.model, "gemini-3.1-flash-live-preview");
811        let _music = live.music();
812    }
813
814    #[test]
815    fn test_merge_generation_config_all_fields() {
816        let config = LiveConnectConfig {
817            response_modalities: Some(vec![Modality::Text]),
818            temperature: Some(0.7),
819            top_p: Some(0.9),
820            top_k: Some(32),
821            max_output_tokens: Some(256),
822            media_resolution: Some(MediaResolution::MediaResolutionHigh),
823            seed: Some(42),
824            speech_config: Some(SpeechConfig::default()),
825            thinking_config: Some(ThinkingConfig::default()),
826            enable_affective_dialog: Some(true),
827            ..LiveConnectConfig::default()
828        };
829        let generation = merge_generation_config(&config).unwrap();
830        assert_eq!(generation.top_p, Some(0.9));
831        assert_eq!(generation.top_k, Some(32.0));
832        assert_eq!(generation.max_output_tokens, Some(256));
833        assert_eq!(generation.seed, Some(42));
834        assert!(generation.speech_config.is_some());
835        assert!(generation.thinking_config.is_some());
836        assert_eq!(generation.enable_affective_dialog, Some(true));
837    }
838
839    #[test]
840    fn test_build_ws_request_invalid_scheme() {
841        let url = Url::parse("file:///tmp/socket").unwrap();
842        let err = build_ws_request(&url, &HeaderMap::new()).unwrap_err();
843        assert!(matches!(err, Error::Parse { .. }));
844    }
845
846    #[test]
847    fn test_build_live_ws_url_scheme_variants_and_invalid_token() {
848        let (url, _) = build_live_ws_url("ws://example.com/", "v1beta", "test-key").unwrap();
849        assert!(url.as_str().starts_with("ws://"));
850        let (url, _) = build_live_ws_url("wss://example.com/", "v1beta", "test-key").unwrap();
851        assert!(url.as_str().starts_with("wss://"));
852
853        let err =
854            build_live_ws_url("https://example.com/", "v1alpha", "auth_tokens/bad\n").unwrap_err();
855        assert!(matches!(err, Error::InvalidConfig { .. }));
856    }
857
858    #[test]
859    fn test_normalize_model_name_with_prefix() {
860        assert_eq!(
861            normalize_model_name("models/gemini-3.1-flash-live-preview"),
862            "models/gemini-3.1-flash-live-preview"
863        );
864    }
865
866    #[test]
867    fn test_poisoned_mutex_accessors() {
868        let state = Arc::new(Mutex::new(LiveSessionResumptionState {
869            handle: Some("handle".into()),
870            resumable: Some(true),
871            last_consumed_client_message_index: Some("idx".into()),
872        }));
873        let go_away = Arc::new(Mutex::new(Some("5s".into())));
874        let state_clone = Arc::clone(&state);
875        let go_away_clone = Arc::clone(&go_away);
876        let _ = std::thread::spawn(move || {
877            let _guard = state_clone.lock().unwrap();
878            let _guard2 = go_away_clone.lock().unwrap();
879            panic!("poison");
880        })
881        .join();
882
883        let (outgoing_tx, _outgoing_rx) = mpsc::unbounded_channel();
884        let (_incoming_tx, incoming_rx) = mpsc::unbounded_channel();
885        let session = LiveSession {
886            outgoing_tx,
887            incoming_rx,
888            shutdown_tx: None,
889            session_id: None,
890            resumption_state: state,
891            go_away_time_left: go_away,
892        };
893        assert_eq!(session.resumption_handle().as_deref(), Some("handle"));
894        assert_eq!(session.last_go_away_time_left().as_deref(), Some("5s"));
895        let state = session.resumption_state();
896        assert_eq!(
897            state.last_consumed_client_message_index.as_deref(),
898            Some("idx")
899        );
900    }
901
902    #[test]
903    fn test_parse_message_and_state_updates() {
904        let message = Message::Text(
905            serde_json::to_string(&LiveServerMessage {
906                session_resumption_update: Some(LiveServerSessionResumptionUpdate {
907                    new_handle: Some("handle".to_string()),
908                    resumable: Some(true),
909                    last_consumed_client_message_index: Some("1".to_string()),
910                }),
911                go_away: Some(LiveServerGoAway {
912                    time_left: Some("5s".to_string()),
913                }),
914                ..LiveServerMessage {
915                    setup_complete: None,
916                    server_content: None,
917                    tool_call: None,
918                    tool_call_cancellation: None,
919                    usage_metadata: None,
920                    voice_activity_detection_signal: None,
921                    session_resumption_update: None,
922                    go_away: None,
923                }
924            })
925            .unwrap()
926            .into(),
927        );
928
929        let parsed = parse_server_message(message).unwrap().unwrap();
930        let state = Arc::new(Mutex::new(LiveSessionResumptionState::default()));
931        update_resumption_state(&state, &parsed);
932        let guard = state.lock().unwrap();
933        assert_eq!(guard.handle.as_deref(), Some("handle"));
934        assert_eq!(guard.resumable, Some(true));
935        drop(guard);
936
937        let go_away = Arc::new(Mutex::new(None));
938        update_go_away(&go_away, &parsed);
939        assert_eq!(*go_away.lock().unwrap(), Some("5s".to_string()));
940
941        let bin_message = Message::Binary(
942            serde_json::to_vec(&LiveServerMessage {
943                setup_complete: None,
944                server_content: None,
945                tool_call: None,
946                tool_call_cancellation: None,
947                usage_metadata: None,
948                go_away: None,
949                session_resumption_update: None,
950                voice_activity_detection_signal: None,
951            })
952            .unwrap()
953            .into(),
954        );
955        assert!(parse_server_message(bin_message).unwrap().is_some());
956    }
957
958    #[test]
959    fn test_parse_server_message_variants() {
960        assert!(parse_server_message(Message::Ping(vec![1].into()))
961            .unwrap()
962            .is_none());
963        assert!(parse_server_message(Message::Close(None))
964            .unwrap()
965            .is_none());
966        assert!(parse_server_message(Message::Text("not-json".into())).is_err());
967    }
968
969    #[test]
970    fn test_update_state_with_partial_resumption_update() {
971        let message = LiveServerMessage {
972            session_resumption_update: Some(LiveServerSessionResumptionUpdate {
973                new_handle: None,
974                resumable: None,
975                last_consumed_client_message_index: Some("2".to_string()),
976            }),
977            setup_complete: None,
978            server_content: None,
979            tool_call: None,
980            tool_call_cancellation: None,
981            usage_metadata: None,
982            voice_activity_detection_signal: None,
983            go_away: None,
984        };
985        let state = Arc::new(Mutex::new(LiveSessionResumptionState {
986            handle: Some("keep".into()),
987            resumable: Some(false),
988            last_consumed_client_message_index: None,
989        }));
990        update_resumption_state(&state, &message);
991        let guard = state.lock().unwrap();
992        assert_eq!(guard.handle.as_deref(), Some("keep"));
993        assert_eq!(guard.resumable, Some(false));
994        assert_eq!(
995            guard.last_consumed_client_message_index.as_deref(),
996            Some("2")
997        );
998        drop(guard);
999
1000        let go_away = Arc::new(Mutex::new(Some("stay".to_string())));
1001        update_go_away(&go_away, &message);
1002        assert_eq!(*go_away.lock().unwrap(), Some("stay".to_string()));
1003    }
1004
1005    #[test]
1006    fn test_live_builder_config_chain() {
1007        let inner = Arc::new(test_client_inner_with_api_key(
1008            Backend::GeminiApi,
1009            Some("key"),
1010        ));
1011        let builder = LiveSessionBuilder::new(inner, "gemini-3.1-flash-live-preview".to_string())
1012            .with_system_instruction("sys")
1013            .with_tools(vec![Tool::default()])
1014            .with_generation_config(GenerationConfig::default())
1015            .with_session_resumption()
1016            .with_context_window_compression(ContextWindowCompressionConfig {
1017                trigger_tokens: None,
1018                sliding_window: None,
1019            })
1020            .with_input_audio_transcription(AudioTranscriptionConfig::default())
1021            .with_output_audio_transcription(AudioTranscriptionConfig::default());
1022
1023        assert_eq!(builder.model, "gemini-3.1-flash-live-preview");
1024        assert!(builder.config.system_instruction.is_some());
1025        assert!(builder.config.tools.is_some());
1026        assert!(builder.config.generation_config.is_some());
1027        assert!(builder.config.session_resumption.is_some());
1028        assert!(builder.config.context_window_compression.is_some());
1029        assert!(builder.config.input_audio_transcription.is_some());
1030        assert!(builder.config.output_audio_transcription.is_some());
1031
1032        let builder = builder.with_session_resumption_handle("handle");
1033        assert_eq!(
1034            builder
1035                .config
1036                .session_resumption
1037                .as_ref()
1038                .and_then(|cfg| cfg.handle.as_deref()),
1039            Some("handle")
1040        );
1041    }
1042
1043    #[tokio::test]
1044    async fn test_live_session_send_and_close() {
1045        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel();
1046        let (_incoming_tx, incoming_rx) = mpsc::unbounded_channel();
1047        let (shutdown_tx, shutdown_rx) = oneshot::channel();
1048        let session = LiveSession {
1049            outgoing_tx,
1050            incoming_rx,
1051            shutdown_tx: Some(shutdown_tx),
1052            session_id: Some("session".to_string()),
1053            resumption_state: Arc::new(Mutex::new(LiveSessionResumptionState::default())),
1054            go_away_time_left: Arc::new(Mutex::new(None)),
1055        };
1056
1057        session.send_text("hi").await.unwrap();
1058        let msg = outgoing_rx.recv().await.unwrap();
1059        assert!(msg.client_content.is_some());
1060        assert!(msg.realtime_input.is_none());
1061
1062        session
1063            .send_audio(vec![1, 2, 3], "audio/pcm")
1064            .await
1065            .unwrap();
1066        let msg = outgoing_rx.recv().await.unwrap();
1067        assert!(msg.realtime_input.as_ref().unwrap().audio.is_some());
1068
1069        session
1070            .send_client_content(LiveSendClientContentParameters {
1071                turns: Some(vec![Content::text("turn")]),
1072                turn_complete: Some(false),
1073            })
1074            .await
1075            .unwrap();
1076        let msg = outgoing_rx.recv().await.unwrap();
1077        assert!(msg.client_content.is_some());
1078
1079        session
1080            .send_realtime_input(LiveSendRealtimeInputParameters {
1081                media: Some(Blob {
1082                    mime_type: "audio/pcm".to_string(),
1083                    data: vec![9],
1084                    display_name: None,
1085                }),
1086                audio: None,
1087                audio_stream_end: Some(true),
1088                video: None,
1089                text: Some("rt".to_string()),
1090                activity_start: None,
1091                activity_end: None,
1092            })
1093            .await
1094            .unwrap();
1095        let msg = outgoing_rx.recv().await.unwrap();
1096        assert!(msg.realtime_input.is_some());
1097
1098        session
1099            .send_tool_response(LiveSendToolResponseParameters {
1100                function_responses: None,
1101            })
1102            .await
1103            .unwrap();
1104        let msg = outgoing_rx.recv().await.unwrap();
1105        assert!(msg.tool_response.is_some());
1106
1107        session.close().await.unwrap();
1108        assert!(shutdown_rx.await.is_ok());
1109    }
1110
1111    #[tokio::test]
1112    async fn test_live_session_send_channel_closed() {
1113        let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel();
1114        drop(outgoing_rx);
1115        let (_incoming_tx, incoming_rx) = mpsc::unbounded_channel();
1116        let session = LiveSession {
1117            outgoing_tx,
1118            incoming_rx,
1119            shutdown_tx: None,
1120            session_id: None,
1121            resumption_state: Arc::new(Mutex::new(LiveSessionResumptionState::default())),
1122            go_away_time_left: Arc::new(Mutex::new(None)),
1123        };
1124        let err = session.send_text("hi").await.unwrap_err();
1125        assert!(matches!(err, Error::ChannelClosed));
1126    }
1127
1128    #[test]
1129    fn test_live_session_state_accessors() {
1130        let (outgoing_tx, _outgoing_rx) = mpsc::unbounded_channel();
1131        let (_incoming_tx, incoming_rx) = mpsc::unbounded_channel();
1132        let state = Arc::new(Mutex::new(LiveSessionResumptionState {
1133            handle: Some("h".to_string()),
1134            resumable: Some(true),
1135            last_consumed_client_message_index: Some("7".to_string()),
1136        }));
1137        let go_away = Arc::new(Mutex::new(Some("10s".to_string())));
1138        let session = LiveSession {
1139            outgoing_tx,
1140            incoming_rx,
1141            shutdown_tx: None,
1142            session_id: None,
1143            resumption_state: state,
1144            go_away_time_left: go_away,
1145        };
1146        assert_eq!(session.resumption_handle().as_deref(), Some("h"));
1147        assert_eq!(session.last_go_away_time_left().as_deref(), Some("10s"));
1148        let state = session.resumption_state();
1149        assert_eq!(
1150            state.last_consumed_client_message_index.as_deref(),
1151            Some("7")
1152        );
1153    }
1154
1155    #[tokio::test]
1156    async fn test_connect_live_session_errors() {
1157        let inner = Arc::new(test_client_inner_with_api_key(
1158            Backend::GeminiApi,
1159            Some("key"),
1160        ));
1161        let config = LiveConnectConfig {
1162            http_options: Some(rust_genai_types::http::HttpOptions::default()),
1163            ..Default::default()
1164        };
1165        let err = connect_live_session(inner, "model".to_string(), config)
1166            .await
1167            .err()
1168            .unwrap();
1169        assert!(matches!(err, Error::InvalidConfig { .. }));
1170
1171        let inner = Arc::new(test_client_inner_with_api_key(
1172            Backend::VertexAi,
1173            Some("key"),
1174        ));
1175        let err = connect_live_session(inner, "model".to_string(), LiveConnectConfig::default())
1176            .await
1177            .err()
1178            .unwrap();
1179        assert!(matches!(err, Error::InvalidConfig { .. }));
1180
1181        let inner = Arc::new(test_client_inner_with_api_key(Backend::GeminiApi, None));
1182        let err = connect_live_session(inner, "model".to_string(), LiveConnectConfig::default())
1183            .await
1184            .err()
1185            .unwrap();
1186        assert!(matches!(err, Error::InvalidConfig { .. }));
1187    }
1188
1189    #[test]
1190    fn test_build_live_ws_url_ephemeral_requires_v1alpha() {
1191        let err = build_live_ws_url(
1192            "https://generativelanguage.googleapis.com/",
1193            "v1beta",
1194            "auth_tokens/abc",
1195        )
1196        .unwrap_err();
1197        assert!(matches!(err, Error::InvalidConfig { .. }));
1198    }
1199
1200    #[test]
1201    fn test_build_live_ws_url_invalid_base_url() {
1202        let err = build_live_ws_url("://bad-url", "v1beta", "test-key").unwrap_err();
1203        assert!(matches!(err, Error::InvalidConfig { .. }));
1204    }
1205}