1use std::sync::{Arc, Mutex};
4
5use futures_util::{SinkExt, StreamExt};
6use reqwest::Url;
7use tokio::sync::{mpsc, oneshot};
8use tokio::time::{timeout, Duration};
9use tokio_tungstenite::connect_async;
10use tokio_tungstenite::tungstenite::client::IntoClientRequest;
11use tokio_tungstenite::tungstenite::http::{HeaderMap, HeaderValue};
12use tokio_tungstenite::tungstenite::Message;
13
14use rust_genai_types::config::GenerationConfig;
15use rust_genai_types::content::{Blob, Content};
16use rust_genai_types::live_types::{
17 AudioTranscriptionConfig, ContextWindowCompressionConfig, LiveClientContent, LiveClientMessage,
18 LiveClientRealtimeInput, LiveClientSetup, LiveConnectConfig, LiveSendClientContentParameters,
19 LiveSendRealtimeInputParameters, LiveSendToolResponseParameters, LiveServerMessage,
20 SessionResumptionConfig,
21};
22use rust_genai_types::tool::Tool;
23
24use crate::client::{Backend, ClientInner};
25use crate::error::{Error, Result};
26use crate::live_music::LiveMusic;
27
28#[derive(Clone)]
29pub struct Live {
30 pub(crate) inner: Arc<ClientInner>,
31}
32
33impl Live {
34 pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
35 Self { inner }
36 }
37
38 pub async fn connect(
43 &self,
44 model: impl Into<String>,
45 config: LiveConnectConfig,
46 ) -> Result<LiveSession> {
47 Box::pin(
48 LiveSessionBuilder::new(self.inner.clone(), model.into())
49 .with_config(config)
50 .connect(),
51 )
52 .await
53 }
54
55 #[must_use]
57 pub fn builder(&self, model: impl Into<String>) -> LiveSessionBuilder {
58 LiveSessionBuilder::new(self.inner.clone(), model.into())
59 }
60
61 #[must_use]
63 pub fn music(&self) -> LiveMusic {
64 LiveMusic::new(self.inner.clone())
65 }
66}
67
68pub struct LiveSessionBuilder {
69 inner: Arc<ClientInner>,
70 model: String,
71 config: LiveConnectConfig,
72}
73
74impl LiveSessionBuilder {
75 pub(crate) fn new(inner: Arc<ClientInner>, model: String) -> Self {
76 Self {
77 inner,
78 model,
79 config: LiveConnectConfig::default(),
80 }
81 }
82
83 #[must_use]
85 pub fn with_config(mut self, config: LiveConnectConfig) -> Self {
86 self.config = config;
87 self
88 }
89
90 #[must_use]
92 pub fn with_system_instruction(mut self, instruction: impl Into<String>) -> Self {
93 self.config.system_instruction = Some(Content::text(instruction));
94 self
95 }
96
97 #[must_use]
99 pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
100 self.config.tools = Some(tools);
101 self
102 }
103
104 #[must_use]
106 pub fn with_generation_config(mut self, config: GenerationConfig) -> Self {
107 self.config.generation_config = Some(config);
108 self
109 }
110
111 #[must_use]
113 pub fn with_session_resumption(mut self) -> Self {
114 self.config.session_resumption = Some(SessionResumptionConfig {
115 handle: None,
116 transparent: None,
117 });
118 self
119 }
120
121 #[must_use]
123 pub fn with_session_resumption_handle(mut self, handle: impl Into<String>) -> Self {
124 self.config.session_resumption = Some(SessionResumptionConfig {
125 handle: Some(handle.into()),
126 transparent: None,
127 });
128 self
129 }
130
131 #[must_use]
133 pub fn with_context_window_compression(
134 mut self,
135 config: ContextWindowCompressionConfig,
136 ) -> Self {
137 self.config.context_window_compression = Some(config);
138 self
139 }
140
141 #[must_use]
143 pub const fn with_input_audio_transcription(
144 mut self,
145 config: AudioTranscriptionConfig,
146 ) -> Self {
147 self.config.input_audio_transcription = Some(config);
148 self
149 }
150
151 #[must_use]
153 pub const fn with_output_audio_transcription(
154 mut self,
155 config: AudioTranscriptionConfig,
156 ) -> Self {
157 self.config.output_audio_transcription = Some(config);
158 self
159 }
160
161 pub async fn connect(self) -> Result<LiveSession> {
166 connect_live_session(self.inner, self.model, self.config).await
167 }
168}
169
170pub struct LiveSession {
172 outgoing_tx: mpsc::UnboundedSender<LiveClientMessage>,
173 incoming_rx: mpsc::UnboundedReceiver<Result<LiveServerMessage>>,
174 shutdown_tx: Option<oneshot::Sender<()>>,
175 pub session_id: Option<String>,
176 resumption_state: Arc<Mutex<LiveSessionResumptionState>>,
177 go_away_time_left: Arc<Mutex<Option<String>>>,
178}
179
/// Snapshot of the session-resumption information reported by the server.
///
/// All fields start as `None` and are filled in from
/// `session_resumption_update` messages as they arrive. The type is plain
/// data, so it also supports equality comparison between snapshots.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LiveSessionResumptionState {
    /// Most recent resumption handle, if the server provided one.
    pub handle: Option<String>,
    /// Most recent `resumable` flag reported by the server.
    pub resumable: Option<bool>,
    /// Most recent `last_consumed_client_message_index` reported by the server.
    pub last_consumed_client_message_index: Option<String>,
}
186
187impl LiveSession {
188 pub async fn send_text(&self, text: impl Into<String>) -> Result<()> {
193 let message = LiveClientMessage {
194 setup: None,
195 client_content: Some(LiveClientContent {
196 turns: Some(vec![Content::text(text)]),
197 turn_complete: Some(true),
198 }),
199 realtime_input: None,
200 tool_response: None,
201 };
202 self.send_async(message).await
203 }
204
205 pub async fn send_audio(&self, data: Vec<u8>, mime_type: impl Into<String>) -> Result<()> {
210 let message = LiveClientMessage {
211 setup: None,
212 client_content: None,
213 realtime_input: Some(LiveClientRealtimeInput {
214 media_chunks: None,
215 audio: Some(Blob {
216 mime_type: mime_type.into(),
217 data,
218 display_name: None,
219 }),
220 audio_stream_end: None,
221 video: None,
222 text: None,
223 activity_start: None,
224 activity_end: None,
225 }),
226 tool_response: None,
227 };
228 self.send_async(message).await
229 }
230
231 pub async fn send_client_content(&self, params: LiveSendClientContentParameters) -> Result<()> {
236 let message = LiveClientMessage {
237 setup: None,
238 client_content: Some(LiveClientContent {
239 turns: params.turns,
240 turn_complete: params.turn_complete,
241 }),
242 realtime_input: None,
243 tool_response: None,
244 };
245 self.send_async(message).await
246 }
247
248 pub async fn send_realtime_input(&self, params: LiveSendRealtimeInputParameters) -> Result<()> {
253 let message = LiveClientMessage {
254 setup: None,
255 client_content: None,
256 realtime_input: Some(LiveClientRealtimeInput {
257 media_chunks: params.media.map(|media| vec![media]),
258 audio: params.audio,
259 audio_stream_end: params.audio_stream_end,
260 video: params.video,
261 text: params.text,
262 activity_start: params.activity_start,
263 activity_end: params.activity_end,
264 }),
265 tool_response: None,
266 };
267 self.send_async(message).await
268 }
269
270 pub async fn send_tool_response(&self, params: LiveSendToolResponseParameters) -> Result<()> {
275 let message = LiveClientMessage {
276 setup: None,
277 client_content: None,
278 realtime_input: None,
279 tool_response: Some(rust_genai_types::live_types::LiveClientToolResponse {
280 function_responses: params.function_responses,
281 }),
282 };
283 self.send_async(message).await
284 }
285
286 pub async fn receive(&mut self) -> Option<Result<LiveServerMessage>> {
288 self.incoming_rx.recv().await
289 }
290
291 pub async fn close(mut self) -> Result<()> {
296 if let Some(tx) = self.shutdown_tx.take() {
297 let _ = tx.send(());
298 }
299 tokio::task::yield_now().await;
300 Ok(())
301 }
302
303 pub fn resumption_state(&self) -> LiveSessionResumptionState {
305 self.resumption_state
306 .lock()
307 .unwrap_or_else(std::sync::PoisonError::into_inner)
308 .clone()
309 }
310
311 pub fn resumption_handle(&self) -> Option<String> {
313 self.resumption_state
314 .lock()
315 .unwrap_or_else(std::sync::PoisonError::into_inner)
316 .handle
317 .clone()
318 }
319
320 pub fn last_go_away_time_left(&self) -> Option<String> {
322 self.go_away_time_left
323 .lock()
324 .unwrap_or_else(std::sync::PoisonError::into_inner)
325 .clone()
326 }
327
328 fn send(&self, message: LiveClientMessage) -> Result<()> {
329 self.outgoing_tx
330 .send(message)
331 .map_err(|_| Error::ChannelClosed)?;
332 Ok(())
333 }
334
335 async fn send_async(&self, message: LiveClientMessage) -> Result<()> {
336 self.send(message)?;
337 tokio::task::yield_now().await;
338 Ok(())
339 }
340}
341
342async fn connect_live_session(
343 inner: Arc<ClientInner>,
344 model: String,
345 config: LiveConnectConfig,
346) -> Result<LiveSession> {
347 if config.http_options.is_some() {
348 return Err(Error::InvalidConfig {
349 message: "LiveConnectConfig.http_options is not supported yet".into(),
350 });
351 }
352
353 if inner.config.backend == Backend::VertexAi {
354 return Err(Error::InvalidConfig {
355 message: "Live API for Vertex AI is not supported yet".into(),
356 });
357 }
358
359 let api_key = inner
360 .config
361 .api_key
362 .as_ref()
363 .ok_or_else(|| Error::InvalidConfig {
364 message: "API key required for Live API".into(),
365 })?;
366
367 let (url, headers) = build_live_ws_url(
368 &inner.api_client.base_url,
369 &inner.api_client.api_version,
370 api_key,
371 )?;
372
373 let setup_timeout_ms = inner.config.http_options.timeout.unwrap_or(30_000);
374 let request = build_ws_request(&url, &headers)?;
375 let (ws_stream, _) = timeout(
376 Duration::from_millis(setup_timeout_ms),
377 connect_async(request),
378 )
379 .await
380 .map_err(|_| Error::Timeout {
381 message: format!("Timed out connecting to Live API after {setup_timeout_ms}ms"),
382 })??;
383 let (mut write, mut read) = ws_stream.split();
384
385 let setup = build_live_setup(&model, &config);
386 let setup_message = LiveClientMessage {
387 setup: Some(setup),
388 client_content: None,
389 realtime_input: None,
390 tool_response: None,
391 };
392 let payload = serde_json::to_string(&setup_message)?;
393 write.send(Message::Text(payload.into())).await?;
394
395 let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
396 let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel();
397 let (shutdown_tx, shutdown_rx) = oneshot::channel();
398 let resumption_state = Arc::new(Mutex::new(LiveSessionResumptionState::default()));
399 let go_away_time_left = Arc::new(Mutex::new(None));
400
401 let session_id = timeout(Duration::from_millis(setup_timeout_ms), async {
402 loop {
403 match read.next().await {
404 Some(Ok(message)) => match message {
405 Message::Close(frame) => {
406 return Err(Error::Parse {
407 message: format!("WebSocket closed before setup_complete: {frame:?}"),
408 })
409 }
410 _ => {
411 if let Some(msg) = parse_server_message(message)? {
412 if let Some(setup) = msg.setup_complete.as_ref() {
413 return Ok(setup.session_id.clone());
414 }
415 }
416 }
417 },
418 Some(Err(err)) => return Err(Error::WebSocket { source: err }),
419 None => {
420 return Err(Error::Parse {
421 message: "WebSocket closed before setup_complete".into(),
422 })
423 }
424 }
425 }
426 })
427 .await
428 .map_err(|_| Error::Timeout {
429 message: format!(
430 "Timed out waiting for Live API setup_complete after {setup_timeout_ms}ms"
431 ),
432 })??;
433
434 tokio::spawn(message_loop(
435 write,
436 read,
437 outgoing_rx,
438 incoming_tx,
439 shutdown_rx,
440 resumption_state.clone(),
441 go_away_time_left.clone(),
442 ));
443
444 Ok(LiveSession {
445 outgoing_tx,
446 incoming_rx,
447 shutdown_tx: Some(shutdown_tx),
448 session_id,
449 resumption_state,
450 go_away_time_left,
451 })
452}
453
454fn build_live_setup(model: &str, config: &LiveConnectConfig) -> LiveClientSetup {
455 let model = normalize_model_name(model);
456 let generation_config = merge_generation_config(config);
457
458 LiveClientSetup {
459 model: Some(model),
460 generation_config,
461 system_instruction: config.system_instruction.clone(),
462 tools: config.tools.clone(),
463 realtime_input_config: config.realtime_input_config.clone(),
464 session_resumption: config.session_resumption.clone(),
465 context_window_compression: config.context_window_compression.clone(),
466 input_audio_transcription: config.input_audio_transcription.clone(),
467 output_audio_transcription: config.output_audio_transcription.clone(),
468 proactivity: config.proactivity.clone(),
469 explicit_vad_signal: config.explicit_vad_signal,
470 }
471}
472
473fn merge_generation_config(config: &LiveConnectConfig) -> Option<GenerationConfig> {
474 let mut generation_config = config.generation_config.clone().unwrap_or_default();
475 let updated = config.generation_config.is_some()
476 || config.response_modalities.is_some()
477 || config.temperature.is_some()
478 || config.top_p.is_some()
479 || config.top_k.is_some()
480 || config.max_output_tokens.is_some()
481 || config.media_resolution.is_some()
482 || config.seed.is_some()
483 || config.speech_config.is_some()
484 || config.thinking_config.is_some()
485 || config.enable_affective_dialog.is_some();
486
487 if let Some(value) = config.response_modalities.clone() {
488 generation_config.response_modalities = Some(value);
489 }
490 if let Some(value) = config.temperature {
491 generation_config.temperature = Some(value);
492 }
493 if let Some(value) = config.top_p {
494 generation_config.top_p = Some(value);
495 }
496 if let Some(value) = config.top_k {
497 let top_k_value = i16::try_from(value).unwrap_or_else(|_| {
498 if value > i32::from(i16::MAX) {
499 i16::MAX
500 } else {
501 i16::MIN
502 }
503 });
504 generation_config.top_k = Some(f32::from(top_k_value));
505 }
506 if let Some(value) = config.max_output_tokens {
507 generation_config.max_output_tokens = Some(value);
508 }
509 if let Some(value) = config.media_resolution {
510 generation_config.media_resolution = Some(value);
511 }
512 if let Some(value) = config.seed {
513 generation_config.seed = Some(value);
514 }
515 if let Some(value) = config.speech_config.clone() {
516 generation_config.speech_config = Some(value);
517 }
518 if let Some(value) = config.thinking_config.clone() {
519 generation_config.thinking_config = Some(value);
520 }
521 if let Some(value) = config.enable_affective_dialog {
522 generation_config.enable_affective_dialog = Some(value);
523 }
524
525 updated.then_some(generation_config)
526}
527
528fn build_ws_request(
529 url: &Url,
530 headers: &HeaderMap,
531) -> Result<tokio_tungstenite::tungstenite::http::Request<()>> {
532 let mut request = url
533 .as_str()
534 .into_client_request()
535 .map_err(|err| Error::Parse {
536 message: err.to_string(),
537 })?;
538 {
539 let request_headers = request.headers_mut();
540 for (key, value) in headers {
541 request_headers.insert(key, value.clone());
542 }
543 }
544 Ok(request)
545}
546
547fn build_live_ws_url(base_url: &str, api_version: &str, api_key: &str) -> Result<(Url, HeaderMap)> {
548 if api_key.starts_with("auth_tokens/") && api_version != "v1alpha" {
549 return Err(Error::InvalidConfig {
550 message: "Ephemeral tokens require v1alpha for Live API".into(),
551 });
552 }
553 let mut url = Url::parse(base_url).map_err(|err| Error::InvalidConfig {
554 message: err.to_string(),
555 })?;
556
557 let scheme = match url.scheme() {
558 "http" | "ws" => "ws",
559 _ => "wss",
560 };
561 url.set_scheme(scheme).map_err(|()| Error::InvalidConfig {
562 message: "Invalid base_url scheme".into(),
563 })?;
564
565 let base_path = url.path().trim_end_matches('/');
566 let method = if api_key.starts_with("auth_tokens/") {
567 "BidiGenerateContentConstrained"
568 } else {
569 "BidiGenerateContent"
570 };
571 let path = format!(
572 "{base_path}/ws/google.ai.generativelanguage.{api_version}.GenerativeService.{method}"
573 );
574 url.set_path(&path);
575
576 let mut headers = HeaderMap::new();
577 if api_key.starts_with("auth_tokens/") {
578 headers.insert(
579 "authorization",
580 HeaderValue::from_str(&format!("Token {api_key}")).map_err(|_| {
581 Error::InvalidConfig {
582 message: "Invalid ephemeral token".into(),
583 }
584 })?,
585 );
586 } else {
587 headers.insert(
588 "x-goog-api-key",
589 HeaderValue::from_str(api_key).map_err(|_| Error::InvalidConfig {
590 message: "Invalid API key".into(),
591 })?,
592 );
593 }
594
595 Ok((url, headers))
596}
597
/// Returns the fully qualified resource name for `model`.
///
/// Names that already start with a known resource collection (`models/` or
/// `tunedModels/`) are returned unchanged; any other name is treated as a
/// bare model id and prefixed with `models/`.
fn normalize_model_name(model: &str) -> String {
    // Prefixing an already-qualified tuned model would yield the invalid
    // name `models/tunedModels/...`, so both collections pass through.
    const QUALIFIED_PREFIXES: [&str; 2] = ["models/", "tunedModels/"];

    let already_qualified = QUALIFIED_PREFIXES
        .iter()
        .any(|prefix| model.starts_with(prefix));
    if already_qualified {
        model.to_owned()
    } else {
        format!("models/{model}")
    }
}
605
606fn parse_server_message(message: Message) -> Result<Option<LiveServerMessage>> {
607 match message {
608 Message::Text(text) => {
609 let msg = serde_json::from_str::<LiveServerMessage>(&text)?;
610 Ok(Some(msg))
611 }
612 Message::Binary(data) => {
613 let msg = serde_json::from_slice::<LiveServerMessage>(&data)?;
614 Ok(Some(msg))
615 }
616 Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => Ok(None),
617 }
618}
619
620fn update_resumption_state(
621 state: &Arc<Mutex<LiveSessionResumptionState>>,
622 message: &LiveServerMessage,
623) {
624 if let Some(update) = message.session_resumption_update.as_ref() {
625 let mut guard = state
626 .lock()
627 .unwrap_or_else(std::sync::PoisonError::into_inner);
628 if update.new_handle.is_some() || update.resumable.is_some() {
629 guard.handle.clone_from(&update.new_handle);
630 }
631 if update.resumable.is_some() {
632 guard.resumable = update.resumable;
633 }
634 if update.last_consumed_client_message_index.is_some() {
635 guard
636 .last_consumed_client_message_index
637 .clone_from(&update.last_consumed_client_message_index);
638 }
639 }
640}
641
642fn update_go_away(state: &Arc<Mutex<Option<String>>>, message: &LiveServerMessage) {
643 if let Some(go_away) = message.go_away.as_ref() {
644 let mut guard = state
645 .lock()
646 .unwrap_or_else(std::sync::PoisonError::into_inner);
647 guard.clone_from(&go_away.time_left);
648 }
649}
650
651async fn message_loop(
652 mut write: futures_util::stream::SplitSink<
653 tokio_tungstenite::WebSocketStream<
654 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
655 >,
656 Message,
657 >,
658 mut read: futures_util::stream::SplitStream<
659 tokio_tungstenite::WebSocketStream<
660 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
661 >,
662 >,
663 mut outgoing_rx: mpsc::UnboundedReceiver<LiveClientMessage>,
664 incoming_tx: mpsc::UnboundedSender<Result<LiveServerMessage>>,
665 mut shutdown_rx: oneshot::Receiver<()>,
666 resumption_state: Arc<Mutex<LiveSessionResumptionState>>,
667 go_away_time_left: Arc<Mutex<Option<String>>>,
668) {
669 loop {
670 tokio::select! {
671 Some(message) = outgoing_rx.recv() => {
672 match serde_json::to_string(&message) {
673 Ok(payload) => {
674 if write.send(Message::Text(payload.into())).await.is_err() {
675 let _ = incoming_tx.send(Err(Error::ChannelClosed));
676 break;
677 }
678 }
679 Err(err) => {
680 let _ = incoming_tx.send(Err(Error::Serialization { source: err }));
681 }
682 }
683 }
684 message = read.next() => {
685 match message {
686 Some(Ok(message)) => {
687 match message {
688 Message::Ping(payload) => {
689 let _ = write.send(Message::Pong(payload)).await;
690 }
691 Message::Close(_) => break,
692 other => match parse_server_message(other) {
693 Ok(Some(parsed)) => {
694 update_resumption_state(&resumption_state, &parsed);
695 update_go_away(&go_away_time_left, &parsed);
696 let _ = incoming_tx.send(Ok(parsed));
697 }
698 Ok(None) => {}
699 Err(err) => {
700 let _ = incoming_tx.send(Err(err));
701 }
702 },
703 }
704 }
705 Some(Err(err)) => {
706 let _ = incoming_tx.send(Err(Error::WebSocket { source: err }));
707 break;
708 }
709 None => break,
710 }
711 }
712 _ = &mut shutdown_rx => {
713 let _ = write.send(Message::Close(None)).await;
714 break;
715 }
716 }
717 }
718}
719
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::test_client_inner_with_api_key;
    use rust_genai_types::config::{SpeechConfig, ThinkingConfig};
    use rust_genai_types::enums::{MediaResolution, Modality};
    use rust_genai_types::live_types::{
        LiveServerGoAway, LiveServerMessage, LiveServerSessionResumptionUpdate,
    };
    use tokio_tungstenite::tungstenite::Message;

    /// Production base URL used by most URL-building tests.
    const GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/";

    /// A session plus the peer ends of its two channels.
    type DetachedSession = (
        LiveSession,
        mpsc::UnboundedReceiver<LiveClientMessage>,
        mpsc::UnboundedSender<Result<LiveServerMessage>>,
    );

    /// Server message with every field unset.
    fn blank_server_message() -> LiveServerMessage {
        LiveServerMessage {
            setup_complete: None,
            server_content: None,
            tool_call: None,
            tool_call_cancellation: None,
            usage_metadata: None,
            voice_activity_detection_signal: None,
            session_resumption_update: None,
            go_away: None,
        }
    }

    /// Builds a `LiveSession` that is not backed by a socket; the caller
    /// receives the opposite channel ends to drive or observe it.
    fn detached_session(
        shutdown_tx: Option<oneshot::Sender<()>>,
        session_id: Option<String>,
        resumption_state: Arc<Mutex<LiveSessionResumptionState>>,
        go_away_time_left: Arc<Mutex<Option<String>>>,
    ) -> DetachedSession {
        let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel();
        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
        let session = LiveSession {
            outgoing_tx,
            incoming_rx,
            shutdown_tx,
            session_id,
            resumption_state,
            go_away_time_left,
        };
        (session, outgoing_rx, incoming_tx)
    }

    #[test]
    fn test_build_live_ws_url() {
        let (url, headers) = build_live_ws_url(GEMINI_BASE_URL, "v1beta", "test-key").unwrap();
        assert!(url.as_str().starts_with("wss://"));
        assert_eq!(
            url.as_str(),
            "wss://generativelanguage.googleapis.com/ws/google.ai.generativelanguage.v1beta.GenerativeService.BidiGenerateContent"
        );
        assert!(headers.contains_key("x-goog-api-key"));
    }

    #[test]
    fn test_build_live_ws_url_with_ephemeral_token() {
        let (_url, headers) =
            build_live_ws_url(GEMINI_BASE_URL, "v1alpha", "auth_tokens/abc").unwrap();
        assert!(headers.contains_key("authorization"));
        assert!(!headers.contains_key("x-goog-api-key"));
    }

    #[test]
    fn test_build_live_ws_url_invalid_key() {
        let outcome = build_live_ws_url(GEMINI_BASE_URL, "v1beta", "bad\nkey");
        assert!(matches!(
            outcome.unwrap_err(),
            Error::InvalidConfig { .. }
        ));
    }

    #[test]
    fn test_merge_generation_config() {
        let config = LiveConnectConfig {
            response_modalities: Some(vec![Modality::Text]),
            temperature: Some(0.7),
            ..LiveConnectConfig::default()
        };
        let merged = merge_generation_config(&config).unwrap();
        assert_eq!(merged.response_modalities.unwrap().len(), 1);
        assert_eq!(merged.temperature, Some(0.7));
    }

    #[test]
    fn test_build_live_setup_and_ws_request() {
        let config = LiveConnectConfig {
            response_modalities: Some(vec![Modality::Text]),
            temperature: Some(0.5),
            ..LiveConnectConfig::default()
        };
        let setup = build_live_setup("gemini-2.0-flash", &config);
        assert_eq!(setup.model.as_deref(), Some("models/gemini-2.0-flash"));
        assert!(setup.generation_config.is_some());

        let (url, headers) =
            build_live_ws_url("https://example.com/", "v1beta", "test-key").unwrap();
        let request = build_ws_request(&url, &headers).unwrap();
        assert!(request.headers().contains_key("x-goog-api-key"));
    }

    #[test]
    fn test_live_builder_and_music_accessors() {
        let inner = Arc::new(test_client_inner_with_api_key(
            Backend::GeminiApi,
            Some("key"),
        ));
        let live = Live::new(inner);
        let builder = live.builder("gemini-2.0-flash");
        assert_eq!(builder.model, "gemini-2.0-flash");
        let _music = live.music();
    }

    #[test]
    fn test_merge_generation_config_all_fields() {
        let config = LiveConnectConfig {
            response_modalities: Some(vec![Modality::Text]),
            temperature: Some(0.7),
            top_p: Some(0.9),
            top_k: Some(32),
            max_output_tokens: Some(256),
            media_resolution: Some(MediaResolution::MediaResolutionHigh),
            seed: Some(42),
            speech_config: Some(SpeechConfig::default()),
            thinking_config: Some(ThinkingConfig::default()),
            enable_affective_dialog: Some(true),
            ..LiveConnectConfig::default()
        };
        let merged = merge_generation_config(&config).unwrap();
        assert_eq!(merged.top_p, Some(0.9));
        assert_eq!(merged.top_k, Some(32.0));
        assert_eq!(merged.max_output_tokens, Some(256));
        assert_eq!(merged.seed, Some(42));
        assert!(merged.speech_config.is_some());
        assert!(merged.thinking_config.is_some());
        assert_eq!(merged.enable_affective_dialog, Some(true));
    }

    #[test]
    fn test_build_ws_request_invalid_scheme() {
        let url = Url::parse("file:///tmp/socket").unwrap();
        let outcome = build_ws_request(&url, &HeaderMap::new());
        assert!(matches!(outcome.unwrap_err(), Error::Parse { .. }));
    }

    #[test]
    fn test_build_live_ws_url_scheme_variants_and_invalid_token() {
        let (plain, _) = build_live_ws_url("ws://example.com/", "v1beta", "test-key").unwrap();
        assert!(plain.as_str().starts_with("ws://"));
        let (secure, _) = build_live_ws_url("wss://example.com/", "v1beta", "test-key").unwrap();
        assert!(secure.as_str().starts_with("wss://"));

        let outcome = build_live_ws_url("https://example.com/", "v1alpha", "auth_tokens/bad\n");
        assert!(matches!(
            outcome.unwrap_err(),
            Error::InvalidConfig { .. }
        ));
    }

    #[test]
    fn test_normalize_model_name_with_prefix() {
        assert_eq!(
            normalize_model_name("models/gemini-2.0-flash"),
            "models/gemini-2.0-flash"
        );
    }

    #[test]
    fn test_poisoned_mutex_accessors() {
        let resumption = Arc::new(Mutex::new(LiveSessionResumptionState {
            handle: Some("handle".into()),
            resumable: Some(true),
            last_consumed_client_message_index: Some("idx".into()),
        }));
        let go_away = Arc::new(Mutex::new(Some("5s".into())));

        // Panic while holding both locks so that both mutexes end up poisoned.
        let resumption_for_thread = Arc::clone(&resumption);
        let go_away_for_thread = Arc::clone(&go_away);
        let _ = std::thread::spawn(move || {
            let _guard = resumption_for_thread.lock().unwrap();
            let _guard2 = go_away_for_thread.lock().unwrap();
            panic!("poison");
        })
        .join();

        let (session, _outgoing_rx, _incoming_tx) =
            detached_session(None, None, resumption, go_away);
        assert_eq!(session.resumption_handle().as_deref(), Some("handle"));
        assert_eq!(session.last_go_away_time_left().as_deref(), Some("5s"));
        let snapshot = session.resumption_state();
        assert_eq!(
            snapshot.last_consumed_client_message_index.as_deref(),
            Some("idx")
        );
    }

    #[test]
    fn test_parse_message_and_state_updates() {
        let server_message = LiveServerMessage {
            session_resumption_update: Some(LiveServerSessionResumptionUpdate {
                new_handle: Some("handle".to_string()),
                resumable: Some(true),
                last_consumed_client_message_index: Some("1".to_string()),
            }),
            go_away: Some(LiveServerGoAway {
                time_left: Some("5s".to_string()),
            }),
            ..blank_server_message()
        };
        let text_frame = Message::Text(serde_json::to_string(&server_message).unwrap().into());

        let parsed = parse_server_message(text_frame).unwrap().unwrap();
        let state = Arc::new(Mutex::new(LiveSessionResumptionState::default()));
        update_resumption_state(&state, &parsed);
        {
            let guard = state.lock().unwrap();
            assert_eq!(guard.handle.as_deref(), Some("handle"));
            assert_eq!(guard.resumable, Some(true));
        }

        let go_away = Arc::new(Mutex::new(None));
        update_go_away(&go_away, &parsed);
        assert_eq!(*go_away.lock().unwrap(), Some("5s".to_string()));

        let binary_frame =
            Message::Binary(serde_json::to_vec(&blank_server_message()).unwrap().into());
        assert!(parse_server_message(binary_frame).unwrap().is_some());
    }

    #[test]
    fn test_parse_server_message_variants() {
        let ping = parse_server_message(Message::Ping(vec![1].into()));
        assert!(ping.unwrap().is_none());
        let close = parse_server_message(Message::Close(None));
        assert!(close.unwrap().is_none());
        assert!(parse_server_message(Message::Text("not-json".into())).is_err());
    }

    #[test]
    fn test_update_state_with_partial_resumption_update() {
        let message = LiveServerMessage {
            session_resumption_update: Some(LiveServerSessionResumptionUpdate {
                new_handle: None,
                resumable: None,
                last_consumed_client_message_index: Some("2".to_string()),
            }),
            ..blank_server_message()
        };
        let state = Arc::new(Mutex::new(LiveSessionResumptionState {
            handle: Some("keep".into()),
            resumable: Some(false),
            last_consumed_client_message_index: None,
        }));
        update_resumption_state(&state, &message);
        {
            let guard = state.lock().unwrap();
            assert_eq!(guard.handle.as_deref(), Some("keep"));
            assert_eq!(guard.resumable, Some(false));
            assert_eq!(
                guard.last_consumed_client_message_index.as_deref(),
                Some("2")
            );
        }

        // No go_away in the message, so the stored value must survive.
        let go_away = Arc::new(Mutex::new(Some("stay".to_string())));
        update_go_away(&go_away, &message);
        assert_eq!(*go_away.lock().unwrap(), Some("stay".to_string()));
    }

    #[test]
    fn test_live_builder_config_chain() {
        let inner = Arc::new(test_client_inner_with_api_key(
            Backend::GeminiApi,
            Some("key"),
        ));
        let builder = LiveSessionBuilder::new(inner, "gemini-2.0-flash".to_string())
            .with_system_instruction("sys")
            .with_tools(vec![Tool::default()])
            .with_generation_config(GenerationConfig::default())
            .with_session_resumption()
            .with_context_window_compression(ContextWindowCompressionConfig {
                trigger_tokens: None,
                sliding_window: None,
            })
            .with_input_audio_transcription(AudioTranscriptionConfig::default())
            .with_output_audio_transcription(AudioTranscriptionConfig::default());

        assert_eq!(builder.model, "gemini-2.0-flash");
        let config = &builder.config;
        assert!(config.system_instruction.is_some());
        assert!(config.tools.is_some());
        assert!(config.generation_config.is_some());
        assert!(config.session_resumption.is_some());
        assert!(config.context_window_compression.is_some());
        assert!(config.input_audio_transcription.is_some());
        assert!(config.output_audio_transcription.is_some());

        let builder = builder.with_session_resumption_handle("handle");
        let handle = builder
            .config
            .session_resumption
            .as_ref()
            .and_then(|cfg| cfg.handle.as_deref());
        assert_eq!(handle, Some("handle"));
    }

    #[tokio::test]
    async fn test_live_session_send_and_close() {
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let (session, mut outgoing_rx, _incoming_tx) = detached_session(
            Some(shutdown_tx),
            Some("session".to_string()),
            Arc::new(Mutex::new(LiveSessionResumptionState::default())),
            Arc::new(Mutex::new(None)),
        );

        session.send_text("hi").await.unwrap();
        let sent = outgoing_rx.recv().await.unwrap();
        assert!(sent.client_content.is_some());
        assert!(sent.realtime_input.is_none());

        session
            .send_audio(vec![1, 2, 3], "audio/pcm")
            .await
            .unwrap();
        let sent = outgoing_rx.recv().await.unwrap();
        assert!(sent.realtime_input.as_ref().unwrap().audio.is_some());

        let content_params = LiveSendClientContentParameters {
            turns: Some(vec![Content::text("turn")]),
            turn_complete: Some(false),
        };
        session.send_client_content(content_params).await.unwrap();
        let sent = outgoing_rx.recv().await.unwrap();
        assert!(sent.client_content.is_some());

        let realtime_params = LiveSendRealtimeInputParameters {
            media: Some(Blob {
                mime_type: "audio/pcm".to_string(),
                data: vec![9],
                display_name: None,
            }),
            audio: None,
            audio_stream_end: Some(true),
            video: None,
            text: Some("rt".to_string()),
            activity_start: None,
            activity_end: None,
        };
        session.send_realtime_input(realtime_params).await.unwrap();
        let sent = outgoing_rx.recv().await.unwrap();
        assert!(sent.realtime_input.is_some());

        let tool_params = LiveSendToolResponseParameters {
            function_responses: None,
        };
        session.send_tool_response(tool_params).await.unwrap();
        let sent = outgoing_rx.recv().await.unwrap();
        assert!(sent.tool_response.is_some());

        session.close().await.unwrap();
        assert!(shutdown_rx.await.is_ok());
    }

    #[tokio::test]
    async fn test_live_session_send_channel_closed() {
        let (session, outgoing_rx, _incoming_tx) = detached_session(
            None,
            None,
            Arc::new(Mutex::new(LiveSessionResumptionState::default())),
            Arc::new(Mutex::new(None)),
        );
        // With the receiving end gone, every send must fail.
        drop(outgoing_rx);

        let err = session.send_text("hi").await.unwrap_err();
        assert!(matches!(err, Error::ChannelClosed));
    }

    #[test]
    fn test_live_session_state_accessors() {
        let resumption = Arc::new(Mutex::new(LiveSessionResumptionState {
            handle: Some("h".to_string()),
            resumable: Some(true),
            last_consumed_client_message_index: Some("7".to_string()),
        }));
        let go_away = Arc::new(Mutex::new(Some("10s".to_string())));
        let (session, _outgoing_rx, _incoming_tx) =
            detached_session(None, None, resumption, go_away);

        assert_eq!(session.resumption_handle().as_deref(), Some("h"));
        assert_eq!(session.last_go_away_time_left().as_deref(), Some("10s"));
        let snapshot = session.resumption_state();
        assert_eq!(
            snapshot.last_consumed_client_message_index.as_deref(),
            Some("7")
        );
    }

    #[tokio::test]
    async fn test_connect_live_session_errors() {
        // Per-call http_options are rejected.
        let inner = Arc::new(test_client_inner_with_api_key(
            Backend::GeminiApi,
            Some("key"),
        ));
        let config = LiveConnectConfig {
            http_options: Some(rust_genai_types::http::HttpOptions::default()),
            ..Default::default()
        };
        let err = connect_live_session(inner, "model".to_string(), config)
            .await
            .err()
            .unwrap();
        assert!(matches!(err, Error::InvalidConfig { .. }));

        // The Vertex AI backend is rejected.
        let inner = Arc::new(test_client_inner_with_api_key(
            Backend::VertexAi,
            Some("key"),
        ));
        let err = connect_live_session(inner, "model".to_string(), LiveConnectConfig::default())
            .await
            .err()
            .unwrap();
        assert!(matches!(err, Error::InvalidConfig { .. }));

        // A missing API key is rejected.
        let inner = Arc::new(test_client_inner_with_api_key(Backend::GeminiApi, None));
        let err = connect_live_session(inner, "model".to_string(), LiveConnectConfig::default())
            .await
            .err()
            .unwrap();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[test]
    fn test_build_live_ws_url_ephemeral_requires_v1alpha() {
        let outcome = build_live_ws_url(GEMINI_BASE_URL, "v1beta", "auth_tokens/abc");
        assert!(matches!(
            outcome.unwrap_err(),
            Error::InvalidConfig { .. }
        ));
    }

    #[test]
    fn test_build_live_ws_url_invalid_base_url() {
        let outcome = build_live_ws_url("://bad-url", "v1beta", "test-key");
        assert!(matches!(
            outcome.unwrap_err(),
            Error::InvalidConfig { .. }
        ));
    }
}