voice_engine/media/
engine.rs

1use super::{
2    asr_processor::AsrProcessor,
3    denoiser::NoiseReducer,
4    processor::Processor,
5    track::{
6        Track,
7        tts::{SynthesisHandle, TtsTrack},
8    },
9    vad::{VADOption, VadProcessor, VadType},
10};
11use crate::{
12    CallOption, EouOption,
13    event::EventSender,
14    media::TrackId,
15    synthesis::{
16        AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
17        TencentCloudTtsBasicClient, TencentCloudTtsClient, VoiceApiTtsClient,
18    },
19    transcription::{
20        AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
21        TranscriptionOption, TranscriptionType, VoiceApiAsrClientBuilder,
22    },
23};
24use anyhow::Result;
25use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
26use tokio::sync::mpsc;
27use tokio_util::sync::CancellationToken;
28
29pub type FnCreateVadProcessor = fn(
30    token: CancellationToken,
31    event_sender: EventSender,
32    option: VADOption,
33) -> Result<Box<dyn Processor>>;
34
35pub type FnCreateEouProcessor = fn(
36    token: CancellationToken,
37    event_sender: EventSender,
38    option: EouOption,
39) -> Result<Box<dyn Processor>>;
40
41pub type FnCreateAsrClient = Box<
42    dyn Fn(
43            TrackId,
44            CancellationToken,
45            TranscriptionOption,
46            EventSender,
47        ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
48        + Send
49        + Sync,
50>;
51pub type FnCreateTtsClient =
52    fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
53
54// Define hook types
55pub type CreateProcessorsHook = Box<
56    dyn Fn(
57            Arc<StreamEngine>,
58            &dyn Track,
59            CancellationToken,
60            EventSender,
61            CallOption,
62        ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
63        + Send
64        + Sync,
65>;
66
67pub struct StreamEngine {
68    vad_creators: HashMap<VadType, FnCreateVadProcessor>,
69    eou_creators: HashMap<String, FnCreateEouProcessor>,
70    asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
71    tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
72    create_processors_hook: Arc<CreateProcessorsHook>,
73}
74
75impl Default for StreamEngine {
76    fn default() -> Self {
77        let mut engine = Self::new();
78        #[cfg(feature = "vad_silero")]
79        engine.register_vad(VadType::Silero, VadProcessor::create_silero);
80        #[cfg(feature = "vad_webrtc")]
81        engine.register_vad(VadType::WebRTC, VadProcessor::create_webrtc);
82        #[cfg(feature = "vad_ten")]
83        engine.register_vad(VadType::Ten, VadProcessor::create_ten);
84        engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
85
86        engine.register_asr(
87            TranscriptionType::TencentCloud,
88            Box::new(TencentCloudAsrClientBuilder::create),
89        );
90        engine.register_asr(
91            TranscriptionType::VoiceApi,
92            Box::new(VoiceApiAsrClientBuilder::create),
93        );
94        engine.register_asr(
95            TranscriptionType::Aliyun,
96            Box::new(AliyunAsrClientBuilder::create),
97        );
98        engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
99        engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
100        engine.register_tts(SynthesisType::VoiceApi, VoiceApiTtsClient::create);
101        engine.register_tts(
102            SynthesisType::Other("tencent_basic".to_string()),
103            TencentCloudTtsBasicClient::create,
104        );
105        engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
106        engine
107    }
108}
109
110impl StreamEngine {
111    pub fn new() -> Self {
112        Self {
113            vad_creators: HashMap::new(),
114            asr_creators: HashMap::new(),
115            tts_creators: HashMap::new(),
116            eou_creators: HashMap::new(),
117            create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
118        }
119    }
120
121    pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
122        self.vad_creators.insert(vad_type, creator);
123        self
124    }
125
126    pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
127        self.eou_creators.insert(name, creator);
128        self
129    }
130
131    pub fn register_asr(
132        &mut self,
133        asr_type: TranscriptionType,
134        creator: FnCreateAsrClient,
135    ) -> &mut Self {
136        self.asr_creators.insert(asr_type, creator);
137        self
138    }
139
140    pub fn register_tts(
141        &mut self,
142        tts_type: SynthesisType,
143        creator: FnCreateTtsClient,
144    ) -> &mut Self {
145        self.tts_creators.insert(tts_type, creator);
146        self
147    }
148
149    pub fn create_vad_processor(
150        &self,
151        token: CancellationToken,
152        event_sender: EventSender,
153        option: VADOption,
154    ) -> Result<Box<dyn Processor>> {
155        let creator = self.vad_creators.get(&option.r#type);
156        if let Some(creator) = creator {
157            creator(token, event_sender, option)
158        } else {
159            Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
160        }
161    }
162    pub fn create_eou_processor(
163        &self,
164        token: CancellationToken,
165        event_sender: EventSender,
166        option: EouOption,
167    ) -> Result<Box<dyn Processor>> {
168        let creator = self
169            .eou_creators
170            .get(&option.r#type.clone().unwrap_or_default());
171        if let Some(creator) = creator {
172            creator(token, event_sender, option)
173        } else {
174            Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
175        }
176    }
177
178    pub async fn create_asr_processor(
179        &self,
180        track_id: TrackId,
181        cancel_token: CancellationToken,
182        option: TranscriptionOption,
183        event_sender: EventSender,
184    ) -> Result<Box<dyn Processor>> {
185        let asr_client = match option.provider {
186            Some(ref provider) => {
187                let creator = self.asr_creators.get(&provider);
188                if let Some(creator) = creator {
189                    creator(track_id, cancel_token, option, event_sender).await?
190                } else {
191                    return Err(anyhow::anyhow!("ASR type not found: {}", provider));
192                }
193            }
194            None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
195        };
196        Ok(Box::new(AsrProcessor { asr_client }))
197    }
198
199    pub async fn create_tts_client(
200        &self,
201        streaming: bool,
202        tts_option: &SynthesisOption,
203    ) -> Result<Box<dyn SynthesisClient>> {
204        match tts_option.provider {
205            Some(ref provider) => {
206                let creator = self.tts_creators.get(&provider);
207                if let Some(creator) = creator {
208                    creator(streaming, tts_option)
209                } else {
210                    Err(anyhow::anyhow!("TTS type not found: {}", provider))
211                }
212            }
213            None => Err(anyhow::anyhow!(
214                "TTS type not found: {:?}",
215                tts_option.provider
216            )),
217        }
218    }
219
220    pub async fn create_processors(
221        engine: Arc<StreamEngine>,
222        track: &dyn Track,
223        cancel_token: CancellationToken,
224        event_sender: EventSender,
225        option: &CallOption,
226    ) -> Result<Vec<Box<dyn Processor>>> {
227        (engine.clone().create_processors_hook)(
228            engine,
229            track,
230            cancel_token,
231            event_sender,
232            option.clone(),
233        )
234        .await
235    }
236
237    pub async fn create_tts_track(
238        engine: Arc<StreamEngine>,
239        cancel_token: CancellationToken,
240        session_id: String,
241        track_id: TrackId,
242        ssrc: u32,
243        play_id: Option<String>,
244        streaming: bool,
245        tts_option: &SynthesisOption,
246    ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
247        let (tx, rx) = mpsc::unbounded_channel();
248        let new_handle = SynthesisHandle::new(tx, play_id.clone());
249        let tts_client = engine.create_tts_client(streaming, tts_option).await?;
250        let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
251            .with_ssrc(ssrc)
252            .with_cancel_token(cancel_token);
253        Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
254    }
255
256    pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
257        self.create_processors_hook = Arc::new(Box::new(hook_fn));
258        self
259    }
260
261    fn default_create_procesors_hook(
262        engine: Arc<StreamEngine>,
263        track: &dyn Track,
264        cancel_token: CancellationToken,
265        event_sender: EventSender,
266        option: CallOption,
267    ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
268        let track_id = track.id().clone();
269        let samplerate = track.config().samplerate as usize;
270        Box::pin(async move {
271            let mut processors = vec![];
272            match option.denoise {
273                Some(true) => {
274                    let noise_reducer = NoiseReducer::new(samplerate)?;
275                    processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
276                }
277                _ => {}
278            }
279            match option.vad {
280                Some(ref option) => {
281                    let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
282                        cancel_token.child_token(),
283                        event_sender.clone(),
284                        option.to_owned(),
285                    )?;
286                    processors.push(vad_processor);
287                }
288                None => {}
289            }
290            match option.asr {
291                Some(ref option) => {
292                    let asr_processor = engine
293                        .create_asr_processor(
294                            track_id,
295                            cancel_token.child_token(),
296                            option.to_owned(),
297                            event_sender.clone(),
298                        )
299                        .await?;
300                    processors.push(asr_processor);
301                }
302                None => {}
303            }
304            match option.eou {
305                Some(ref option) => {
306                    let eou_processor = engine.create_eou_processor(
307                        cancel_token.child_token(),
308                        event_sender.clone(),
309                        option.to_owned(),
310                    )?;
311                    processors.push(eou_processor);
312                }
313                None => {}
314            }
315
316            Ok(processors)
317        })
318    }
319}