tauri_plugin_tts/
models.rs

1use serde::{Deserialize, Serialize};
2use std::borrow::Cow;
3use ts_rs::TS;
4
/// Upper bound on the UTF-8 byte length (`str::len`) of text accepted for speech (10 kB).
pub const MAX_TEXT_LENGTH: usize = 10_000;
/// Upper bound on the UTF-8 byte length (`str::len`) of a voice identifier.
pub const MAX_VOICE_ID_LENGTH: usize = 256;
/// Upper bound on the UTF-8 byte length (`str::len`) of a language/locale code.
///
/// 35 equals the minimum tag length RFC 5646 §4.4.1 recommends implementations support.
pub const MAX_LANGUAGE_LENGTH: usize = 35;
11
12#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, PartialEq, Eq, TS)]
13#[ts(export, export_to = "../guest-js/bindings/")]
14#[serde(rename_all = "lowercase")]
15pub enum QueueMode {
16    /// Flush any pending speech and start speaking immediately (default)
17    #[default]
18    Flush,
19    /// Add to queue and speak after current speech finishes
20    Add,
21}
22
23#[derive(Debug, Clone, Deserialize, Serialize, TS)]
24#[ts(export, export_to = "../guest-js/bindings/")]
25#[serde(rename_all = "camelCase")]
26pub struct SpeakOptions {
27    /// The text to speak (max 10,000 characters)
28    pub text: String,
29    /// The language/locale code (e.g., "en-US", "pt-BR", "ja-JP")
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub language: Option<String>,
32    /// Specific voice ID to use (from getVoices). Takes priority over language
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub voice_id: Option<String>,
35    /// Speech rate (0.1 to 4.0, where 1.0 = normal)
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub rate: Option<f32>,
38    /// Pitch (0.5 to 2.0, where 1.0 = normal)
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub pitch: Option<f32>,
41    /// Volume (0.0 to 1.0, where 1.0 = full volume)
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub volume: Option<f32>,
44    /// Queue mode: "flush" (default) or "add"
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub queue_mode: Option<QueueMode>,
47}
48
49#[derive(Debug, Clone, Deserialize, Serialize, TS)]
50#[ts(export, export_to = "../guest-js/bindings/")]
51#[serde(rename_all = "camelCase")]
52pub struct PreviewVoiceOptions {
53    /// Voice ID to preview
54    pub voice_id: String,
55    /// Optional custom sample text (uses default if not provided)
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub text: Option<String>,
58}
59
60#[derive(Debug, Deserialize, Serialize)]
61#[serde(rename_all = "camelCase")]
62pub struct SpeakRequest {
63    /// The text to speak
64    pub text: String,
65    /// The language/locale code (e.g., "en-US", "pt-BR", "ja-JP")
66    #[serde(default)]
67    pub language: Option<String>,
68    /// Voice ID to use (from getVoices)
69    #[serde(default)]
70    pub voice_id: Option<String>,
71    /// Speech rate (0.1 to 4.0, where 1.0 = normal, 2.0 = double, 0.5 = half)
72    #[serde(default = "default_rate")]
73    pub rate: f32,
74    /// Pitch (0.5 = low, 1.0 = normal, 2.0 = high)
75    #[serde(default = "default_pitch")]
76    pub pitch: f32,
77    /// Volume (0.0 = silent, 1.0 = full volume)
78    #[serde(default = "default_volume")]
79    pub volume: f32,
80    /// Queue mode: "flush" (default) or "add"
81    #[serde(default)]
82    pub queue_mode: QueueMode,
83}
84
/// Serde default for `SpeakRequest::rate`: normal speaking speed.
fn default_rate() -> f32 {
    1.0_f32
}

/// Serde default for `SpeakRequest::pitch`: the voice's natural pitch.
fn default_pitch() -> f32 {
    1.0_f32
}

