voice_engine/transcription/mod.rs
use crate::event::SessionEvent;
use crate::media::AudioFrame;
use crate::media::Sample;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::debug;

mod aliyun;
mod tencent_cloud;
mod voiceapi;

pub use aliyun::AliyunAsrClient;
pub use aliyun::AliyunAsrClientBuilder;
pub use tencent_cloud::TencentCloudAsrClient;
pub use tencent_cloud::TencentCloudAsrClientBuilder;
pub use voiceapi::VoiceApiAsrClient;
pub use voiceapi::VoiceApiAsrClientBuilder;

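/// Wait for the session to be answered before starting transcription.
///
/// Runs until an `Answer` event is received, the audio channel closes, or the
/// token is cancelled; while waiting, incoming audio chunks are received and
/// dropped so the channel does not back up.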
pub async fn handle_wait_for_answer_with_audio_drop(
    event_rx: Option<crate::event::EventReceiver>,
    audio_rx: &mut mpsc::UnboundedReceiver<Vec<u8>>,
    token: &CancellationToken,
) {
    tokio::select! {
        _ = token.cancelled() => {
            debug!("Cancelled before answer");
        }
        _ = async {
            while (audio_rx.recv().await).is_some() {}
        } => {}
        _ = async {
            if let Some(mut rx) = event_rx {
                while let Ok(event) = rx.recv().await {
                    if let SessionEvent::Answer { .. } = event {
                        debug!("Received answer event, starting transcription");
                        break;
                    }
                }
            }
        } => {
            debug!("Wait for answer completed");
        }
    }
}

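/// Speech-to-text (ASR) providers supported by this module.
///
/// The unit variants serialize as the tags `"tencent"`, `"voiceapi"` and
/// `"aliyun"`; any other provider string deserializes into `Other`.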
#[derive(Debug, Clone, Serialize, Hash, Eq, PartialEq)]
pub enum TranscriptionType {
    #[serde(rename = "tencent")]
    TencentCloud,
    #[serde(rename = "voiceapi")]
    VoiceApi,
    #[serde(rename = "aliyun")]
    Aliyun,
    Other(String),
}

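/// Provider-agnostic transcription configuration.
///
/// Every field is optional and (de)serialized in camelCase; missing
/// credentials or endpoints can be filled in from environment variables via
/// [`TranscriptionOption::check_default`].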
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
#[serde(rename_all = "camelCase")]
#[serde(default)]
pub struct TranscriptionOption {
    pub provider: Option<TranscriptionType>,
    pub language: Option<String>,
    pub app_id: Option<String>,
    pub secret_id: Option<String>,
    pub secret_key: Option<String>,
    pub model_type: Option<String>,
    pub buffer_size: Option<usize>,
    pub samplerate: Option<u32>,
    pub endpoint: Option<String>,
    pub extra: Option<HashMap<String, String>>,
    pub start_when_answer: Option<bool>,
}

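// Formats the provider as its configuration tag: "tencent", "voiceapi",
// "aliyun", or the inner string for `Other`.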
impl std::fmt::Display for TranscriptionType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TranscriptionType::TencentCloud => write!(f, "tencent"),
            TranscriptionType::VoiceApi => write!(f, "voiceapi"),
            TranscriptionType::Aliyun => write!(f, "aliyun"),
            TranscriptionType::Other(provider) => write!(f, "{}", provider),
        }
    }
}

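// Deserialize is implemented by hand so that unrecognized provider strings
// map to `TranscriptionType::Other` instead of failing to parse.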
impl<'de> Deserialize<'de> for TranscriptionType {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let value = String::deserialize(deserializer)?;
        match value.as_str() {
            "tencent" => Ok(TranscriptionType::TencentCloud),
            "voiceapi" => Ok(TranscriptionType::VoiceApi),
            "aliyun" => Ok(TranscriptionType::Aliyun),
            _ => Ok(TranscriptionType::Other(value)),
        }
    }
}

impl TranscriptionOption {
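    /// Fill in missing credentials from environment variables, leaving any
    /// explicitly configured values untouched: `TENCENT_APPID`,
    /// `TENCENT_SECRET_ID` and `TENCENT_SECRET_KEY` for Tencent Cloud,
    /// `VOICEAPI_ENDPOINT` for VoiceAPI, and `DASHSCOPE_API_KEY` for Aliyun.
    ///
    /// A minimal usage sketch (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut option = TranscriptionOption {
    ///     provider: Some(TranscriptionType::TencentCloud),
    ///     ..Default::default()
    /// };
    /// option.check_default(); // pulls the TENCENT_* variables if unset
    /// ```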
    pub fn check_default(&mut self) {
        match self.provider {
            Some(TranscriptionType::TencentCloud) => {
                if self.app_id.is_none() {
                    self.app_id = std::env::var("TENCENT_APPID").ok();
                }
                if self.secret_id.is_none() {
                    self.secret_id = std::env::var("TENCENT_SECRET_ID").ok();
                }
                if self.secret_key.is_none() {
                    self.secret_key = std::env::var("TENCENT_SECRET_KEY").ok();
                }
            }
            Some(TranscriptionType::VoiceApi) => {
                if self.endpoint.is_none() {
                    self.endpoint = std::env::var("VOICEAPI_ENDPOINT").ok();
                }
            }
            Some(TranscriptionType::Aliyun) => {
                if self.secret_key.is_none() {
                    self.secret_key = std::env::var("DASHSCOPE_API_KEY").ok();
                }
            }
            _ => {}
        }
    }
}

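// Channel endpoints used to pass `AudioFrame`s to the transcription layer.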
pub type TranscriptionSender = mpsc::UnboundedSender<AudioFrame>;
pub type TranscriptionReceiver = mpsc::UnboundedReceiver<AudioFrame>;

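/// Common interface implemented by the ASR clients in this module (Tencent
/// Cloud, VoiceAPI, Aliyun); `send_audio` hands a slice of audio samples to
/// the recognizer.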
#[async_trait]
pub trait TranscriptionClient: Send + Sync {
    fn send_audio(&self, samples: &[Sample]) -> Result<()>;
}

#[cfg(test)]
mod tests;