1use crate::{
2 event::{EventSender, SessionEvent},
3 media::AudioFrame,
4 media::Samples,
5 media::{
6 cache,
7 codecs::bytes_to_samples,
8 processor::ProcessorChain,
9 track::{Track, TrackConfig, TrackId, TrackPacketSender},
10 },
11 synthesis::{
12 Subtitle, SynthesisClient, SynthesisCommand, SynthesisCommandReceiver,
13 SynthesisCommandSender, SynthesisEvent, bytes_size_to_duration,
14 },
15};
16use anyhow::{Result, anyhow};
17use async_trait::async_trait;
18use base64::{Engine, prelude::BASE64_STANDARD};
19use bytes::{Bytes, BytesMut};
20use futures::StreamExt;
21use std::{
22 collections::{HashMap, VecDeque},
23 sync::{
24 Arc,
25 atomic::{AtomicBool, Ordering},
26 },
27};
28use tokio::{
29 sync::{Mutex, mpsc},
30 time::{Duration, Instant},
31};
32use tokio_util::sync::CancellationToken;
33use tracing::{debug, error, info, warn};
34
/// Handle used to push synthesis commands to a TTS task.
pub struct SynthesisHandle {
    /// Play identifier this handle is bound to; `try_send` rejects commands
    /// whose `play_id` differs from it.
    pub play_id: Option<String>,
    /// Sending half of the synthesis command channel.
    pub command_tx: SynthesisCommandSender,
}
39
/// Audio queued for emission on behalf of one command sequence.
struct EmitEntry {
    /// Audio chunks not yet copied into outgoing frames, in playback order.
    chunks: VecDeque<Bytes>,
    /// Set once no further audio is expected for this entry (synthesis
    /// finished or failed, or the payload was supplied inline / from cache).
    finished: bool,
    /// Creation instant, advanced by the duration of every chunk received;
    /// its elapsed time is used to time out entries that stop receiving audio.
    finish_at: Instant,
}
45
/// Per-command-sequence bookkeeping kept alongside the emit queue.
struct Metadata {
    /// Cache key computed for the command (empty until `handle_cache` sets it).
    cache_key: String,
    /// Text of the command as received.
    text: String,
    /// `true` until the first audio chunk has been processed; used to strip a
    /// leading RIFF header exactly once.
    first_chunk: bool,
    /// Copies of the received audio chunks, retained only when caching is
    /// enabled so they can be written to the cache on completion.
    chunks: Vec<Bytes>,
    /// Subtitles reported by the synthesis client for this sequence.
    subtitles: Vec<Subtitle>,
    /// Number of audio bytes received so far.
    total_bytes: usize,
    /// Number of audio bytes already copied into outgoing frames.
    emitted_bytes: usize,
    /// Timestamp (from `crate::media::get_timestamp`) at which the command was received.
    recv_time: u64,
}
56
57impl Default for Metadata {
58 fn default() -> Self {
59 Self {
60 cache_key: String::new(),
61 text: String::new(),
62 chunks: Vec::new(),
63 first_chunk: true,
64 subtitles: Vec::new(),
65 total_bytes: 0,
66 emitted_bytes: 0,
67 recv_time: 0,
68 }
69 }
70}
71
/// State owned by the background task that turns synthesis commands into
/// timed audio frames.
struct TtsTask {
    /// SSRC reported in the final `TrackEnd` event.
    ssrc: u32,
    /// Optional play identifier attached to logs and emitted events.
    play_id: Option<String>,
    /// Identifier of the track the frames belong to.
    track_id: TrackId,
    /// Session identifier, used for logging.
    session_id: String,
    /// Synthesis provider client.
    client: Box<dyn SynthesisClient>,
    /// Receiver of synthesis commands.
    command_rx: SynthesisCommandReceiver,
    /// Destination for generated audio frames.
    packet_sender: TrackPacketSender,
    /// Sink for session events (metrics, interruption, track end).
    event_sender: EventSender,
    /// Cancelling this token interrupts the task.
    cancel_token: CancellationToken,
    /// Processors applied to every frame before it is sent.
    processor_chain: ProcessorChain,
    /// Whether synthesized audio is looked up in, and stored to, the cache.
    cache_enabled: bool,
    /// Sample rate used for frames and for byte/duration conversions.
    sample_rate: u32,
    /// Interval between emitted frames.
    ptime: Duration,
    /// Scratch buffer reused for cache reads.
    cache_buffer: BytesMut,
    /// Audio waiting to be emitted; index 0 corresponds to `cur_seq`.
    emit_q: VecDeque<EmitEntry>,
    /// Per-sequence bookkeeping keyed by command sequence number.
    metadatas: HashMap<usize, Metadata>,
    /// Sequence number of the entry currently at the front of `emit_q`.
    cur_seq: usize,
    /// When `true`, commands carry no sequence number and all audio is
    /// accounted under sequence 0.
    streaming: bool,
    /// Shared flag set by `stop_graceful` before cancellation.
    graceful: Arc<AtomicBool>,
}
96
97impl TtsTask {
98 async fn run(mut self) -> Result<()> {
99 let mut stream;
100 match self.client.start().await {
101 Ok(s) => stream = s,
102 Err(e) => {
103 error!(
104 session_id = %self.session_id,
105 track_id = %self.track_id,
106 play_id = ?self.play_id,
107 provider = %self.client.provider(),
108 error = %e,
109 "failed to start tts task"
110 );
111 return Err(e);
112 }
113 };
114
115 info!(
116 session_id = %self.session_id,
117 track_id = %self.track_id,
118 play_id = ?self.play_id,
119 streaming = self.streaming,
120 provider = %self.client.provider(),
121 "tts task started"
122 );
123 let start_time = crate::media::get_timestamp();
124 let mut cmd_seq = if self.streaming { None } else { Some(0) };
126 let mut cmd_finished = false;
127 let mut tts_finished = false;
128 let mut cancel_received = false;
129 let sample_rate = self.sample_rate;
130 let packet_duration_ms = self.ptime.as_millis();
131 let capacity = sample_rate as usize * packet_duration_ms as usize / 500;
133 let mut ptimer = tokio::time::interval(self.ptime);
134 let mut samples = vec![0u8; capacity];
136 while !cmd_finished || !tts_finished || !self.emit_q.is_empty() {
138 tokio::select! {
139 biased;
140 _ = self.cancel_token.cancelled(), if !cancel_received => {
141 cancel_received = true;
142 let graceful = self.graceful.load(Ordering::Relaxed);
143 let emitted_bytes = self.metadatas.get(&self.cur_seq).map(|entry| entry.emitted_bytes).unwrap_or(0);
144 let total_bytes = self.metadatas.get(&self.cur_seq).map(|entry| entry.total_bytes).unwrap_or(0);
145 debug!(
146 session_id = %self.session_id,
147 track_id = %self.track_id,
148 play_id = ?self.play_id,
149 cur_seq = self.cur_seq,
150 emitted_bytes,
151 total_bytes,
152 graceful,
153 streaming = self.streaming,
154 "tts task cancelled"
155 );
156 self.handle_interrupt();
157 if self.streaming || !graceful || emitted_bytes == 0 {
161 break;
162 }
163
164 cmd_finished = true;
166 self.client.stop().await?;
167 }
168 _ = ptimer.tick() => {
169 samples.fill(0);
170 let mut i = 0;
171 while i < capacity && !self.emit_q.is_empty(){
173 let first_entry = &mut self.emit_q[0];
175
176 while i < capacity && !first_entry.chunks.is_empty() {
178 let first_chunk = &mut first_entry.chunks[0];
179 let remaining = capacity - i;
180 let available = first_chunk.len();
181 let len = usize::min(remaining, available);
182 let cut = first_chunk.split_to(len);
183 samples[i..i+len].copy_from_slice(&cut);
184 i += len;
185 self.metadatas.get_mut(&self.cur_seq).map(|entry| {
186 entry.emitted_bytes += len;
187 });
188 if first_chunk.is_empty() {
189 first_entry.chunks.pop_front();
190 }
191 }
192
193 if first_entry.chunks.is_empty(){
194 let elapsed = first_entry.finish_at.elapsed();
195 if self.streaming && cmd_finished && (tts_finished || elapsed > Duration::from_secs(10)) {
196 debug!(
197 session_id = %self.session_id,
198 track_id = %self.track_id,
199 play_id = ?self.play_id,
200 tts_finished,
201 elapsed_ms = elapsed.as_millis(),
202 "tts streaming finished"
203 );
204 tts_finished = true;
205 self.emit_q.clear();
206 continue;
207 }
208
209 if !self.streaming && (first_entry.finished || elapsed > Duration::from_secs(3))
210 {
211 debug!(
212 session_id = %self.session_id,
213 track_id = %self.track_id,
214 play_id = ?self.play_id,
215 cur_seq = self.cur_seq,
216 entry_finished = first_entry.finished,
217 elapsed_ms = elapsed.as_millis(),
218 "tts entry finished"
219 );
220
221 self.emit_q.pop_front();
222 self.cur_seq += 1;
223
224 if self.graceful.load(Ordering::Relaxed) {
226 self.emit_q.clear();
227 }
228
229 continue;
231 }
232
233 break;
235 }
236 }
237
238 if i == 0 && self.cur_seq == 0 && self.metadatas.get(&self.cur_seq).map(|entry| entry.emitted_bytes).unwrap_or(0) == 0 {
240 continue;
241 }
242
243 let samples = Samples::PCM{
244 samples: bytes_to_samples(&samples[..]),
245 };
246
247 let mut frame = AudioFrame {
248 track_id: self.track_id.clone(),
249 samples,
250 timestamp: crate::media::get_timestamp(),
251 sample_rate,
252 };
253
254 if let Err(e) = self.processor_chain.process_frame(&mut frame) {
255 warn!(
256 session_id = %self.session_id,
257 track_id = %self.track_id,
258 play_id = ?self.play_id,
259 error = %e,
260 "error processing audio frame"
261 );
262 break;
263 }
264
265 if let Err(_) = self.packet_sender.send(frame) {
266 warn!(
267 session_id = %self.session_id,
268 track_id = %self.track_id,
269 play_id = ?self.play_id,
270 "track packet sender closed, stopping task"
271 );
272 break;
273 }
274 }
275 cmd = self.command_rx.recv(), if !cmd_finished => {
276 if let Some(cmd) = cmd.as_ref() {
277 self.handle_cmd(cmd, cmd_seq).await;
278 cmd_seq.as_mut().map(|seq| *seq += 1);
279 }
280
281 if cmd.is_none() || cmd.as_ref().map(|c| c.end_of_stream).unwrap_or(false) {
283 debug!(
284 session_id = %self.session_id,
285 track_id = %self.track_id,
286 play_id = ?self.play_id,
287 cmd_seq = ?cmd_seq,
288 end_of_stream = cmd.as_ref().map(|c| c.end_of_stream).unwrap_or(true),
289 "tts command finished"
290 );
291 cmd_finished = true;
292 self.client.stop().await?;
293 }
294 }
295 item = stream.next(), if !tts_finished => {
296 if let Some((cmd_seq, res)) = item {
297 self.handle_event(cmd_seq, res).await
298 }else{
299 debug!(
300 session_id = %self.session_id,
301 track_id = %self.track_id,
302 play_id = ?self.play_id,
303 "tts event stream finished"
304 );
305 tts_finished = true;
306 }
307 }
308 }
309 }
310
311 let (emitted_bytes, total_bytes) = self.metadatas.values().fold((0, 0), |(a, b), entry| {
312 (a + entry.emitted_bytes, b + entry.total_bytes)
313 });
314
315 let duration_ms = (crate::media::get_timestamp() - start_time) as f64 / 1000.0;
316 info!(
317 session_id = %self.session_id,
318 track_id = %self.track_id,
319 play_id = ?self.play_id,
320 cur_seq = self.cur_seq,
321 cmd_seq = ?cmd_seq,
322 cmd_finished,
323 tts_finished,
324 streaming = self.streaming,
325 emitted_bytes,
326 total_bytes,
327 duration_ms,
328 provider = %self.client.provider(),
329 "tts task finished"
330 );
331
332 self.event_sender
333 .send(SessionEvent::TrackEnd {
334 track_id: self.track_id.clone(),
335 timestamp: crate::media::get_timestamp(),
336 duration: crate::media::get_timestamp() - start_time,
337 ssrc: self.ssrc,
338 play_id: self.play_id.clone(),
339 })
340 .ok();
341 Ok(())
342 }
343
344 async fn handle_cmd(&mut self, cmd: &SynthesisCommand, cmd_seq: Option<usize>) {
345 let session_id = self.session_id.clone();
346 let track_id = self.track_id.clone();
347 let play_id = self.play_id.clone();
348 let streaming = self.streaming;
349 debug!(
350 session_id = %session_id,
351 track_id = %self.track_id,
352 play_id = ?play_id,
353 cmd_seq = ?cmd_seq,
354 text_preview = %cmd.text.chars().take(20).collect::<String>(),
355 text_length = cmd.text.len(),
356 base64 = cmd.base64,
357 end_of_stream = cmd.end_of_stream,
358 "tts track: received command"
359 );
360 let text = &cmd.text;
361
362 let assume_seq = cmd_seq.unwrap_or(0);
364 let meta_entry = self.metadatas.entry(assume_seq).or_default();
365 meta_entry.text = text.clone();
366 meta_entry.recv_time = crate::media::get_timestamp();
367
368 let emit_entry = self.get_emit_entry_mut(assume_seq);
369
370 if text.is_empty() {
374 if !streaming {
375 emit_entry.map(|entry| entry.finished = true);
376 }
377 return;
378 }
379
380 if cmd.base64 {
381 match BASE64_STANDARD.decode(text) {
382 Ok(bytes) => {
383 emit_entry.map(|entry| {
384 entry.chunks.push_back(Bytes::from(bytes));
385 entry.finished = true;
386 });
387 }
388 Err(e) => {
389 warn!(
390 session_id = %session_id,
391 track_id = %track_id,
392 play_id = ?play_id,
393 cmd_seq = ?cmd_seq,
394 error = %e,
395 "failed to decode base64 text"
396 );
397 emit_entry.map(|entry| entry.finished = true);
398 }
399 }
400 return;
401 }
402
403 if self.cache_enabled && self.handle_cache(&cmd, assume_seq).await {
404 return;
405 }
406
407 if let Err(e) = self
408 .client
409 .synthesize(&text, cmd_seq, Some(cmd.option.clone()))
410 .await
411 {
412 warn!(
413 session_id = %session_id,
414 track_id = %track_id,
415 play_id = ?play_id,
416 cmd_seq = ?cmd_seq,
417 text_length = text.len(),
418 provider = %self.client.provider(),
419 error = %e,
420 "failed to synthesize text"
421 );
422 }
423 }
424
425 async fn handle_cache(&mut self, cmd: &SynthesisCommand, cmd_seq: usize) -> bool {
427 let cache_key = cache::generate_cache_key(
428 &format!("tts:{}{}", self.client.provider(), cmd.text),
429 self.sample_rate,
430 cmd.option.speaker.as_ref(),
431 cmd.option.speed,
432 );
433
434 self.metadatas.get_mut(&cmd_seq).map(|entry| {
436 entry.cache_key = cache_key.clone();
437 });
438
439 if cache::is_cached(&cache_key).await.unwrap_or_default() {
440 match cache::retrieve_from_cache_with_buffer(&cache_key, &mut self.cache_buffer).await {
441 Ok(()) => {
442 debug!(
443 session_id = %self.session_id,
444 track_id = %self.track_id,
445 play_id = ?self.play_id,
446 cmd_seq,
447 cache_key = %cache_key,
448 text_preview = %cmd.text.chars().take(20).collect::<String>(),
449 "using cached audio"
450 );
451 let bytes = self.cache_buffer.split().freeze();
452 let len = bytes.len();
453
454 self.get_emit_entry_mut(cmd_seq).map(|entry| {
455 entry.chunks.push_back(bytes);
456 entry.finished = true;
457 });
458
459 self.event_sender
460 .send(SessionEvent::Metrics {
461 timestamp: crate::media::get_timestamp(),
462 key: format!("completed.tts.{}", self.client.provider()),
463 data: serde_json::json!({
464 "speaker": cmd.option.speaker,
465 "playId": self.play_id,
466 "cmdSeq": cmd_seq,
467 "length": len,
468 "cached": true,
469 }),
470 duration: 0,
471 })
472 .ok();
473 return true;
474 }
475 Err(e) => {
476 warn!(
477 session_id = %self.session_id,
478 track_id = %self.track_id,
479 play_id = ?self.play_id,
480 cmd_seq,
481 cache_key = %cache_key,
482 error = %e,
483 "error retrieving cached audio"
484 );
485 }
486 }
487 }
488 false
489 }
490
491 async fn handle_event(&mut self, cmd_seq: Option<usize>, event: Result<SynthesisEvent>) {
492 let assume_seq = cmd_seq.unwrap_or(0);
493 match event {
494 Ok(SynthesisEvent::AudioChunk(mut chunk)) => {
495 let entry = self.metadatas.entry(assume_seq).or_default();
496
497 if entry.first_chunk {
498 if chunk.len() > 44 && chunk[..4] == [0x52, 0x49, 0x46, 0x46] {
500 let _ = chunk.split_to(44);
501 }
502 entry.first_chunk = false;
503 }
504
505 entry.total_bytes += chunk.len();
506
507 if self.cache_enabled {
509 entry.chunks.push(chunk.clone());
510 }
511
512 let duration = Duration::from_millis(bytes_size_to_duration(
513 chunk.len(),
514 self.sample_rate,
515 ) as u64);
516 self.get_emit_entry_mut(assume_seq).map(|entry| {
517 entry.chunks.push_back(chunk.clone());
518 entry.finish_at += duration;
519 });
520 }
521 Ok(SynthesisEvent::Subtitles(subtitles)) => {
522 self.metadatas.get_mut(&assume_seq).map(|entry| {
523 entry.subtitles.extend(subtitles);
524 });
525 }
526 Ok(SynthesisEvent::Finished) => {
527 let entry = self.metadatas.entry(assume_seq).or_default();
528 debug!(
529 session_id = %self.session_id,
530 track_id = %self.track_id,
531 play_id = ?self.play_id,
532 cmd_seq = ?cmd_seq,
533 streaming = self.streaming,
534 total_bytes = entry.total_bytes,
535 "tts synthesis completed for command sequence"
536 );
537 self.event_sender
538 .send(SessionEvent::Metrics {
539 timestamp: crate::media::get_timestamp(),
540 key: format!("completed.tts.{}", self.client.provider()),
541 data: serde_json::json!({
542 "playId": self.play_id,
543 "cmdSeq": cmd_seq,
544 "length": entry.total_bytes,
545 "cached": false,
546 }),
547 duration: (crate::media::get_timestamp() - entry.recv_time) as u32,
548 })
549 .ok();
550
551 if self.streaming {
553 return;
554 }
555
556 if self.cache_enabled
558 && !cache::is_cached(&entry.cache_key).await.unwrap_or_default()
559 {
560 if let Err(e) =
561 cache::store_in_cache_vectored(&entry.cache_key, &entry.chunks).await
562 {
563 warn!(
564 session_id = %self.session_id,
565 track_id = %self.track_id,
566 play_id = ?self.play_id,
567 cmd_seq = ?cmd_seq,
568 cache_key = %entry.cache_key,
569 error = %e,
570 "failed to store audio in cache"
571 );
572 } else {
573 debug!(
574 session_id = %self.session_id,
575 track_id = %self.track_id,
576 play_id = ?self.play_id,
577 cmd_seq = ?cmd_seq,
578 cache_key = %entry.cache_key,
579 total_bytes = entry.total_bytes,
580 "stored audio in cache"
581 );
582 }
583 entry.chunks.clear();
584 }
585
586 self.get_emit_entry_mut(assume_seq)
587 .map(|entry| entry.finished = true);
588 }
589 Err(e) => {
590 warn!(
591 session_id = %self.session_id,
592 track_id = %self.track_id,
593 play_id = ?self.play_id,
594 cmd_seq = ?cmd_seq,
595 error = %e,
596 "tts synthesis event error"
597 );
598 self.get_emit_entry_mut(assume_seq)
600 .map(|entry| entry.finished = true);
601 }
602 }
603 }
604
605 fn get_emit_entry_mut(&mut self, cmd_seq: usize) -> Option<&mut EmitEntry> {
608 if cmd_seq < self.cur_seq {
610 debug!(
611 session_id = %self.session_id,
612 track_id = %self.track_id,
613 play_id = ?self.play_id,
614 cmd_seq,
615 cur_seq = self.cur_seq,
616 "ignoring timeout tts result"
617 );
618 return None;
619 }
620
621 let i = cmd_seq - self.cur_seq;
623 if i >= self.emit_q.len() {
624 self.emit_q.resize_with(i + 1, || EmitEntry {
625 chunks: VecDeque::new(),
626 finished: false,
627 finish_at: Instant::now(),
628 });
629 }
630 Some(&mut self.emit_q[i])
631 }
632
633 fn handle_interrupt(&self) {
634 if let Some(entry) = self.metadatas.get(&self.cur_seq) {
635 let current = bytes_size_to_duration(entry.emitted_bytes, self.sample_rate);
636 let total_duration = bytes_size_to_duration(entry.total_bytes, self.sample_rate);
637 let text = entry.text.clone();
638 let mut position = None;
639
640 for subtitle in entry.subtitles.iter().rev() {
641 if subtitle.begin_time < current {
642 position = Some(subtitle.begin_index);
643 break;
644 }
645 }
646
647 let interruption = SessionEvent::Interruption {
648 track_id: self.track_id.clone(),
649 timestamp: crate::media::get_timestamp(),
650 play_id: self.play_id.clone(),
651 subtitle: Some(text),
652 position,
653 total_duration,
654 current,
655 };
656 self.event_sender.send(interruption).ok();
657 }
658 }
659}
660
/// Track that plays synthesized speech.
///
/// Holds the configuration and one-shot resources (client, command receiver)
/// that are handed over to a background `TtsTask` when the track is started.
pub struct TtsTrack {
    /// Identifier of this track.
    track_id: TrackId,
    /// Session identifier, used for logging.
    session_id: String,
    /// Whether the track runs in streaming mode; caching is disabled in that mode.
    streaming: bool,
    /// Optional play identifier attached to logs and events.
    play_id: Option<String>,
    /// Processor chain cloned into the task on start.
    processor_chain: ProcessorChain,
    /// Track configuration (sample rate, ptime).
    config: TrackConfig,
    /// Token cancelled by `stop` to interrupt the task.
    cancel_token: CancellationToken,
    /// Whether the audio cache may be used.
    use_cache: bool,
    /// Command receiver, taken by the first call to `start`.
    command_rx: Mutex<Option<SynthesisCommandReceiver>>,
    /// Synthesis client, taken by the first call to `start`.
    client: Mutex<Option<Box<dyn SynthesisClient>>>,
    /// SSRC reported for this track.
    ssrc: u32,
    /// Flag shared with the task; set by `stop_graceful` before cancelling.
    graceful: Arc<AtomicBool>,
}
675
676impl SynthesisHandle {
677 pub fn new(command_tx: SynthesisCommandSender, play_id: Option<String>) -> Self {
678 Self {
679 play_id,
680 command_tx,
681 }
682 }
683 pub fn try_send(
684 &self,
685 cmd: SynthesisCommand,
686 ) -> Result<(), mpsc::error::SendError<SynthesisCommand>> {
687 if self.play_id == cmd.play_id {
688 self.command_tx.send(cmd)
689 } else {
690 Err(mpsc::error::SendError(cmd))
691 }
692 }
693}
694
695impl TtsTrack {
696 pub fn new(
697 track_id: TrackId,
698 session_id: String,
699 streaming: bool,
700 play_id: Option<String>,
701 command_rx: SynthesisCommandReceiver,
702 client: Box<dyn SynthesisClient>,
703 ) -> Self {
704 let config = TrackConfig::default();
705 Self {
706 track_id,
707 session_id,
708 streaming,
709 play_id,
710 processor_chain: ProcessorChain::new(config.samplerate),
711 config,
712 cancel_token: CancellationToken::new(),
713 command_rx: Mutex::new(Some(command_rx)),
714 use_cache: true,
715 client: Mutex::new(Some(client)),
716 graceful: Arc::new(AtomicBool::new(false)),
717 ssrc: 0,
718 }
719 }
720 pub fn with_ssrc(mut self, ssrc: u32) -> Self {
721 self.ssrc = ssrc;
722 self
723 }
724 pub fn with_config(mut self, config: TrackConfig) -> Self {
725 self.config = config;
726 self
727 }
728
729 pub fn with_cancel_token(mut self, cancel_token: CancellationToken) -> Self {
730 self.cancel_token = cancel_token;
731 self
732 }
733
734 pub fn with_sample_rate(mut self, sample_rate: u32) -> Self {
735 self.config = self.config.with_sample_rate(sample_rate);
736 self.processor_chain = ProcessorChain::new(sample_rate);
737 self
738 }
739
740 pub fn with_ptime(mut self, ptime: Duration) -> Self {
741 self.config = self.config.with_ptime(ptime);
742 self
743 }
744
745 pub fn with_cache_enabled(mut self, use_cache: bool) -> Self {
746 self.use_cache = use_cache;
747 self
748 }
749}
750
751#[async_trait]
752impl Track for TtsTrack {
753 fn ssrc(&self) -> u32 {
754 self.ssrc
755 }
756 fn id(&self) -> &TrackId {
757 &self.track_id
758 }
759 fn config(&self) -> &TrackConfig {
760 &self.config
761 }
762 fn processor_chain(&mut self) -> &mut ProcessorChain {
763 &mut self.processor_chain
764 }
765
766 async fn handshake(&mut self, _offer: String, _timeout: Option<Duration>) -> Result<String> {
767 Ok("".to_string())
768 }
769 async fn update_remote_description(&mut self, _answer: &String) -> Result<()> {
770 Ok(())
771 }
772
773 async fn start(
774 &self,
775 event_sender: EventSender,
776 packet_sender: TrackPacketSender,
777 ) -> Result<()> {
778 let client = self
779 .client
780 .lock()
781 .await
782 .take()
783 .ok_or(anyhow!("Client not found"))?;
784 let command_rx = self
785 .command_rx
786 .lock()
787 .await
788 .take()
789 .ok_or(anyhow!("Command receiver not found"))?;
790
791 let task = TtsTask {
792 play_id: self.play_id.clone(),
793 track_id: self.track_id.clone(),
794 session_id: self.session_id.clone(),
795 client,
796 command_rx,
797 event_sender,
798 packet_sender,
799 cancel_token: self.cancel_token.clone(),
800 processor_chain: self.processor_chain.clone(),
801 cache_enabled: self.use_cache && !self.streaming,
802 sample_rate: self.config.samplerate,
803 ptime: self.config.ptime,
804 cache_buffer: BytesMut::new(),
805 emit_q: VecDeque::new(),
806 metadatas: HashMap::new(),
807 cur_seq: 0,
808 streaming: self.streaming,
809 graceful: self.graceful.clone(),
810 ssrc: self.ssrc,
811 };
812 debug!(
813 session_id = %self.session_id,
814 track_id = %self.track_id,
815 play_id = ?self.play_id,
816 streaming = self.streaming,
817 "spawning tts task"
818 );
819 tokio::spawn(async move { task.run().await });
820 Ok(())
821 }
822
823 async fn stop(&self) -> Result<()> {
824 self.cancel_token.cancel();
825 Ok(())
826 }
827
828 async fn stop_graceful(&self) -> Result<()> {
829 self.graceful.store(true, Ordering::Relaxed);
830 self.stop().await
831 }
832
833 async fn send_packet(&self, _packet: &AudioFrame) -> Result<()> {
834 Ok(())
835 }
836}