voice_engine/media/track/media_pass.rs

1use crate::{
2    event::{EventSender, SessionEvent},
3    media::AudioFrame,
4    media::Samples,
5    media::TrackId,
6    media::{
7        codecs::{bytes_to_samples, resample::resample_mono, samples_to_bytes},
8        processor::{Processor, ProcessorChain},
9        track::{Track, TrackConfig, TrackPacketSender},
10    },
11};
12use anyhow::Result;
13use async_trait::async_trait;
14use bytes::{Bytes, BytesMut};
15use futures::{SinkExt, StreamExt, stream::SplitSink};
16use serde::{Deserialize, Serialize};
17use std::{
18    sync::Arc,
19    sync::atomic::{AtomicU64, Ordering},
20    time::Duration,
21};
22use tokio::{net::TcpStream, sync::Mutex};
23use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
24use tokio_util::sync::CancellationToken;
25use tracing::{error, info, warn};
26
27type WsConn = WebSocketStream<MaybeTlsStream<TcpStream>>;
28type WsSink = SplitSink<WsConn, Message>;
29
30#[derive(Debug, Deserialize, Serialize, Clone)]
31#[serde(rename_all = "camelCase")]
32pub struct MediaPassOption {
33    url: String,              // websocket url, e.g. ws://localhost:8080/
34    input_sample_rate: u32,   // sample rate of audio receiving from websocket
35    output_sample_rate: u32,  // sample rate of audio sending to websocket server
36    packet_size: Option<u32>, // packet size send to websocket server, default is 3200
37    ptime: Option<u32>, // packet time in milliseconds, if set, buffer and emit at fixed intervals
38}
39
40impl MediaPassOption {
41    pub fn new(
42        url: String,
43        input_sample_rate: u32,
44        output_sample_rate: u32,
45        packet_size: Option<u32>,
46        ptime: Option<u32>,
47    ) -> Self {
48        Self {
49            url,
50            input_sample_rate,
51            output_sample_rate,
52            packet_size,
53            ptime,
54        }
55    }
56}
57
58pub struct MediaPassTrack {
59    session_id: String,
60    track_id: TrackId,
61    cancel_token: CancellationToken,
62    config: TrackConfig, // input sample rate is here, it is 0 if the ptime is None
63    url: String,
64    output_sample_rate: u32, // output sample rate
65    packet_size: u32,
66    // buffer the rtp/webrtc packets, send to websocket server with packet size
67    buffer: Mutex<BytesMut>,
68    ws_sink: Arc<Mutex<Option<WsSink>>>,
69    bytes_sent: Arc<AtomicU64>, // bytes sent to websocket
70    ssrc: u32,
71    processor_chain: ProcessorChain,
72}
73
74impl MediaPassTrack {
75    pub fn new(
76        session_id: String,
77        ssrc: u32,
78        track_id: TrackId,
79        cancel_token: CancellationToken,
80        option: MediaPassOption,
81    ) -> Self {
82        let sample_rate = option.output_sample_rate;
83        let mut config = TrackConfig::default();
84        config = config.with_sample_rate(option.input_sample_rate);
85        config = config.with_ptime(Duration::from_millis(option.ptime.unwrap_or(0) as u64));
86        // for 16000Hz, 20ms ptime, 3200 is 5 packets
87        let packet_size = option.packet_size.unwrap_or(3200);
88        let buffer: Mutex<BytesMut> = Mutex::new(BytesMut::with_capacity(packet_size as usize * 2));
89        Self {
90            session_id,
91            track_id,
92            cancel_token,
93            config,
94            url: option.url,
95            output_sample_rate: sample_rate,
96            packet_size,
97            buffer,
98            ssrc,
99            ws_sink: Arc::new(Mutex::new(None)),
100            bytes_sent: Arc::new(AtomicU64::new(0)),
101            // dummy processor chain, will ignore processor
102            processor_chain: ProcessorChain::new(sample_rate),
103        }
104    }
105}
106
107#[async_trait]
108impl Track for MediaPassTrack {
109    fn ssrc(&self) -> u32 {
110        self.ssrc
111    }
112    fn id(&self) -> &TrackId {
113        &self.track_id
114    }
115    fn config(&self) -> &TrackConfig {
116        &self.config
117    }
118    fn processor_chain(&mut self) -> &mut ProcessorChain {
119        warn!(track_id = %self.track_id, "ignore processor for media pass track");
120        &mut self.processor_chain
121    }
122    fn insert_processor(&mut self, _: Box<dyn Processor>) {
123        warn!(track_id = %self.track_id, "ignore processor for media pass track");
124    }
125    fn append_processor(&mut self, _: Box<dyn Processor>) {
126        warn!(track_id = %self.track_id, "ignore processor for media pass track");
127    }
128    async fn handshake(&mut self, _: String, _: Option<Duration>) -> Result<String> {
129        Ok("".to_string())
130    }
131    async fn update_remote_description(&mut self, _answer: &String) -> Result<()> {
132        Ok(())
133    }
134
135    async fn start(
136        &self,
137        event_sender: EventSender,
138        packet_sender: TrackPacketSender,
139    ) -> Result<()> {
140        let mut url = url::Url::parse(&self.url)?;
141        {
142            let mut query = url.query_pairs_mut();
143            query.append_pair("sample_rate", self.output_sample_rate.to_string().as_str());
144            query.append_pair("packet_size", self.packet_size.to_string().as_str());
145        }
146        info!(
147            session_id = %self.session_id,
148            track_id = %self.track_id,
149            input_sample_rate = self.config.samplerate,
150            output_sample_rate = self.output_sample_rate,
151            packet_size = self.packet_size,
152            ptime_ms = self.config.ptime.as_millis(),
153            "Media pass track starting"
154        );
155        let input_sample_rate = self.config.samplerate;
156        let output_sample_rate = self.output_sample_rate;
157        let (ws_stream, _) = tokio_tungstenite::connect_async(url.as_str()).await?;
158        let (ws_sink, mut ws_source) = ws_stream.split();
159        *self.ws_sink.lock().await = Some(ws_sink);
160        let ws_sink = self.ws_sink.clone();
161        let bytes_sent = self.bytes_sent.clone();
162        let session_id = self.session_id.clone();
163        let ssrc = self.ssrc;
164        let track_id = self.track_id.clone();
165        let start_time = crate::media::get_timestamp();
166        let cancel_token = self.cancel_token.clone();
167        let ptime = self.config.ptime;
168        let ptime_ms = ptime.as_millis() as u32;
169        tokio::spawn(async move {
170            let mut bytes_received = 0u64;
171            let mut bytes_emitted = 0u64;
172
173            // ptimer is polled only if ptime > 0
174            let capacity = input_sample_rate as usize * ptime_ms as usize / 500;
175            let (mut ptimer, mut samples, mut buffer) = if ptime_ms > 0 {
176                (
177                    tokio::time::interval(Duration::from_millis(ptime_ms as u64)),
178                    vec![0u8; capacity],
179                    BytesMut::with_capacity(8 * 1024),
180                )
181            } else {
182                (
183                    tokio::time::interval(Duration::MAX),
184                    Vec::new(),
185                    BytesMut::new(),
186                )
187            };
188
189            loop {
190                tokio::select! {
191                    biased;
192                    _ = cancel_token.cancelled() => {
193                        info!(session_id, "Media pass track cancelled");
194                        break;
195                    }
196                    _ = ptimer.tick(), if ptime_ms > 0 => {
197                        // Fill samples buffer from audio buffer
198                        samples.fill(0);
199                        let mut i = 0;
200
201                        // Fill samples until it's full or there's no more data
202                        while i < capacity && buffer.len() > 0 {
203                            let remaining = capacity - i;
204                            let available = buffer.len();
205                            let len = usize::min(remaining, available);
206                            let cut = buffer.split_to(len);
207                            samples[i..i+len].copy_from_slice(&cut);
208                            i += len;
209                        }
210
211                        // Create frame (will have zeros if not enough data)
212                        let samples_vec = bytes_to_samples(&samples[..]);
213                        let frame = AudioFrame {
214                            track_id: track_id.clone(),
215                            samples: Samples::PCM { samples: samples_vec.clone() },
216                            timestamp: crate::media::get_timestamp(),
217                            sample_rate: input_sample_rate,
218                        };
219
220                        if let Ok(_) = packet_sender.send(frame) {
221                            // only count the actual bytes filled from buffer
222                            bytes_emitted += i as u64;
223                        } else {
224                            warn!(
225                                track_id,
226                                "packet sender closed, stopping emit loop"
227                            );
228                            break;
229                        }
230                    }
231                    msg = ws_source.next() => {
232                        match msg {
233                            Some(Ok(Message::Binary(data))) => {
234                                bytes_received += data.len() as u64;
235                                if ptime_ms > 0 {
236                                    // buffer if ptime was set
237                                    buffer.reserve(data.len());
238                                    buffer.extend_from_slice(&data);
239                                } else {
240                                    // send immediately if ptime was not set
241                                    let samples_vec = bytes_to_samples(&data);
242                                    let frame = AudioFrame {
243                                        track_id: track_id.clone(),
244                                        samples: Samples::PCM { samples: samples_vec.clone() },
245                                        timestamp: crate::media::get_timestamp(),
246                                        sample_rate: input_sample_rate,
247                                    };
248
249                                    if let Ok(_) = packet_sender.send(frame) {
250                                        bytes_emitted += data.len() as u64;
251                                    } else {
252                                        warn!(
253                                            track_id,
254                                            "packet sender closed, stopping emit loop"
255                                        );
256                                        break;
257                                    }
258                                }
259                            }
260                            Some(Ok(Message::Close(res))) => {
261                                warn!(
262                                    track_id,
263                                    close_reason = ?res,
264                                    bytes_received,
265                                    "Media pass track closed by remote"
266                                );
267                                break;
268                            }
269                            Some(Err(e)) => {
270                                error!(
271                                    track_id,
272                                    error = %e,
273                                    bytes_received,
274                                    "Media pass track WebSocket error"
275                                );
276                                let error = SessionEvent::Error {
277                                    track_id: track_id.clone(),
278                                    timestamp: crate::media::get_timestamp(),
279                                    sender: format!("media_pass: {}", url),
280                                    error: format!("Media pass track error: {:?}", e),
281                                    code: None,
282                                };
283                                event_sender.send(error).ok();
284                                break;
285                            }
286                            None => {
287                                info!(
288                                    track_id,
289                                    bytes_received,
290                                    "Media pass track WebSocket stream ended"
291                                );
292                                break;
293                            }
294                            _ => {}
295                        }
296                    }
297                }
298
299                // Break if packet sender is closed
300                if packet_sender.is_closed() {
301                    break;
302                }
303            }
304
305            if let Some(mut ws_sink) = ws_sink.lock().await.take() {
306                ws_sink.close().await.ok();
307            };
308
309            let duration = crate::media::get_timestamp() - start_time;
310            let bytes_sent_to_ws = bytes_sent.load(Ordering::Relaxed);
311            info!(
312                session_id,
313                duration,
314                input_sample_rate,
315                output_sample_rate,
316                bytes_received,
317                bytes_emitted,
318                bytes_sent_to_ws,
319                "Media pass track ended"
320            );
321
322            event_sender
323                .send(SessionEvent::TrackEnd {
324                    track_id,
325                    timestamp: crate::media::get_timestamp(),
326                    duration,
327                    ssrc,
328                    play_id: None,
329                })
330                .ok();
331        });
332        Ok(())
333    }
334
335    async fn stop(&self) -> Result<()> {
336        if let Some(mut ws_sink) = self.ws_sink.lock().await.take() {
337            ws_sink.close().await.ok();
338        }
339        self.cancel_token.cancel();
340        Ok(())
341    }
342
343    async fn send_packet(&self, packet: &AudioFrame) -> Result<()> {
344        if let Some(ws_sink) = self.ws_sink.lock().await.as_mut() {
345            if let Samples::PCM { samples } = &packet.samples {
346                let mut buffer = self.buffer.lock().await;
347                let bytes = samples_to_bytes(samples.as_slice());
348                buffer.reserve(bytes.len());
349                buffer.extend_from_slice(bytes.as_slice());
350                while buffer.len() >= self.packet_size as usize {
351                    let bytes = buffer.split_to(self.packet_size as usize).freeze();
352                    let bytes = if packet.sample_rate == self.output_sample_rate {
353                        bytes
354                    } else {
355                        let samples = bytes_to_samples(&bytes);
356                        let resample =
357                            resample_mono(&samples, packet.sample_rate, self.output_sample_rate);
358                        let bytes = samples_to_bytes(resample.as_slice());
359                        Bytes::copy_from_slice(bytes.as_slice())
360                    };
361                    let bytes_len = bytes.len();
362                    ws_sink.send(Message::Binary(bytes)).await?;
363                    self.bytes_sent
364                        .fetch_add(bytes_len as u64, Ordering::Relaxed);
365                }
366            }
367        }
368        Ok(())
369    }
370}