// voice_engine/media/track/webrtc.rs

1use super::track_codec::TrackCodec;
2use crate::{
3    IceServer,
4    event::{EventSender, SessionEvent},
5    media::AudioFrame,
6    media::{
7        codecs::CodecType,
8        negotiate::prefer_audio_codec,
9        processor::ProcessorChain,
10        track::{Track, TrackConfig, TrackId, TrackPacketSender},
11    },
12};
13use anyhow::Result;
14use async_trait::async_trait;
15use std::{sync::Arc, time::SystemTime};
16use tokio::time::sleep;
17use tokio::{select, sync::Mutex, time::Duration};
18use tokio_util::sync::CancellationToken;
19use tracing::{debug, info, warn};
20use webrtc::{
21    api::{
22        APIBuilder,
23        media_engine::{
24            MIME_TYPE_G722, MIME_TYPE_PCMA, MIME_TYPE_PCMU, MIME_TYPE_TELEPHONE_EVENT, MediaEngine,
25        },
26        setting_engine::SettingEngine,
27    },
28    ice_transport::{ice_candidate_type::RTCIceCandidateType, ice_server::RTCIceServer},
29    peer_connection::{
30        configuration::RTCConfiguration, peer_connection_state::RTCPeerConnectionState,
31        sdp::session_description::RTCSessionDescription,
32    },
33    rtp_transceiver::{
34        RTCRtpTransceiver,
35        rtp_codec::{RTCRtpCodecCapability, RTCRtpCodecParameters, RTPCodecType},
36        rtp_receiver::RTCRtpReceiver,
37    },
38    track::{track_local::TrackLocal, track_remote::TrackRemote},
39};
40use webrtc::{
41    peer_connection::RTCPeerConnection,
42    track::track_local::track_local_static_sample::TrackLocalStaticSample,
43};
44
/// Default upper bound on how long to wait for ICE candidate gathering to
/// finish during SDP negotiation when the caller does not supply a timeout.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(15_000);
46// Configuration for integrating a webrtc crate track with our WebrtcTrack
47#[derive(Clone)]
48pub struct WebrtcTrackConfig {
49    pub track: Arc<TrackLocalStaticSample>,
50    pub payload_type: u8,
51}
52
53pub struct WebrtcTrack {
54    track_id: TrackId,
55    track_config: TrackConfig,
56    processor_chain: ProcessorChain,
57    packet_sender: Arc<Mutex<Option<TrackPacketSender>>>,
58    cancel_token: CancellationToken,
59    local_track: Option<Arc<TrackLocalStaticSample>>,
60    encoder: TrackCodec,
61    pub prefered_codec: Option<CodecType>,
62    ssrc: u32,
63    pub peer_connection: Option<Arc<RTCPeerConnection>>,
64    pub ice_servers: Option<Vec<IceServer>>,
65    pub external_ip: Option<String>,
66}
67
68impl WebrtcTrack {
69    pub fn create_audio_track(
70        codec: CodecType,
71        stream_id: Option<String>,
72    ) -> Arc<TrackLocalStaticSample> {
73        let stream_id = stream_id.unwrap_or("rustpbx-track".to_string());
74        Arc::new(TrackLocalStaticSample::new(
75            RTCRtpCodecCapability {
76                mime_type: codec.mime_type().to_string(),
77                clock_rate: codec.clock_rate(),
78                channels: codec.channels(),
79                ..Default::default()
80            },
81            "audio".to_string(),
82            stream_id,
83        ))
84    }
85    pub fn get_media_engine(prefered_codec: Option<CodecType>) -> Result<MediaEngine> {
86        let mut media_engine = MediaEngine::default();
87        for codec in vec![
88            #[cfg(feature = "opus")]
89            RTCRtpCodecParameters {
90                capability: RTCRtpCodecCapability {
91                    mime_type: "audio/opus".to_owned(),
92                    clock_rate: 48000,
93                    channels: 2,
94                    sdp_fmtp_line: "minptime=10".to_owned(),
95                    rtcp_feedback: vec![],
96                },
97                payload_type: 111,
98                ..Default::default()
99            },
100            RTCRtpCodecParameters {
101                capability: RTCRtpCodecCapability {
102                    mime_type: MIME_TYPE_G722.to_owned(),
103                    clock_rate: 8000,
104                    channels: 1,
105                    sdp_fmtp_line: "".to_owned(),
106                    rtcp_feedback: vec![],
107                },
108                payload_type: 9,
109                ..Default::default()
110            },
111            RTCRtpCodecParameters {
112                capability: RTCRtpCodecCapability {
113                    mime_type: MIME_TYPE_PCMU.to_owned(),
114                    clock_rate: 8000,
115                    channels: 1,
116                    sdp_fmtp_line: "".to_owned(),
117                    rtcp_feedback: vec![],
118                },
119                payload_type: 0,
120                ..Default::default()
121            },
122            RTCRtpCodecParameters {
123                capability: RTCRtpCodecCapability {
124                    mime_type: MIME_TYPE_PCMA.to_owned(),
125                    clock_rate: 8000,
126                    channels: 1,
127                    sdp_fmtp_line: "".to_owned(),
128                    rtcp_feedback: vec![],
129                },
130                payload_type: 8,
131                ..Default::default()
132            },
133            RTCRtpCodecParameters {
134                capability: RTCRtpCodecCapability {
135                    mime_type: MIME_TYPE_TELEPHONE_EVENT.to_owned(),
136                    clock_rate: 8000,
137                    channels: 1,
138                    sdp_fmtp_line: "".to_owned(),
139                    rtcp_feedback: vec![],
140                },
141                payload_type: 101,
142                ..Default::default()
143            },
144        ] {
145            if let Some(prefered_codec) = prefered_codec {
146                if codec.capability.mime_type == prefered_codec.mime_type() {
147                    media_engine.register_codec(codec, RTPCodecType::Audio)?;
148                }
149            } else {
150                media_engine.register_codec(codec, RTPCodecType::Audio)?;
151            }
152        }
153        Ok(media_engine)
154    }
155
156    pub fn new(
157        cancel_token: CancellationToken,
158        id: TrackId,
159        track_config: TrackConfig,
160        ice_servers: Option<Vec<IceServer>>,
161    ) -> Self {
162        let processor_chain = ProcessorChain::new(track_config.samplerate);
163        Self {
164            track_id: id,
165            track_config,
166            processor_chain,
167            packet_sender: Arc::new(Mutex::new(None)),
168            cancel_token,
169            local_track: None,
170            encoder: TrackCodec::new(),
171            prefered_codec: None,
172            ssrc: 0,
173            peer_connection: None,
174            ice_servers,
175            external_ip: None,
176        }
177    }
178
179    pub fn with_external_ip(mut self, external_ip: String) -> Self {
180        self.external_ip = Some(external_ip);
181        self
182    }
183
184    pub fn with_ssrc(mut self, ssrc: u32) -> Self {
185        self.ssrc = ssrc;
186        self
187    }
188
189    pub fn with_prefered_codec(mut self, codec: Option<CodecType>) -> Self {
190        self.prefered_codec = codec;
191        self
192    }
193
194    async fn create(&mut self) -> Result<()> {
195        let media_engine = Self::get_media_engine(self.prefered_codec)?;
196        let mut setting_engine = SettingEngine::default();
197
198        if let Some(ref external_ip) = self.external_ip {
199            setting_engine.set_nat_1to1_ips(vec![external_ip.clone()], RTCIceCandidateType::Srflx);
200        }
201        let api = APIBuilder::new()
202            .with_setting_engine(setting_engine)
203            .with_media_engine(media_engine)
204            .build();
205
206        let ice_servers = if let Some(ice_servers) = &self.ice_servers {
207            ice_servers
208                .iter()
209                .map(|s| RTCIceServer {
210                    urls: s.urls.clone(),
211                    username: s.username.clone().unwrap_or_default(),
212                    credential: s.credential.clone().unwrap_or_default(),
213                    ..Default::default()
214                })
215                .collect()
216        } else {
217            vec![RTCIceServer {
218                urls: vec!["stun:stun.l.google.com:19302".to_string()],
219                ..Default::default()
220            }]
221        };
222        let config = RTCConfiguration {
223            ice_servers,
224            ..Default::default()
225        };
226
227        let cancel_token = self.cancel_token.clone();
228        let peer_connection = Arc::new(api.new_peer_connection(config).await?);
229        self.peer_connection = Some(peer_connection.clone());
230        let peer_connection_clone = peer_connection.clone();
231
232        let cancel_token_clone = cancel_token.clone();
233        let track_id = self.track_id.clone();
234        peer_connection.on_peer_connection_state_change(Box::new(
235            move |s: RTCPeerConnectionState| {
236                debug!(track_id, "peer connection state changed: {}", s);
237                let cancel_token = cancel_token.clone();
238                let peer_connection_clone = peer_connection_clone.clone();
239                let track_id_clone = track_id.clone();
240                Box::pin(async move {
241                    match s {
242                        RTCPeerConnectionState::Connected => {}
243                        RTCPeerConnectionState::Disconnected
244                        | RTCPeerConnectionState::Closed
245                        | RTCPeerConnectionState::Failed => {
246                            info!(
247                                track_id = track_id_clone,
248                                "peer connection is {}, try to close", s
249                            );
250                            cancel_token.cancel();
251                            peer_connection_clone.close().await.ok();
252                        }
253                        _ => {}
254                    }
255                })
256            },
257        ));
258        let packet_sender = self.packet_sender.clone();
259        let track_id_clone = self.track_id.clone();
260        let processor_chain = self.processor_chain.clone();
261        peer_connection.on_track(Box::new(
262            move |track: Arc<TrackRemote>,
263                  _receiver: Arc<RTCRtpReceiver>,
264                  _transceiver: Arc<RTCRtpTransceiver>| {
265                let track_id_clone = track_id_clone.clone();
266                let packet_sender_clone = packet_sender.clone();
267                let processor_chain = processor_chain.clone();
268                let track_samplerate = match track.codec().payload_type {
269                    9 => 16000,   // G722
270                    111 => 48000, // Opus
271                    _ => 8000,    // PCMU, PCMA, TELEPHONE_EVENT
272                };
273                info!(
274                    track_id=track_id_clone,
275                    "on_track received: {} samplerate: {}",
276                    track.codec().capability.mime_type,
277                    track_samplerate,
278                );
279                let cancel_token_clone = cancel_token_clone.clone();
280                Box::pin(async move {
281                    loop {
282                        select! {
283                            _ = cancel_token_clone.cancelled() => {
284                                info!(track_id=track_id_clone, "track cancelled");
285                                break;
286                            }
287                            Ok((packet, _)) = track.read_rtp() => {
288                                let packet_sender = packet_sender_clone.lock().await;
289                            if let Some(sender) = packet_sender.as_ref() {
290                                let mut frame = AudioFrame {
291                                    track_id: track_id_clone.clone(),
292                                    samples: crate::media::Samples::RTP {
293                                        payload_type: packet.header.payload_type,
294                                        payload: packet.payload.to_vec(),
295                                        sequence_number: packet.header.sequence_number,
296                                    },
297                                    timestamp: crate::media::get_timestamp(),
298                                    sample_rate: track_samplerate,
299                                    ..Default::default()
300                                };
301                                if let Err(e) = processor_chain.process_frame(&mut frame) {
302                                    warn!(track_id=track_id_clone,"Failed to process frame: {}", e);
303                                    break;
304                                }
305                                match sender.send(frame) {
306                                    Ok(_) => {}
307                                    Err(e) => {
308                                        warn!(track_id=track_id_clone,"Failed to send packet: {}", e);
309                                        break;
310                                        }
311                                    }
312                                }
313                            }
314                        }
315                    }
316                })
317            },
318        ));
319
320        #[cfg(feature = "opus")]
321        let codec = self.prefered_codec.clone().unwrap_or(CodecType::Opus);
322
323        #[cfg(not(feature = "opus"))]
324        let codec = self.prefered_codec.clone().unwrap_or(CodecType::G722);
325
326        let track = Self::create_audio_track(codec, Some(self.track_id.clone()));
327        peer_connection
328            .add_track(Arc::clone(&track) as Arc<dyn TrackLocal + Send + Sync>)
329            .await?;
330        self.local_track = Some(track.clone());
331        self.track_config.codec = codec;
332
333        Ok(())
334    }
335
336    pub async fn setup_with_offer(
337        &mut self,
338        offer: String,
339        timeout: Option<Duration>,
340    ) -> Result<RTCSessionDescription> {
341        let remote_desc = RTCSessionDescription::offer(offer)?;
342        if self.prefered_codec.is_none() {
343            let codec = match prefer_audio_codec(&remote_desc.unmarshal()?) {
344                Some(codec) => codec,
345                None => {
346                    return Err(anyhow::anyhow!("No codec found"));
347                }
348            };
349            self.prefered_codec = Some(codec);
350        }
351        self.create().await?;
352
353        let peer_connection = self
354            .peer_connection
355            .as_ref()
356            .ok_or_else(|| anyhow::anyhow!("Peer connection is not created"))?;
357
358        peer_connection.set_remote_description(remote_desc).await?;
359
360        let answer = peer_connection.create_answer(None).await?;
361        let mut gather_complete = peer_connection.gathering_complete_promise().await;
362        peer_connection.set_local_description(answer).await?;
363        select! {
364            _ = gather_complete.recv() => {
365                info!(track_id = self.track_id,"ICE candidate received");
366            }
367            _ = sleep(timeout.unwrap_or(HANDSHAKE_TIMEOUT)) => {
368                warn!(track_id = self.track_id,"wait candidate timeout");
369            }
370        }
371
372        let answer = peer_connection
373            .local_description()
374            .await
375            .ok_or(anyhow::anyhow!("Failed to get local description"))?;
376
377        info!(
378            track_id = self.track_id,
379            codec = ?self.prefered_codec,
380            "set remote description and create answer success"
381        );
382        Ok(answer)
383    }
384
385    pub async fn local_description(&mut self) -> Result<String> {
386        if self.peer_connection.is_none() {
387            self.create().await?;
388            if let Some(peer_connection) = &self.peer_connection {
389                let offer = peer_connection.create_offer(None).await?;
390                peer_connection.set_local_description(offer).await?;
391                peer_connection
392                    .gathering_complete_promise()
393                    .await
394                    .recv()
395                    .await;
396            }
397        }
398        let peer_connection = self
399            .peer_connection
400            .as_ref()
401            .ok_or_else(|| anyhow::anyhow!("Peer connection is not created"))?;
402
403        peer_connection
404            .local_description()
405            .await
406            .ok_or(anyhow::anyhow!("Failed to get local description"))
407            .map(|desc| desc.sdp)
408    }
409}
410
411#[async_trait]
412impl Track for WebrtcTrack {
413    fn ssrc(&self) -> u32 {
414        self.ssrc
415    }
416    fn id(&self) -> &TrackId {
417        &self.track_id
418    }
419    fn config(&self) -> &TrackConfig {
420        &self.track_config
421    }
422    fn processor_chain(&mut self) -> &mut ProcessorChain {
423        &mut self.processor_chain
424    }
425
426    async fn handshake(&mut self, offer: String, timeout: Option<Duration>) -> Result<String> {
427        self.setup_with_offer(offer, timeout)
428            .await
429            .map(|answer| answer.sdp)
430    }
431
432    async fn update_remote_description(&mut self, answer: &String) -> Result<()> {
433        let peer_connection = self
434            .peer_connection
435            .as_ref()
436            .ok_or_else(|| anyhow::anyhow!("Peer connection is not created"))?;
437        let remote_desc = RTCSessionDescription::answer(answer.clone())?;
438        peer_connection.set_remote_description(remote_desc).await?;
439        Ok(())
440    }
441
442    async fn start(
443        &self,
444        event_sender: EventSender,
445        packet_sender: TrackPacketSender,
446    ) -> Result<()> {
447        // Store the packet sender
448        *self.packet_sender.lock().await = Some(packet_sender.clone());
449        let token_clone = self.cancel_token.clone();
450        let event_sender_clone = event_sender.clone();
451        let track_id = self.track_id.clone();
452        let start_time = crate::media::get_timestamp();
453        let ssrc = self.ssrc;
454        tokio::spawn(async move {
455            token_clone.cancelled().await;
456            let _ = event_sender_clone.send(SessionEvent::TrackEnd {
457                track_id,
458                timestamp: crate::media::get_timestamp(),
459                duration: crate::media::get_timestamp() - start_time,
460                ssrc,
461                play_id: None,
462            });
463        });
464
465        Ok(())
466    }
467
468    async fn stop(&self) -> Result<()> {
469        // Cancel all processing
470        self.cancel_token.cancel();
471        Ok(())
472    }
473
474    async fn send_packet(&self, packet: &AudioFrame) -> Result<()> {
475        if self.local_track.is_none() {
476            return Ok(());
477        }
478        let local_track = match self.local_track.as_ref() {
479            Some(track) => track,
480            None => {
481                return Ok(()); // no local track, ignore
482            }
483        };
484
485        let payload_type = self.track_config.codec.payload_type();
486        let (_payload_type, payload) = self.encoder.encode(payload_type, packet.clone());
487        if payload.is_empty() {
488            return Ok(());
489        }
490
491        let sample = webrtc::media::Sample {
492            data: payload.into(),
493            duration: Duration::from_millis(self.track_config.ptime.as_millis() as u64),
494            timestamp: SystemTime::now(),
495            packet_timestamp: packet.timestamp as u32,
496            ..Default::default()
497        };
498        match local_track.write_sample(&sample).await {
499            Ok(_) => {}
500            Err(e) => {
501                warn!("failed to send sample: {}", e);
502                return Err(anyhow::anyhow!("Failed to send sample: {}", e));
503            }
504        }
505        Ok(())
506    }
507}