1use super::{SynthesisClient, SynthesisOption, SynthesisType};
2use crate::synthesis::SynthesisEvent;
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use futures::{
6 FutureExt, SinkExt, Stream, StreamExt, future,
7 stream::{self, BoxStream, SplitSink},
8};
9use serde::{Deserialize, Serialize};
10use serde_with::skip_serializing_none;
11use std::sync::Arc;
12use tokio::{
13 net::TcpStream,
14 sync::{Notify, mpsc},
15};
16use tokio_stream::wrappers::UnboundedReceiverStream;
17use tokio_tungstenite::{
18 MaybeTlsStream, WebSocketStream, connect_async,
19 tungstenite::{self, Message, client::IntoClientRequest},
20};
21use tracing::warn;
22use uuid::Uuid;
/// Full-duplex websocket connection (plain or TLS over TCP) to the DashScope
/// inference endpoint.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a [`WsStream`] obtained from `split`; used to send task commands.
type WsSink = SplitSink<WsStream, Message>;
25
/// A client-to-server task command: a header plus an action-specific payload.
/// It is serialized with `serde_json` and sent as a websocket text frame.
#[derive(Debug, Serialize)]
struct Command {
    header: CommandHeader,
    payload: CommandPayload,
}

/// Action-specific body of a [`Command`].
///
/// With `untagged`, a variant serializes as its inner struct with no
/// discriminator. The action is therefore identified only by
/// `CommandHeader::action`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum CommandPayload {
    /// Body of `run-task`.
    Run(RunTaskPayload),
    /// Body of `continue-task`.
    Continue(ContinueTaskPayload),
    /// Body of `finish-task`.
    Finish(FinishTaskPayload),
}
42
43impl Command {
44 fn run_task(option: &SynthesisOption, task_id: &str) -> Self {
45 let voice = option
46 .speaker
47 .clone()
48 .unwrap_or_else(|| "longyumi_v2".to_string());
49
50 let format = option.codec.as_deref().unwrap_or("pcm");
51
52 let sample_rate = option.samplerate.unwrap_or(16000) as u32;
53 let volume = option.volume.unwrap_or(50) as u32;
54 let rate = option.speed.unwrap_or(1.0);
55
56 Command {
57 header: CommandHeader {
58 action: "run-task".to_string(),
59 task_id: task_id.to_string(),
60 streaming: "duplex".to_string(),
61 },
62 payload: CommandPayload::Run(RunTaskPayload {
63 task_group: "audio".to_string(),
64 task: "tts".to_string(),
65 function: "SpeechSynthesizer".to_string(),
66 model: "cosyvoice-v2".to_string(),
67 parameters: RunTaskParameters {
68 text_type: "PlainText".to_string(),
69 voice,
70 format: Some(format.to_string()),
71 sample_rate: Some(sample_rate),
72 volume: Some(volume),
73 rate: Some(rate),
74 },
75 input: EmptyInput {},
76 }),
77 }
78 }
79
80 fn continue_task(task_id: &str, text: &str) -> Self {
81 Command {
82 header: CommandHeader {
83 action: "continue-task".to_string(),
84 task_id: task_id.to_string(),
85 streaming: "duplex".to_string(),
86 },
87 payload: CommandPayload::Continue(ContinueTaskPayload {
88 input: PayloadInput {
89 text: text.to_string(),
90 },
91 }),
92 }
93 }
94
95 fn finish_task(task_id: &str) -> Self {
96 Command {
97 header: CommandHeader {
98 action: "finish-task".to_string(),
99 task_id: task_id.to_string(),
100 streaming: "duplex".to_string(),
101 },
102 payload: CommandPayload::Finish(FinishTaskPayload {
103 input: EmptyInput {},
104 }),
105 }
106 }
107}
108
/// Header common to every [`Command`].
#[derive(Debug, Serialize)]
struct CommandHeader {
    /// Command name: `run-task`, `continue-task` or `finish-task` in this file.
    action: String,
    /// Identifier of the task the command belongs to.
    task_id: String,
    /// Streaming mode; `Command`'s builders always set `duplex`.
    streaming: String,
}

