Skip to main content

vona_deepgram/
lib.rs

1use serde::{Deserialize, Serialize};
2use serde_json::{Value, json};
3use thiserror::Error;
4
/// Base URL used when no other API base is configured (see `DeepgramConfig`).
pub const DEFAULT_API_BASE: &str = "https://api.deepgram.com";
/// Model name selected by `DeepgramSttConfig::default()`.
pub const DEFAULT_STT_MODEL: &str = "flux-general-en";
/// Model name selected by `DeepgramTtsConfig::default()`.
pub const DEFAULT_TTS_MODEL: &str = "aura-2-thalia-en";
8
/// Connection settings shared by the speech-to-text and text-to-speech configs.
///
/// Note that the derived `Debug` implementation prints `api_key` verbatim, so
/// avoid logging this struct where the key must stay secret.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeepgramConfig {
    /// Base URL of the API; the `websocket_url` builders turn its `http(s)`
    /// scheme into `ws(s)` and append the endpoint path.
    pub api_base: String,
    /// API key, if one is available. Nothing in this file reads it; presumably
    /// the caller attaches it when opening the connection — confirm at the call site.
    pub api_key: Option<String>,
}
14
15impl Default for DeepgramConfig {
16    fn default() -> Self {
17        Self {
18            api_base: DEFAULT_API_BASE.to_string(),
19            api_key: None,
20        }
21    }
22}
23
24impl DeepgramConfig {
25    pub fn from_env() -> Self {
26        Self {
27            api_base: std::env::var("DEEPGRAM_API_BASE")
28                .unwrap_or_else(|_| DEFAULT_API_BASE.to_string()),
29            api_key: std::env::var("DEEPGRAM_API_KEY")
30                .ok()
31                .filter(|value| !value.is_empty()),
32        }
33    }
34}
35
/// Settings for the streaming speech-to-text (`listen`) WebSocket endpoint.
///
/// Every field except `base` ends up in the URL produced by `websocket_url`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeepgramSttConfig {
    /// Shared connection settings; `api_base` supplies the scheme and host.
    pub base: DeepgramConfig,
    /// Value sent as the `model` query parameter.
    pub model: String,
    /// Path segment placed in front of `/listen` (the default is `"v2"`).
    pub endpoint_version: String,
    /// Value sent as the `encoding` query parameter (the default is `"linear16"`).
    pub encoding: String,
    /// Value sent as the `sample_rate` query parameter.
    pub sample_rate_hz: u32,
    /// Value sent as the `channels` query parameter.
    pub channels: u16,
    /// Value sent as the `interim_results` query parameter (`true`/`false`).
    pub interim_results: bool,
}
46
47impl Default for DeepgramSttConfig {
48    fn default() -> Self {
49        Self {
50            base: DeepgramConfig::default(),
51            model: DEFAULT_STT_MODEL.to_string(),
52            endpoint_version: "v2".to_string(),
53            encoding: "linear16".to_string(),
54            sample_rate_hz: 16_000,
55            channels: 1,
56            interim_results: true,
57        }
58    }
59}
60
61impl DeepgramSttConfig {
62    pub fn websocket_url(&self) -> String {
63        let base = self
64            .base
65            .api_base
66            .trim_end_matches('/')
67            .replace("https://", "wss://")
68            .replace("http://", "ws://");
69        format!(
70            "{base}/{}/listen?model={}&encoding={}&sample_rate={}&channels={}&interim_results={}",
71            self.endpoint_version,
72            self.model,
73            self.encoding,
74            self.sample_rate_hz,
75            self.channels,
76            self.interim_results
77        )
78    }
79}
80
/// Settings for the streaming text-to-speech (`speak`) WebSocket endpoint.
///
/// Every field except `base` ends up in the URL produced by `websocket_url`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeepgramTtsConfig {
    /// Shared connection settings; `api_base` supplies the scheme and host.
    pub base: DeepgramConfig,
    /// Value sent as the `model` query parameter.
    pub model: String,
    /// Value sent as the `encoding` query parameter (the default is `"linear16"`).
    pub encoding: String,
    /// Value sent as the `sample_rate` query parameter.
    pub sample_rate_hz: u32,
}
88
89impl Default for DeepgramTtsConfig {
90    fn default() -> Self {
91        Self {
92            base: DeepgramConfig::default(),
93            model: DEFAULT_TTS_MODEL.to_string(),
94            encoding: "linear16".to_string(),
95            sample_rate_hz: 24_000,
96        }
97    }
98}
99
100impl DeepgramTtsConfig {
101    pub fn websocket_url(&self) -> String {
102        let base = self
103            .base
104            .api_base
105            .trim_end_matches('/')
106            .replace("https://", "wss://")
107            .replace("http://", "ws://");
108        format!(
109            "{base}/v1/speak?model={}&encoding={}&sample_rate={}",
110            self.model, self.encoding, self.sample_rate_hz
111        )
112    }
113}
114
/// A JSON control or text message for the text-to-speech WebSocket.
///
/// Instances are produced by `tts_text_message`, `tts_flush_message` and
/// `tts_close_message`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeepgramTtsMessage {
    /// The message body. `#[serde(flatten)]` places the payload's keys at the
    /// top level of the serialized message instead of under a `payload` key.
    #[serde(flatten)]
    pub payload: Value,
}
120
/// Errors raised while mapping caller input onto Deepgram messages.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum DeepgramMappingError {
    /// Returned by `tts_text_message` when the supplied text is the empty string.
    #[error("text cannot be empty")]
    EmptyText,
}
126
127pub fn tts_text_message(
128    text: impl Into<String>,
129) -> Result<DeepgramTtsMessage, DeepgramMappingError> {
130    let text = text.into();
131    if text.is_empty() {
132        return Err(DeepgramMappingError::EmptyText);
133    }
134    Ok(DeepgramTtsMessage {
135        payload: json!({ "type": "Speak", "text": text }),
136    })
137}
138
139pub fn tts_flush_message() -> DeepgramTtsMessage {
140    DeepgramTtsMessage {
141        payload: json!({ "type": "Flush" }),
142    }
143}
144
145pub fn tts_close_message() -> DeepgramTtsMessage {
146    DeepgramTtsMessage {
147        payload: json!({ "type": "Close" }),
148    }
149}
150
151pub fn transcript_from_listen_message(message: &Value) -> Option<String> {
152    message
153        .pointer("/channel/alternatives/0/transcript")
154        .and_then(Value::as_str)
155        .filter(|value| !value.is_empty())
156        .map(ToString::to_string)
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection settings pointing at a placeholder host, shared by the URL tests.
    fn example_base() -> DeepgramConfig {
        DeepgramConfig {
            api_base: String::from("https://example.test"),
            api_key: None,
        }
    }

    /// Wraps `transcript` in the nested shape the parser reads.
    fn listen_message(transcript: &str) -> Value {
        json!({ "channel": { "alternatives": [{ "transcript": transcript }] } })
    }

    #[test]
    fn stt_url_targets_listen_websocket() {
        let stt = DeepgramSttConfig {
            base: example_base(),
            ..Default::default()
        };
        let expected = "wss://example.test/v2/listen?model=flux-general-en&encoding=linear16&sample_rate=16000&channels=1&interim_results=true";
        assert_eq!(stt.websocket_url(), expected);
    }

    #[test]
    fn tts_url_targets_speak_websocket() {
        let tts = DeepgramTtsConfig {
            base: example_base(),
            ..Default::default()
        };
        let expected =
            "wss://example.test/v1/speak?model=aura-2-thalia-en&encoding=linear16&sample_rate=24000";
        assert_eq!(tts.websocket_url(), expected);
    }

    #[test]
    fn transcript_parser_ignores_empty_transcripts() {
        let parsed = transcript_from_listen_message(&listen_message(""));
        assert_eq!(parsed, None);
    }

    #[test]
    fn transcript_parser_reads_first_alternative() {
        let parsed = transcript_from_listen_message(&listen_message("hello"));
        assert_eq!(parsed, Some(String::from("hello")));
    }
}