voice_engine/transcription/
mod.rs

1use crate::event::SessionEvent;
2use crate::media::AudioFrame;
3use crate::media::Sample;
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use tokio::sync::mpsc;
9use tokio_util::sync::CancellationToken;
10use tracing::debug;
11
12mod aliyun;
13mod tencent_cloud;
14mod voiceapi;
15
16pub use aliyun::AliyunAsrClient;
17pub use aliyun::AliyunAsrClientBuilder;
18pub use tencent_cloud::TencentCloudAsrClient;
19pub use tencent_cloud::TencentCloudAsrClientBuilder;
20pub use voiceapi::VoiceApiAsrClient;
21pub use voiceapi::VoiceApiAsrClientBuilder;
22
23/// Common helper function for handling wait_for_answer logic with audio dropping
24pub async fn handle_wait_for_answer_with_audio_drop(
25    event_rx: Option<crate::event::EventReceiver>,
26    audio_rx: &mut mpsc::UnboundedReceiver<Vec<u8>>,
27    token: &CancellationToken,
28) {
29    tokio::select! {
30        _ = token.cancelled() => {
31            debug!("Cancelled before answer");
32        }
33        // drop audio if not started after answer
34        _ = async {
35            while (audio_rx.recv().await).is_some() {}
36        } => {}
37        _ = async {
38            if let Some(mut rx) = event_rx {
39                while let Ok(event) = rx.recv().await {
40                    if let SessionEvent::Answer { .. } = event {
41                        debug!("Received answer event, starting transcription");
42                        break;
43                    }
44                }
45            }
46        } => {
47            debug!("Wait for answer completed");
48        }
49    }
50}
51
52#[derive(Debug, Clone, Serialize, Hash, Eq, PartialEq)]
53pub enum TranscriptionType {
54    #[serde(rename = "tencent")]
55    TencentCloud,
56    #[serde(rename = "voiceapi")]
57    VoiceApi,
58    #[serde(rename = "aliyun")]
59    Aliyun,
60    Other(String),
61}
62
63#[derive(Debug, Clone, Deserialize, Serialize, Default)]
64#[serde(rename_all = "camelCase")]
65#[serde(default)]
66pub struct TranscriptionOption {
67    pub provider: Option<TranscriptionType>,
68    pub language: Option<String>,
69    pub app_id: Option<String>,
70    pub secret_id: Option<String>,
71    pub secret_key: Option<String>,
72    pub model_type: Option<String>,
73    pub buffer_size: Option<usize>,
74    pub samplerate: Option<u32>,
75    pub endpoint: Option<String>,
76    pub extra: Option<HashMap<String, String>>,
77    pub start_when_answer: Option<bool>,
78}
79
80impl std::fmt::Display for TranscriptionType {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        match self {
83            TranscriptionType::TencentCloud => write!(f, "tencent"),
84            TranscriptionType::VoiceApi => write!(f, "voiceapi"),
85            TranscriptionType::Aliyun => write!(f, "aliyun"),
86            TranscriptionType::Other(provider) => write!(f, "{}", provider),
87        }
88    }
89}
90
91impl<'de> Deserialize<'de> for TranscriptionType {
92    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
93    where
94        D: serde::Deserializer<'de>,
95    {
96        let value = String::deserialize(deserializer)?;
97        match value.as_str() {
98            "tencent" => Ok(TranscriptionType::TencentCloud),
99            "voiceapi" => Ok(TranscriptionType::VoiceApi),
100            "aliyun" => Ok(TranscriptionType::Aliyun),
101            _ => Ok(TranscriptionType::Other(value)),
102        }
103    }
104}
105
106impl TranscriptionOption {
107    pub fn check_default(&mut self) {
108        match self.provider {
109            Some(TranscriptionType::TencentCloud) => {
110                if self.app_id.is_none() {
111                    self.app_id = std::env::var("TENCENT_APPID").ok();
112                }
113                if self.secret_id.is_none() {
114                    self.secret_id = std::env::var("TENCENT_SECRET_ID").ok();
115                }
116                if self.secret_key.is_none() {
117                    self.secret_key = std::env::var("TENCENT_SECRET_KEY").ok();
118                }
119            }
120            Some(TranscriptionType::VoiceApi) => {
121                // Set the host from environment variable if not already set
122                if self.endpoint.is_none() {
123                    self.endpoint = std::env::var("VOICEAPI_ENDPOINT").ok();
124                }
125            }
126            Some(TranscriptionType::Aliyun) => {
127                if self.secret_key.is_none() {
128                    self.secret_key = std::env::var("DASHSCOPE_API_KEY").ok();
129                }
130            }
131            _ => {}
132        }
133    }
134}
135pub type TranscriptionSender = mpsc::UnboundedSender<AudioFrame>;
136pub type TranscriptionReceiver = mpsc::UnboundedReceiver<AudioFrame>;
137
138// Unified transcription client trait with async_trait support
139#[async_trait]
140pub trait TranscriptionClient: Send + Sync {
141    fn send_audio(&self, samples: &[Sample]) -> Result<()>;
142}
143
144#[cfg(test)]
145mod tests;