/// Body of the `run-task` command, which configures and opens a task.
#[derive(Debug, Serialize)]
struct RunTaskPayload {
    /// Set to `audio` by `Command::run_task`.
    task_group: String,
    /// Set to `tts` by `Command::run_task`.
    task: String,
    /// Set to `SpeechSynthesizer` by `Command::run_task`.
    function: String,
    /// Model name; `cosyvoice-v2` in `Command::run_task`.
    model: String,
    /// Voice, format, sample rate, volume and rate settings.
    parameters: RunTaskParameters,
    /// Serialized as an empty JSON object (`{}`).
    input: EmptyInput,
}
125
/// Synthesis parameters of `run-task`.
///
/// `skip_serializing_none` omits `None` fields from the JSON instead of
/// emitting `null`. `Command::run_task` currently fills every optional field.
#[skip_serializing_none]
#[derive(Debug, Serialize)]
struct RunTaskParameters {
    /// Input text type; always `PlainText` here.
    text_type: String,
    /// Voice (speaker) name.
    voice: String,
    /// Output audio format, e.g. `pcm`.
    format: Option<String>,
    /// Output sample rate. NOTE(review): presumably in Hz; confirm against the API docs.
    sample_rate: Option<u32>,
    /// Output volume (defaults to 50 in `Command::run_task`).
    volume: Option<u32>,
    /// Speech rate. NOTE(review): presumably a multiplier where 1.0 is normal speed; confirm.
    rate: Option<f32>,
}
136
/// Body of `continue-task`: carries one piece of text to synthesize.
#[derive(Debug, Serialize)]
struct ContinueTaskPayload {
    input: PayloadInput,
}

/// Text wrapper, serialized as `{"text": ...}`.
// `Deserialize` is derived but not exercised by the code in this file.
#[derive(Debug, Serialize, Deserialize)]
struct PayloadInput {
    text: String,
}

/// Body of `finish-task`, the command that ends a task's text input.
#[derive(Debug, Serialize)]
struct FinishTaskPayload {
    input: EmptyInput,
}

/// Placeholder that serializes as an empty JSON object (`{}`).
#[derive(Debug, Serialize)]
struct EmptyInput {}
154
/// A server-to-client event frame. Only `header` is decoded; serde ignores
/// any other top-level fields.
#[derive(Debug, Deserialize)]
struct Event {
    header: EventHeader,
}