/// Serde default for `SpeakRequest::volume`: full volume.
fn default_volume() -> f32 {
    1.0_f32
}
94
95#[derive(Debug, Clone, thiserror::Error)]
96pub enum ValidationError {
97    #[error("Text cannot be empty")]
98    EmptyText,
99    #[error("Text too long: {len} bytes (max: {max})")]
100    TextTooLong { len: usize, max: usize },
101    #[error("Voice ID too long: {len} chars (max: {max})")]
102    VoiceIdTooLong { len: usize, max: usize },
103    #[error("Invalid voice ID format - only alphanumeric, dots, underscores, and hyphens allowed")]
104    InvalidVoiceId,
105    #[error("Language code too long: {len} chars (max: {max})")]
106    LanguageTooLong { len: usize, max: usize },
107}
108
109#[derive(Debug, Clone)]
110pub struct ValidatedSpeakRequest {
111    pub text: String,
112    pub language: Option<String>,
113    pub voice_id: Option<String>,
114    pub rate: f32,
115    pub pitch: f32,
116    pub volume: f32,
117    pub queue_mode: QueueMode,
118}
119
120impl SpeakRequest {
121    pub fn validate(&self) -> Result<ValidatedSpeakRequest, ValidationError> {
122        // Text validation
123        if self.text.is_empty() {
124            return Err(ValidationError::EmptyText);
125        }
126        if self.text.len() > MAX_TEXT_LENGTH {
127            return Err(ValidationError::TextTooLong {
128                len: self.text.len(),
129                max: MAX_TEXT_LENGTH,
130            });
131        }
132
133        // Voice ID validation (if provided)
134        let sanitized_voice_id = self
135            .voice_id
136            .as_ref()
137            .map(|id| Self::validate_voice_id(id))
138            .transpose()?;
139
140        // Language validation (if provided)
141        let sanitized_language = self
142            .language
143            .as_ref()
144            .map(|lang| Self::validate_language(lang))
145            .transpose()?;
146
147        Ok(ValidatedSpeakRequest {
148            text: self.text.clone(),
149            language: sanitized_language,
150            voice_id: sanitized_voice_id,
151            rate: self.rate.clamp(0.1, 4.0),
152            pitch: self.pitch.clamp(0.5, 2.0),
153            volume: self.volume.clamp(0.0, 1.0),
154            queue_mode: self.queue_mode,
155        })
156    }
157
158    fn validate_voice_id(id: &str) -> Result<String, ValidationError> {
159        if id.len() > MAX_VOICE_ID_LENGTH {
160            return Err(ValidationError::VoiceIdTooLong {
161                len: id.len(),
162                max: MAX_VOICE_ID_LENGTH,
163            });
164        }
165        if !id
166            .chars()
167            .all(|c| c.is_alphanumeric() || c == '.' || c == '_' || c == '-')
168        {
169            return Err(ValidationError::InvalidVoiceId);
170        }
171        Ok(id.to_string())
172    }
173
174    fn validate_language(lang: &str) -> Result<String, ValidationError> {
175        if lang.len() > MAX_LANGUAGE_LENGTH {
176            return Err(ValidationError::LanguageTooLong {
177                len: lang.len(),
178                max: MAX_LANGUAGE_LENGTH,
179            });
180        }
181        Ok(lang.to_string())
182    }
183}
184
185#[derive(Debug, Clone, Default, Deserialize, Serialize)]
186#[serde(rename_all = "camelCase")]
187pub struct SpeakResponse {
188    /// Whether speech was successfully initiated
189    pub success: bool,
190    /// Optional warning message (e.g., voice not found, using fallback)
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub warning: Option<String>,
193}
194
195#[derive(Debug, Clone, Default, Deserialize, Serialize)]
196#[serde(rename_all = "camelCase")]
197pub struct StopResponse {
198    pub success: bool,
199}
200
201#[derive(Debug, Clone, Deserialize, Serialize, TS)]
202#[ts(export, export_to = "../guest-js/bindings/")]
203#[serde(rename_all = "camelCase")]
204pub struct Voice {
205    /// Unique identifier for the voice
206    pub id: String,
207    /// Display name of the voice
208    pub name: String,
209    /// Language code (e.g., "en-US")
210    pub language: String,
211}
212
213#[derive(Debug, Deserialize, Serialize)]
214#[serde(rename_all = "camelCase")]
215pub struct GetVoicesRequest {
216    /// Optional language filter
217    #[serde(default)]
218    pub language: Option<String>,
219}
220
221#[derive(Debug, Clone, Default, Deserialize, Serialize)]
222#[serde(rename_all = "camelCase")]
223pub struct GetVoicesResponse {
224    pub voices: Vec<Voice>,
225}
226
227#[derive(Debug, Clone, Default, Deserialize, Serialize)]
228#[serde(rename_all = "camelCase")]
229pub struct IsSpeakingResponse {
230    pub speaking: bool,
231}
232
233#[derive(Debug, Clone, Default, Deserialize, Serialize)]
234#[serde(rename_all = "camelCase")]
235pub struct IsInitializedResponse {
236    /// Whether the TTS engine is initialized and ready
237    pub initialized: bool,
238    /// Number of available voices (0 if not initialized)
239    pub voice_count: u32,
240}
241
242#[derive(Debug, Clone, Default, Deserialize, Serialize, TS)]
243#[ts(export, export_to = "../guest-js/bindings/")]
244#[serde(rename_all = "camelCase")]
245pub struct PauseResumeResponse {
246    pub success: bool,
247    /// Reason for failure (if success is false)
248    #[serde(skip_serializing_if = "Option::is_none")]
249    pub reason: Option<String>,
250}
251
252#[derive(Debug, Deserialize, Serialize)]
253#[serde(rename_all = "camelCase")]
254pub struct PreviewVoiceRequest {
255    /// Voice ID to preview
256    pub voice_id: String,
257    /// Optional custom sample text (uses default if not provided)
258    #[serde(default)]
259    pub text: Option<String>,
260}
261
262impl PreviewVoiceRequest {
263    pub const DEFAULT_SAMPLE_TEXT: &'static str =
264        "Hello! This is a sample of how this voice sounds.";
265
266    pub fn sample_text(&self) -> Cow<'_, str> {
267        match &self.text {
268            Some(text) => Cow::Borrowed(text.as_str()),
269            None => Cow::Borrowed(Self::DEFAULT_SAMPLE_TEXT),
270        }
271    }
272
273    pub fn validate(&self) -> Result<(), ValidationError> {
274        // Validate voice_id
275        SpeakRequest::validate_voice_id(&self.voice_id)?;
276
277        // Validate custom text if provided
278        if let Some(ref text) = self.text {
279            if text.is_empty() {
280                return Err(ValidationError::EmptyText);
281            }
282            if text.len() > MAX_TEXT_LENGTH {
283                return Err(ValidationError::TextTooLong {
284                    len: text.len(),
285                    max: MAX_TEXT_LENGTH,
286                });
287            }
288        }
289
290        Ok(())
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request for the validation tests: neutral rate, pitch and
    /// volume, no language or voice, flush mode. Individual tests override
    /// fields with struct-update syntax.
    fn baseline(text: &str) -> SpeakRequest {
        SpeakRequest {
            text: text.to_string(),
            language: None,
            voice_id: None,
            rate: 1.0,
            pitch: 1.0,
            volume: 1.0,
            queue_mode: QueueMode::Flush,
        }
    }

    #[test]
    fn test_speak_request_defaults() {
        let parsed: SpeakRequest = serde_json::from_str(r#"{"text": "Hello world"}"#).unwrap();

        assert_eq!(parsed.text, "Hello world");
        assert!(parsed.language.is_none());
        assert!(parsed.voice_id.is_none());
        assert_eq!((parsed.rate, parsed.pitch, parsed.volume), (1.0, 1.0, 1.0));
    }

    #[test]
    fn test_speak_request_full() {
        let payload = r#"{
            "text": "Olá",
            "language": "pt-BR",
            "voiceId": "com.apple.voice.enhanced.pt-BR",
            "rate": 0.8,
            "pitch": 1.2,
            "volume": 0.9
        }"#;
        let parsed: SpeakRequest = serde_json::from_str(payload).unwrap();

        assert_eq!(parsed.text, "Olá");
        assert_eq!(parsed.language.as_deref(), Some("pt-BR"));
        assert_eq!(
            parsed.voice_id.as_deref(),
            Some("com.apple.voice.enhanced.pt-BR")
        );
        assert_eq!(parsed.rate, 0.8);
        assert_eq!(parsed.pitch, 1.2);
        assert_eq!(parsed.volume, 0.9);
    }

    #[test]
    fn test_voice_serialization() {
        let json = serde_json::to_string(&Voice {
            id: "test-voice".to_string(),
            name: "Test Voice".to_string(),
            language: "en-US".to_string(),
        })
        .unwrap();

        for expected in [
            r#""id":"test-voice""#,
            r#""name":"Test Voice""#,
            r#""language":"en-US""#,
        ] {
            assert!(json.contains(expected), "{} not found in {}", expected, json);
        }
    }

    #[test]
    fn test_get_voices_request_optional_language() {
        let unfiltered: GetVoicesRequest = serde_json::from_str("{}").unwrap();
        assert!(unfiltered.language.is_none());

        let filtered: GetVoicesRequest =
            serde_json::from_str(r#"{"language": "en"}"#).unwrap();
        assert_eq!(filtered.language.as_deref(), Some("en"));
    }

    #[test]
    fn test_validation_empty_text() {
        let outcome = baseline("").validate();
        assert!(matches!(outcome, Err(ValidationError::EmptyText)));
    }

    #[test]
    fn test_validation_text_too_long() {
        let outcome = baseline(&"x".repeat(MAX_TEXT_LENGTH + 1)).validate();
        assert!(matches!(outcome, Err(ValidationError::TextTooLong { .. })));
    }

    #[test]
    fn test_validation_valid_voice_id() {
        let request = SpeakRequest {
            voice_id: Some("com.apple.voice.enhanced.en-US".to_string()),
            ..baseline("Hello")
        };
        let validated = request
            .validate()
            .expect("well-formed voice id must be accepted");
        assert_eq!(
            validated.voice_id.as_deref(),
            Some("com.apple.voice.enhanced.en-US")
        );
    }

    #[test]
    fn test_validation_invalid_voice_id_special_chars() {
        let request = SpeakRequest {
            voice_id: Some("voice'; DROP TABLE--".to_string()),
            ..baseline("Hello")
        };
        assert!(matches!(
            request.validate(),
            Err(ValidationError::InvalidVoiceId)
        ));
    }

    #[test]
    fn test_validation_voice_id_too_long() {
        let request = SpeakRequest {
            voice_id: Some("x".repeat(MAX_VOICE_ID_LENGTH + 1)),
            ..baseline("Hello")
        };
        assert!(matches!(
            request.validate(),
            Err(ValidationError::VoiceIdTooLong { .. })
        ));
    }

    #[test]
    fn test_validation_rate_clamping() {
        let validated = SpeakRequest {
            rate: 999.0,
            ..baseline("Hello")
        }
        .validate()
        .expect("out-of-range rate is clamped, not rejected");
        assert_eq!(validated.rate, 4.0); // upper bound
    }

    #[test]
    fn test_validation_pitch_clamping() {
        let validated = SpeakRequest {
            pitch: 0.1,
            ..baseline("Hello")
        }
        .validate()
        .expect("out-of-range pitch is clamped, not rejected");
        assert_eq!(validated.pitch, 0.5); // lower bound
    }

    #[test]
    fn test_validation_volume_clamping() {
        let validated = SpeakRequest {
            volume: 5.0,
            ..baseline("Hello")
        }
        .validate()
        .expect("out-of-range volume is clamped, not rejected");
        assert_eq!(validated.volume, 1.0); // upper bound
    }

    #[test]
    fn test_preview_voice_validation() {
        let accepted = PreviewVoiceRequest {
            voice_id: "valid-voice_123".to_string(),
            text: None,
        };
        assert!(accepted.validate().is_ok());

        let rejected = PreviewVoiceRequest {
            voice_id: "invalid<script>".to_string(),
            text: None,
        };
        assert!(rejected.validate().is_err());
    }

    #[test]
    fn test_preview_voice_sample_text() {
        let defaulted = PreviewVoiceRequest {
            voice_id: "voice".to_string(),
            text: None,
        };
        assert_eq!(
            defaulted.sample_text(),
            PreviewVoiceRequest::DEFAULT_SAMPLE_TEXT
        );

        let custom = PreviewVoiceRequest {
            voice_id: "voice".to_string(),
            text: Some("Custom sample".to_string()),
        };
        assert_eq!(custom.sample_text(), "Custom sample");
    }
}