voice_engine/synthesis/
aliyun.rs

1use super::{SynthesisClient, SynthesisOption, SynthesisType};
2use crate::synthesis::SynthesisEvent;
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use futures::{
6    FutureExt, SinkExt, Stream, StreamExt, future,
7    stream::{self, BoxStream, SplitSink},
8};
9use serde::{Deserialize, Serialize};
10use serde_with::skip_serializing_none;
11use std::sync::Arc;
12use tokio::{
13    net::TcpStream,
14    sync::{Notify, mpsc},
15};
16use tokio_stream::wrappers::UnboundedReceiverStream;
17use tokio_tungstenite::{
18    MaybeTlsStream, WebSocketStream, connect_async,
19    tungstenite::{self, Message, client::IntoClientRequest},
20};
21use tracing::warn;
22use uuid::Uuid;
23type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
24type WsSink = SplitSink<WsStream, Message>;
25
26/// Aliyun CosyVoice WebSocket API Client
27/// https://help.aliyun.com/zh/model-studio/cosyvoice-websocket-api
28
29#[derive(Debug, Serialize)]
30struct Command {
31    header: CommandHeader,
32    payload: CommandPayload,
33}
34
35#[derive(Debug, Serialize)]
36#[serde(untagged)]
37enum CommandPayload {
38    Run(RunTaskPayload),
39    Continue(ContinueTaskPayload),
40    Finish(FinishTaskPayload),
41}
42
43impl Command {
44    fn run_task(option: &SynthesisOption, task_id: &str) -> Self {
45        let voice = option
46            .speaker
47            .clone()
48            .unwrap_or_else(|| "longyumi_v2".to_string());
49
50        let format = option.codec.as_deref().unwrap_or("pcm");
51
52        let sample_rate = option.samplerate.unwrap_or(16000) as u32;
53        let volume = option.volume.unwrap_or(50) as u32;
54        let rate = option.speed.unwrap_or(1.0);
55
56        Command {
57            header: CommandHeader {
58                action: "run-task".to_string(),
59                task_id: task_id.to_string(),
60                streaming: "duplex".to_string(),
61            },
62            payload: CommandPayload::Run(RunTaskPayload {
63                task_group: "audio".to_string(),
64                task: "tts".to_string(),
65                function: "SpeechSynthesizer".to_string(),
66                model: "cosyvoice-v2".to_string(),
67                parameters: RunTaskParameters {
68                    text_type: "PlainText".to_string(),
69                    voice,
70                    format: Some(format.to_string()),
71                    sample_rate: Some(sample_rate),
72                    volume: Some(volume),
73                    rate: Some(rate),
74                },
75                input: EmptyInput {},
76            }),
77        }
78    }
79
80    fn continue_task(task_id: &str, text: &str) -> Self {
81        Command {
82            header: CommandHeader {
83                action: "continue-task".to_string(),
84                task_id: task_id.to_string(),
85                streaming: "duplex".to_string(),
86            },
87            payload: CommandPayload::Continue(ContinueTaskPayload {
88                input: PayloadInput {
89                    text: text.to_string(),
90                },
91            }),
92        }
93    }
94
95    fn finish_task(task_id: &str) -> Self {
96        Command {
97            header: CommandHeader {
98                action: "finish-task".to_string(),
99                task_id: task_id.to_string(),
100                streaming: "duplex".to_string(),
101            },
102            payload: CommandPayload::Finish(FinishTaskPayload {
103                input: EmptyInput {},
104            }),
105        }
106    }
107}
108
109#[derive(Debug, Serialize)]
110struct CommandHeader {
111    action: String,
112    task_id: String,
113    streaming: String,
114}
115
116#[derive(Debug, Serialize)]
117struct RunTaskPayload {
118    task_group: String,
119    task: String,
120    function: String,
121    model: String,
122    parameters: RunTaskParameters,
123    input: EmptyInput,
124}
125
126#[skip_serializing_none]
127#[derive(Debug, Serialize)]
128struct RunTaskParameters {
129    text_type: String,
130    voice: String,
131    format: Option<String>,
132    sample_rate: Option<u32>,
133    volume: Option<u32>,
134    rate: Option<f32>,
135}
136
137#[derive(Debug, Serialize)]
138struct ContinueTaskPayload {
139    input: PayloadInput,
140}
141
142#[derive(Debug, Serialize, Deserialize)]
143struct PayloadInput {
144    text: String,
145}
146
147#[derive(Debug, Serialize)]
148struct FinishTaskPayload {
149    input: EmptyInput,
150}
151
152#[derive(Debug, Serialize)]
153struct EmptyInput {}
154
155/// WebSocket event response structure
156#[derive(Debug, Deserialize)]
157struct Event {
158    header: EventHeader,
159}
160
161#[allow(dead_code)]
162#[derive(Debug, Deserialize)]
163struct EventHeader {
164    task_id: String,
165    event: String,
166    error_code: Option<String>,
167    error_message: Option<String>,
168}
169
170async fn connect(task_id: String, option: SynthesisOption) -> Result<WsStream> {
171    let api_key = option
172        .secret_key
173        .as_ref()
174        .ok_or_else(|| anyhow!("Aliyun TTS: missing api key"))?;
175    let ws_url = option
176        .endpoint
177        .as_deref()
178        .unwrap_or("wss://dashscope.aliyuncs.com/api-ws/v1/inference");
179
180    let mut request = ws_url.into_client_request()?;
181    let headers = request.headers_mut();
182    headers.insert("Authorization", format!("Bearer {}", api_key).parse()?);
183    headers.insert("X-DashScope-DataInspection", "enable".parse()?);
184
185    let (mut ws_stream, _) = connect_async(request).await?;
186    let run_task_cmd = Command::run_task(&option, task_id.as_str());
187    let run_task_json = serde_json::to_string(&run_task_cmd)?;
188    ws_stream.send(Message::text(run_task_json)).await?;
189    while let Some(message) = ws_stream.next().await {
190        match message {
191            Ok(Message::Text(text)) => {
192                let event = serde_json::from_str::<Event>(&text)?;
193                match event.header.event.as_str() {
194                    "task-started" => {
195                        break;
196                    }
197                    "task-failed" => {
198                        let error_code = event
199                            .header
200                            .error_code
201                            .unwrap_or_else(|| "Unknown error code".to_string());
202                        let error_msg = event
203                            .header
204                            .error_message
205                            .unwrap_or_else(|| "Unknown error message".to_string());
206                        return Err(anyhow!(
207                            "Aliyun TTS Task: {} failed: {}, {}",
208                            task_id,
209                            error_code,
210                            error_msg
211                        ))?;
212                    }
213                    _ => {
214                        warn!("Aliyun TTS Task: {} unexpected event: {:?}", task_id, event);
215                    }
216                }
217            }
218            Ok(Message::Close(_)) => {
219                return Err(anyhow!("Aliyun TTS start failed: closed by server"));
220            }
221            Err(e) => {
222                return Err(anyhow!("Aliyun TTS start failed:: {}", e));
223            }
224            _ => {}
225        }
226    }
227    Ok(ws_stream)
228}
229
230fn event_stream<T>(ws_stream: T) -> impl Stream<Item = Result<SynthesisEvent>> + Send + 'static
231where
232    T: Stream<Item = Result<Message, tungstenite::Error>> + Send + 'static,
233{
234    let notify = Arc::new(Notify::new());
235    let notify_clone = notify.clone();
236    ws_stream
237        .take_until(notify.notified_owned())
238        .filter_map(move |message| {
239            let notify = notify_clone.clone();
240            async move {
241                match message {
242                    Ok(Message::Binary(data)) => Some(Ok(SynthesisEvent::AudioChunk(data))),
243                    Ok(Message::Text(text)) => {
244                        let event: Event =
245                            serde_json::from_str(&text).expect("Aliyun TTS API changed!");
246
247                        match event.header.event.as_str() {
248                            "task-finished" => {
249                                notify.notify_one();
250                                Some(Ok(SynthesisEvent::Finished))
251                            }
252                            "task-failed" => {
253                                let error_code = event
254                                    .header
255                                    .error_code
256                                    .unwrap_or_else(|| "Unknown error code".to_string());
257                                let error_msg = event
258                                    .header
259                                    .error_message
260                                    .unwrap_or_else(|| "Unknown error message".to_string());
261                                notify.notify_one();
262                                Some(Err(anyhow!(
263                                    "Aliyun TTS Task: {} failed: {}, {}",
264                                    event.header.task_id,
265                                    error_code,
266                                    error_msg
267                                )))
268                            }
269                            _ => None,
270                        }
271                    }
272                    Ok(Message::Close(_)) => {
273                        notify.notify_one();
274                        warn!("Aliyun TTS: closed by remote");
275                        None
276                    }
277                    Err(e) => {
278                        notify.notify_one();
279                        Some(Err(anyhow!("Aliyun TTS: websocket error: {:?}", e)))
280                    }
281                    _ => None,
282                }
283            }
284        })
285}
286#[derive(Debug)]
287pub struct StreamingClient {
288    task_id: String,
289    option: SynthesisOption,
290    ws_sink: Option<WsSink>,
291}
292
293impl StreamingClient {
294    pub fn new(option: SynthesisOption) -> Self {
295        Self {
296            task_id: Uuid::new_v4().to_string(),
297            option,
298            ws_sink: None,
299        }
300    }
301}
302
303#[async_trait]
304impl SynthesisClient for StreamingClient {
305    fn provider(&self) -> SynthesisType {
306        SynthesisType::Aliyun
307    }
308
309    async fn start(
310        &mut self,
311    ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
312        let ws_stream = connect(self.task_id.clone(), self.option.clone()).await?;
313        let (ws_sink, ws_source) = ws_stream.split();
314        self.ws_sink.replace(ws_sink);
315        Ok(event_stream(ws_source).map(move |x| (None, x)).boxed())
316    }
317
318    async fn synthesize(
319        &mut self,
320        text: &str,
321        _cmd_seq: Option<usize>,
322        _option: Option<SynthesisOption>,
323    ) -> Result<()> {
324        if let Some(ws_sink) = self.ws_sink.as_mut() {
325            if !text.is_empty() {
326                let continue_task_cmd = Command::continue_task(self.task_id.as_str(), text);
327                let continue_task_json = serde_json::to_string(&continue_task_cmd)?;
328                ws_sink.send(Message::text(continue_task_json)).await?;
329            }
330        } else {
331            return Err(anyhow!("Aliyun TTS Task: not connected"));
332        }
333
334        Ok(())
335    }
336
337    async fn stop(&mut self) -> Result<()> {
338        if let Some(ws_sink) = self.ws_sink.as_mut() {
339            let finish_task_cmd = Command::finish_task(self.task_id.as_str());
340            let finish_task_json = serde_json::to_string(&finish_task_cmd)?;
341            ws_sink.send(Message::text(finish_task_json)).await?;
342        }
343
344        Ok(())
345    }
346}
347
348pub struct NonStreamingClient {
349    option: SynthesisOption,
350    tx: Option<mpsc::UnboundedSender<(String, Option<usize>, Option<SynthesisOption>)>>,
351}
352
353impl NonStreamingClient {
354    pub fn new(option: SynthesisOption) -> Self {
355        Self { option, tx: None }
356    }
357}
358
359#[async_trait]
360impl SynthesisClient for NonStreamingClient {
361    fn provider(&self) -> SynthesisType {
362        SynthesisType::Aliyun
363    }
364
365    async fn start(
366        &mut self,
367    ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
368        let (tx, rx) = mpsc::unbounded_channel();
369        self.tx.replace(tx);
370        let client_option = self.option.clone();
371        let max_concurrent_tasks = client_option.max_concurrent_tasks.unwrap_or(1);
372
373        let stream = UnboundedReceiverStream::new(rx)
374            .flat_map_unordered(max_concurrent_tasks, move |(text, cmd_seq, option)| {
375                let option = client_option.merge_with(option);
376                let task_id = Uuid::new_v4().to_string();
377                let text_clone = text.clone();
378                let task_id_clone = task_id.clone();
379                connect(task_id, option)
380                    .then(async move |res| match res {
381                        Ok(mut ws_stream) => {
382                            let continue_task_cmd =
383                                Command::continue_task(task_id_clone.as_str(), text_clone.as_str());
384                            let continue_task_json = serde_json::to_string(&continue_task_cmd)
385                                .expect("Aliyun TTS API changed!");
386                            ws_stream.send(Message::text(continue_task_json)).await.ok();
387                            let finish_task_cmd = Command::finish_task(task_id_clone.as_str());
388                            let finish_task_json = serde_json::to_string(&finish_task_cmd)
389                                .expect("Aliyun TTS API changed!");
390                            ws_stream.send(Message::text(finish_task_json)).await.ok();
391                            event_stream(ws_stream).boxed()
392                        }
393                        Err(e) => {
394                            warn!("Aliyun TTS: websocket error: {:?}, {:?}", cmd_seq, e);
395                            stream::once(future::ready(Err(e.into()))).boxed()
396                        }
397                    })
398                    .flatten_stream()
399                    .map(move |x| (cmd_seq, x))
400                    .boxed()
401            })
402            .boxed();
403        Ok(stream)
404    }
405
406    async fn synthesize(
407        &mut self,
408        text: &str,
409        cmd_seq: Option<usize>,
410        option: Option<SynthesisOption>,
411    ) -> Result<()> {
412        if let Some(tx) = &self.tx {
413            tx.send((text.to_string(), cmd_seq, option))?;
414        } else {
415            return Err(anyhow!("Aliyun TTS Task: not connected"));
416        }
417        Ok(())
418    }
419
420    async fn stop(&mut self) -> Result<()> {
421        self.tx.take();
422        Ok(())
423    }
424}
425
426pub struct AliyunTtsClient;
427impl AliyunTtsClient {
428    pub fn create(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>> {
429        if streaming {
430            Ok(Box::new(StreamingClient::new(option.clone())))
431        } else {
432            Ok(Box::new(NonStreamingClient::new(option.clone())))
433        }
434    }
435}