1use serde::{Deserialize, Serialize};
2use serde_json::{Value, json};
3use thiserror::Error;
4
/// Base URL used when no API base override is supplied.
pub const DEFAULT_API_BASE: &str = "https://api.deepgram.com";
/// Model name used by the default speech-to-text configuration.
pub const DEFAULT_STT_MODEL: &str = "flux-general-en";
/// Model name used by the default text-to-speech configuration.
pub const DEFAULT_TTS_MODEL: &str = "aura-2-thalia-en";
8
/// Connection settings shared by the speech-to-text and text-to-speech configs.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs or panic messages; only its presence is reported.
#[derive(Clone, PartialEq, Eq)]
pub struct DeepgramConfig {
    /// Base URL of the API, e.g. `https://api.deepgram.com`.
    pub api_base: String,
    /// API key, or `None` when no key is configured.
    pub api_key: Option<String>,
}

impl std::fmt::Debug for DeepgramConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a key is set without echoing the secret itself.
        let api_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("DeepgramConfig")
            .field("api_base", &self.api_base)
            .field("api_key", &api_key)
            .finish()
    }
}
14
15impl Default for DeepgramConfig {
16 fn default() -> Self {
17 Self {
18 api_base: DEFAULT_API_BASE.to_string(),
19 api_key: None,
20 }
21 }
22}
23
24impl DeepgramConfig {
25 pub fn from_env() -> Self {
26 Self {
27 api_base: std::env::var("DEEPGRAM_API_BASE")
28 .unwrap_or_else(|_| DEFAULT_API_BASE.to_string()),
29 api_key: std::env::var("DEEPGRAM_API_KEY")
30 .ok()
31 .filter(|value| !value.is_empty()),
32 }
33 }
34}
35
/// Settings used to build the `listen` websocket URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeepgramSttConfig {
    /// Shared connection settings (API base URL and key).
    pub base: DeepgramConfig,
    /// Value sent as the `model` query parameter.
    pub model: String,
    /// Path segment placed directly before `/listen` in the URL.
    pub endpoint_version: String,
    /// Value sent as the `encoding` query parameter.
    pub encoding: String,
    /// Value sent as the `sample_rate` query parameter.
    pub sample_rate_hz: u32,
    /// Value sent as the `channels` query parameter.
    pub channels: u16,
    /// Value sent as the `interim_results` query parameter.
    pub interim_results: bool,
}
46
47impl Default for DeepgramSttConfig {
48 fn default() -> Self {
49 Self {
50 base: DeepgramConfig::default(),
51 model: DEFAULT_STT_MODEL.to_string(),
52 endpoint_version: "v2".to_string(),
53 encoding: "linear16".to_string(),
54 sample_rate_hz: 16_000,
55 channels: 1,
56 interim_results: true,
57 }
58 }
59}
60
61impl DeepgramSttConfig {
62 pub fn websocket_url(&self) -> String {
63 let base = self
64 .base
65 .api_base
66 .trim_end_matches('/')
67 .replace("https://", "wss://")
68 .replace("http://", "ws://");
69 format!(
70 "{base}/{}/listen?model={}&encoding={}&sample_rate={}&channels={}&interim_results={}",
71 self.endpoint_version,
72 self.model,
73 self.encoding,
74 self.sample_rate_hz,
75 self.channels,
76 self.interim_results
77 )
78 }
79}
80
/// Settings used to build the `speak` websocket URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeepgramTtsConfig {
    /// Shared connection settings (API base URL and key).
    pub base: DeepgramConfig,
    /// Value sent as the `model` query parameter.
    pub model: String,
    /// Value sent as the `encoding` query parameter.
    pub encoding: String,
    /// Value sent as the `sample_rate` query parameter.
    pub sample_rate_hz: u32,
}
88
89impl Default for DeepgramTtsConfig {
90 fn default() -> Self {
91 Self {
92 base: DeepgramConfig::default(),
93 model: DEFAULT_TTS_MODEL.to_string(),
94 encoding: "linear16".to_string(),
95 sample_rate_hz: 24_000,
96 }
97 }
98}
99
100impl DeepgramTtsConfig {
101 pub fn websocket_url(&self) -> String {
102 let base = self
103 .base
104 .api_base
105 .trim_end_matches('/')
106 .replace("https://", "wss://")
107 .replace("http://", "ws://");
108 format!(
109 "{base}/v1/speak?model={}&encoding={}&sample_rate={}",
110 self.model, self.encoding, self.sample_rate_hz
111 )
112 }
113}
114
/// JSON message returned by the `tts_*_message` constructors in this file.
///
/// `payload` is marked `#[serde(flatten)]`, so the message (de)serializes as
/// the payload's own top-level fields rather than under a `payload` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DeepgramTtsMessage {
    #[serde(flatten)]
    pub payload: Value,
}
120
/// Errors returned while building outgoing messages.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum DeepgramMappingError {
    /// The text passed to `tts_text_message` was empty.
    #[error("text cannot be empty")]
    EmptyText,
}
126
127pub fn tts_text_message(
128 text: impl Into<String>,
129) -> Result<DeepgramTtsMessage, DeepgramMappingError> {
130 let text = text.into();
131 if text.is_empty() {
132 return Err(DeepgramMappingError::EmptyText);
133 }
134 Ok(DeepgramTtsMessage {
135 payload: json!({ "type": "Speak", "text": text }),
136 })
137}
138
139pub fn tts_flush_message() -> DeepgramTtsMessage {
140 DeepgramTtsMessage {
141 payload: json!({ "type": "Flush" }),
142 }
143}
144
145pub fn tts_close_message() -> DeepgramTtsMessage {
146 DeepgramTtsMessage {
147 payload: json!({ "type": "Close" }),
148 }
149}
150
151pub fn transcript_from_listen_message(message: &Value) -> Option<String> {
152 message
153 .pointer("/channel/alternatives/0/transcript")
154 .and_then(Value::as_str)
155 .filter(|value| !value.is_empty())
156 .map(ToString::to_string)
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Base config pointing at a placeholder host, shared by the URL tests.
    fn example_base() -> DeepgramConfig {
        DeepgramConfig {
            api_base: String::from("https://example.test"),
            api_key: None,
        }
    }

    /// Wraps `transcript` in the minimal listen-message shape the parser reads.
    fn listen_message(transcript: &str) -> Value {
        json!({ "channel": { "alternatives": [{ "transcript": transcript }] } })
    }

    #[test]
    fn stt_url_targets_listen_websocket() {
        let url = DeepgramSttConfig {
            base: example_base(),
            ..Default::default()
        }
        .websocket_url();
        let expected = "wss://example.test/v2/listen?model=flux-general-en&encoding=linear16&sample_rate=16000&channels=1&interim_results=true";
        assert_eq!(url, expected);
    }

    #[test]
    fn tts_url_targets_speak_websocket() {
        let url = DeepgramTtsConfig {
            base: example_base(),
            ..Default::default()
        }
        .websocket_url();
        let expected =
            "wss://example.test/v1/speak?model=aura-2-thalia-en&encoding=linear16&sample_rate=24000";
        assert_eq!(url, expected);
    }

    #[test]
    fn transcript_parser_ignores_empty_transcripts() {
        let parsed = transcript_from_listen_message(&listen_message(""));
        assert_eq!(parsed, None);
    }

    #[test]
    fn transcript_parser_reads_first_alternative() {
        let parsed = transcript_from_listen_message(&listen_message("hello"));
        assert_eq!(parsed, Some(String::from("hello")));
    }
}