saorsa_webrtc/
media.rs

1//! Media stream management for WebRTC
2//!
3//! This module handles audio, video, and screen share media streams.
4
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use thiserror::Error;
8use tokio::sync::broadcast;
9use webrtc::track::track_local::track_local_static_sample::TrackLocalStaticSample;
10use webrtc::rtp_transceiver::rtp_codec::RTCRtpCodecCapability;
11use crate::types::MediaType;
12
13/// Media-related errors
14#[derive(Error, Debug)]
15pub enum MediaError {
16    /// Device not found
17    #[error("Device not found: {0}")]
18    DeviceNotFound(String),
19
20    /// Stream error
21    #[error("Stream error: {0}")]
22    StreamError(String),
23
24    /// Configuration error
25    #[error("Configuration error: {0}")]
26    ConfigError(String),
27}
28
29/// Media events
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub enum MediaEvent {
32    /// Device connected
33    DeviceConnected {
34        /// Device identifier
35        device_id: String,
36    },
37    /// Device disconnected
38    DeviceDisconnected {
39        /// Device identifier
40        device_id: String,
41    },
42    /// Stream started
43    StreamStarted {
44        /// Stream identifier
45        stream_id: String,
46    },
47    /// Stream stopped
48    StreamStopped {
49        /// Stream identifier
50        stream_id: String,
51    },
52}
53
/// An audio device known to the media layer.
///
/// Derives `PartialEq`/`Eq` so device lists can be compared and searched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AudioDevice {
    /// Stable device identifier.
    pub id: String,
    /// Human-readable device name.
    pub name: String,
}
62
/// A video device known to the media layer.
///
/// Derives `PartialEq`/`Eq` so device lists can be compared and searched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VideoDevice {
    /// Stable device identifier.
    pub id: String,
    /// Human-readable device name.
    pub name: String,
}
71
/// Lightweight handle identifying an audio track.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AudioTrack {
    /// Track identifier.
    pub id: String,
}
78
/// Lightweight handle identifying a video track.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VideoTrack {
    /// Track identifier.
    pub id: String,
}
85
86/// WebRTC media track wrapper
87#[derive(Debug, Clone)]
88pub struct WebRtcTrack {
89    /// Local WebRTC track
90    pub track: Arc<TrackLocalStaticSample>,
91    /// Track type
92    pub track_type: MediaType,
93    /// Track ID
94    pub id: String,
95}
96
/// Lightweight handle identifying a media stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MediaStream {
    /// Stream identifier.
    pub id: String,
}
103
104/// Media stream manager
105pub struct MediaStreamManager {
106    event_sender: broadcast::Sender<MediaEvent>,
107    #[allow(dead_code)]
108    audio_devices: Vec<AudioDevice>,
109    #[allow(dead_code)]
110    video_devices: Vec<VideoDevice>,
111    webrtc_tracks: Vec<WebRtcTrack>,
112}
113
114impl MediaStreamManager {
115    /// Create new media stream manager
116    #[must_use]
117    pub fn new() -> Self {
118        let (event_sender, _) = broadcast::channel(100);
119        Self {
120            event_sender,
121            audio_devices: Vec::new(),
122            video_devices: Vec::new(),
123            webrtc_tracks: Vec::new(),
124        }
125    }
126
127    /// Initialize media devices
128    ///
129    /// # Errors
130    ///
131    /// Returns error if device initialization fails
132    pub async fn initialize(&self) -> Result<(), MediaError> {
133        // For now, add some fake devices for testing
134        // In a real implementation, this would enumerate actual hardware devices
135        let audio_device = AudioDevice {
136            id: "default-audio".to_string(),
137            name: "Default Audio Device".to_string(),
138        };
139
140        let video_device = VideoDevice {
141            id: "default-video".to_string(),
142            name: "Default Video Device".to_string(),
143        };
144
145        // Emit device connected events
146        let _ = self.event_sender.send(MediaEvent::DeviceConnected {
147            device_id: audio_device.id.clone(),
148        });
149
150        let _ = self.event_sender.send(MediaEvent::DeviceConnected {
151            device_id: video_device.id.clone(),
152        });
153
154        Ok(())
155    }
156
157    /// Get available audio devices
158    #[must_use]
159    pub fn get_audio_devices(&self) -> &[AudioDevice] {
160        // Return empty for now, as we can't enumerate real devices easily
161        // In a real implementation, this would return actual devices
162        &[]
163    }
164
165    /// Get available video devices
166    #[must_use]
167    pub fn get_video_devices(&self) -> &[VideoDevice] {
168        // Return empty for now
169        &[]
170    }
171
172    /// Create a new audio track
173    ///
174    /// # Errors
175    ///
176    /// Returns error if track creation fails
177    pub async fn create_audio_track(&mut self) -> Result<&WebRtcTrack, MediaError> {
178        let track_id = format!("audio-{}", self.webrtc_tracks.len());
179
180        let codec = RTCRtpCodecCapability {
181            mime_type: "audio/opus".to_string(),
182            clock_rate: 48000,
183            channels: 2,
184            sdp_fmtp_line: "".to_string(),
185            rtcp_feedback: vec![],
186        };
187
188        let track = Arc::new(TrackLocalStaticSample::new(
189            codec,
190            track_id.clone(),
191            "audio".to_string(),
192        ));
193
194        let webrtc_track = WebRtcTrack {
195            track,
196            track_type: MediaType::Audio,
197            id: track_id,
198        };
199
200        self.webrtc_tracks.push(webrtc_track);
201        self.webrtc_tracks
202            .last()
203            .ok_or(MediaError::StreamError(
204                "Failed to get last track after push".to_string(),
205            ))
206    }
207
208    /// Create a new video track
209    ///
210    /// # Errors
211    ///
212    /// Returns error if track creation fails
213    pub async fn create_video_track(&mut self) -> Result<&WebRtcTrack, MediaError> {
214        let track_id = format!("video-{}", self.webrtc_tracks.len());
215
216        let codec = RTCRtpCodecCapability {
217            mime_type: "video/VP8".to_string(),
218            clock_rate: 90000,
219            channels: 0,
220            sdp_fmtp_line: "".to_string(),
221            rtcp_feedback: vec![],
222        };
223
224        let track = Arc::new(TrackLocalStaticSample::new(
225            codec,
226            track_id.clone(),
227            "video".to_string(),
228        ));
229
230        let webrtc_track = WebRtcTrack {
231            track,
232            track_type: MediaType::Video,
233            id: track_id,
234        };
235
236        self.webrtc_tracks.push(webrtc_track);
237        self.webrtc_tracks
238            .last()
239            .ok_or(MediaError::StreamError(
240                "Failed to get last track after push".to_string(),
241            ))
242    }
243
244    /// Get all WebRTC tracks
245    #[must_use]
246    pub fn get_webrtc_tracks(&self) -> &[WebRtcTrack] {
247        &self.webrtc_tracks
248    }
249
250    /// Subscribe to media events
251    #[must_use]
252    pub fn subscribe_events(&self) -> broadcast::Receiver<MediaEvent> {
253        self.event_sender.subscribe()
254    }
255
256    /// Remove a track by ID
257    ///
258    /// Returns true if the track was found and removed
259    pub fn remove_track(&mut self, track_id: &str) -> bool {
260        if let Some(pos) = self.webrtc_tracks.iter().position(|t| t.id == track_id) {
261            self.webrtc_tracks.remove(pos);
262            tracing::debug!("Removed track: {}", track_id);
263            true
264        } else {
265            tracing::warn!("Track not found for removal: {}", track_id);
266            false
267        }
268    }
269}
270
271impl Default for MediaStreamManager {
272    fn default() -> Self {
273        Self::new()
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_media_stream_manager_initialize() {
        let manager = MediaStreamManager::new();
        // Subscribe *before* initializing: a broadcast receiver only sees
        // events sent after it was created.
        let mut events = manager.subscribe_events();

        let result = manager.initialize().await;
        assert!(result.is_ok());

        let mut connected = Vec::new();
        while let Ok(event) = events.try_recv() {
            if let MediaEvent::DeviceConnected { device_id } = event {
                connected.push(device_id);
            }
        }
        assert_eq!(
            connected,
            vec!["default-audio".to_string(), "default-video".to_string()]
        );
    }

    #[test]
    fn test_media_stream_manager_get_devices() {
        let manager = MediaStreamManager::new();

        assert!(manager.get_audio_devices().is_empty());
        assert!(manager.get_video_devices().is_empty());
    }

    #[tokio::test]
    async fn test_media_stream_manager_create_audio_track() {
        let mut manager = MediaStreamManager::new();

        let track = manager.create_audio_track().await.unwrap();
        assert_eq!(track.track_type, MediaType::Audio);
        assert!(track.id.starts_with("audio-"));

        let tracks = manager.get_webrtc_tracks();
        assert_eq!(tracks.len(), 1);
        assert_eq!(tracks[0].track_type, MediaType::Audio);
    }

    #[tokio::test]
    async fn test_media_stream_manager_create_video_track() {
        let mut manager = MediaStreamManager::new();

        let track = manager.create_video_track().await.unwrap();
        assert_eq!(track.track_type, MediaType::Video);
        assert!(track.id.starts_with("video-"));

        let tracks = manager.get_webrtc_tracks();
        assert_eq!(tracks.len(), 1);
        assert_eq!(tracks[0].track_type, MediaType::Video);
    }

    #[tokio::test]
    async fn test_media_stream_manager_multiple_tracks() {
        let mut manager = MediaStreamManager::new();

        manager.create_audio_track().await.unwrap();
        manager.create_video_track().await.unwrap();

        let tracks = manager.get_webrtc_tracks();
        assert_eq!(tracks.len(), 2);

        // Track IDs must be distinct.
        assert_ne!(tracks[0].id, tracks[1].id);

        // Exactly one audio and one video track.
        let audio_count = tracks.iter().filter(|t| t.track_type == MediaType::Audio).count();
        let video_count = tracks.iter().filter(|t| t.track_type == MediaType::Video).count();

        assert_eq!(audio_count, 1);
        assert_eq!(video_count, 1);
    }

    #[tokio::test]
    async fn test_media_stream_manager_remove_track() {
        let mut manager = MediaStreamManager::new();
        let id = manager.create_audio_track().await.unwrap().id.clone();

        assert!(manager.remove_track(&id));
        assert!(manager.get_webrtc_tracks().is_empty());
        // A second removal of the same ID finds nothing.
        assert!(!manager.remove_track(&id));
    }
}
351}