1use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use base64::Engine as _;
8use futures::{SinkExt, StreamExt};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use thiserror::Error;
12use tokio::sync::mpsc;
13use tokio_tungstenite::{
14 connect_async,
15 tungstenite::{Message, http::Request},
16};
17use tracing::{debug, error, info, trace, warn};
18
19use crate::messages::ToolCallPart;
20
/// Errors produced by the Grok realtime client.
#[derive(Debug, Error)]
pub enum Error {
    /// The outbound event channel is closed, so the event could not be queued
    /// (returned by every `GrokSender` method when `send` on its channel fails).
    #[error("connection closed")]
    ConnectionClosed,
    /// JSON (de)serialization failed; holds the underlying error's message.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// A WebSocket transport error; holds the underlying error's message.
    #[error("websocket error: {0}")]
    WebSocket(String),
    /// Other failures while setting up the session: building the handshake
    /// request, connecting, or sending the initial `session.update`.
    #[error("provider error: {0}")]
    Provider(String),
}
32
33impl From<serde_json::Error> for Error {
34 fn from(err: serde_json::Error) -> Self {
35 Self::Serialization(err.to_string())
36 }
37}
38
39impl From<tokio_tungstenite::tungstenite::Error> for Error {
40 fn from(err: tokio_tungstenite::tungstenite::Error) -> Self {
41 Self::WebSocket(err.to_string())
42 }
43}
44
/// Shorthand for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
46
/// Events sent from the client to the Grok realtime API.
///
/// Serialized as internally tagged JSON: each variant's `rename` string becomes
/// the `"type"` field and the variant's fields are written next to it. Optional
/// `event_id`s are omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ClientEvent {
    /// Wire type `session.update`; `GrokClient::connect` sends one right after
    /// the handshake.
    #[serde(rename = "session.update")]
    SessionUpdate { session: SessionUpdatePayload },

    /// Wire type `input_audio_buffer.append`; sent by `GrokSender::send_audio`.
    #[serde(rename = "input_audio_buffer.append")]
    InputAudioBufferAppend {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        // `GrokSender::send_audio` passes its `audio_base64` argument here.
        audio: String,
    },

    /// Wire type `conversation.item.commit`; sent by `GrokSender::commit_audio`.
    // NOTE(review): the matching server acknowledgement handled in
    // `ServerEvent` is `input_audio_buffer.committed`; confirm against the xAI
    // realtime API docs whether this wire name should be
    // `input_audio_buffer.commit` instead.
    #[serde(rename = "conversation.item.commit")]
    ConversationItemCommit {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },

    /// Wire type `input_audio_buffer.clear`.
    #[serde(rename = "input_audio_buffer.clear")]
    InputAudioBufferClear {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },

    /// Wire type `conversation.item.create`; sent by
    /// `GrokSender::send_user_text` and `GrokSender::send_tool_result`.
    #[serde(rename = "conversation.item.create")]
    ConversationItemCreate {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        item: ConversationItem,
    },

    /// Wire type `response.create`; `response` is omitted when `None`.
    #[serde(rename = "response.create")]
    ResponseCreate {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        response: Option<ResponseCreatePayload>,
    },

    /// Wire type `response.cancel`; sent by `GrokSender::cancel_response`.
    #[serde(rename = "response.cancel")]
    ResponseCancel {
        #[serde(skip_serializing_if = "Option::is_none")]
        event_id: Option<String>,
    },
}
101
/// Body of a `session.update` event.
///
/// Every field is optional and is omitted from the JSON when `None`, so only
/// the settings that are set get sent. `SessionConfig::to_update_payload`
/// builds one with every field populated except `tools` when there are none.
#[derive(Debug, Clone, Serialize)]
pub struct SessionUpdatePayload {
    /// System instructions for the session.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,
    /// Voice name (e.g. `"Ara"` in `SessionConfig::default`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub voice: Option<String>,
    /// Turn-detection settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub turn_detection: Option<TurnDetection>,
    /// Tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GrokToolDefinition>>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Input/output audio formats.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<AudioConfig>,
}
117
/// Turn-detection settings sent inside `session.update`.
///
/// `detection_type` is serialized as `"type"` (the `Default` impl uses
/// `"server_vad"`). The optional tuning fields are omitted from the JSON when
/// `None`.
#[derive(Debug, Clone, Serialize)]
pub struct TurnDetection {
    #[serde(rename = "type")]
    pub detection_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub threshold: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefix_padding_ms: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub silence_duration_ms: Option<u32>,
}
129
130impl Default for TurnDetection {
131 fn default() -> Self {
132 Self {
133 detection_type: "server_vad".to_string(),
134 threshold: Some(0.5),
135 prefix_padding_ms: Some(300),
136 silence_duration_ms: Some(200),
137 }
138 }
139}
140
/// Audio format configuration for both directions of the session.
#[derive(Debug, Clone, Serialize)]
pub struct AudioConfig {
    /// Format of audio sent by the client.
    pub input: AudioChannelConfig,
    /// Format of audio produced by the server.
    pub output: AudioChannelConfig,
}
146
/// Per-direction audio settings; serialized as `{"format": {...}}`.
#[derive(Debug, Clone, Serialize)]
pub struct AudioChannelConfig {
    pub format: AudioFormat,
}
151
/// Audio encoding for one direction of the session.
///
/// `format_type` is serialized as `"type"`; this file uses `"audio/pcm"` and
/// `"audio/pcmu"`. `rate` is omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct AudioFormat {
    #[serde(rename = "type")]
    pub format_type: String,
    // Presumably the sample rate in Hz (the tests use 16_000) — TODO confirm
    // against the xAI API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate: Option<u32>,
}
159
/// A tool declaration sent in `session.update`.
///
/// `tool_type` is serialized as `"type"`; both constructors in this file set it
/// to `"function"`. `description` and `parameters` are omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrokToolDefinition {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON schema for the tool's arguments (the `From` impl copies
    /// `ToolDefinition::parameters_json_schema` here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}
