1use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use base64::Engine as _;
8use futures::{SinkExt, StreamExt};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use thiserror::Error;
12use tokio::sync::mpsc;
13use tokio_tungstenite::{
14 connect_async,
15 tungstenite::{Message, http::Request},
16};
17use tracing::{debug, error, info, trace, warn};
18
19use crate::messages::ToolCallPart;
20
/// Errors surfaced by the Grok realtime client in this module.
///
/// Payloads are plain strings: the `From` impls in this file flatten the
/// underlying library errors with `to_string()`.
#[derive(Debug, Error)]
pub enum Error {
    /// The channel feeding the socket writer task is closed, so an outgoing
    /// event could not be queued.
    #[error("connection closed")]
    ConnectionClosed,
    /// A JSON (de)serialization step failed; holds the rendered `serde_json` error.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// A WebSocket-level failure; holds the rendered tungstenite error.
    #[error("websocket error: {0}")]
    WebSocket(String),
    /// A failure while talking to the provider (used by `GrokClient::connect` for
    /// request-building, handshake, and initial `session.update` failures).
    #[error("provider error: {0}")]
    Provider(String),
}
32
33impl From<serde_json::Error> for Error {
34 fn from(err: serde_json::Error) -> Self {
35 Self::Serialization(err.to_string())
36 }
37}
38
39impl From<tokio_tungstenite::tungstenite::Error> for Error {
40 fn from(err: tokio_tungstenite::tungstenite::Error) -> Self {
41 Self::WebSocket(err.to_string())
42 }
43}
44
/// Result alias used throughout this module, fixed to this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
46
/// Events this client sends to the server, serialized as JSON objects.
///
/// The enum is internally tagged: the variant is written into a `"type"` field
/// using the dotted name given by each variant's `rename`. Every optional
/// `event_id` is left out of the JSON when it is `None`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientEvent {
    /// `session.update`: pushes session settings to the server.
    #[serde(rename = "session.update")]
    SessionUpdate { session: SessionUpdatePayload },

    /// `input_audio_buffer.append`: carries one chunk of input audio.
    #[serde(rename = "input_audio_buffer.append")]
    InputAudioBufferAppend {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        /// Audio chunk; `GrokSender::send_audio` passes a base64 string here.
        audio: String,
    },

    /// `conversation.item.commit`: sent by `GrokSender::commit_audio`.
    // NOTE(review): the sibling audio-buffer events use the
    // `input_audio_buffer.*` prefix; confirm this wire name against the
    // provider's documentation.
    #[serde(rename = "conversation.item.commit")]
    ConversationItemCommit {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },

    /// `input_audio_buffer.clear`.
    #[serde(rename = "input_audio_buffer.clear")]
    InputAudioBufferClear {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },

    /// `conversation.item.create`: adds an item (user text, tool output, ...)
    /// to the conversation.
    #[serde(rename = "conversation.item.create")]
    ConversationItemCreate {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        item: ConversationItem,
    },

    /// `response.create`: asks the server for a response; `response` is
    /// omitted from the JSON when `None`.
    #[serde(rename = "response.create")]
    ResponseCreate {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        response: Option<ResponseCreatePayload>,
    },

    /// `response.cancel`.
    #[serde(rename = "response.cancel")]
    ResponseCancel {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },
}
101
/// Body of a `session.update` event.
///
/// Every field is optional and is left out of the JSON when `None`, so only
/// the settings that are set get sent.
#[derive(Debug, Clone, Serialize)]
pub struct SessionUpdatePayload {
    /// Instruction text for the session.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Voice name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub voice: Option<String>,
    /// Turn-detection settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub turn_detection: Option<TurnDetection>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GrokToolDefinition>>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Input/output audio format settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<AudioConfig>,
}
117
/// Turn-detection settings sent inside a session update.
///
/// Optional fields are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct TurnDetection {
    /// Detection mode, serialized as `"type"` (the `Default` impl uses `"server_vad"`).
    #[serde(rename = "type")]
    pub detection_type: String,
    /// Detection threshold. NOTE(review): valid range is not shown here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub threshold: Option<f32>,
    /// Prefix padding, in milliseconds per the field name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefix_padding_ms: Option<u32>,
    /// Silence duration, in milliseconds per the field name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub silence_duration_ms: Option<u32>,
}
129
130impl Default for TurnDetection {
131 fn default() -> Self {
132 Self {
133 detection_type: "server_vad".to_string(),
134 threshold: Some(0.5),
135 prefix_padding_ms: Some(300),
136 silence_duration_ms: Some(200),
137 }
138 }
139}
140
/// Audio settings for a session, split by direction.
#[derive(Debug, Clone, Serialize)]
pub struct AudioConfig {
    /// Settings for audio sent to the server.
    pub input: AudioChannelConfig,
    /// Settings for audio received from the server.
    pub output: AudioChannelConfig,
}
146
/// Settings for one audio direction; currently just the format.
#[derive(Debug, Clone, Serialize)]
pub struct AudioChannelConfig {
    /// Audio format for this direction.
    pub format: AudioFormat,
}
151
/// Audio format descriptor.
#[derive(Debug, Clone, Serialize)]
pub struct AudioFormat {
    /// Format identifier, serialized as `"type"` (e.g. `"audio/pcmu"`, the
    /// value used by `SessionConfig::default`).
    #[serde(rename = "type")]
    pub format_type: String,
    /// Optional rate; omitted from the JSON when `None`.
    /// NOTE(review): presumably the sample rate in Hz — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate: Option<u32>,
}
159
/// A tool definition as sent to the server in a session update.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrokToolDefinition {
    /// Tool kind, serialized as `"type"`; the constructors in this file
    /// always set `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Tool name.
    pub name: String,
    /// Optional description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional parameter schema as raw JSON; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}
