1use futures_util::{SinkExt, StreamExt};
4use tokio_tungstenite::{
5 connect_async,
6 tungstenite::{client::IntoClientRequest, Message as WsMessage},
7};
8
9use crate::client::XaiClient;
10use crate::models::tool::Tool;
11use crate::models::voice::{
12 AudioFormat, ConversationItem, RealtimeClientMessage, RealtimeServerMessage, SessionConfig,
13 Voice,
14};
15use crate::{Error, Result};
16
/// Entry point for the realtime (WebSocket) API.
///
/// Wraps an [`XaiClient`] and hands out [`RealtimeSessionBuilder`]s that carry
/// a clone of that client.
#[derive(Debug, Clone)]
pub struct RealtimeApi {
    /// Client cloned into every builder created from this handle.
    client: XaiClient,
}
22
23impl RealtimeApi {
24 pub(crate) fn new(client: XaiClient) -> Self {
25 Self { client }
26 }
27
28 pub fn connect(&self, model: impl Into<String>) -> RealtimeSessionBuilder {
49 RealtimeSessionBuilder::new(self.client.clone(), model.into())
50 }
51
52 pub fn resume(&self, config: SessionConfig) -> RealtimeSessionBuilder {
74 RealtimeSessionBuilder::from_config(self.client.clone(), config)
75 }
76}
77
/// Builder that accumulates a [`SessionConfig`] and then opens a
/// [`RealtimeSession`] with it.
#[derive(Debug)]
pub struct RealtimeSessionBuilder {
    /// Client supplying the base URL and API key used when connecting.
    client: XaiClient,
    /// Configuration sent to the server once the connection is established.
    config: SessionConfig,
}
84
85impl RealtimeSessionBuilder {
86 fn new(client: XaiClient, model: String) -> Self {
87 Self::from_config(client, SessionConfig::new(model))
88 }
89
90 fn from_config(client: XaiClient, config: SessionConfig) -> Self {
91 Self { client, config }
92 }
93
94 pub fn voice(mut self, voice: Voice) -> Self {
96 self.config.voice = voice;
97 self
98 }
99
100 pub fn audio_format(mut self, format: AudioFormat) -> Self {
102 self.config.input_audio_format = format;
103 self.config.output_audio_format = format;
104 self
105 }
106
107 pub fn input_format(mut self, format: AudioFormat) -> Self {
109 self.config.input_audio_format = format;
110 self
111 }
112
113 pub fn output_format(mut self, format: AudioFormat) -> Self {
115 self.config.output_audio_format = format;
116 self
117 }
118
119 pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
121 self.config.instructions = Some(instructions.into());
122 self
123 }
124
125 pub fn tools(mut self, tools: Vec<Tool>) -> Self {
127 self.config.tools = Some(tools);
128 self
129 }
130
131 pub async fn start(self) -> Result<RealtimeSession> {
136 let mut parsed = url::Url::parse(self.client.base_url())?;
138 let scheme = match parsed.scheme() {
139 "https" => "wss",
140 "http" => "ws",
141 s => {
142 return Err(Error::InvalidRequest(format!(
143 "Unsupported URL scheme: {}",
144 s
145 )))
146 }
147 };
148 parsed
149 .set_scheme(scheme)
150 .map_err(|_| Error::InvalidRequest("Failed to set WebSocket scheme".to_string()))?;
151 parsed
152 .path_segments_mut()
153 .map_err(|_| Error::InvalidRequest("Cannot-be-a-base URL".to_string()))?
154 .push("realtime");
155 parsed
156 .query_pairs_mut()
157 .append_pair("model", &self.config.model);
158 let ws_url = parsed.to_string();
159
160 let mut request = ws_url
162 .into_client_request()
163 .map_err(|e| Error::InvalidRequest(e.to_string()))?;
164 request.headers_mut().insert(
165 "Authorization",
166 http::HeaderValue::from_str(&format!("Bearer {}", self.client.api_key()))
167 .map_err(|e| Error::InvalidRequest(e.to_string()))?,
168 );
169 request.headers_mut().insert(
170 "Sec-WebSocket-Protocol",
171 http::HeaderValue::from_static("realtime"),
172 );
173
174 let (ws_stream, _) = connect_async(request).await?;
175 let (write, read) = ws_stream.split();
176
177 let mut session = RealtimeSession {
178 client: self.client.clone(),
179 config: self.config.clone(),
180 write: Box::new(write),
181 read: Box::new(read),
182 };
183
184 session.update_session(self.config).await?;
187
188 Ok(session)
189 }
190}
191
/// A live realtime session backed by a split WebSocket connection.
pub struct RealtimeSession {
    /// Client reused when building a reconnect builder.
    client: XaiClient,
    /// Configuration most recently stored by `update_session`.
    config: SessionConfig,
    /// Outgoing half of the WebSocket, type-erased behind a boxed `Sink`.
    write: Box<
        dyn futures_util::Sink<WsMessage, Error = tokio_tungstenite::tungstenite::Error>
            + Send
            + Unpin,
    >,
    /// Incoming half of the WebSocket, type-erased behind a boxed `Stream`.
    read: Box<
        dyn futures_util::Stream<
                Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>,
            > + Send
            + Unpin,
    >,
}
208
209impl RealtimeSession {
210 pub fn config(&self) -> &SessionConfig {
212 &self.config
213 }
214
215 pub fn reconnect_builder(&self) -> RealtimeSessionBuilder {
219 RealtimeSessionBuilder::from_config(self.client.clone(), self.config.clone())
220 }
221
222 pub async fn reconnect(&self) -> Result<RealtimeSession> {
227 self.reconnect_builder().start().await
228 }
229
230 pub async fn reconnect_and_replay(
234 &self,
235 items: impl IntoIterator<Item = ConversationItem>,
236 ) -> Result<RealtimeSession> {
237 let mut session = self.reconnect().await?;
238 for item in items {
239 session.create_item(item).await?;
240 }
241 Ok(session)
242 }
243
244 pub async fn send(&mut self, message: RealtimeClientMessage) -> Result<()> {
246 let json = serde_json::to_string(&message)?;
247 self.write.send(WsMessage::Text(json.into())).await?;
248 Ok(())
249 }
250
251 pub async fn receive(&mut self) -> Result<Option<RealtimeServerMessage>> {
253 while let Some(result) = self.read.next().await {
254 match result? {
255 WsMessage::Text(text) => {
256 let message: RealtimeServerMessage = serde_json::from_str(&text)?;
257 return Ok(Some(message));
258 }
259 WsMessage::Binary(_) => {
260 continue;
263 }
264 WsMessage::Close(_) => return Ok(None),
265 WsMessage::Ping(_) | WsMessage::Pong(_) | WsMessage::Frame(_) => continue,
266 }
267 }
268 Ok(None)
269 }
270
271 pub async fn receive_raw(&mut self) -> Result<Option<RawRealtimeMessage>> {
273 while let Some(result) = self.read.next().await {
274 match result? {
275 WsMessage::Text(text) => {
276 let message: RealtimeServerMessage = serde_json::from_str(&text)?;
277 return Ok(Some(RawRealtimeMessage::Event(message)));
278 }
279 WsMessage::Binary(data) => {
280 return Ok(Some(RawRealtimeMessage::Audio(data.to_vec())));
281 }
282 WsMessage::Close(_) => return Ok(None),
283 WsMessage::Ping(_) | WsMessage::Pong(_) | WsMessage::Frame(_) => continue,
284 }
285 }
286 Ok(None)
287 }
288
289 pub async fn update_session(&mut self, config: SessionConfig) -> Result<()> {
291 self.config = config.clone();
292 self.send(RealtimeClientMessage::SessionUpdate { session: config })
293 .await
294 }
295
296 pub async fn append_audio(&mut self, audio_base64: impl Into<String>) -> Result<()> {
298 self.send(RealtimeClientMessage::InputAudioBufferAppend {
299 audio: audio_base64.into(),
300 })
301 .await
302 }
303
304 pub async fn commit_audio(&mut self) -> Result<()> {
306 self.send(RealtimeClientMessage::InputAudioBufferCommit {})
307 .await
308 }
309
310 pub async fn clear_audio(&mut self) -> Result<()> {
312 self.send(RealtimeClientMessage::InputAudioBufferClear {})
313 .await
314 }
315
316 pub async fn create_item(&mut self, item: ConversationItem) -> Result<()> {
318 self.send(RealtimeClientMessage::ConversationItemCreate { item })
319 .await
320 }
321
322 pub async fn create_response(&mut self) -> Result<()> {
324 self.send(RealtimeClientMessage::ResponseCreate { response: None })
325 .await
326 }
327
328 pub async fn cancel_response(&mut self) -> Result<()> {
330 self.send(RealtimeClientMessage::ResponseCancel {}).await
331 }
332
333 pub async fn close(mut self) -> Result<()> {
335 self.write.close().await?;
336 Ok(())
337 }
338}
339
340impl std::fmt::Debug for RealtimeSession {
341 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342 f.debug_struct("RealtimeSession")
343 .field("config", &self.config)
344 .finish_non_exhaustive()
345 }
346}
347
/// A message yielded by `RealtimeSession::receive_raw`.
#[derive(Debug)]
pub enum RawRealtimeMessage {
    /// A server event parsed from a text frame.
    Event(RealtimeServerMessage),
    /// The bytes of a binary frame.
    Audio(Vec<u8>),
}
357
358#[cfg(test)]
359mod tests {
360 use super::*;
361 use crate::models::voice::{
362 ConversationContent, ConversationContentType, ConversationItemType,
363 };
364 use futures_util::{SinkExt, StreamExt};
365 use std::time::Duration;
366 use tokio::net::TcpListener;
367 use tokio::time::timeout;
368 use tokio_tungstenite::{
369 accept_hdr_async,
370 tungstenite::{
371 handshake::server::{Request, Response},
372 Message as WsMessage,
373 },
374 };
375
376 async fn spawn_capture_server(
377 expected_messages: usize,
378 ) -> (String, tokio::task::JoinHandle<Vec<String>>) {
379 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
380 let addr = listener.local_addr().unwrap();
381
382 let handle = tokio::spawn(async move {
383 let (stream, _) = listener.accept().await.unwrap();
384 let ws_stream =
385 accept_hdr_async(stream, |request: &Request, mut response: Response| {
386 let has_realtime = request
387 .headers()
388 .get("Sec-WebSocket-Protocol")
389 .and_then(|v| v.to_str().ok())
390 .map(|raw| raw.split(',').any(|v| v.trim() == "realtime"))
391 .unwrap_or(false);
392
393 if has_realtime {
394 response.headers_mut().insert(
395 "Sec-WebSocket-Protocol",
396 http::HeaderValue::from_static("realtime"),
397 );
398 }
399
400 Ok(response)
401 })
402 .await
403 .unwrap();
404 let (_write, mut read) = ws_stream.split();
405
406 let mut messages = Vec::new();
407 while let Some(frame) = read.next().await {
408 match frame.unwrap() {
409 WsMessage::Text(text) => {
410 messages.push(text.to_string());
411 if messages.len() >= expected_messages {
412 break;
413 }
414 }
415 WsMessage::Close(_) => break,
416 WsMessage::Ping(_)
417 | WsMessage::Pong(_)
418 | WsMessage::Binary(_)
419 | WsMessage::Frame(_) => {}
420 }
421 }
422 messages
423 });
424
425 (format!("http://{}", addr), handle)
426 }
427
428 async fn spawn_multi_capture_server(
429 expected_messages_per_connection: Vec<usize>,
430 ) -> (String, tokio::task::JoinHandle<Vec<Vec<String>>>) {
431 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
432 let addr = listener.local_addr().unwrap();
433
434 let handle = tokio::spawn(async move {
435 let mut all_connections = Vec::new();
436
437 for expected_messages in expected_messages_per_connection {
438 let (stream, _) = listener.accept().await.unwrap();
439 let ws_stream =
440 accept_hdr_async(stream, |request: &Request, mut response: Response| {
441 let has_realtime = request
442 .headers()
443 .get("Sec-WebSocket-Protocol")
444 .and_then(|v| v.to_str().ok())
445 .map(|raw| raw.split(',').any(|v| v.trim() == "realtime"))
446 .unwrap_or(false);
447
448 if has_realtime {
449 response.headers_mut().insert(
450 "Sec-WebSocket-Protocol",
451 http::HeaderValue::from_static("realtime"),
452 );
453 }
454
455 Ok(response)
456 })
457 .await
458 .unwrap();
459 let (_write, mut read) = ws_stream.split();
460
461 let mut messages = Vec::new();
462 while let Some(frame) = read.next().await {
463 match frame.unwrap() {
464 WsMessage::Text(text) => {
465 messages.push(text.to_string());
466 if messages.len() >= expected_messages {
467 break;
468 }
469 }
470 WsMessage::Close(_) => break,
471 WsMessage::Ping(_)
472 | WsMessage::Pong(_)
473 | WsMessage::Binary(_)
474 | WsMessage::Frame(_) => {}
475 }
476 }
477
478 all_connections.push(messages);
479 }
480
481 all_connections
482 });
483
484 (format!("http://{}", addr), handle)
485 }
486
487 async fn spawn_response_server(
488 frames: Vec<WsMessage>,
489 ) -> (String, tokio::task::JoinHandle<()>) {
490 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
491 let addr = listener.local_addr().unwrap();
492
493 let handle = tokio::spawn(async move {
494 let (stream, _) = listener.accept().await.unwrap();
495 let ws_stream =
496 accept_hdr_async(stream, |request: &Request, mut response: Response| {
497 let has_realtime = request
498 .headers()
499 .get("Sec-WebSocket-Protocol")
500 .and_then(|v| v.to_str().ok())
501 .map(|raw| raw.split(',').any(|v| v.trim() == "realtime"))
502 .unwrap_or(false);
503
504 if has_realtime {
505 response.headers_mut().insert(
506 "Sec-WebSocket-Protocol",
507 http::HeaderValue::from_static("realtime"),
508 );
509 }
510
511 Ok(response)
512 })
513 .await
514 .unwrap();
515 let (mut write, mut read) = ws_stream.split();
516 for frame in frames {
517 write.send(frame).await.unwrap();
518 }
519
520 while let Some(frame) = read.next().await {
521 if matches!(frame, Ok(WsMessage::Close(_))) {
522 break;
523 }
524 }
525 });
526
527 (format!("http://{}", addr), handle)
528 }
529
530 #[tokio::test]
531 async fn start_sends_session_update_with_builder_config() {
532 let (base_url, server_handle) = spawn_capture_server(1).await;
533
534 let client = XaiClient::builder()
535 .api_key("test-key")
536 .base_url(base_url)
537 .build()
538 .unwrap();
539
540 let session = timeout(
541 Duration::from_secs(2),
542 client
543 .realtime()
544 .connect("grok-4")
545 .voice(Voice::Rex)
546 .audio_format(AudioFormat::G711Ulaw)
547 .instructions("Be brief")
548 .start(),
549 )
550 .await
551 .expect("start() should not wait on incoming messages")
552 .unwrap();
553
554 assert_eq!(session.config().voice, Voice::Rex);
555 assert_eq!(session.config().instructions.as_deref(), Some("Be brief"));
556
557 let frames = timeout(Duration::from_secs(2), server_handle)
558 .await
559 .unwrap()
560 .unwrap();
561 assert_eq!(frames.len(), 1);
562
563 let msg: serde_json::Value = serde_json::from_str(&frames[0]).unwrap();
564 assert_eq!(msg["type"], "session_update");
565 assert_eq!(msg["session"]["model"], "grok-4");
566 assert_eq!(msg["session"]["voice"], "rex");
567 assert_eq!(msg["session"]["input_audio_format"], "g711_ulaw");
568 assert_eq!(msg["session"]["output_audio_format"], "g711_ulaw");
569 assert_eq!(msg["session"]["instructions"], "Be brief");
570
571 session.close().await.unwrap();
572 }
573
574 #[tokio::test]
575 async fn start_sends_session_update_with_format_and_tools_config() {
576 let (base_url, server_handle) = spawn_capture_server(1).await;
577
578 let client = XaiClient::builder()
579 .api_key("test-key")
580 .base_url(base_url)
581 .build()
582 .unwrap();
583
584 let session = timeout(
585 Duration::from_secs(2),
586 client
587 .realtime()
588 .connect("grok-4")
589 .voice(Voice::Ara)
590 .input_format(AudioFormat::G711Alaw)
591 .output_format(AudioFormat::G711Ulaw)
592 .tools(vec![Tool::web_search(), Tool::x_search()])
593 .instructions("Use web and x search tools")
594 .start(),
595 )
596 .await
597 .expect("start() should not wait on incoming messages")
598 .unwrap();
599
600 assert_eq!(session.config().voice, Voice::Ara);
601 assert_eq!(
602 session.config().instructions.as_deref(),
603 Some("Use web and x search tools")
604 );
605
606 let frames = timeout(Duration::from_secs(2), server_handle)
607 .await
608 .unwrap()
609 .unwrap();
610 assert_eq!(frames.len(), 1);
611
612 let msg: serde_json::Value = serde_json::from_str(&frames[0]).unwrap();
613 assert_eq!(msg["type"], "session_update");
614 assert_eq!(msg["session"]["model"], "grok-4");
615 assert_eq!(msg["session"]["voice"], "ara");
616 assert_eq!(msg["session"]["input_audio_format"], "g711_alaw");
617 assert_eq!(msg["session"]["output_audio_format"], "g711_ulaw");
618
619 let tools = msg["session"]["tools"].as_array().unwrap();
620 assert_eq!(tools.len(), 2);
621 assert_eq!(tools[0]["type"], "web_search");
622 assert_eq!(tools[1]["type"], "x_search");
623
624 session.close().await.unwrap();
625 }
626
627 #[tokio::test]
628 async fn update_session_updates_local_config_and_sends_update() {
629 let (base_url, server_handle) = spawn_capture_server(2).await;
630
631 let client = XaiClient::builder()
632 .api_key("test-key")
633 .base_url(base_url)
634 .build()
635 .unwrap();
636
637 let mut session = client
638 .realtime()
639 .connect("grok-4")
640 .voice(Voice::Ara)
641 .start()
642 .await
643 .unwrap();
644
645 let updated = SessionConfig::new("grok-4")
646 .voice(Voice::Leo)
647 .input_format(AudioFormat::G711Alaw)
648 .output_format(AudioFormat::G711Alaw)
649 .instructions("Updated instructions");
650
651 session.update_session(updated.clone()).await.unwrap();
652
653 assert_eq!(session.config().voice, Voice::Leo);
654 assert_eq!(session.config().input_audio_format, AudioFormat::G711Alaw);
655 assert_eq!(session.config().output_audio_format, AudioFormat::G711Alaw);
656 assert_eq!(
657 session.config().instructions.as_deref(),
658 Some("Updated instructions")
659 );
660
661 session.close().await.unwrap();
662
663 let frames = timeout(Duration::from_secs(2), server_handle)
664 .await
665 .unwrap()
666 .unwrap();
667 assert_eq!(frames.len(), 2);
668
669 let second: serde_json::Value = serde_json::from_str(&frames[1]).unwrap();
670 assert_eq!(second["type"], "session_update");
671 assert_eq!(second["session"]["voice"], "leo");
672 assert_eq!(second["session"]["input_audio_format"], "g711_alaw");
673 assert_eq!(second["session"]["output_audio_format"], "g711_alaw");
674 assert_eq!(second["session"]["instructions"], "Updated instructions");
675 }
676
677 #[tokio::test]
678 async fn resume_starts_session_with_existing_config() {
679 let (base_url, server_handle) = spawn_capture_server(1).await;
680
681 let client = XaiClient::builder()
682 .api_key("test-key")
683 .base_url(base_url)
684 .build()
685 .unwrap();
686
687 let config = SessionConfig::new("grok-4")
688 .voice(Voice::Leo)
689 .input_format(AudioFormat::G711Alaw)
690 .output_format(AudioFormat::G711Ulaw)
691 .instructions("Resume config");
692
693 let session = client
694 .realtime()
695 .resume(config.clone())
696 .start()
697 .await
698 .unwrap();
699
700 assert_eq!(session.config().model, "grok-4");
701 assert_eq!(session.config().voice, Voice::Leo);
702 assert_eq!(session.config().input_audio_format, AudioFormat::G711Alaw);
703 assert_eq!(session.config().output_audio_format, AudioFormat::G711Ulaw);
704 assert_eq!(
705 session.config().instructions.as_deref(),
706 Some("Resume config")
707 );
708
709 session.close().await.unwrap();
710
711 let frames = timeout(Duration::from_secs(2), server_handle)
712 .await
713 .unwrap()
714 .unwrap();
715 assert_eq!(frames.len(), 1);
716
717 let msg: serde_json::Value = serde_json::from_str(&frames[0]).unwrap();
718 assert_eq!(msg["type"], "session_update");
719 assert_eq!(msg["session"]["model"], "grok-4");
720 assert_eq!(msg["session"]["voice"], "leo");
721 assert_eq!(msg["session"]["input_audio_format"], "g711_alaw");
722 assert_eq!(msg["session"]["output_audio_format"], "g711_ulaw");
723 assert_eq!(msg["session"]["instructions"], "Resume config");
724 }
725
726 #[tokio::test]
727 async fn reconnect_reuses_current_session_config() {
728 let (base_url, server_handle) = spawn_multi_capture_server(vec![1, 1]).await;
729
730 let client = XaiClient::builder()
731 .api_key("test-key")
732 .base_url(base_url)
733 .build()
734 .unwrap();
735
736 let session = client
737 .realtime()
738 .connect("grok-4")
739 .voice(Voice::Rex)
740 .instructions("Reconnect me")
741 .start()
742 .await
743 .unwrap();
744
745 let reconnected = session.reconnect().await.unwrap();
746
747 assert_eq!(reconnected.config().model, "grok-4");
748 assert_eq!(reconnected.config().voice, Voice::Rex);
749 assert_eq!(
750 reconnected.config().instructions.as_deref(),
751 Some("Reconnect me")
752 );
753
754 session.close().await.unwrap();
755 reconnected.close().await.unwrap();
756
757 let frames_by_connection = timeout(Duration::from_secs(2), server_handle)
758 .await
759 .unwrap()
760 .unwrap();
761 assert_eq!(frames_by_connection.len(), 2);
762 assert_eq!(frames_by_connection[0].len(), 1);
763 assert_eq!(frames_by_connection[1].len(), 1);
764
765 let first: serde_json::Value = serde_json::from_str(&frames_by_connection[0][0]).unwrap();
766 let second: serde_json::Value = serde_json::from_str(&frames_by_connection[1][0]).unwrap();
767 assert_eq!(first["type"], "session_update");
768 assert_eq!(second["type"], "session_update");
769 assert_eq!(second["session"]["model"], "grok-4");
770 assert_eq!(second["session"]["voice"], "rex");
771 assert_eq!(second["session"]["instructions"], "Reconnect me");
772 }
773
774 #[tokio::test]
775 async fn reconnect_and_replay_sends_conversation_items() {
776 let (base_url, server_handle) = spawn_multi_capture_server(vec![1, 3]).await;
777
778 let client = XaiClient::builder()
779 .api_key("test-key")
780 .base_url(base_url)
781 .build()
782 .unwrap();
783
784 let session = client.realtime().connect("grok-4").start().await.unwrap();
785
786 let items = vec![
787 ConversationItem {
788 id: Some("item-1".to_string()),
789 item_type: ConversationItemType::Message,
790 role: Some("user".to_string()),
791 content: Some(vec![ConversationContent {
792 content_type: ConversationContentType::InputText,
793 text: Some("hello".to_string()),
794 audio: None,
795 transcript: None,
796 }]),
797 },
798 ConversationItem {
799 id: Some("item-2".to_string()),
800 item_type: ConversationItemType::Message,
801 role: Some("assistant".to_string()),
802 content: Some(vec![ConversationContent {
803 content_type: ConversationContentType::Text,
804 text: Some("world".to_string()),
805 audio: None,
806 transcript: None,
807 }]),
808 },
809 ];
810
811 let resumed = session.reconnect_and_replay(items).await.unwrap();
812
813 session.close().await.unwrap();
814 resumed.close().await.unwrap();
815
816 let frames_by_connection = timeout(Duration::from_secs(2), server_handle)
817 .await
818 .unwrap()
819 .unwrap();
820 assert_eq!(frames_by_connection.len(), 2);
821 assert_eq!(frames_by_connection[1].len(), 3);
822
823 let second_connection = &frames_by_connection[1];
824 let first: serde_json::Value = serde_json::from_str(&second_connection[0]).unwrap();
825 let second: serde_json::Value = serde_json::from_str(&second_connection[1]).unwrap();
826 let third: serde_json::Value = serde_json::from_str(&second_connection[2]).unwrap();
827
828 assert_eq!(first["type"], "session_update");
829 assert_eq!(second["type"], "conversation_item_create");
830 assert_eq!(second["item"]["id"], "item-1");
831 assert_eq!(second["item"]["content"][0]["text"], "hello");
832 assert_eq!(third["type"], "conversation_item_create");
833 assert_eq!(third["item"]["id"], "item-2");
834 assert_eq!(third["item"]["content"][0]["text"], "world");
835 }
836
837 #[tokio::test]
838 async fn receive_skips_binary_frames_and_returns_event() {
839 let binary = WsMessage::Binary(vec![0x10, 0x20].into());
840 let event = WsMessage::Text(
841 r#"{"type":"session_updated","session":{"model":"grok-4","voice":"rex","input_audio_format":"pcm16","output_audio_format":"pcm16"}}"#
842 .to_string()
843 .into(),
844 );
845
846 let (base_url, server_handle) = spawn_response_server(vec![binary, event]).await;
847
848 let client = XaiClient::builder()
849 .api_key("test-key")
850 .base_url(base_url)
851 .build()
852 .unwrap();
853
854 let mut session = client
855 .realtime()
856 .connect("grok-4")
857 .voice(Voice::Rex)
858 .start()
859 .await
860 .unwrap();
861
862 let event = session
863 .receive()
864 .await
865 .expect("event should arrive after binary frame");
866 assert!(matches!(
867 event,
868 Some(RealtimeServerMessage::SessionUpdated { .. })
869 ));
870
871 session.close().await.unwrap();
872 server_handle.await.unwrap();
873 }
874
875 #[tokio::test]
876 async fn receive_returns_none_on_close() {
877 let close = WsMessage::Close(None);
878
879 let (base_url, server_handle) = spawn_response_server(vec![close]).await;
880
881 let client = XaiClient::builder()
882 .api_key("test-key")
883 .base_url(base_url)
884 .build()
885 .unwrap();
886
887 let mut session = client.realtime().connect("grok-4").start().await.unwrap();
888
889 let event = session.receive().await.expect("close should be observed");
890 assert!(event.is_none());
891
892 session.close().await.unwrap();
893 server_handle.await.unwrap();
894 }
895
896 #[tokio::test]
897 async fn receive_skips_control_frames_until_event() {
898 let ping = WsMessage::Ping(vec![0x01].into());
899 let pong = WsMessage::Pong(vec![0x02].into());
900 let event = WsMessage::Text(
901 r#"{"type":"session_updated","session":{"model":"grok-4","voice":"rex","input_audio_format":"pcm16","output_audio_format":"pcm16"}}"#
902 .to_string()
903 .into(),
904 );
905
906 let (base_url, server_handle) = spawn_response_server(vec![ping, pong, event]).await;
907
908 let client = XaiClient::builder()
909 .api_key("test-key")
910 .base_url(base_url)
911 .build()
912 .unwrap();
913
914 let mut session = client
915 .realtime()
916 .connect("grok-4")
917 .voice(Voice::Rex)
918 .start()
919 .await
920 .unwrap();
921
922 let event = session
923 .receive()
924 .await
925 .expect("event should arrive after control frames");
926 assert!(matches!(
927 event,
928 Some(RealtimeServerMessage::SessionUpdated { .. })
929 ));
930
931 session.close().await.unwrap();
932 server_handle.await.unwrap();
933 }
934
935 #[tokio::test]
936 async fn receive_raw_returns_none_on_close() {
937 let close = WsMessage::Close(None);
938
939 let (base_url, server_handle) = spawn_response_server(vec![close]).await;
940
941 let client = XaiClient::builder()
942 .api_key("test-key")
943 .base_url(base_url)
944 .build()
945 .unwrap();
946
947 let mut session = client.realtime().connect("grok-4").start().await.unwrap();
948
949 let event = session
950 .receive_raw()
951 .await
952 .expect("close should be observed");
953 assert!(event.is_none());
954
955 session.close().await.unwrap();
956 server_handle.await.unwrap();
957 }
958
959 #[tokio::test]
960 async fn receive_raw_skips_control_frames_until_audio() {
961 let ping = WsMessage::Ping(vec![0x01].into());
962 let audio = WsMessage::Binary(vec![9, 8, 7].into());
963
964 let (base_url, server_handle) = spawn_response_server(vec![ping, audio]).await;
965
966 let client = XaiClient::builder()
967 .api_key("test-key")
968 .base_url(base_url)
969 .build()
970 .unwrap();
971
972 let mut session = client.realtime().connect("grok-4").start().await.unwrap();
973
974 let first = session
975 .receive_raw()
976 .await
977 .expect("audio after control frame")
978 .expect("audio is present");
979 assert!(matches!(first, RawRealtimeMessage::Audio(_)));
980 match first {
981 RawRealtimeMessage::Audio(bytes) => assert_eq!(bytes, vec![9, 8, 7]),
982 RawRealtimeMessage::Event(_) => unreachable!("expected audio"),
983 }
984
985 session.close().await.unwrap();
986 server_handle.await.unwrap();
987 }
988
989 #[tokio::test]
990 async fn start_rejects_unsupported_base_url_scheme() {
991 let client = XaiClient::builder()
992 .api_key("test-key")
993 .base_url("ftp://localhost")
994 .build()
995 .unwrap();
996
997 let err = client
998 .realtime()
999 .connect("grok-4")
1000 .start()
1001 .await
1002 .unwrap_err();
1003
1004 match err {
1005 Error::InvalidRequest(message) => {
1006 assert_eq!(message, "Unsupported URL scheme: ftp")
1007 }
1008 _ => panic!("expected unsupported scheme error"),
1009 }
1010 }
1011
1012 #[tokio::test]
1013 async fn receive_raw_supports_event_and_audio() {
1014 let event = WsMessage::Text(
1015 r#"{"type":"response_audio_delta","response_id":"resp","item_id":"item","delta":"AQID"}"#
1016 .to_string()
1017 .into(),
1018 );
1019 let audio = WsMessage::Binary(vec![9, 8, 7].into());
1020
1021 let (base_url, server_handle) = spawn_response_server(vec![event, audio]).await;
1022
1023 let client = XaiClient::builder()
1024 .api_key("test-key")
1025 .base_url(base_url)
1026 .build()
1027 .unwrap();
1028
1029 let mut session = client.realtime().connect("grok-4").start().await.unwrap();
1030
1031 let first = session
1032 .receive_raw()
1033 .await
1034 .expect("received event")
1035 .expect("event is present");
1036 match first {
1037 RawRealtimeMessage::Event(message) => {
1038 assert!(matches!(
1039 message,
1040 RealtimeServerMessage::ResponseAudioDelta { .. }
1041 ))
1042 }
1043 RawRealtimeMessage::Audio(_) => panic!("expected event first"),
1044 }
1045
1046 let second = session
1047 .receive_raw()
1048 .await
1049 .expect("received audio")
1050 .expect("audio is present");
1051 match second {
1052 RawRealtimeMessage::Audio(bytes) => assert_eq!(bytes, vec![9, 8, 7]),
1053 RawRealtimeMessage::Event(_) => panic!("expected audio second"),
1054 }
1055
1056 session.close().await.unwrap();
1057 server_handle.await.unwrap();
1058 }
1059}