1use super::{
2 asr_processor::AsrProcessor,
3 denoiser::NoiseReducer,
4 processor::Processor,
5 track::{
6 Track,
7 tts::{SynthesisHandle, TtsTrack},
8 },
9 vad::{VADOption, VadProcessor, VadType},
10};
11use crate::{
12 CallOption, EouOption,
13 event::EventSender,
14 media::TrackId,
15 synthesis::{
16 AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
17 TencentCloudTtsBasicClient, TencentCloudTtsClient, VoiceApiTtsClient,
18 },
19 transcription::{
20 AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
21 TranscriptionOption, TranscriptionType, VoiceApiAsrClientBuilder,
22 },
23};
24use anyhow::Result;
25use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
26use tokio::sync::mpsc;
27use tokio_util::sync::CancellationToken;
28
29pub type FnCreateVadProcessor = fn(
30 token: CancellationToken,
31 event_sender: EventSender,
32 option: VADOption,
33) -> Result<Box<dyn Processor>>;
34
35pub type FnCreateEouProcessor = fn(
36 token: CancellationToken,
37 event_sender: EventSender,
38 option: EouOption,
39) -> Result<Box<dyn Processor>>;
40
41pub type FnCreateAsrClient = Box<
42 dyn Fn(
43 TrackId,
44 CancellationToken,
45 TranscriptionOption,
46 EventSender,
47 ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
48 + Send
49 + Sync,
50>;
51pub type FnCreateTtsClient =
52 fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
53
54pub type CreateProcessorsHook = Box<
56 dyn Fn(
57 Arc<StreamEngine>,
58 &dyn Track,
59 CancellationToken,
60 EventSender,
61 CallOption,
62 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
63 + Send
64 + Sync,
65>;
66
/// Registry of media-processing backends (VAD, EOU, ASR and TTS factories)
/// together with the hook that assembles a track's processor chain.
///
/// `StreamEngine::new` yields an empty registry; `StreamEngine::default`
/// pre-registers the built-in backends. More can be added with the
/// `register_*` methods.
pub struct StreamEngine {
    /// VAD processor factories, looked up by `VADOption::r#type`.
    vad_creators: HashMap<VadType, FnCreateVadProcessor>,
    /// EOU processor factories, looked up by `EouOption::r#type` (an unset
    /// type is looked up as the empty string).
    eou_creators: HashMap<String, FnCreateEouProcessor>,
    /// ASR client factories, looked up by `TranscriptionOption::provider`.
    asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
    /// TTS client factories, looked up by `SynthesisOption::provider`.
    tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
    /// Hook invoked by `create_processors`; replaceable via
    /// `with_processor_hook`.
    create_processors_hook: Arc<CreateProcessorsHook>,
}
74
75impl Default for StreamEngine {
76 fn default() -> Self {
77 let mut engine = Self::new();
78 #[cfg(feature = "vad_silero")]
79 engine.register_vad(VadType::Silero, VadProcessor::create_silero);
80 #[cfg(feature = "vad_webrtc")]
81 engine.register_vad(VadType::WebRTC, VadProcessor::create_webrtc);
82 #[cfg(feature = "vad_ten")]
83 engine.register_vad(VadType::Ten, VadProcessor::create_ten);
84 engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
85
86 engine.register_asr(
87 TranscriptionType::TencentCloud,
88 Box::new(TencentCloudAsrClientBuilder::create),
89 );
90 engine.register_asr(
91 TranscriptionType::VoiceApi,
92 Box::new(VoiceApiAsrClientBuilder::create),
93 );
94 engine.register_asr(
95 TranscriptionType::Aliyun,
96 Box::new(AliyunAsrClientBuilder::create),
97 );
98 engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
99 engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
100 engine.register_tts(SynthesisType::VoiceApi, VoiceApiTtsClient::create);
101 engine.register_tts(
102 SynthesisType::Other("tencent_basic".to_string()),
103 TencentCloudTtsBasicClient::create,
104 );
105 engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
106 engine
107 }
108}
109
110impl StreamEngine {
111 pub fn new() -> Self {
112 Self {
113 vad_creators: HashMap::new(),
114 asr_creators: HashMap::new(),
115 tts_creators: HashMap::new(),
116 eou_creators: HashMap::new(),
117 create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
118 }
119 }
120
121 pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
122 self.vad_creators.insert(vad_type, creator);
123 self
124 }
125
126 pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
127 self.eou_creators.insert(name, creator);
128 self
129 }
130
131 pub fn register_asr(
132 &mut self,
133 asr_type: TranscriptionType,
134 creator: FnCreateAsrClient,
135 ) -> &mut Self {
136 self.asr_creators.insert(asr_type, creator);
137 self
138 }
139
140 pub fn register_tts(
141 &mut self,
142 tts_type: SynthesisType,
143 creator: FnCreateTtsClient,
144 ) -> &mut Self {
145 self.tts_creators.insert(tts_type, creator);
146 self
147 }
148
149 pub fn create_vad_processor(
150 &self,
151 token: CancellationToken,
152 event_sender: EventSender,
153 option: VADOption,
154 ) -> Result<Box<dyn Processor>> {
155 let creator = self.vad_creators.get(&option.r#type);
156 if let Some(creator) = creator {
157 creator(token, event_sender, option)
158 } else {
159 Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
160 }
161 }
162 pub fn create_eou_processor(
163 &self,
164 token: CancellationToken,
165 event_sender: EventSender,
166 option: EouOption,
167 ) -> Result<Box<dyn Processor>> {
168 let creator = self
169 .eou_creators
170 .get(&option.r#type.clone().unwrap_or_default());
171 if let Some(creator) = creator {
172 creator(token, event_sender, option)
173 } else {
174 Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
175 }
176 }
177
178 pub async fn create_asr_processor(
179 &self,
180 track_id: TrackId,
181 cancel_token: CancellationToken,
182 option: TranscriptionOption,
183 event_sender: EventSender,
184 ) -> Result<Box<dyn Processor>> {
185 let asr_client = match option.provider {
186 Some(ref provider) => {
187 let creator = self.asr_creators.get(&provider);
188 if let Some(creator) = creator {
189 creator(track_id, cancel_token, option, event_sender).await?
190 } else {
191 return Err(anyhow::anyhow!("ASR type not found: {}", provider));
192 }
193 }
194 None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
195 };
196 Ok(Box::new(AsrProcessor { asr_client }))
197 }
198
199 pub async fn create_tts_client(
200 &self,
201 streaming: bool,
202 tts_option: &SynthesisOption,
203 ) -> Result<Box<dyn SynthesisClient>> {
204 match tts_option.provider {
205 Some(ref provider) => {
206 let creator = self.tts_creators.get(&provider);
207 if let Some(creator) = creator {
208 creator(streaming, tts_option)
209 } else {
210 Err(anyhow::anyhow!("TTS type not found: {}", provider))
211 }
212 }
213 None => Err(anyhow::anyhow!(
214 "TTS type not found: {:?}",
215 tts_option.provider
216 )),
217 }
218 }
219
220 pub async fn create_processors(
221 engine: Arc<StreamEngine>,
222 track: &dyn Track,
223 cancel_token: CancellationToken,
224 event_sender: EventSender,
225 option: &CallOption,
226 ) -> Result<Vec<Box<dyn Processor>>> {
227 (engine.clone().create_processors_hook)(
228 engine,
229 track,
230 cancel_token,
231 event_sender,
232 option.clone(),
233 )
234 .await
235 }
236
237 pub async fn create_tts_track(
238 engine: Arc<StreamEngine>,
239 cancel_token: CancellationToken,
240 session_id: String,
241 track_id: TrackId,
242 ssrc: u32,
243 play_id: Option<String>,
244 streaming: bool,
245 tts_option: &SynthesisOption,
246 ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
247 let (tx, rx) = mpsc::unbounded_channel();
248 let new_handle = SynthesisHandle::new(tx, play_id.clone());
249 let tts_client = engine.create_tts_client(streaming, tts_option).await?;
250 let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
251 .with_ssrc(ssrc)
252 .with_cancel_token(cancel_token);
253 Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
254 }
255
256 pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
257 self.create_processors_hook = Arc::new(Box::new(hook_fn));
258 self
259 }
260
261 fn default_create_procesors_hook(
262 engine: Arc<StreamEngine>,
263 track: &dyn Track,
264 cancel_token: CancellationToken,
265 event_sender: EventSender,
266 option: CallOption,
267 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
268 let track_id = track.id().clone();
269 let samplerate = track.config().samplerate as usize;
270 Box::pin(async move {
271 let mut processors = vec![];
272 match option.denoise {
273 Some(true) => {
274 let noise_reducer = NoiseReducer::new(samplerate)?;
275 processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
276 }
277 _ => {}
278 }
279 match option.vad {
280 Some(ref option) => {
281 let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
282 cancel_token.child_token(),
283 event_sender.clone(),
284 option.to_owned(),
285 )?;
286 processors.push(vad_processor);
287 }
288 None => {}
289 }
290 match option.asr {
291 Some(ref option) => {
292 let asr_processor = engine
293 .create_asr_processor(
294 track_id,
295 cancel_token.child_token(),
296 option.to_owned(),
297 event_sender.clone(),
298 )
299 .await?;
300 processors.push(asr_processor);
301 }
302 None => {}
303 }
304 match option.eou {
305 Some(ref option) => {
306 let eou_processor = engine.create_eou_processor(
307 cancel_token.child_token(),
308 event_sender.clone(),
309 option.to_owned(),
310 )?;
311 processors.push(eou_processor);
312 }
313 None => {}
314 }
315
316 Ok(processors)
317 })
318 }
319}