170
171impl GrokToolDefinition {
172 pub fn function(
173 name: impl Into<String>,
174 description: impl Into<String>,
175 parameters: Value,
176 ) -> Self {
177 Self {
178 tool_type: "function".to_string(),
179 name: name.into(),
180 description: Some(description.into()),
181 parameters: Some(parameters),
182 }
183 }
184}
185
186impl From<&crate::tools::ToolDefinition> for GrokToolDefinition {
187 fn from(tool: &crate::tools::ToolDefinition) -> Self {
188 Self {
189 tool_type: "function".to_string(),
190 name: tool.name.clone(),
191 description: tool.description.clone(),
192 parameters: Some(tool.parameters_json_schema.clone()),
193 }
194 }
195}
196
/// A conversation item carried by a `conversation.item.create` event.
///
/// One struct covers several item shapes; fields that do not apply are left
/// `None` and omitted from the JSON.
#[derive(Debug, Clone, Serialize)]
pub struct ConversationItem {
    /// Item kind, serialized as `"type"` (the constructors in this file use
    /// `"function_call_output"` and `"message"`).
    #[serde(rename = "type")]
    pub item_type: String,
    /// Optional item id.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Id of the function call this item answers (function-call outputs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub call_id: Option<String>,
    /// Function output text (function-call outputs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Message role, e.g. `"user"` (messages).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Message content parts (messages).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ContentPart>>,
}
212
213impl ConversationItem {
214 pub fn function_call_output(call_id: String, output: String) -> Self {
216 Self {
217 item_type: "function_call_output".to_string(),
218 id: None,
219 call_id: Some(call_id),
220 output: Some(output),
221 role: None,
222 content: None,
223 }
224 }
225
226 pub fn user_text(text: impl Into<String>) -> Self {
228 Self {
229 item_type: "message".to_string(),
230 id: None,
231 call_id: None,
232 output: None,
233 role: Some("user".to_string()),
234 content: Some(vec![ContentPart {
235 content_type: "input_text".to_string(),
236 text: Some(text.into()),
237 audio: None,
238 }]),
239 }
240 }
241}
242
/// One content part of a message item.
#[derive(Debug, Clone, Serialize)]
pub struct ContentPart {
    /// Part kind, serialized as `"type"` (e.g. `"input_text"`).
    #[serde(rename = "type")]
    pub content_type: String,
    /// Text payload; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Audio payload; omitted from the JSON when `None`.
    /// NOTE(review): presumably base64-encoded — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<String>,
}
252
/// Optional body of a `response.create` event.
#[derive(Debug, Clone, Serialize)]
pub struct ResponseCreatePayload {
    /// Requested modalities; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
}
258
/// Events received from the server, deserialized from JSON objects.
///
/// The enum is internally tagged on the `"type"` field; each variant lists
/// its dotted wire name in `rename`. Any tag not listed here deserializes to
/// [`ServerEvent::Unknown`]. Fields without `Option` are required, so an
/// event missing one of them fails to parse.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerEvent {
    /// `session.created`.
    #[serde(rename = "session.created")]
    SessionCreated { session: SessionInfo },

    /// `session.updated`.
    #[serde(rename = "session.updated")]
    SessionUpdated { session: SessionInfo },

    /// `conversation.created`.
    #[serde(rename = "conversation.created")]
    ConversationCreated {
        event_id: String,
        conversation: ConversationInfo,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// `response.audio.delta`: a chunk of response audio. Handled the same as
    /// `response.output_audio.delta` by [`ServerEvent::audio_delta`].
    #[serde(rename = "response.audio.delta")]
    ResponseAudioDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        content_index: u32,
        /// Audio chunk. NOTE(review): presumably base64-encoded — confirm.
        delta: String,
    },

    /// `response.output_audio.delta`: a chunk of response audio.
    #[serde(rename = "response.output_audio.delta")]
    ResponseOutputAudioDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        content_index: u32,
        delta: String,
    },

    /// `response.function_call_arguments.delta`: a fragment of function-call
    /// arguments.
    #[serde(rename = "response.function_call_arguments.delta")]
    ResponseFunctionCallArgumentsDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        call_id: String,
        delta: String,
    },

    /// `response.function_call_arguments.done`: the complete function call;
    /// see [`ServerEvent::function_call`].
    #[serde(rename = "response.function_call_arguments.done")]
    ResponseFunctionCallArgumentsDone {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        call_id: String,
        name: String,
        /// Arguments as a string; `FunctionCall::to_tool_call_part` tries to
        /// parse it as JSON.
        arguments: String,
    },

    /// `response.done`.
    #[serde(rename = "response.done")]
    ResponseDone {
        event_id: String,
        response_id: String,
        #[serde(default)]
        response: Option<ResponseInfo>,
    },

    /// `input_audio_buffer.speech_started`.
    #[serde(rename = "input_audio_buffer.speech_started")]
    InputAudioBufferSpeechStarted {
        event_id: String,
        audio_start_ms: u64,
        item_id: String,
    },

    /// `input_audio_buffer.speech_stopped`.
    #[serde(rename = "input_audio_buffer.speech_stopped")]
    InputAudioBufferSpeechStopped {
        event_id: String,
        audio_end_ms: u64,
        item_id: String,
    },

    /// `input_audio_buffer.committed`.
    #[serde(rename = "input_audio_buffer.committed")]
    InputAudioBufferCommitted {
        event_id: String,
        item_id: String,
        previous_item_id: Option<String>,
    },

    /// `conversation.item.input_audio_transcription.completed`: transcript of
    /// input audio.
    #[serde(rename = "conversation.item.input_audio_transcription.completed")]
    InputAudioTranscriptionCompleted {
        event_id: String,
        item_id: String,
        transcript: String,
        content_index: u32,
        status: String,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// `response.output_audio_transcript.delta`: a fragment of the response
    /// transcript.
    #[serde(rename = "response.output_audio_transcript.delta")]
    ResponseOutputAudioTranscriptDelta {
        event_id: String,
        item_id: String,
        response_id: String,
        delta: String,
        content_index: u32,
        output_index: u32,
        #[serde(default)]
        start_time: Option<f32>,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// `response.output_audio_transcript.done`: the full response transcript.
    #[serde(rename = "response.output_audio_transcript.done")]
    ResponseOutputAudioTranscriptDone {
        event_id: String,
        item_id: String,
        response_id: String,
        transcript: String,
        content_index: u32,
        output_index: u32,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// `rate_limits.updated`.
    #[serde(rename = "rate_limits.updated")]
    RateLimitsUpdated {
        event_id: String,
        rate_limits: Vec<RateLimit>,
    },

    /// `error`: an error reported by the server.
    #[serde(rename = "error")]
    Error { event_id: String, error: ErrorInfo },

    /// Catch-all for any `"type"` value not listed above.
    #[serde(other)]
    Unknown,
}
413
414impl ServerEvent {
415 pub fn audio_delta(&self) -> Option<&str> {
417 match self {
418 Self::ResponseAudioDelta { delta, .. } => Some(delta),
419 Self::ResponseOutputAudioDelta { delta, .. } => Some(delta),
420 _ => None,
421 }
422 }
423
424 pub fn function_call(&self) -> Option<FunctionCall> {
426 match self {
427 Self::ResponseFunctionCallArgumentsDone {
428 call_id,
429 name,
430 arguments,
431 ..
432 } => Some(FunctionCall {
433 call_id: call_id.clone(),
434 name: name.clone(),
435 arguments: arguments.clone(),
436 }),
437 _ => None,
438 }
439 }
440}
441
/// A completed function call extracted from a server event.
#[derive(Debug, Clone)]
pub struct FunctionCall {
    /// Id of the call, as reported by the server.
    pub call_id: String,
    /// Name of the function to invoke.
    pub name: String,
    /// Argument string exactly as received (not yet parsed).
    pub arguments: String,
}
448
449impl FunctionCall {
450 pub fn to_tool_call_part(&self) -> ToolCallPart {
451 let args = serde_json::from_str::<Value>(&self.arguments)
452 .unwrap_or_else(|_| Value::String(self.arguments.clone()));
453 ToolCallPart {
454 id: self.call_id.clone(),
455 name: self.name.clone(),
456 arguments: args,
457 }
458 }
459}
460
/// Conversation details carried by a `conversation.created` event.
#[derive(Debug, Clone, Deserialize)]
pub struct ConversationInfo {
    /// Conversation id (required).
    pub id: String,
    /// Optional `object` field; `None` when absent.
    #[serde(default)]
    pub object: Option<String>,
}
467
/// Session details carried by `session.created` / `session.updated` events.
///
/// All fields are optional and default to `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct SessionInfo {
    /// Session id.
    #[serde(default)]
    pub id: Option<String>,
    /// Model name.
    #[serde(default)]
    pub model: Option<String>,
    /// Voice name.
    #[serde(default)]
    pub voice: Option<String>,
}
477
/// Response details optionally carried by a `response.done` event.
///
/// Both fields default to `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct ResponseInfo {
    /// Response id.
    #[serde(default)]
    pub id: Option<String>,
    /// Response status string.
    #[serde(default)]
    pub status: Option<String>,
}
485
/// One entry of a `rate_limits.updated` event. All fields are required.
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimit {
    /// Name of the limit.
    pub name: String,
    /// Limit value.
    pub limit: u32,
    /// Remaining amount.
    pub remaining: u32,
    /// Time until reset, in seconds per the field name.
    pub reset_seconds: f32,
}
493
/// Error details carried by a server `error` event.
#[derive(Debug, Clone, Deserialize)]
pub struct ErrorInfo {
    /// Error kind, read from the JSON `"type"` field (required).
    #[serde(rename = "type")]
    pub error_type: String,
    /// Optional error code.
    pub code: Option<String>,
    /// Error message (required).
    pub message: String,
}
501
/// Caller-facing session settings, turned into a `session.update` payload by
/// `SessionConfig::to_update_payload`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Instruction text for the session.
    pub instructions: String,
    /// Voice name.
    pub voice: String,
    /// Tools offered to the model; an empty list is not sent at all.
    pub tools: Vec<GrokToolDefinition>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Audio format, applied to both input and output.
    pub audio_format: AudioFormat,
    /// Turn-detection settings.
    pub turn_detection: TurnDetection,
}
512
513impl Default for SessionConfig {
514 fn default() -> Self {
515 Self {
516 instructions: "You are a helpful voice assistant.".to_string(),
517 voice: "Ara".to_string(),
518 tools: Vec::new(),
519 temperature: 0.8,
520 audio_format: AudioFormat {
521 format_type: "audio/pcmu".to_string(),
522 rate: None,
523 },
524 turn_detection: TurnDetection::default(),
525 }
526 }
527}
528
529impl SessionConfig {
530 pub fn new(instructions: impl Into<String>) -> Self {
531 Self {
532 instructions: instructions.into(),
533 ..Default::default()
534 }
535 }
536
537 pub fn with_voice(mut self, voice: impl Into<String>) -> Self {
538 self.voice = voice.into();
539 self
540 }
541
542 pub fn with_tools(mut self, tools: Vec<GrokToolDefinition>) -> Self {
543 self.tools = tools;
544 self
545 }
546
547 pub fn with_rustic_tools(mut self, tools: &[crate::tools::ToolDefinition]) -> Self {
548 self.tools = tools.iter().map(GrokToolDefinition::from).collect();
549 self
550 }
551
552 pub fn with_temperature(mut self, temperature: f32) -> Self {
553 self.temperature = temperature;
554 self
555 }
556
557 pub fn with_audio_format(mut self, format_type: impl Into<String>, rate: Option<u32>) -> Self {
558 self.audio_format = AudioFormat {
559 format_type: format_type.into(),
560 rate,
561 };
562 self
563 }
564
565 pub fn with_turn_detection(mut self, detection: TurnDetection) -> Self {
566 self.turn_detection = detection;
567 self
568 }
569
570 pub fn to_update_payload(&self) -> SessionUpdatePayload {
572 SessionUpdatePayload {
573 instructions: Some(self.instructions.clone()),
574 voice: Some(self.voice.clone()),
575 turn_detection: Some(self.turn_detection.clone()),
576 tools: if self.tools.is_empty() {
577 None
578 } else {
579 Some(self.tools.clone())
580 },
581 temperature: Some(self.temperature),
582 audio: Some(AudioConfig {
583 input: AudioChannelConfig {
584 format: self.audio_format.clone(),
585 },
586 output: AudioChannelConfig {
587 format: self.audio_format.clone(),
588 },
589 }),
590 }
591 }
592}
593
/// Cloneable handle for queueing outgoing events to the socket writer task.
///
/// Returned by `GrokClient::connect`; all clones feed the same channel.
#[derive(Clone)]
pub struct GrokSender {
    /// Sending half of the channel drained by the writer task.
    tx: mpsc::Sender<ClientEvent>,
}
599
600impl GrokSender {
601 pub async fn send_audio(&self, audio_base64: String) -> Result<()> {
603 self.tx
604 .send(ClientEvent::InputAudioBufferAppend {
605 event_id: None,
606 audio: audio_base64,
607 })
608 .await
609 .map_err(|_| Error::ConnectionClosed)
610 }
611
612 pub async fn send_tool_result(&self, call_id: String, result: String) -> Result<()> {
614 self.tx
615 .send(ClientEvent::ConversationItemCreate {
616 event_id: None,
617 item: ConversationItem::function_call_output(call_id, result),
618 })
619 .await
620 .map_err(|_| Error::ConnectionClosed)?;
621
622 self.tx
623 .send(ClientEvent::ResponseCreate {
624 event_id: None,
625 response: None,
626 })
627 .await
628 .map_err(|_| Error::ConnectionClosed)
629 }
630
631 pub async fn send_user_text(&self, text: String) -> Result<()> {
633 self.tx
634 .send(ClientEvent::ConversationItemCreate {
635 event_id: None,
636 item: ConversationItem::user_text(text),
637 })
638 .await
639 .map_err(|_| Error::ConnectionClosed)
640 }
641
642 pub async fn request_response(&self, modalities: Option<Vec<String>>) -> Result<()> {
644 self.tx
645 .send(ClientEvent::ResponseCreate {
646 event_id: None,
647 response: Some(ResponseCreatePayload { modalities }),
648 })
649 .await
650 .map_err(|_| Error::ConnectionClosed)
651 }
652
653 pub async fn cancel_response(&self) -> Result<()> {
655 self.tx
656 .send(ClientEvent::ResponseCancel { event_id: None })
657 .await
658 .map_err(|_| Error::ConnectionClosed)
659 }
660
661 pub async fn commit_audio(&self) -> Result<()> {
663 self.tx
664 .send(ClientEvent::ConversationItemCommit { event_id: None })
665 .await
666 .map_err(|_| Error::ConnectionClosed)
667 }
668}
669
/// Connection parameters for the Grok realtime WebSocket endpoint.
pub struct GrokClient {
    /// WebSocket URL to connect to.
    ws_url: String,
    /// API key sent as a bearer token in the `Authorization` header.
    api_key: String,
}
675
676impl GrokClient {
677 pub fn new(ws_url: String, api_key: String) -> Self {
678 Self { ws_url, api_key }
679 }
680
681 pub async fn connect(
687 &self,
688 session_config: SessionConfig,
689 ) -> Result<(GrokSender, mpsc::Receiver<ServerEvent>)> {
690 let request = Request::builder()
691 .uri(&self.ws_url)
692 .header("Authorization", format!("Bearer {}", self.api_key))
693 .header("Sec-WebSocket-Key", generate_ws_key())
694 .header("Sec-WebSocket-Version", "13")
695 .header("Connection", "Upgrade")
696 .header("Upgrade", "websocket")
697 .header("Host", extract_host(&self.ws_url))
698 .body(())
699 .map_err(|e| Error::Provider(format!("failed to build request: {e}")))?;
700
701 info!(url = %self.ws_url, "Connecting to Grok Realtime API");
702
703 let (ws_stream, _response) = connect_async(request)
704 .await
705 .map_err(|e| Error::Provider(format!("websocket connection failed: {e}")))?;
706
707 info!("Connected to Grok Realtime API");
708
709 let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
710
711 let (client_tx, mut client_rx) = mpsc::channel::<ClientEvent>(256);
712 let (server_tx, server_rx) = mpsc::channel::<ServerEvent>(256);
713
714 let session_update = ClientEvent::SessionUpdate {
715 session: session_config.to_update_payload(),
716 };
717 let msg = serde_json::to_string(&session_update)?;
718 ws_sink
719 .send(Message::Text(msg))
720 .await
721 .map_err(|e| Error::Provider(format!("failed to send session update: {e}")))?;
722 debug!("Sent session.update");
723
724 tokio::spawn(async move {
725 while let Some(event) = client_rx.recv().await {
726 match serde_json::to_string(&event) {
727 Ok(msg) => {
728 if let Err(e) = ws_sink.send(Message::Text(msg)).await {
729 error!(error = %e, "Failed to send to Grok WebSocket");
730 break;
731 }
732 }
733 Err(e) => {
734 error!(error = %e, "Failed to serialize client event");
735 }
736 }
737 }
738 debug!("Grok sender task ended");
739 });
740
741 tokio::spawn(async move {
742 while let Some(msg_result) = ws_stream_rx.next().await {
743 match msg_result {
744 Ok(Message::Text(text)) => match serde_json::from_str::<Value>(&text) {
745 Ok(value) => {
746 let event_type = value
747 .get("type")
748 .and_then(|val| val.as_str())
749 .unwrap_or("unknown");
750 match serde_json::from_value::<ServerEvent>(value.clone()) {
751 Ok(event) => {
752 if matches!(event, ServerEvent::Unknown) {
753 trace!(event_type = %event_type, raw = %text, "Unhandled Grok event");
754 } else if event.audio_delta().is_none() {
755 debug!(?event, "Received Grok event");
756 }
757 if server_tx.send(event).await.is_err() {
758 debug!("Server event receiver dropped");
759 break;
760 }
761 }
762 Err(e) => {
763 warn!(
764 error = %e,
765 event_type = %event_type,
766 "Failed to parse Grok event"
767 );
768 trace!(raw = %text, "Grok event parse failure payload");
769 }
770 }
771 }
772 Err(e) => {
773 warn!(error = %e, "Failed to parse Grok event");
774 trace!(raw = %text, "Grok event parse failure payload");
775 }
776 },
777 Ok(Message::Close(_)) => {
778 info!("Grok WebSocket closed");
779 break;
780 }
781 Ok(Message::Ping(data)) => {
782 debug!("Received ping from Grok");
783 let _ = data;
784 }
785 Ok(_) => {}
786 Err(e) => {
787 error!(error = %e, "Grok WebSocket error");
788 break;
789 }
790 }
791 }
792 debug!("Grok receiver task ended");
793 });
794
795 Ok((GrokSender { tx: client_tx }, server_rx))
796 }
797}
798
799fn generate_ws_key() -> String {
800 let mut key = [0u8; 16];
801 for (i, byte) in key.iter_mut().enumerate() {
802 let now = SystemTime::now()
803 .duration_since(UNIX_EPOCH)
804 .unwrap_or(Duration::from_secs(0));
805 *byte = (now.as_nanos() as u8).wrapping_add(i as u8);
806 }
807 base64::engine::general_purpose::STANDARD.encode(key)
808}
809
/// Extracts the `host[:port]` part of a WebSocket URL for the `Host` header.
///
/// A leading `wss://` or `ws://` is removed, the remainder is cut at the
/// first `/`, `?` or `#`, and any `user:password@` prefix is dropped.
/// Returns `"api.x.ai"` when no host is left (for example `""` or `"wss://"`).
fn extract_host(url: &str) -> String {
    const DEFAULT_HOST: &str = "api.x.ai";

    // Strip the scheme only where it can legally appear: at the start.
    let without_scheme = url
        .strip_prefix("wss://")
        .or_else(|| url.strip_prefix("ws://"))
        .unwrap_or(url);

    // The authority ends at the first path, query, or fragment delimiter.
    let authority = without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");

    // Userinfo must not be sent in the Host header.
    let host = authority.rsplit('@').next().unwrap_or(authority);

    // `split` always yields one item, so the fallback has to test for an
    // empty host explicitly.
    if host.is_empty() {
        DEFAULT_HOST.to_string()
    } else {
        host.to_string()
    }
}