// voice_engine/media/track/websocket.rs

1use super::{Track, TrackConfig, TrackPacketSender, track_codec::TrackCodec};
2use crate::{
3    event::{EventSender, SessionEvent},
4    media::AudioFrame,
5    media::Samples,
6    media::TrackId,
7    media::{codecs::bytes_to_samples, processor::ProcessorChain},
8};
9use anyhow::Result;
10use async_trait::async_trait;
11use bytes::Bytes;
12use std::{sync::Mutex, time::Duration};
13use tokio::select;
14use tokio_util::sync::CancellationToken;
15use tracing::{info, warn};
16
17pub type WebsocketBytesSender = tokio::sync::mpsc::UnboundedSender<Bytes>;
18pub type WebsocketBytesReceiver = tokio::sync::mpsc::UnboundedReceiver<Bytes>;
19
20pub struct WebsocketTrack {
21    track_id: TrackId,
22    config: TrackConfig,
23    cancel_token: CancellationToken,
24    processor_chain: ProcessorChain,
25    rx: Mutex<Option<WebsocketBytesReceiver>>,
26    encoder: TrackCodec,
27    payload_type: u8,
28    event_sender: EventSender,
29    ssrc: u32,
30}
31
32impl WebsocketTrack {
33    pub fn new(
34        cancel_token: CancellationToken,
35        track_id: TrackId,
36        track_config: TrackConfig,
37        event_sender: EventSender,
38        audio_receiver: WebsocketBytesReceiver,
39        codec: Option<String>,
40        ssrc: u32,
41    ) -> Self {
42        let processor_chain = ProcessorChain::new(track_config.samplerate);
43        let payload_type = match codec.unwrap_or("pcm".to_string()).to_lowercase().as_str() {
44            "pcmu" => 0,
45            "pcma" => 8,
46            "g722" => 9,
47            _ => u8::MAX, // PCM
48        };
49        Self {
50            track_id,
51            config: track_config,
52            cancel_token,
53            processor_chain,
54            rx: Mutex::new(Some(audio_receiver)),
55            encoder: TrackCodec::new(),
56            payload_type,
57            event_sender,
58            ssrc,
59        }
60    }
61}
62
63#[async_trait]
64impl Track for WebsocketTrack {
65    fn ssrc(&self) -> u32 {
66        self.ssrc
67    }
68    fn id(&self) -> &TrackId {
69        &self.track_id
70    }
71    fn config(&self) -> &TrackConfig {
72        &self.config
73    }
74    fn processor_chain(&mut self) -> &mut ProcessorChain {
75        &mut self.processor_chain
76    }
77
78    async fn handshake(&mut self, _offer: String, _timeout: Option<Duration>) -> Result<String> {
79        Ok("".to_string())
80    }
81    async fn update_remote_description(&mut self, _answer: &String) -> Result<()> {
82        Ok(())
83    }
84
85    async fn start(
86        &self,
87        event_sender: EventSender,
88        packet_sender: TrackPacketSender,
89    ) -> Result<()> {
90        let track_id = self.track_id.clone();
91        let token = self.cancel_token.clone();
92        let mut audio_from_ws = match self.rx.lock().unwrap().take() {
93            Some(rx) => rx,
94            None => {
95                warn!(track_id, "no audio from ws");
96                return Ok(());
97            }
98        };
99        let sample_rate = self.config.samplerate;
100        let payload_type = self.payload_type;
101        let start_time = crate::media::get_timestamp();
102        let ssrc = self.ssrc;
103        tokio::spawn(async move {
104            let track_id_clone = track_id.clone();
105            let audio_from_ws_loop = async move {
106                let mut sequence_number = 0;
107                while let Some(bytes) = audio_from_ws.recv().await {
108                    sequence_number += 1;
109
110                    let samples = match payload_type {
111                        u8::MAX => Samples::PCM {
112                            samples: bytes_to_samples(&bytes.to_vec()),
113                        },
114                        _ => Samples::RTP {
115                            sequence_number,
116                            payload_type,
117                            payload: bytes.to_vec(),
118                        },
119                    };
120
121                    let packet = AudioFrame {
122                        track_id: track_id_clone.clone(),
123                        samples,
124                        timestamp: crate::media::get_timestamp(),
125                        sample_rate,
126                    };
127                    match packet_sender.send(packet) {
128                        Ok(_) => (),
129                        Err(e) => {
130                            warn!("error sending packet: {}", e);
131                            break;
132                        }
133                    }
134                }
135            };
136
137            select! {
138                _ = token.cancelled() => {
139                    info!("RTC process cancelled");
140                },
141                _ = audio_from_ws_loop => {
142                    info!("audio_from_ws_loop");
143                }
144            };
145
146            event_sender
147                .send(SessionEvent::TrackEnd {
148                    track_id,
149                    timestamp: crate::media::get_timestamp(),
150                    duration: crate::media::get_timestamp() - start_time,
151                    ssrc,
152                    play_id: None,
153                })
154                .ok();
155        });
156        Ok(())
157    }
158
159    async fn stop(&self) -> Result<()> {
160        self.cancel_token.cancel();
161        Ok(())
162    }
163
164    async fn send_packet(&self, packet: &AudioFrame) -> Result<()> {
165        let (_, payload) = self.encoder.encode(self.payload_type, packet.clone());
166        if payload.is_empty() {
167            return Ok(());
168        }
169        self.event_sender
170            .send(SessionEvent::Binary {
171                track_id: self.track_id.clone(),
172                timestamp: crate::media::get_timestamp(),
173                data: payload,
174            })
175            .map(|_| ())
176            .map_err(|_| anyhow::anyhow!("error sending binary event"))
177    }
178}