/// Header of a server event.
// NOTE(review): every field appears to be read explicitly by `connect` /
// `event_stream`; confirm whether this `allow` is still needed.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct EventHeader {
    /// Task the event refers to.
    task_id: String,
    /// Event name. This file reacts to `task-started`, `task-finished` and `task-failed`.
    event: String,
    /// Failure code reported with `task-failed`. NOTE(review): presumably absent otherwise.
    error_code: Option<String>,
    /// Failure description reported with `task-failed`.
    error_message: Option<String>,
}
169
170async fn connect(task_id: String, option: SynthesisOption) -> Result<WsStream> {
171 let api_key = option
172 .secret_key
173 .as_ref()
174 .ok_or_else(|| anyhow!("Aliyun TTS: missing api key"))?;
175 let ws_url = option
176 .endpoint
177 .as_deref()
178 .unwrap_or("wss://dashscope.aliyuncs.com/api-ws/v1/inference");
179
180 let mut request = ws_url.into_client_request()?;
181 let headers = request.headers_mut();
182 headers.insert("Authorization", format!("Bearer {}", api_key).parse()?);
183 headers.insert("X-DashScope-DataInspection", "enable".parse()?);
184
185 let (mut ws_stream, _) = connect_async(request).await?;
186 let run_task_cmd = Command::run_task(&option, task_id.as_str());
187 let run_task_json = serde_json::to_string(&run_task_cmd)?;
188 ws_stream.send(Message::text(run_task_json)).await?;
189 while let Some(message) = ws_stream.next().await {
190 match message {
191 Ok(Message::Text(text)) => {
192 let event = serde_json::from_str::<Event>(&text)?;
193 match event.header.event.as_str() {
194 "task-started" => {
195 break;
196 }
197 "task-failed" => {
198 let error_code = event
199 .header
200 .error_code
201 .unwrap_or_else(|| "Unknown error code".to_string());
202 let error_msg = event
203 .header
204 .error_message
205 .unwrap_or_else(|| "Unknown error message".to_string());
206 return Err(anyhow!(
207 "Aliyun TTS Task: {} failed: {}, {}",
208 task_id,
209 error_code,
210 error_msg
211 ))?;
212 }
213 _ => {
214 warn!("Aliyun TTS Task: {} unexpected event: {:?}", task_id, event);
215 }
216 }
217 }
218 Ok(Message::Close(_)) => {
219 return Err(anyhow!("Aliyun TTS start failed: closed by server"));
220 }
221 Err(e) => {
222 return Err(anyhow!("Aliyun TTS start failed:: {}", e));
223 }
224 _ => {}
225 }
226 }
227 Ok(ws_stream)
228}
229
230fn event_stream<T>(ws_stream: T) -> impl Stream<Item = Result<SynthesisEvent>> + Send + 'static
231where
232 T: Stream<Item = Result<Message, tungstenite::Error>> + Send + 'static,
233{
234 let notify = Arc::new(Notify::new());
235 let notify_clone = notify.clone();
236 ws_stream
237 .take_until(notify.notified_owned())
238 .filter_map(move |message| {
239 let notify = notify_clone.clone();
240 async move {
241 match message {
242 Ok(Message::Binary(data)) => Some(Ok(SynthesisEvent::AudioChunk(data))),
243 Ok(Message::Text(text)) => {
244 let event: Event =
245 serde_json::from_str(&text).expect("Aliyun TTS API changed!");
246
247 match event.header.event.as_str() {
248 "task-finished" => {
249 notify.notify_one();
250 Some(Ok(SynthesisEvent::Finished))
251 }
252 "task-failed" => {
253 let error_code = event
254 .header
255 .error_code
256 .unwrap_or_else(|| "Unknown error code".to_string());
257 let error_msg = event
258 .header
259 .error_message
260 .unwrap_or_else(|| "Unknown error message".to_string());
261 notify.notify_one();
262 Some(Err(anyhow!(
263 "Aliyun TTS Task: {} failed: {}, {}",
264 event.header.task_id,
265 error_code,
266 error_msg
267 )))
268 }
269 _ => None,
270 }
271 }
272 Ok(Message::Close(_)) => {
273 notify.notify_one();
274 warn!("Aliyun TTS: closed by remote");
275 None
276 }
277 Err(e) => {
278 notify.notify_one();
279 Some(Err(anyhow!("Aliyun TTS: websocket error: {:?}", e)))
280 }
281 _ => None,
282 }
283 }
284 })
285}
/// Duplex client: `start` opens one long-lived task, `synthesize` feeds it text
/// incrementally, and `stop` sends `finish-task`.
#[derive(Debug)]
pub struct StreamingClient {
    /// Task id generated in `new` and used for every command this client sends.
    task_id: String,
    /// Connection and voice options passed to `connect`.
    option: SynthesisOption,
    /// Write half of the task socket; `None` until `start` succeeds.
    ws_sink: Option<WsSink>,
}
292
293impl StreamingClient {
294 pub fn new(option: SynthesisOption) -> Self {
295 Self {
296 task_id: Uuid::new_v4().to_string(),
297 option,
298 ws_sink: None,
299 }
300 }
301}
302
303#[async_trait]
304impl SynthesisClient for StreamingClient {
305 fn provider(&self) -> SynthesisType {
306 SynthesisType::Aliyun
307 }
308
309 async fn start(
310 &mut self,
311 ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
312 let ws_stream = connect(self.task_id.clone(), self.option.clone()).await?;
313 let (ws_sink, ws_source) = ws_stream.split();
314 self.ws_sink.replace(ws_sink);
315 Ok(event_stream(ws_source).map(move |x| (None, x)).boxed())
316 }
317
318 async fn synthesize(
319 &mut self,
320 text: &str,
321 _cmd_seq: Option<usize>,
322 _option: Option<SynthesisOption>,
323 ) -> Result<()> {
324 if let Some(ws_sink) = self.ws_sink.as_mut() {
325 if !text.is_empty() {
326 let continue_task_cmd = Command::continue_task(self.task_id.as_str(), text);
327 let continue_task_json = serde_json::to_string(&continue_task_cmd)?;
328 ws_sink.send(Message::text(continue_task_json)).await?;
329 }
330 } else {
331 return Err(anyhow!("Aliyun TTS Task: not connected"));
332 }
333
334 Ok(())
335 }
336
337 async fn stop(&mut self) -> Result<()> {
338 if let Some(ws_sink) = self.ws_sink.as_mut() {
339 let finish_task_cmd = Command::finish_task(self.task_id.as_str());
340 let finish_task_json = serde_json::to_string(&finish_task_cmd)?;
341 ws_sink.send(Message::text(finish_task_json)).await?;
342 }
343
344 Ok(())
345 }
346}
347
/// One-shot client: each `synthesize` call runs its own short-lived task on a
/// fresh connection.
pub struct NonStreamingClient {
    /// Base options, merged with any per-call override before each task.
    option: SynthesisOption,
    /// Queue of `(text, cmd_seq, per-call option)` consumed by the stream built
    /// in `start`. It is `None` before `start` and after `stop`.
    tx: Option<mpsc::UnboundedSender<(String, Option<usize>, Option<SynthesisOption>)>>,
}
352
353impl NonStreamingClient {
354 pub fn new(option: SynthesisOption) -> Self {
355 Self { option, tx: None }
356 }
357}
358
359#[async_trait]
360impl SynthesisClient for NonStreamingClient {
361 fn provider(&self) -> SynthesisType {
362 SynthesisType::Aliyun
363 }
364
365 async fn start(
366 &mut self,
367 ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
368 let (tx, rx) = mpsc::unbounded_channel();
369 self.tx.replace(tx);
370 let client_option = self.option.clone();
371 let max_concurrent_tasks = client_option.max_concurrent_tasks.unwrap_or(1);
372
373 let stream = UnboundedReceiverStream::new(rx)
374 .flat_map_unordered(max_concurrent_tasks, move |(text, cmd_seq, option)| {
375 let option = client_option.merge_with(option);
376 let task_id = Uuid::new_v4().to_string();
377 let text_clone = text.clone();
378 let task_id_clone = task_id.clone();
379 connect(task_id, option)
380 .then(async move |res| match res {
381 Ok(mut ws_stream) => {
382 let continue_task_cmd =
383 Command::continue_task(task_id_clone.as_str(), text_clone.as_str());
384 let continue_task_json = serde_json::to_string(&continue_task_cmd)
385 .expect("Aliyun TTS API changed!");
386 ws_stream.send(Message::text(continue_task_json)).await.ok();
387 let finish_task_cmd = Command::finish_task(task_id_clone.as_str());
388 let finish_task_json = serde_json::to_string(&finish_task_cmd)
389 .expect("Aliyun TTS API changed!");
390 ws_stream.send(Message::text(finish_task_json)).await.ok();
391 event_stream(ws_stream).boxed()
392 }
393 Err(e) => {
394 warn!("Aliyun TTS: websocket error: {:?}, {:?}", cmd_seq, e);
395 stream::once(future::ready(Err(e.into()))).boxed()
396 }
397 })
398 .flatten_stream()
399 .map(move |x| (cmd_seq, x))
400 .boxed()
401 })
402 .boxed();
403 Ok(stream)
404 }
405
406 async fn synthesize(
407 &mut self,
408 text: &str,
409 cmd_seq: Option<usize>,
410 option: Option<SynthesisOption>,
411 ) -> Result<()> {
412 if let Some(tx) = &self.tx {
413 tx.send((text.to_string(), cmd_seq, option))?;
414 } else {
415 return Err(anyhow!("Aliyun TTS Task: not connected"));
416 }
417 Ok(())
418 }
419
420 async fn stop(&mut self) -> Result<()> {
421 self.tx.take();
422 Ok(())
423 }
424}
425
426pub struct AliyunTtsClient;
427impl AliyunTtsClient {
428 pub fn create(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>> {
429 if streaming {
430 Ok(Box::new(StreamingClient::new(option.clone())))
431 } else {
432 Ok(Box::new(NonStreamingClient::new(option.clone())))
433 }
434 }
435}