Skip to main content

xai_rust/models/
voice.rs

//! Voice/Realtime API types.
2
3use serde::{Deserialize, Serialize};
4
5use super::tool::Tool;
6
7/// Voice options for the realtime API.
8#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "lowercase")]
10pub enum Voice {
11    /// Female, warm and friendly (default).
12    #[default]
13    Ara,
14    /// Male, confident and clear.
15    Rex,
16    /// Neutral, smooth and balanced.
17    Sal,
18    /// Female, energetic and upbeat.
19    Eve,
20    /// Male, authoritative and strong.
21    Leo,
22}
23
24/// Audio format for realtime API.
25#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
26#[serde(rename_all = "lowercase")]
27pub enum AudioFormat {
28    /// PCM 16-bit (configurable sample rate).
29    #[default]
30    Pcm16,
31    /// G.711 μ-law (for telephony).
32    #[serde(rename = "g711_ulaw")]
33    G711Ulaw,
34    /// G.711 A-law (international telephony).
35    #[serde(rename = "g711_alaw")]
36    G711Alaw,
37}
38
39/// Session configuration for realtime API.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct SessionConfig {
42    /// The model to use.
43    pub model: String,
44    /// Voice to use for responses.
45    #[serde(default)]
46    pub voice: Voice,
47    /// Input audio format.
48    #[serde(default)]
49    pub input_audio_format: AudioFormat,
50    /// Output audio format.
51    #[serde(default)]
52    pub output_audio_format: AudioFormat,
53    /// System instructions.
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub instructions: Option<String>,
56    /// Tools available to the model.
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub tools: Option<Vec<Tool>>,
59    /// Input audio transcription settings.
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub input_audio_transcription: Option<AudioTranscriptionConfig>,
62    /// Turn detection settings.
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub turn_detection: Option<TurnDetectionConfig>,
65}
66
67impl SessionConfig {
68    /// Create a new session configuration.
69    pub fn new(model: impl Into<String>) -> Self {
70        Self {
71            model: model.into(),
72            voice: Voice::default(),
73            input_audio_format: AudioFormat::default(),
74            output_audio_format: AudioFormat::default(),
75            instructions: None,
76            tools: None,
77            input_audio_transcription: None,
78            turn_detection: None,
79        }
80    }
81
82    /// Set the voice.
83    pub fn voice(mut self, voice: Voice) -> Self {
84        self.voice = voice;
85        self
86    }
87
88    /// Set the input audio format.
89    pub fn input_format(mut self, format: AudioFormat) -> Self {
90        self.input_audio_format = format;
91        self
92    }
93
94    /// Set the output audio format.
95    pub fn output_format(mut self, format: AudioFormat) -> Self {
96        self.output_audio_format = format;
97        self
98    }
99
100    /// Set system instructions.
101    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
102        self.instructions = Some(instructions.into());
103        self
104    }
105
106    /// Add tools.
107    pub fn tools(mut self, tools: Vec<Tool>) -> Self {
108        self.tools = Some(tools);
109        self
110    }
111}
112
113/// Audio transcription configuration.
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct AudioTranscriptionConfig {
116    /// Whether to enable transcription.
117    #[serde(default)]
118    pub enabled: bool,
119}
120
121/// Turn detection configuration.
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct TurnDetectionConfig {
124    /// Type of turn detection.
125    #[serde(rename = "type")]
126    pub detection_type: String,
127    /// Silence threshold in milliseconds.
128    #[serde(skip_serializing_if = "Option::is_none")]
129    pub silence_duration_ms: Option<u32>,
130}
131
132/// Type of a conversation item.
133#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
134#[serde(rename_all = "snake_case")]
135pub enum ConversationItemType {
136    /// A message item.
137    Message,
138    /// A function call item.
139    FunctionCall,
140    /// A function call output item.
141    FunctionCallOutput,
142}
143
144/// A conversation item.
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct ConversationItem {
147    /// Item ID.
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub id: Option<String>,
150    /// Item type.
151    #[serde(rename = "type")]
152    pub item_type: ConversationItemType,
153    /// Role (user, assistant).
154    #[serde(skip_serializing_if = "Option::is_none")]
155    pub role: Option<String>,
156    /// Content parts.
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub content: Option<Vec<ConversationContent>>,
159}
160
161/// Type of content in a conversation item.
162#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
163#[serde(rename_all = "snake_case")]
164pub enum ConversationContentType {
165    /// Text content.
166    Text,
167    /// Audio content.
168    Audio,
169    /// Input text content.
170    InputText,
171    /// Input audio content.
172    InputAudio,
173}
174
175/// Content in a conversation item.
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct ConversationContent {
178    /// Content type.
179    #[serde(rename = "type")]
180    pub content_type: ConversationContentType,
181    /// Text content.
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub text: Option<String>,
184    /// Audio content (base64).
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub audio: Option<String>,
187    /// Transcript.
188    #[serde(skip_serializing_if = "Option::is_none")]
189    pub transcript: Option<String>,
190}
191
192/// Client message for realtime WebSocket.
193#[derive(Debug, Clone, Serialize)]
194#[serde(tag = "type", rename_all = "snake_case")]
195pub enum RealtimeClientMessage {
196    /// Update session configuration.
197    SessionUpdate {
198        /// The session configuration.
199        session: SessionConfig,
200    },
201    /// Append audio to input buffer.
202    InputAudioBufferAppend {
203        /// Base64-encoded audio data.
204        audio: String,
205    },
206    /// Commit the audio buffer.
207    InputAudioBufferCommit {},
208    /// Clear the audio buffer.
209    InputAudioBufferClear {},
210    /// Create a conversation item.
211    ConversationItemCreate {
212        /// The item to create.
213        item: ConversationItem,
214    },
215    /// Request a response.
216    ResponseCreate {
217        /// Optional response configuration.
218        #[serde(skip_serializing_if = "Option::is_none")]
219        response: Option<ResponseConfig>,
220    },
221    /// Cancel the current response.
222    ResponseCancel {},
223}
224
225/// Response configuration.
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct ResponseConfig {
228    /// Modalities to include.
229    #[serde(skip_serializing_if = "Option::is_none")]
230    pub modalities: Option<Vec<String>>,
231    /// Instructions for this response.
232    #[serde(skip_serializing_if = "Option::is_none")]
233    pub instructions: Option<String>,
234}
235
236/// Server message from realtime WebSocket.
237#[derive(Debug, Clone, Deserialize)]
238#[serde(tag = "type", rename_all = "snake_case")]
239pub enum RealtimeServerMessage {
240    /// Session created.
241    SessionCreated {
242        /// The session configuration.
243        session: SessionConfig,
244    },
245    /// Session updated.
246    SessionUpdated {
247        /// The updated session configuration.
248        session: SessionConfig,
249    },
250    /// Conversation item created.
251    ConversationItemCreated {
252        /// The created item.
253        item: ConversationItem,
254    },
255    /// Input audio buffer committed.
256    InputAudioBufferCommitted {
257        /// The item ID.
258        item_id: String,
259    },
260    /// Input audio buffer cleared.
261    InputAudioBufferCleared {},
262    /// Input audio buffer speech started.
263    InputAudioBufferSpeechStarted {
264        /// Audio start time in milliseconds.
265        audio_start_ms: u32,
266    },
267    /// Input audio buffer speech stopped.
268    InputAudioBufferSpeechStopped {
269        /// Audio end time in milliseconds.
270        audio_end_ms: u32,
271    },
272    /// Response created.
273    ResponseCreated {
274        /// The response.
275        response: RealtimeResponse,
276    },
277    /// Response audio delta.
278    ResponseAudioDelta {
279        /// Response ID.
280        response_id: String,
281        /// Item ID.
282        item_id: String,
283        /// Base64-encoded audio delta.
284        delta: String,
285    },
286    /// Response audio transcript delta.
287    ResponseAudioTranscriptDelta {
288        /// Response ID.
289        response_id: String,
290        /// Item ID.
291        item_id: String,
292        /// Transcript delta.
293        delta: String,
294    },
295    /// Response text delta.
296    ResponseTextDelta {
297        /// Response ID.
298        response_id: String,
299        /// Item ID.
300        item_id: String,
301        /// Text delta.
302        delta: String,
303    },
304    /// Response done.
305    ResponseDone {
306        /// The completed response.
307        response: RealtimeResponse,
308    },
309    /// Error occurred.
310    Error {
311        /// Error details.
312        error: RealtimeError,
313    },
314    /// Rate limits updated.
315    RateLimitsUpdated {
316        /// Rate limit information.
317        rate_limits: Vec<RateLimit>,
318    },
319}
320
321/// Realtime response.
322#[derive(Debug, Clone, Deserialize)]
323pub struct RealtimeResponse {
324    /// Response ID.
325    pub id: String,
326    /// Response status.
327    #[serde(default)]
328    pub status: String,
329    /// Output items.
330    #[serde(default)]
331    pub output: Vec<ConversationItem>,
332}
333
334/// Realtime error.
335#[derive(Debug, Clone, Deserialize)]
336pub struct RealtimeError {
337    /// Error type.
338    #[serde(rename = "type")]
339    pub error_type: String,
340    /// Error code.
341    #[serde(default)]
342    pub code: Option<String>,
343    /// Error message.
344    pub message: String,
345}
346
347impl std::fmt::Display for RealtimeError {
348    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349        write!(f, "{}: {}", self.error_type, self.message)
350    }
351}
352
353/// Rate limit information.
354#[derive(Debug, Clone, Deserialize)]
355pub struct RateLimit {
356    /// Rate limit name.
357    pub name: String,
358    /// Limit value.
359    pub limit: u32,
360    /// Remaining value.
361    pub remaining: u32,
362    /// Reset time.
363    #[serde(default)]
364    pub reset_seconds: Option<f64>,
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- ConversationItemType ------------------------------------------

    /// Each item type serializes to its snake_case name and parses back.
    #[test]
    fn conversation_item_type_roundtrip_all() {
        let cases = [
            (ConversationItemType::Message, "message"),
            (ConversationItemType::FunctionCall, "function_call"),
            (
                ConversationItemType::FunctionCallOutput,
                "function_call_output",
            ),
        ];

        for (variant, wire) in cases {
            let encoded = serde_json::to_value(variant).unwrap();
            assert_eq!(encoded, json!(wire));

            let decoded: ConversationItemType = serde_json::from_value(encoded).unwrap();
            assert_eq!(decoded, variant);
        }
    }

    /// An unrecognized item type string is a deserialization error.
    #[test]
    fn conversation_item_type_rejects_unknown() {
        let parsed = serde_json::from_str::<ConversationItemType>(r#""unknown_type""#);
        assert!(parsed.is_err());
    }

    // ---- ConversationContentType ---------------------------------------

    /// Each content type serializes to its snake_case name and parses back.
    #[test]
    fn conversation_content_type_roundtrip_all() {
        let cases = [
            (ConversationContentType::Text, "text"),
            (ConversationContentType::Audio, "audio"),
            (ConversationContentType::InputText, "input_text"),
            (ConversationContentType::InputAudio, "input_audio"),
        ];

        for (variant, wire) in cases {
            let encoded = serde_json::to_value(variant).unwrap();
            assert_eq!(encoded, json!(wire));

            let decoded: ConversationContentType = serde_json::from_value(encoded).unwrap();
            assert_eq!(decoded, variant);
        }
    }

    /// An unrecognized content type string is a deserialization error.
    #[test]
    fn conversation_content_type_rejects_unknown() {
        let parsed = serde_json::from_str::<ConversationContentType>(r#""video""#);
        assert!(parsed.is_err());
    }

    // ---- ConversationItem ----------------------------------------------

    /// A user message item keeps its type, role, content and id.
    #[test]
    fn conversation_item_message_roundtrip() {
        let hello = ConversationContent {
            content_type: ConversationContentType::InputText,
            text: Some(String::from("Hello")),
            audio: None,
            transcript: None,
        };
        let original = ConversationItem {
            id: Some(String::from("item_1")),
            item_type: ConversationItemType::Message,
            role: Some(String::from("user")),
            content: Some(vec![hello]),
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["type"], "message");
        assert_eq!(encoded["role"], "user");
        assert_eq!(encoded["content"][0]["type"], "input_text");
        assert_eq!(encoded["content"][0]["text"], "Hello");

        let decoded: ConversationItem = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded.item_type, ConversationItemType::Message);
        assert_eq!(decoded.id.as_deref(), Some("item_1"));
    }

    /// A function call item omits its unset optional fields.
    #[test]
    fn conversation_item_function_call_roundtrip() {
        let original = ConversationItem {
            id: Some(String::from("fc_1")),
            item_type: ConversationItemType::FunctionCall,
            role: None,
            content: None,
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["type"], "function_call");
        assert!(encoded.get("role").is_none());
        assert!(encoded.get("content").is_none());

        let decoded: ConversationItem = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded.item_type, ConversationItemType::FunctionCall);
    }

    // ---- ConversationContent -------------------------------------------

    /// Audio content keeps its type, audio payload and transcript.
    #[test]
    fn conversation_content_audio_roundtrip() {
        let original = ConversationContent {
            content_type: ConversationContentType::Audio,
            text: None,
            audio: Some(String::from("base64data")),
            transcript: Some(String::from("transcribed text")),
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["type"], "audio");
        assert_eq!(encoded["audio"], "base64data");
        assert_eq!(encoded["transcript"], "transcribed text");

        let decoded: ConversationContent = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded.content_type, ConversationContentType::Audio);
        assert_eq!(decoded.audio.as_deref(), Some("base64data"));
    }

    // ---- Voice ---------------------------------------------------------

    /// Each voice serializes to its lowercase name and parses back.
    #[test]
    fn voice_roundtrip_all() {
        let cases = [
            (Voice::Ara, "ara"),
            (Voice::Rex, "rex"),
            (Voice::Sal, "sal"),
            (Voice::Eve, "eve"),
            (Voice::Leo, "leo"),
        ];

        for (variant, wire) in cases {
            let encoded = serde_json::to_value(variant).unwrap();
            assert_eq!(encoded, json!(wire));

            let decoded: Voice = serde_json::from_value(encoded).unwrap();
            assert_eq!(decoded, variant);
        }
    }

    /// The default voice is `Ara`.
    #[test]
    fn voice_default_is_ara() {
        assert_eq!(Voice::default(), Voice::Ara);
    }

    // ---- AudioFormat ---------------------------------------------------

    /// Each audio format serializes to its wire name and parses back.
    #[test]
    fn audio_format_roundtrip_all() {
        let cases = [
            (AudioFormat::Pcm16, "pcm16"),
            (AudioFormat::G711Ulaw, "g711_ulaw"),
            (AudioFormat::G711Alaw, "g711_alaw"),
        ];

        for (variant, wire) in cases {
            let encoded = serde_json::to_value(variant).unwrap();
            assert_eq!(encoded, json!(wire));

            let decoded: AudioFormat = serde_json::from_value(encoded).unwrap();
            assert_eq!(decoded, variant);
        }
    }

    /// The default audio format is `Pcm16`.
    #[test]
    fn audio_format_default_is_pcm16() {
        assert_eq!(AudioFormat::default(), AudioFormat::Pcm16);
    }

    // ---- SessionConfig -------------------------------------------------

    /// Chained builder calls land in the corresponding fields.
    #[test]
    fn session_config_builder_pattern() {
        let built = SessionConfig::new("grok-4")
            .voice(Voice::Eve)
            .input_format(AudioFormat::G711Ulaw)
            .output_format(AudioFormat::G711Alaw)
            .instructions("Be helpful");

        assert_eq!(built.model, "grok-4");
        assert_eq!(built.voice, Voice::Eve);
        assert_eq!(built.input_audio_format, AudioFormat::G711Ulaw);
        assert_eq!(built.output_audio_format, AudioFormat::G711Alaw);
        assert_eq!(built.instructions.as_deref(), Some("Be helpful"));
    }

    /// A session configuration survives a JSON roundtrip.
    #[test]
    fn session_config_roundtrip() {
        let original = SessionConfig::new("grok-4")
            .voice(Voice::Rex)
            .instructions("Test instructions");

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["model"], "grok-4");
        assert_eq!(encoded["voice"], "rex");
        assert_eq!(encoded["instructions"], "Test instructions");

        let decoded: SessionConfig = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded.model, "grok-4");
        assert_eq!(decoded.voice, Voice::Rex);
    }

    // ---- RealtimeClientMessage serialization ---------------------------

    /// A session update carries its tag and the nested session.
    #[test]
    fn realtime_client_message_session_update_serialize() {
        let session = SessionConfig::new("grok-4");
        let encoded =
            serde_json::to_value(&RealtimeClientMessage::SessionUpdate { session }).unwrap();

        assert_eq!(encoded["type"], "session_update");
        assert_eq!(encoded["session"]["model"], "grok-4");
    }

    /// An audio append carries its tag and the audio payload.
    #[test]
    fn realtime_client_message_audio_append_serialize() {
        let audio = String::from("base64data");
        let encoded =
            serde_json::to_value(&RealtimeClientMessage::InputAudioBufferAppend { audio })
                .unwrap();

        assert_eq!(encoded["type"], "input_audio_buffer_append");
        assert_eq!(encoded["audio"], "base64data");
    }

    /// A response request without configuration still carries its tag.
    #[test]
    fn realtime_client_message_response_create_serialize() {
        let encoded =
            serde_json::to_value(&RealtimeClientMessage::ResponseCreate { response: None })
                .unwrap();

        assert_eq!(encoded["type"], "response_create");
    }

    // ---- RealtimeServerMessage deserialization -------------------------

    /// A `session_created` payload parses into the matching variant.
    #[test]
    fn realtime_server_message_session_created() {
        let payload = json!({
            "type": "session_created",
            "session": {
                "model": "grok-4",
                "voice": "ara",
                "input_audio_format": "pcm16",
                "output_audio_format": "pcm16"
            }
        });

        let parsed: RealtimeServerMessage = serde_json::from_value(payload).unwrap();
        assert!(matches!(parsed, RealtimeServerMessage::SessionCreated { .. }));
    }

    /// An `error` payload exposes its type, message and display form.
    #[test]
    fn realtime_server_message_error() {
        let payload = json!({
            "type": "error",
            "error": {
                "type": "invalid_request",
                "message": "Bad request"
            }
        });

        let parsed: RealtimeServerMessage = serde_json::from_value(payload).unwrap();
        match parsed {
            RealtimeServerMessage::Error { error } => {
                assert_eq!(error.error_type, "invalid_request");
                assert_eq!(error.message, "Bad request");
                assert_eq!(error.to_string(), "invalid_request: Bad request");
            }
            _ => panic!("Expected Error variant"),
        }
    }

    /// A `response_audio_delta` payload exposes its audio chunk.
    #[test]
    fn realtime_server_message_response_audio_delta() {
        let payload = json!({
            "type": "response_audio_delta",
            "response_id": "resp_1",
            "item_id": "item_1",
            "delta": "YXVkaW8="
        });

        let parsed: RealtimeServerMessage = serde_json::from_value(payload).unwrap();
        match parsed {
            RealtimeServerMessage::ResponseAudioDelta { delta, .. } => {
                assert_eq!(delta, "YXVkaW8=");
            }
            _ => panic!("Expected ResponseAudioDelta variant"),
        }
    }
}