170
171impl GrokToolDefinition {
172 pub fn function(
173 name: impl Into<String>,
174 description: impl Into<String>,
175 parameters: Value,
176 ) -> Self {
177 Self {
178 tool_type: "function".to_string(),
179 name: name.into(),
180 description: Some(description.into()),
181 parameters: Some(parameters),
182 }
183 }
184}
185
186impl From<&crate::tools::ToolDefinition> for GrokToolDefinition {
187 fn from(tool: &crate::tools::ToolDefinition) -> Self {
188 Self {
189 tool_type: "function".to_string(),
190 name: tool.name.clone(),
191 description: tool.description.clone(),
192 parameters: Some(tool.parameters_json_schema.clone()),
193 }
194 }
195}
196
/// A conversation item carried by `conversation.item.create`.
///
/// `item_type` is serialized as `"type"`; this file builds `"message"` and
/// `"function_call_output"` items. Every other field is omitted when `None`,
/// so each item type only sends the fields it sets.
#[derive(Debug, Clone, Serialize)]
pub struct ConversationItem {
    #[serde(rename = "type")]
    pub item_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Set on `function_call_output` items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub call_id: Option<String>,
    /// Set on `function_call_output` items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Set on `message` items (e.g. `"user"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Set on `message` items.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ContentPart>>,
}
212
213impl ConversationItem {
214 pub fn function_call_output(call_id: String, output: String) -> Self {
216 Self {
217 item_type: "function_call_output".to_string(),
218 id: None,
219 call_id: Some(call_id),
220 output: Some(output),
221 role: None,
222 content: None,
223 }
224 }
225
226 pub fn user_text(text: impl Into<String>) -> Self {
228 Self {
229 item_type: "message".to_string(),
230 id: None,
231 call_id: None,
232 output: None,
233 role: Some("user".to_string()),
234 content: Some(vec![ContentPart {
235 content_type: "input_text".to_string(),
236 text: Some(text.into()),
237 audio: None,
238 }]),
239 }
240 }
241}
242
/// One content part of a `message` conversation item.
///
/// `content_type` is serialized as `"type"` (e.g. `"input_text"` in
/// `ConversationItem::user_text`); `text` and `audio` are omitted when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct ContentPart {
    #[serde(rename = "type")]
    pub content_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    // Not set anywhere in this file; presumably encoded audio for audio
    // content parts — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<String>,
}
252
/// Optional body of a `response.create` event.
#[derive(Debug, Clone, Serialize)]
pub struct ResponseCreatePayload {
    /// Requested output modalities (e.g. `["text"]`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
}
258
/// Events received from the Grok realtime API.
///
/// Deserialized from internally tagged JSON: the `"type"` field selects the
/// variant via each `rename`. Any unrecognised `"type"` maps to
/// [`ServerEvent::Unknown`] instead of failing. Fields marked
/// `#[serde(default)]` may be absent from the payload.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ServerEvent {
    /// Wire type `session.created`.
    #[serde(rename = "session.created")]
    SessionCreated { session: SessionInfo },

    /// Wire type `session.updated`.
    #[serde(rename = "session.updated")]
    SessionUpdated { session: SessionInfo },

    /// Wire type `conversation.created`.
    #[serde(rename = "conversation.created")]
    ConversationCreated {
        event_id: String,
        conversation: ConversationInfo,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// Wire type `response.audio.delta`; `delta` is returned by
    /// [`ServerEvent::audio_delta`].
    #[serde(rename = "response.audio.delta")]
    ResponseAudioDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        content_index: u32,
        delta: String,
    },

    /// Wire type `response.output_audio.delta`; same shape as
    /// `ResponseAudioDelta` and also surfaced by [`ServerEvent::audio_delta`].
    #[serde(rename = "response.output_audio.delta")]
    ResponseOutputAudioDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        content_index: u32,
        delta: String,
    },

    /// Wire type `response.function_call_arguments.delta`: a fragment of a
    /// function call's argument string.
    #[serde(rename = "response.function_call_arguments.delta")]
    ResponseFunctionCallArgumentsDelta {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        call_id: String,
        delta: String,
    },

    /// Wire type `response.function_call_arguments.done`: the complete call,
    /// extracted by [`ServerEvent::function_call`].
    #[serde(rename = "response.function_call_arguments.done")]
    ResponseFunctionCallArgumentsDone {
        event_id: String,
        response_id: String,
        item_id: String,
        output_index: u32,
        call_id: String,
        name: String,
        arguments: String,
    },

    /// Wire type `response.done`.
    #[serde(rename = "response.done")]
    ResponseDone {
        event_id: String,
        response_id: String,
        #[serde(default)]
        response: Option<ResponseInfo>,
    },

    /// Wire type `input_audio_buffer.speech_started`.
    #[serde(rename = "input_audio_buffer.speech_started")]
    InputAudioBufferSpeechStarted {
        event_id: String,
        audio_start_ms: u64,
        item_id: String,
    },

    /// Wire type `input_audio_buffer.speech_stopped`.
    #[serde(rename = "input_audio_buffer.speech_stopped")]
    InputAudioBufferSpeechStopped {
        event_id: String,
        audio_end_ms: u64,
        item_id: String,
    },

    /// Wire type `input_audio_buffer.committed`.
    #[serde(rename = "input_audio_buffer.committed")]
    InputAudioBufferCommitted {
        event_id: String,
        item_id: String,
        // serde derive treats a missing `Option` field as `None` even without
        // `#[serde(default)]`.
        previous_item_id: Option<String>,
    },

    /// Wire type `conversation.item.input_audio_transcription.completed`.
    #[serde(rename = "conversation.item.input_audio_transcription.completed")]
    InputAudioTranscriptionCompleted {
        event_id: String,
        item_id: String,
        transcript: String,
        content_index: u32,
        status: String,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// Wire type `response.output_audio_transcript.delta`.
    #[serde(rename = "response.output_audio_transcript.delta")]
    ResponseOutputAudioTranscriptDelta {
        event_id: String,
        item_id: String,
        response_id: String,
        delta: String,
        content_index: u32,
        output_index: u32,
        #[serde(default)]
        start_time: Option<f32>,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// Wire type `response.output_audio_transcript.done`.
    #[serde(rename = "response.output_audio_transcript.done")]
    ResponseOutputAudioTranscriptDone {
        event_id: String,
        item_id: String,
        response_id: String,
        transcript: String,
        content_index: u32,
        output_index: u32,
        #[serde(default)]
        previous_item_id: Option<String>,
    },

    /// Wire type `rate_limits.updated`.
    #[serde(rename = "rate_limits.updated")]
    RateLimitsUpdated {
        event_id: String,
        rate_limits: Vec<RateLimit>,
    },

    /// Wire type `error`.
    #[serde(rename = "error")]
    Error { event_id: String, error: ErrorInfo },

    /// Catch-all for any `"type"` not listed above; its payload is discarded.
    #[serde(other)]
    Unknown,
}
413
/// Unit tests for wire serialization, builders, accessor helpers and the
/// sender handle.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// `session.update` serializes with its wire `type` and nested session fields.
    #[test]
    fn session_update_serializes() {
        let event = ClientEvent::SessionUpdate {
            session: SessionUpdatePayload {
                instructions: Some("be concise".to_string()),
                voice: Some("alloy".to_string()),
                turn_detection: Some(TurnDetection::default()),
                tools: Some(vec![GrokToolDefinition::function(
                    "echo",
                    "echo back",
                    json!({"type": "object", "properties": {}}),
                )]),
                temperature: Some(0.3),
                audio: Some(AudioConfig {
                    input: AudioChannelConfig {
                        format: AudioFormat {
                            format_type: "audio/pcm".to_string(),
                            rate: Some(16_000),
                        },
                    },
                    output: AudioChannelConfig {
                        format: AudioFormat {
                            format_type: "audio/pcm".to_string(),
                            rate: Some(16_000),
                        },
                    },
                }),
            },
        };

        let value = serde_json::to_value(event).expect("serialize");
        assert_eq!(
            value.get("type"),
            Some(&Value::String("session.update".to_string()))
        );
        assert_eq!(
            value
                .get("session")
                .and_then(|v| v.get("instructions"))
                .and_then(|v| v.as_str()),
            Some("be concise")
        );
        assert_eq!(
            value
                .get("session")
                .and_then(|v| v.get("voice"))
                .and_then(|v| v.as_str()),
            Some("alloy")
        );
    }

    /// The `ConversationItem` constructors produce the expected JSON shapes.
    #[test]
    fn conversation_item_helpers_build_expected_shapes() {
        let output = ConversationItem::function_call_output("call-1".to_string(), "ok".to_string());
        let output_value = serde_json::to_value(output).expect("serialize output");
        assert_eq!(
            output_value.get("type"),
            Some(&Value::String("function_call_output".to_string()))
        );
        assert_eq!(
            output_value.get("call_id"),
            Some(&Value::String("call-1".to_string()))
        );
        assert_eq!(
            output_value.get("output"),
            Some(&Value::String("ok".to_string()))
        );

        let user = ConversationItem::user_text("hello");
        let user_value = serde_json::to_value(user).expect("serialize user");
        assert_eq!(
            user_value.get("type"),
            Some(&Value::String("message".to_string()))
        );
        assert_eq!(
            user_value.get("role"),
            Some(&Value::String("user".to_string()))
        );
        let content = user_value
            .get("content")
            .and_then(|v| v.as_array())
            .expect("content array");
        assert_eq!(
            content[0].get("type"),
            Some(&Value::String("input_text".to_string()))
        );
    }

    /// Conversion from the crate's `ToolDefinition` keeps name, description and schema.
    #[test]
    fn tool_definition_from_tool() {
        let tool = crate::tools::ToolDefinition::new(
            "tool",
            Some("desc".to_string()),
            json!({"type": "object", "properties": {}}),
        );
        let def: GrokToolDefinition = (&tool).into();
        assert_eq!(def.tool_type, "function");
        assert_eq!(def.name, "tool");
        assert_eq!(def.description.as_deref(), Some("desc"));
        assert!(def.parameters.is_some());
    }

    /// `audio_delta` covers both audio-delta variants; `function_call` only the done event.
    #[test]
    fn server_event_helpers_extract_audio_and_tool_calls() {
        let audio_event = ServerEvent::ResponseAudioDelta {
            event_id: "evt".to_string(),
            response_id: "resp".to_string(),
            item_id: "item".to_string(),
            output_index: 0,
            content_index: 0,
            delta: "audio".to_string(),
        };
        assert_eq!(audio_event.audio_delta(), Some("audio"));
        assert!(audio_event.function_call().is_none());

        let output_audio_event = ServerEvent::ResponseOutputAudioDelta {
            event_id: "evt".to_string(),
            response_id: "resp".to_string(),
            item_id: "item".to_string(),
            output_index: 0,
            content_index: 0,
            delta: "audio2".to_string(),
        };
        assert_eq!(output_audio_event.audio_delta(), Some("audio2"));

        let call_event = ServerEvent::ResponseFunctionCallArgumentsDone {
            event_id: "evt".to_string(),
            response_id: "resp".to_string(),
            item_id: "item".to_string(),
            output_index: 0,
            call_id: "call".to_string(),
            name: "tool".to_string(),
            arguments: "{\"a\":1}".to_string(),
        };
        let call = call_event.function_call().expect("function call");
        assert_eq!(call.call_id, "call");
        assert_eq!(call.name, "tool");
    }

    /// Valid JSON arguments are parsed; invalid ones fall back to a JSON string.
    #[test]
    fn function_call_to_tool_call_part_parses_json_or_string() {
        let call = FunctionCall {
            call_id: "call-1".to_string(),
            name: "tool".to_string(),
            arguments: "{\"a\":1}".to_string(),
        };
        let part = call.to_tool_call_part();
        assert_eq!(part.name, "tool");
        assert_eq!(part.arguments, json!({"a": 1}));

        let call = FunctionCall {
            call_id: "call-2".to_string(),
            name: "tool".to_string(),
            arguments: "not-json".to_string(),
        };
        let part = call.to_tool_call_part();
        assert_eq!(part.arguments, Value::String("not-json".to_string()));
    }

    /// Builder methods flow through to the `session.update` payload; empty tools become `None`.
    #[test]
    fn session_config_builders_populate_payload() {
        let config = SessionConfig::new("hello")
            .with_voice("Nova")
            .with_temperature(0.4)
            .with_audio_format("audio/pcm", Some(16_000))
            .with_turn_detection(TurnDetection::default());
        let payload = config.to_update_payload();
        assert_eq!(payload.instructions.as_deref(), Some("hello"));
        assert_eq!(payload.voice.as_deref(), Some("Nova"));
        assert!(payload.tools.is_none());
        assert_eq!(payload.temperature, Some(0.4));
        let audio = payload.audio.expect("audio");
        assert_eq!(audio.input.format.format_type, "audio/pcm");
        assert_eq!(audio.input.format.rate, Some(16_000));

        let tools = vec![GrokToolDefinition::function(
            "echo",
            "Echo back",
            json!({"type": "object"}),
        )];
        let config = SessionConfig::default().with_tools(tools.clone());
        let payload = config.to_update_payload();
        assert!(payload.tools.is_some());
        assert_eq!(payload.tools.unwrap().len(), tools.len());
    }

    /// Each `GrokSender` method enqueues the expected `ClientEvent`(s), in order.
    #[tokio::test]
    async fn grok_sender_emits_events() {
        let (tx, mut rx) = mpsc::channel(10);
        let sender = GrokSender { tx };

        sender
            .send_audio("audio".to_string())
            .await
            .expect("send audio");
        match rx.recv().await.expect("audio event") {
            ClientEvent::InputAudioBufferAppend { audio, .. } => {
                assert_eq!(audio, "audio");
            }
            other => panic!("unexpected event: {other:?}"),
        }

        sender
            .send_user_text("hello".to_string())
            .await
            .expect("send text");
        match rx.recv().await.expect("user event") {
            ClientEvent::ConversationItemCreate { item, .. } => {
                assert_eq!(item.item_type, "message");
                assert_eq!(item.role.as_deref(), Some("user"));
            }
            other => panic!("unexpected event: {other:?}"),
        }

        // A tool result is followed by a payload-less `response.create`.
        sender
            .send_tool_result("call-1".to_string(), "ok".to_string())
            .await
            .expect("send tool result");
        match rx.recv().await.expect("tool result") {
            ClientEvent::ConversationItemCreate { item, .. } => {
                assert_eq!(item.item_type, "function_call_output");
                assert_eq!(item.call_id.as_deref(), Some("call-1"));
            }
            other => panic!("unexpected event: {other:?}"),
        }
        match rx.recv().await.expect("response create") {
            ClientEvent::ResponseCreate { response, .. } => {
                assert!(response.is_none());
            }
            other => panic!("unexpected event: {other:?}"),
        }

        sender
            .request_response(Some(vec!["text".to_string()]))
            .await
            .expect("request response");
        match rx.recv().await.expect("response create") {
            ClientEvent::ResponseCreate { response, .. } => {
                let response = response.expect("response payload");
                assert_eq!(response.modalities, Some(vec!["text".to_string()]));
            }
            other => panic!("unexpected event: {other:?}"),
        }

        sender.cancel_response().await.expect("cancel response");
        match rx.recv().await.expect("cancel event") {
            ClientEvent::ResponseCancel { .. } => {}
            other => panic!("unexpected event: {other:?}"),
        }

        sender.commit_audio().await.expect("commit audio");
        match rx.recv().await.expect("commit event") {
            ClientEvent::ConversationItemCommit { .. } => {}
            other => panic!("unexpected event: {other:?}"),
        }
    }

    /// WebSocket key decodes to 16 bytes, host extraction strips scheme/path,
    /// and `TurnDetection::default` values.
    #[test]
    fn misc_helpers_cover_key_generation_and_host_extraction() {
        let key = generate_ws_key();
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(key.as_bytes())
            .expect("decode");
        assert_eq!(decoded.len(), 16);

        assert_eq!(
            extract_host("wss://api.x.ai/v1/realtime"),
            "api.x.ai".to_string()
        );
        assert_eq!(
            extract_host("ws://localhost:8080/socket"),
            "localhost:8080".to_string()
        );

        let detection = TurnDetection::default();
        assert_eq!(detection.detection_type, "server_vad");
        assert_eq!(detection.threshold, Some(0.5));
    }

    /// `GrokToolDefinition::function` fills every field.
    #[test]
    fn tool_definition_constructor_sets_fields() {
        let def = GrokToolDefinition::function(
            "tool",
            "desc",
            json!({"type": "object", "properties": {}}),
        );
        assert_eq!(def.tool_type, "function");
        assert_eq!(def.name, "tool");
        assert_eq!(def.description.as_deref(), Some("desc"));
        assert!(def.parameters.is_some());
    }
}
711
712impl ServerEvent {
713 pub fn audio_delta(&self) -> Option<&str> {
715 match self {
716 Self::ResponseAudioDelta { delta, .. } => Some(delta),
717 Self::ResponseOutputAudioDelta { delta, .. } => Some(delta),
718 _ => None,
719 }
720 }
721
722 pub fn function_call(&self) -> Option<FunctionCall> {
724 match self {
725 Self::ResponseFunctionCallArgumentsDone {
726 call_id,
727 name,
728 arguments,
729 ..
730 } => Some(FunctionCall {
731 call_id: call_id.clone(),
732 name: name.clone(),
733 arguments: arguments.clone(),
734 }),
735 _ => None,
736 }
737 }
738}
739
/// A completed function call, extracted from a
/// `response.function_call_arguments.done` event by
/// `ServerEvent::function_call`.
#[derive(Debug, Clone)]
pub struct FunctionCall {
    pub call_id: String,
    pub name: String,
    /// Raw argument string as received; `to_tool_call_part` tries to parse it
    /// as JSON.
    pub arguments: String,
}
746
747impl FunctionCall {
748 pub fn to_tool_call_part(&self) -> ToolCallPart {
749 let args = serde_json::from_str::<Value>(&self.arguments)
750 .unwrap_or_else(|_| Value::String(self.arguments.clone()));
751 ToolCallPart {
752 id: self.call_id.clone(),
753 name: self.name.clone(),
754 arguments: args,
755 }
756 }
757}
758
/// Conversation metadata from `conversation.created`.
#[derive(Debug, Clone, Deserialize)]
pub struct ConversationInfo {
    pub id: String,
    /// May be absent; defaults to `None`.
    #[serde(default)]
    pub object: Option<String>,
}
765
/// Session metadata from `session.created` / `session.updated`.
///
/// Every field may be absent and defaults to `None`; other fields in the
/// payload are ignored.
#[derive(Debug, Clone, Deserialize)]
pub struct SessionInfo {
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub model: Option<String>,
    #[serde(default)]
    pub voice: Option<String>,
}
775
/// Response metadata attached to `response.done`.
///
/// Both fields may be absent and default to `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct ResponseInfo {
    #[serde(default)]
    pub id: Option<String>,
    #[serde(default)]
    pub status: Option<String>,
}
783
/// One entry of a `rate_limits.updated` event; all fields are required.
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimit {
    pub name: String,
    pub limit: u32,
    pub remaining: u32,
    // Presumably seconds until the limit resets — TODO confirm against the API docs.
    pub reset_seconds: f32,
}
791
/// Error details carried by a server `error` event.
#[derive(Debug, Clone, Deserialize)]
pub struct ErrorInfo {
    /// Read from the JSON `"type"` field.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Optional; serde derive yields `None` when the field is missing.
    pub code: Option<String>,
    pub message: String,
}
799
/// High-level session settings, turned into a `session.update` payload by
/// `SessionConfig::to_update_payload`.
///
/// Unlike `SessionUpdatePayload`, every field is required. The `Default` impl
/// provides baseline values and the `with_*` methods override them.
/// `audio_format` is applied to both input and output audio.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    pub instructions: String,
    pub voice: String,
    /// Sent only when non-empty.
    pub tools: Vec<GrokToolDefinition>,
    pub temperature: f32,
    pub audio_format: AudioFormat,
    pub turn_detection: TurnDetection,
}
810
811impl Default for SessionConfig {
812 fn default() -> Self {
813 Self {
814 instructions: "You are a helpful voice assistant.".to_string(),
815 voice: "Ara".to_string(),
816 tools: Vec::new(),
817 temperature: 0.8,
818 audio_format: AudioFormat {
819 format_type: "audio/pcmu".to_string(),
820 rate: None,
821 },
822 turn_detection: TurnDetection::default(),
823 }
824 }
825}
826
827impl SessionConfig {
828 pub fn new(instructions: impl Into<String>) -> Self {
829 Self {
830 instructions: instructions.into(),
831 ..Default::default()
832 }
833 }
834
835 pub fn with_voice(mut self, voice: impl Into<String>) -> Self {
836 self.voice = voice.into();
837 self
838 }
839
840 pub fn with_tools(mut self, tools: Vec<GrokToolDefinition>) -> Self {
841 self.tools = tools;
842 self
843 }
844
845 pub fn with_rustic_tools(mut self, tools: &[crate::tools::ToolDefinition]) -> Self {
846 self.tools = tools.iter().map(GrokToolDefinition::from).collect();
847 self
848 }
849
850 pub fn with_temperature(mut self, temperature: f32) -> Self {
851 self.temperature = temperature;
852 self
853 }
854
855 pub fn with_audio_format(mut self, format_type: impl Into<String>, rate: Option<u32>) -> Self {
856 self.audio_format = AudioFormat {
857 format_type: format_type.into(),
858 rate,
859 };
860 self
861 }
862
863 pub fn with_turn_detection(mut self, detection: TurnDetection) -> Self {
864 self.turn_detection = detection;
865 self
866 }
867
868 pub fn to_update_payload(&self) -> SessionUpdatePayload {
870 SessionUpdatePayload {
871 instructions: Some(self.instructions.clone()),
872 voice: Some(self.voice.clone()),
873 turn_detection: Some(self.turn_detection.clone()),
874 tools: if self.tools.is_empty() {
875 None
876 } else {
877 Some(self.tools.clone())
878 },
879 temperature: Some(self.temperature),
880 audio: Some(AudioConfig {
881 input: AudioChannelConfig {
882 format: self.audio_format.clone(),
883 },
884 output: AudioChannelConfig {
885 format: self.audio_format.clone(),
886 },
887 }),
888 }
889 }
890}
891
892#[derive(Clone)]
894pub struct GrokSender {
895 tx: mpsc::Sender<ClientEvent>,
896}
897
898impl GrokSender {
899 pub async fn send_audio(&self, audio_base64: String) -> Result<()> {
901 self.tx
902 .send(ClientEvent::InputAudioBufferAppend {
903 event_id: None,
904 audio: audio_base64,
905 })
906 .await
907 .map_err(|_| Error::ConnectionClosed)
908 }
909
910 pub async fn send_tool_result(&self, call_id: String, result: String) -> Result<()> {
912 self.tx
913 .send(ClientEvent::ConversationItemCreate {
914 event_id: None,
915 item: ConversationItem::function_call_output(call_id, result),
916 })
917 .await
918 .map_err(|_| Error::ConnectionClosed)?;
919
920 self.tx
921 .send(ClientEvent::ResponseCreate {
922 event_id: None,
923 response: None,
924 })
925 .await
926 .map_err(|_| Error::ConnectionClosed)
927 }
928
929 pub async fn send_user_text(&self, text: String) -> Result<()> {
931 self.tx
932 .send(ClientEvent::ConversationItemCreate {
933 event_id: None,
934 item: ConversationItem::user_text(text),
935 })
936 .await
937 .map_err(|_| Error::ConnectionClosed)
938 }
939
940 pub async fn request_response(&self, modalities: Option<Vec<String>>) -> Result<()> {
942 self.tx
943 .send(ClientEvent::ResponseCreate {
944 event_id: None,
945 response: Some(ResponseCreatePayload { modalities }),
946 })
947 .await
948 .map_err(|_| Error::ConnectionClosed)
949 }
950
951 pub async fn cancel_response(&self) -> Result<()> {
953 self.tx
954 .send(ClientEvent::ResponseCancel { event_id: None })
955 .await
956 .map_err(|_| Error::ConnectionClosed)
957 }
958
959 pub async fn commit_audio(&self) -> Result<()> {
961 self.tx
962 .send(ClientEvent::ConversationItemCommit { event_id: None })
963 .await
964 .map_err(|_| Error::ConnectionClosed)
965 }
966}
967
/// Entry point for opening Grok realtime sessions.
///
/// Stores the WebSocket endpoint and API key; `GrokClient::connect` performs
/// the actual connection.
// No `Debug` derive: a derived impl would print `api_key`.
pub struct GrokClient {
    ws_url: String,
    /// Sent as `Authorization: Bearer <api_key>` during the handshake.
    api_key: String,
}
973
974impl GrokClient {
975 pub fn new(ws_url: String, api_key: String) -> Self {
976 Self { ws_url, api_key }
977 }
978
979 pub async fn connect(
985 &self,
986 session_config: SessionConfig,
987 ) -> Result<(GrokSender, mpsc::Receiver<ServerEvent>)> {
988 let request = Request::builder()
989 .uri(&self.ws_url)
990 .header("Authorization", format!("Bearer {}", self.api_key))
991 .header("Sec-WebSocket-Key", generate_ws_key())
992 .header("Sec-WebSocket-Version", "13")
993 .header("Connection", "Upgrade")
994 .header("Upgrade", "websocket")
995 .header("Host", extract_host(&self.ws_url))
996 .body(())
997 .map_err(|e| Error::Provider(format!("failed to build request: {e}")))?;
998
999 info!(url = %self.ws_url, "Connecting to Grok Realtime API");
1000
1001 let (ws_stream, _response) = connect_async(request)
1002 .await
1003 .map_err(|e| Error::Provider(format!("websocket connection failed: {e}")))?;
1004
1005 info!("Connected to Grok Realtime API");
1006
1007 let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
1008
1009 let (client_tx, mut client_rx) = mpsc::channel::<ClientEvent>(256);
1010 let (server_tx, server_rx) = mpsc::channel::<ServerEvent>(256);
1011
1012 let session_update = ClientEvent::SessionUpdate {
1013 session: session_config.to_update_payload(),
1014 };
1015 let msg = serde_json::to_string(&session_update)?;
1016 ws_sink
1017 .send(Message::Text(msg))
1018 .await
1019 .map_err(|e| Error::Provider(format!("failed to send session update: {e}")))?;
1020 debug!("Sent session.update");
1021
1022 tokio::spawn(async move {
1023 while let Some(event) = client_rx.recv().await {
1024 match serde_json::to_string(&event) {
1025 Ok(msg) => {
1026 if let Err(e) = ws_sink.send(Message::Text(msg)).await {
1027 error!(error = %e, "Failed to send to Grok WebSocket");
1028 break;
1029 }
1030 }
1031 Err(e) => {
1032 error!(error = %e, "Failed to serialize client event");
1033 }
1034 }
1035 }
1036 debug!("Grok sender task ended");
1037 });
1038
1039 tokio::spawn(async move {
1040 while let Some(msg_result) = ws_stream_rx.next().await {
1041 match msg_result {
1042 Ok(Message::Text(text)) => match serde_json::from_str::<Value>(&text) {
1043 Ok(value) => {
1044 let event_type = value
1045 .get("type")
1046 .and_then(|val| val.as_str())
1047 .unwrap_or("unknown");
1048 match serde_json::from_value::<ServerEvent>(value.clone()) {
1049 Ok(event) => {
1050 if matches!(event, ServerEvent::Unknown) {
1051 trace!(event_type = %event_type, raw = %text, "Unhandled Grok event");
1052 } else if event.audio_delta().is_none() {
1053 debug!(?event, "Received Grok event");
1054 }
1055 if server_tx.send(event).await.is_err() {
1056 debug!("Server event receiver dropped");
1057 break;
1058 }
1059 }
1060 Err(e) => {
1061 warn!(
1062 error = %e,
1063 event_type = %event_type,
1064 "Failed to parse Grok event"
1065 );
1066 trace!(raw = %text, "Grok event parse failure payload");
1067 }
1068 }
1069 }
1070 Err(e) => {
1071 warn!(error = %e, "Failed to parse Grok event");
1072 trace!(raw = %text, "Grok event parse failure payload");
1073 }
1074 },
1075 Ok(Message::Close(_)) => {
1076 info!("Grok WebSocket closed");
1077 break;
1078 }
1079 Ok(Message::Ping(data)) => {
1080 debug!("Received ping from Grok");
1081 let _ = data;
1082 }
1083 Ok(_) => {}
1084 Err(e) => {
1085 error!(error = %e, "Grok WebSocket error");
1086 break;
1087 }
1088 }
1089 }
1090 debug!("Grok receiver task ended");
1091 });
1092
1093 Ok((GrokSender { tx: client_tx }, server_rx))
1094 }
1095}
1096
1097fn generate_ws_key() -> String {
1098 let mut key = [0u8; 16];
1099 for (i, byte) in key.iter_mut().enumerate() {
1100 let now = SystemTime::now()
1101 .duration_since(UNIX_EPOCH)
1102 .unwrap_or(Duration::from_secs(0));
1103 *byte = (now.as_nanos() as u8).wrapping_add(i as u8);
1104 }
1105 base64::engine::general_purpose::STANDARD.encode(key)
1106}
1107
/// Extracts the `host[:port]` authority from a `ws://` or `wss://` URL, for
/// use as the `Host` handshake header.
///
/// * Only a *leading* scheme is stripped; input without a recognised scheme is
///   treated as starting with the authority.
/// * The authority ends at the first `/`, `?` or `#`.
/// * Any `user:pass@` userinfo is dropped.
/// * An empty authority yields an empty string.
///
/// The previous version used `str::replace`, which also removed `ws://` found
/// later in the URL (for example inside a query string). It also let a query
/// or userinfo leak into the header.
fn extract_host(url: &str) -> String {
    let rest = url
        .strip_prefix("wss://")
        .or_else(|| url.strip_prefix("ws://"))
        .unwrap_or(url);
    // `split` always yields at least one (possibly empty) piece.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    let host = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);
    host.to_string()
}