Skip to main content

tauri_plugin_tts/
models.rs

1use serde::{Deserialize, Serialize};
2use std::borrow::Cow;
3use ts_rs::TS;
4
/// Maximum text length in bytes of the UTF-8 encoding (10,000 bytes, ~10KB).
/// Multi-byte characters count once per byte, so the character limit can be lower.
pub const MAX_TEXT_LENGTH: usize = 10_000;
/// Maximum voice ID length
pub const MAX_VOICE_ID_LENGTH: usize = 256;
/// Maximum language code length, compared against the code's byte length
pub const MAX_LANGUAGE_LENGTH: usize = 35;
11
/// Controls how a new utterance interacts with speech that is already pending.
///
/// Serialized and deserialized as the lowercase variant name (`"flush"` or
/// `"add"`). `Flush` is the `Default` value.
#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, PartialEq, Eq, TS)]
#[ts(export, export_to = "../guest-js/bindings/")]
#[serde(rename_all = "lowercase")]
pub enum QueueMode {
    /// Flush any pending speech and start speaking immediately (default)
    #[default]
    Flush,
    /// Add to queue and speak after current speech finishes
    Add,
}
22
/// Speak options exported to the TypeScript bindings.
///
/// Field names are camelCase on the wire. Every field except `text` is
/// optional and is left out of the serialized output when it is `None`.
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
#[ts(export, export_to = "../guest-js/bindings/")]
#[serde(rename_all = "camelCase")]
pub struct SpeakOptions {
    /// The text to speak (max 10,000 bytes of UTF-8, see `MAX_TEXT_LENGTH`)
    pub text: String,
    /// The language/locale code (e.g., "en-US", "pt-BR", "ja-JP")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Specific voice ID to use (from getVoices). Takes priority over language
    #[serde(skip_serializing_if = "Option::is_none")]
    pub voice_id: Option<String>,
    /// Speech rate (0.1 to 4.0, where 1.0 = normal)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate: Option<f32>,
    /// Pitch (0.5 to 2.0, where 1.0 = normal)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pitch: Option<f32>,
    /// Volume (0.0 to 1.0, where 1.0 = full volume)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub volume: Option<f32>,
    /// Queue mode: "flush" (default) or "add"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub queue_mode: Option<QueueMode>,
}
48
/// Voice-preview options exported to the TypeScript bindings.
///
/// Field names are camelCase on the wire; `text` is omitted from the
/// serialized output when it is `None`.
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
#[ts(export, export_to = "../guest-js/bindings/")]
#[serde(rename_all = "camelCase")]
pub struct PreviewVoiceOptions {
    /// Voice ID to preview
    pub voice_id: String,
    /// Optional custom sample text (uses default if not provided)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
}
59
/// A speak request as it arrives over the wire (camelCase field names).
///
/// Only `text` is required when deserializing. Absent fields fall back to
/// `None` for `language` and `voice_id`, `1.0` for `rate`, `pitch` and
/// `volume`, and `QueueMode::Flush` for `queue_mode`. The values are not
/// range-checked at this stage; call `validate` to obtain a checked copy.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SpeakRequest {
    /// The text to speak
    pub text: String,
    /// The language/locale code (e.g., "en-US", "pt-BR", "ja-JP")
    #[serde(default)]
    pub language: Option<String>,
    /// Voice ID to use (from getVoices)
    #[serde(default)]
    pub voice_id: Option<String>,
    /// Speech rate (0.1 to 4.0, where 1.0 = normal, 2.0 = double, 0.5 = half)
    #[serde(default = "default_rate")]
    pub rate: f32,
    /// Pitch (0.5 = low, 1.0 = normal, 2.0 = high)
    #[serde(default = "default_pitch")]
    pub pitch: f32,
    /// Volume (0.0 = silent, 1.0 = full volume)
    #[serde(default = "default_volume")]
    pub volume: f32,
    /// Queue mode: "flush" (default) or "add"
    #[serde(default)]
    pub queue_mode: QueueMode,
}
84
/// Serde fallback for the `rate` field when it is absent: normal speed.
fn default_rate() -> f32 {
    const NORMAL_RATE: f32 = 1.0;
    NORMAL_RATE
}
/// Serde fallback for the `pitch` field when it is absent: normal pitch.
fn default_pitch() -> f32 {
    const NORMAL_PITCH: f32 = 1.0;
    NORMAL_PITCH
}
/// Serde fallback for the `volume` field when it is absent: full volume.
fn default_volume() -> f32 {
    const FULL_VOLUME: f32 = 1.0;
    FULL_VOLUME
}
94
/// Reasons a request can be rejected during validation.
///
/// The `len` fields carry the offending value's length and `max` the limit
/// it was compared against. For text and language codes the length is the
/// byte length reported by `str::len`, even where the message says "chars".
#[derive(Debug, Clone, thiserror::Error)]
pub enum ValidationError {
    /// The text to speak was an empty string.
    #[error("Text cannot be empty")]
    EmptyText,
    /// The text exceeded `MAX_TEXT_LENGTH` bytes.
    #[error("Text too long: {len} bytes (max: {max})")]
    TextTooLong { len: usize, max: usize },
    /// The voice ID exceeded the allowed length.
    // NOTE(review): presumably measured against `MAX_VOICE_ID_LENGTH` — confirm at the construction site.
    #[error("Voice ID too long: {len} chars (max: {max})")]
    VoiceIdTooLong { len: usize, max: usize },
    /// The language code exceeded `MAX_LANGUAGE_LENGTH`.
    #[error("Language code too long: {len} chars (max: {max})")]
    LanguageTooLong { len: usize, max: usize },
}
106
/// The checked form of a `SpeakRequest`, produced by `SpeakRequest::validate`.
///
/// `text` is non-empty and within `MAX_TEXT_LENGTH`, `language` is within
/// `MAX_LANGUAGE_LENGTH`, and the numeric fields have been clamped:
/// `rate` to 0.1..=4.0, `pitch` to 0.5..=2.0, `volume` to 0.0..=1.0.
#[derive(Debug, Clone)]
pub struct ValidatedSpeakRequest {
    pub text: String,
    pub language: Option<String>,
    pub voice_id: Option<String>,
    pub rate: f32,
    pub pitch: f32,
    pub volume: f32,
    pub queue_mode: QueueMode,
}
117
118impl SpeakRequest {
119    pub fn validate(&self) -> Result<ValidatedSpeakRequest, ValidationError> {
120        // Text validation
121        if self.text.is_empty() {
122            return Err(ValidationError::EmptyText);
123        }
124        if self.text.len() > MAX_TEXT_LENGTH {
125            return Err(ValidationError::TextTooLong {
126                len: self.text.len(),
127                max: MAX_TEXT_LENGTH,
128            });
129        }
130
131        // Language validation (if provided)
132        let sanitized_language = self
133            .language
134            .as_ref()
135            .map(|lang| Self::validate_language(lang))
136            .transpose()?;
137
138        Ok(ValidatedSpeakRequest {
139            text: self.text.clone(),
140            language: sanitized_language,
141            voice_id: self.voice_id.clone(),
142            rate: self.rate.clamp(0.1, 4.0),
143            pitch: self.pitch.clamp(0.5, 2.0),
144            volume: self.volume.clamp(0.0, 1.0),
145            queue_mode: self.queue_mode,
146        })
147    }
148
149    fn validate_language(lang: &str) -> Result<String, ValidationError> {
150        if lang.len() > MAX_LANGUAGE_LENGTH {
151            return Err(ValidationError::LanguageTooLong {
152                len: lang.len(),
153                max: MAX_LANGUAGE_LENGTH,
154            });
155        }
156        Ok(lang.to_string())
157    }
158}
159
/// Result payload for a speak operation (camelCase on the wire).
///
/// `warning` is left out of the serialized output when it is `None`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SpeakResponse {
    /// Whether speech was successfully initiated
    pub success: bool,
    /// Optional warning message (e.g., voice not found, using fallback)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub warning: Option<String>,
}
169
/// Result payload for a stop operation (camelCase on the wire).
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct StopResponse {
    // NOTE(review): presumably true when the stop request succeeded — confirm against the command handler.
    pub success: bool,
}
175
/// Description of a single voice, exported to the TypeScript bindings.
///
/// Serialized with the keys `id`, `name` and `language`.
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
#[ts(export, export_to = "../guest-js/bindings/")]
#[serde(rename_all = "camelCase")]
pub struct Voice {
    /// Unique identifier for the voice
    pub id: String,
    /// Display name of the voice
    pub name: String,
    /// Language code (e.g., "en-US")
    pub language: String,
}
187
/// Request payload for listing voices (camelCase on the wire).
///
/// An empty JSON object is accepted; `language` then deserializes to `None`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetVoicesRequest {
    /// Optional language filter
    #[serde(default)]
    pub language: Option<String>,
}
195
/// Response payload carrying a list of voices (camelCase on the wire).
///
/// The `Default` value holds an empty list.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetVoicesResponse {
    /// The voices being returned.
    pub voices: Vec<Voice>,
}
201
/// Response payload for a speaking-state query (camelCase on the wire).
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IsSpeakingResponse {
    // NOTE(review): presumably true while speech is in progress — confirm against the command handler.
    pub speaking: bool,
}
207
/// Response payload for an initialization-state query.
///
/// Serialized with camelCase keys (`initialized`, `voiceCount`). The
/// `Default` value is `initialized: false`, `voice_count: 0`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IsInitializedResponse {
    /// Whether the TTS engine is initialized and ready
    pub initialized: bool,
    /// Number of available voices (0 if not initialized)
    pub voice_count: u32,
}
216
/// Result payload for pause and resume operations, exported to the
/// TypeScript bindings.
///
/// `reason` is left out of the serialized output when it is `None`.
#[derive(Debug, Clone, Default, Deserialize, Serialize, TS)]
#[ts(export, export_to = "../guest-js/bindings/")]
#[serde(rename_all = "camelCase")]
pub struct PauseResumeResponse {
    // NOTE(review): presumably true when the pause/resume took effect — confirm against the command handler.
    pub success: bool,
    /// Reason for failure (if success is false)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
226
/// A voice-preview request as it arrives over the wire (camelCase keys).
///
/// `voiceId` is required; `text` deserializes to `None` when absent.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PreviewVoiceRequest {
    /// Voice ID to preview
    pub voice_id: String,
    /// Optional custom sample text (uses default if not provided)
    #[serde(default)]
    pub text: Option<String>,
}
236
237impl PreviewVoiceRequest {
238    pub const DEFAULT_SAMPLE_TEXT: &'static str =
239        "Hello! This is a sample of how this voice sounds.";
240
241    pub fn sample_text(&self) -> Cow<'_, str> {
242        match &self.text {
243            Some(text) => Cow::Borrowed(text.as_str()),
244            None => Cow::Borrowed(Self::DEFAULT_SAMPLE_TEXT),
245        }
246    }
247
248    pub fn validate(&self) -> Result<(), ValidationError> {
249        // Validate custom text if provided
250        if let Some(ref text) = self.text {
251            if text.is_empty() {
252                return Err(ValidationError::EmptyText);
253            }
254            if text.len() > MAX_TEXT_LENGTH {
255                return Err(ValidationError::TextTooLong {
256                    len: text.len(),
257                    max: MAX_TEXT_LENGTH,
258                });
259            }
260        }
261
262        Ok(())
263    }
264}
265
266#[cfg(test)]
267mod tests {
268    use super::*;
269
270    #[test]
271    fn test_speak_request_defaults() {
272        let json = r#"{"text": "Hello world"}"#;
273        let request: SpeakRequest = serde_json::from_str(json).unwrap();
274
275        assert_eq!(request.text, "Hello world");
276        assert!(request.language.is_none());
277        assert!(request.voice_id.is_none());
278        assert_eq!(request.rate, 1.0);
279        assert_eq!(request.pitch, 1.0);
280        assert_eq!(request.volume, 1.0);
281    }
282
283    #[test]
284    fn test_speak_request_full() {
285        let json = r#"{
286            "text": "Olá",
287            "language": "pt-BR",
288            "voiceId": "com.apple.voice.enhanced.pt-BR",
289            "rate": 0.8,
290            "pitch": 1.2,
291            "volume": 0.9
292        }"#;
293
294        let request: SpeakRequest = serde_json::from_str(json).unwrap();
295        assert_eq!(request.text, "Olá");
296        assert_eq!(request.language, Some("pt-BR".to_string()));
297        assert_eq!(
298            request.voice_id,
299            Some("com.apple.voice.enhanced.pt-BR".to_string())
300        );
301        assert_eq!(request.rate, 0.8);
302        assert_eq!(request.pitch, 1.2);
303        assert_eq!(request.volume, 0.9);
304    }
305
306    #[test]
307    fn test_voice_serialization() {
308        let voice = Voice {
309            id: "test-voice".to_string(),
310            name: "Test Voice".to_string(),
311            language: "en-US".to_string(),
312        };
313
314        let json = serde_json::to_string(&voice).unwrap();
315        assert!(json.contains("\"id\":\"test-voice\""));
316        assert!(json.contains("\"name\":\"Test Voice\""));
317        assert!(json.contains("\"language\":\"en-US\""));
318    }
319
320    #[test]
321    fn test_get_voices_request_optional_language() {
322        let json1 = r#"{}"#;
323        let request1: GetVoicesRequest = serde_json::from_str(json1).unwrap();
324        assert!(request1.language.is_none());
325
326        let json2 = r#"{"language": "en"}"#;
327        let request2: GetVoicesRequest = serde_json::from_str(json2).unwrap();
328        assert_eq!(request2.language, Some("en".to_string()));
329    }
330
331    #[test]
332    fn test_validation_empty_text() {
333        let request = SpeakRequest {
334            text: "".to_string(),
335            language: None,
336            voice_id: None,
337            rate: 1.0,
338            pitch: 1.0,
339            volume: 1.0,
340            queue_mode: QueueMode::Flush,
341        };
342
343        let result = request.validate();
344        assert!(result.is_err());
345        assert!(matches!(result.unwrap_err(), ValidationError::EmptyText));
346    }
347
348    #[test]
349    fn test_validation_text_too_long() {
350        let long_text = "x".repeat(MAX_TEXT_LENGTH + 1);
351        let request = SpeakRequest {
352            text: long_text,
353            language: None,
354            voice_id: None,
355            rate: 1.0,
356            pitch: 1.0,
357            volume: 1.0,
358            queue_mode: QueueMode::Flush,
359        };
360
361        let result = request.validate();
362        assert!(result.is_err());
363        assert!(matches!(
364            result.unwrap_err(),
365            ValidationError::TextTooLong { .. }
366        ));
367    }
368
369    #[test]
370    fn test_validation_valid_voice_id() {
371        let request = SpeakRequest {
372            text: "Hello".to_string(),
373            language: None,
374            voice_id: Some("com.apple.voice.enhanced.en-US".to_string()),
375            rate: 1.0,
376            pitch: 1.0,
377            volume: 1.0,
378            queue_mode: QueueMode::Flush,
379        };
380
381        let result = request.validate();
382        assert!(result.is_ok());
383        assert_eq!(
384            result.unwrap().voice_id,
385            Some("com.apple.voice.enhanced.en-US".to_string())
386        );
387    }
388    
389    #[test]
390    fn test_validation_voice_id_too_long() {
391        let long_voice_id = "x".repeat(MAX_VOICE_ID_LENGTH + 1);
392        let request = SpeakRequest {
393            text: "Hello".to_string(),
394            language: None,
395            voice_id: Some(long_voice_id),
396            rate: 1.0,
397            pitch: 1.0,
398            volume: 1.0,
399            queue_mode: QueueMode::Flush,
400        };
401
402        let result = request.validate();
403        assert!(result.is_err());
404        assert!(matches!(
405            result.unwrap_err(),
406            ValidationError::VoiceIdTooLong { .. }
407        ));
408    }
409
410    #[test]
411    fn test_validation_rate_clamping() {
412        let request = SpeakRequest {
413            text: "Hello".to_string(),
414            language: None,
415            voice_id: None,
416            rate: 999.0,
417            pitch: 1.0,
418            volume: 1.0,
419            queue_mode: QueueMode::Flush,
420        };
421
422        let result = request.validate();
423        assert!(result.is_ok());
424        let validated = result.unwrap();
425        assert_eq!(validated.rate, 4.0); // Clamped to max
426    }
427
428    #[test]
429    fn test_validation_pitch_clamping() {
430        let request = SpeakRequest {
431            text: "Hello".to_string(),
432            language: None,
433            voice_id: None,
434            rate: 1.0,
435            pitch: 0.1,
436            volume: 1.0,
437            queue_mode: QueueMode::Flush,
438        };
439
440        let result = request.validate();
441        assert!(result.is_ok());
442        let validated = result.unwrap();
443        assert_eq!(validated.pitch, 0.5); // Clamped to min
444    }
445
446    #[test]
447    fn test_validation_volume_clamping() {
448        let request = SpeakRequest {
449            text: "Hello".to_string(),
450            language: None,
451            voice_id: None,
452            rate: 1.0,
453            pitch: 1.0,
454            volume: 5.0,
455            queue_mode: QueueMode::Flush,
456        };
457
458        let result = request.validate();
459        assert!(result.is_ok());
460        let validated = result.unwrap();
461        assert_eq!(validated.volume, 1.0); // Clamped to max
462    }
463
464    #[test]
465    fn test_preview_voice_validation() {
466        // Valid preview
467        let valid = PreviewVoiceRequest {
468            voice_id: "valid-voice_123".to_string(),
469            text: None,
470        };
471        assert!(valid.validate().is_ok());
472
473        // Invalid voice_id
474        let invalid = PreviewVoiceRequest {
475            voice_id: "invalid<script>".to_string(),
476            text: None,
477        };
478        assert!(invalid.validate().is_err());
479    }
480
481    #[test]
482    fn test_preview_voice_sample_text() {
483        let without_text = PreviewVoiceRequest {
484            voice_id: "voice".to_string(),
485            text: None,
486        };
487        assert_eq!(
488            without_text.sample_text(),
489            PreviewVoiceRequest::DEFAULT_SAMPLE_TEXT
490        );
491
492        let with_text = PreviewVoiceRequest {
493            voice_id: "voice".to_string(),
494            text: Some("Custom sample".to_string()),
495        };
496        assert_eq!(with_text.sample_text(), "Custom sample");
497    }
498}