Skip to main content

webrtc_streaming_actix/
lib.rs

1use sctp::EndpointConfig;
2pub use actix_web::{App, HttpServer, web};
3use clap::{arg, command, Parser};
4use log::error;
5use opentelemetry::{/*global,*/ KeyValue};
6use opentelemetry_sdk::metrics::{PeriodicReader, SdkMeterProvider};
7use opentelemetry_stdout::MetricsExporterBuilder;
8use opentelemetry_sdk::{Resource, runtime};
9
10use std::net::{IpAddr, Ipv4Addr};
11use std::str::FromStr;
12use tera::Tera;
13use wg::WaitGroup;
14use std::cell::RefCell;
15use std::collections::HashMap;
16use std::io::{Error, ErrorKind};
17use std::net::{SocketAddr, UdpSocket};
18use std::rc::Rc;
19use std::sync::mpsc::{Receiver, SyncSender};
20use std::sync::{Arc, mpsc};
21use std::time::{Duration, Instant};
22use bytes::{Bytes, BytesMut};
23
24use std::io::Write;
25use actix_web::{HttpRequest, HttpResponse};
26
27use retty::channel::{InboundPipeline, Pipeline};
28use retty::transport::{TaggedBytesMut, TransportContext};
29use sfu::{DtlsHandler, ExceptionHandler, GatewayHandler, InterceptorHandler, SrtpHandler, StunHandler, DataChannelHandler, DemuxerHandler, RTCSessionDescription};
30
31pub mod util;
32pub mod messages;
33pub mod interceptors;
34pub mod types;
35pub mod metrics;
36
37use dtls::config;
38use dtls::extension::extension_use_srtp::SrtpProtectionProfile;
39use sfu::{RTCCertificate, SctpHandler, ServerConfig, ServerStates};
40
41
42
43#[derive(Default, Debug, Copy, Clone, clap::ValueEnum)]
44pub enum Level {
45    Error,
46    Warn,
47    #[default]
48    Info,
49    Debug,
50    Trace,
51}
52
53impl From<Level> for log::LevelFilter {
54    fn from(level: Level) -> Self {
55        match level {
56            Level::Error => log::LevelFilter::Error,
57            Level::Warn => log::LevelFilter::Warn,
58            Level::Info => log::LevelFilter::Info,
59            Level::Debug => log::LevelFilter::Debug,
60            Level::Trace => log::LevelFilter::Trace,
61        }
62    }
63}
64
65#[derive(Parser)]
66#[command(name = "SFU Server")]
67#[command(author = "Rusty Rain <y@ngr.tc>")]
68#[command(version = "0.1.0")]
69#[command(about = "An example of SFU Server", long_about = None)]
70pub struct Cli {
71    #[arg(long, default_value_t = format!("127.0.0.1"))]
72    pub host: String,
73    #[arg(short, long, default_value_t = 8080)]
74    pub signal_port: u16,
75    #[arg(long, default_value_t = 3478)]
76    pub media_port_min: u16,
77    #[arg(long, default_value_t = 3495)]
78    pub media_port_max: u16,
79
80    #[arg(short, long)]
81    pub force_local_loop: bool,
82    #[arg(short, long)]
83    pub debug: bool,
84    #[arg(short, long, default_value_t = Level::Info)]
85    #[clap(value_enum)]
86    pub level: Level,
87}
88
89pub fn init_meter_provider(
90    mut _stop_rx: async_broadcast::Receiver<()>,
91    _wait_group: WaitGroup,
92) -> SdkMeterProvider {
93    let exporter = MetricsExporterBuilder::default()
94        .with_encoder(|writer, data| {
95            Ok(serde_json::to_writer_pretty(writer, &data).unwrap())
96        })
97        .build();
98    let reader = PeriodicReader::builder(exporter, runtime::TokioCurrentThread)
99        .with_interval(Duration::from_secs(30))
100        .build();
101    let meter_provider = SdkMeterProvider::builder()
102        .with_reader(reader)
103        .with_resource(Resource::new(vec![KeyValue::new("chat", "metrics")]))
104        .build();
105
106    meter_provider
107}
108
109#[actix_web::main]
110async fn main() -> anyhow::Result<()> {
111
112    let (_stop_tx, _stop_rx) = crossbeam_channel::bounded::<()>(1);
113    let cli = Cli::parse();
114    if cli.debug {
115        env_logger::Builder::new()
116            .format(|buf, record| {
117                writeln!(
118                    buf,
119                    "{}:{} [{}] {} - {}",
120                    record.file().unwrap_or("unknown"),
121                    record.line().unwrap_or(0),
122                    record.level(),
123                    chrono::Local::now().format("%H:%M:%S.%6f"),
124                    record.args()
125                )
126            })
127            .filter(None, cli.level.into())
128            .init();
129    }
130    
131    // Figure out some public IP address, since Firefox will not accept 127.0.0.1 for WebRTC traffic.
132    let host_addr = if cli.host == "127.0.0.1" && !cli.force_local_loop {
133        util::select_host_address()
134    } else {
135        IpAddr::from_str(&cli.host)?
136    };
137
138    let _media_ports: Vec<u16> = (cli.media_port_min..=cli.media_port_max).collect();
139    let (_stop_tx, stop_rx) = crossbeam_channel::bounded::<()>(1);
140    let mut media_port_thread_map = HashMap::new();
141
142
143    let key_pair = rcgen::KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?;
144    let certificates = vec![RTCCertificate::from_key_pair(key_pair)?];
145    let dtls_handshake_config = Arc::new(
146        config::ConfigBuilder::default()
147            .with_certificates(
148                certificates
149                    .iter()
150                    .map(|c| c.dtls_certificate.clone())
151                    .collect(),
152            )
153            .with_srtp_protection_profiles(vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80])
154            .with_extended_master_secret(config::ExtendedMasterSecretType::Require)
155            .build(false, None)?,
156    );
157    let sctp_endpoint_config = Arc::new(EndpointConfig::default());
158    let sctp_server_config = Arc::new(sctp::ServerConfig::default());
159    let server_config = Arc::new(
160        ServerConfig::new(certificates)
161            .with_dtls_handshake_config(dtls_handshake_config)
162            .with_sctp_endpoint_config(sctp_endpoint_config)
163            .with_sctp_server_config(sctp_server_config)
164            .with_idle_timeout(Duration::from_secs(30)),
165    );
166    let (_stop_meter_tx, stop_meter_rx) = async_broadcast::broadcast::<()>(1);
167    let wait_group = WaitGroup::new();
168    let meter_provider = init_meter_provider(stop_meter_rx, wait_group.clone());
169    let media_ports: Vec<u16> = (cli.media_port_min..=cli.media_port_max).collect();
170    for port in media_ports {
171        let worker = wait_group.add(1);
172        let stop_rx = stop_rx.clone();
173        let (signaling_tx, signaling_rx) = mpsc::sync_channel(1);
174
175        media_port_thread_map.insert(port, signaling_tx);
176        // Spin up a UDP socket for the RTC. All WebRTC traffic is going to be multiplexed over this single
177        // server socket. Clients are identified via their respective remote (UDP) socket address.
178        let socket = UdpSocket::bind(format!("{host_addr}:{port}"))
179            .expect(&format!("binding to {host_addr}:{port}"));
180
181        let server_config = server_config.clone();
182        let meter_provider = meter_provider.clone();
183        std::thread::spawn(move || {
184            if let Err(err) = sync_run(stop_rx, socket, signaling_rx, server_config, meter_provider)
185            {
186                eprintln!("run_sfu got error: {}", err);
187            }
188            worker.done();
189        });
190    }
191
192    let media_port_thread_map = Arc::new(media_port_thread_map);
193    let signal_port = cli.signal_port;
194    let host_addr = if cli.host == "127.0.0.1" && !cli.force_local_loop {
195        util::select_host_address()
196    } else {
197        IpAddr::from_str(&cli.host)?
198    };
199
200    println!("Connect a browser to https://{}:{}", host_addr, signal_port);
201
202    HttpServer::new(move || {
203        let tera = Tera::new("templates/**/*").unwrap();
204        App::new()
205            .service(actix_files::Files::new("/static", "./src/static/"))
206            .app_data(web::Data::new(tera))
207            .app_data(web::Data::new(media_port_thread_map.clone()))
208            .service(web::resource("/").route(web::get().to(index)))
209            .service(web::resource("/{path}/{session_id}/{endpoint_id}")
210                .route(web::get().to(web_request))
211                .route(web::post().to(web_request))
212            )
213    })
214    .bind(SocketAddr::new(IpAddr::from(Ipv4Addr::new(0,0,0,0)), 80))?
215    .run()
216    .await?;
217
218    println!("Wait for Signaling Server and Media Server Gracefully Shutdown...");
219    wait_group.wait();
220
221    Ok(())
222}
223
224pub enum SignalingProtocolMessage {
225    Ok {
226        session_id: u64,
227        endpoint_id: u64,
228    },
229    Err {
230        session_id: u64,
231        endpoint_id: u64,
232        reason: Bytes,
233    },
234    Offer {
235        session_id: u64,
236        endpoint_id: u64,
237        offer_sdp: Bytes,
238    },
239    Answer {
240        session_id: u64,
241        endpoint_id: u64,
242        answer_sdp: Bytes,
243    },
244    Leave {
245        session_id: u64,
246        endpoint_id: u64,
247    },
248}
249
250pub struct SignalingMessage {
251    pub request: SignalingProtocolMessage,
252    pub response_tx: SyncSender<SignalingProtocolMessage>,
253}
254
255pub async fn index(tera: web::Data<Tera>) -> HttpResponse {
256    let rendered = tera.render("chat.html.tera", &tera::Context::new()).unwrap();
257    HttpResponse::Ok().content_type("text/html").body(rendered)
258}
259
260
261fn build_pipeline(local_addr: SocketAddr, server_states: Rc<RefCell<ServerStates>>) -> Rc<Pipeline<TaggedBytesMut, TaggedBytesMut>> {
262    let pipeline: Pipeline<TaggedBytesMut, TaggedBytesMut> = Pipeline::new();
263
264    let demuxer_handler = DemuxerHandler::new();
265    let stun_handler = StunHandler::new();
266    // DTLS
267    let dtls_handler = DtlsHandler::new(local_addr, Rc::clone(&server_states));
268    let sctp_handler = SctpHandler::new(local_addr, Rc::clone(&server_states));
269    let data_channel_handler = DataChannelHandler::new();
270    // SRTP
271    let srtp_handler = SrtpHandler::new(Rc::clone(&server_states));
272    let interceptor_handler = InterceptorHandler::new(Rc::clone(&server_states));
273    // Gateway
274    let gateway_handler = GatewayHandler::new(Rc::clone(&server_states));
275    let exception_handler = ExceptionHandler::new();
276
277    pipeline.add_back(demuxer_handler);
278    pipeline.add_back(stun_handler);
279    // DTLS
280    pipeline.add_back(dtls_handler);
281    pipeline.add_back(sctp_handler);
282    pipeline.add_back(data_channel_handler);
283    // SRTP
284    pipeline.add_back(srtp_handler);
285    pipeline.add_back(interceptor_handler);
286    // Gateway
287    pipeline.add_back(gateway_handler);
288    pipeline.add_back(exception_handler);
289
290    pipeline.finalize()
291}
292
293fn write_socket_output(socket: &UdpSocket, pipeline: &Rc<Pipeline<TaggedBytesMut, TaggedBytesMut>>) -> anyhow::Result<()> {
294    while let Some(transmit) = pipeline.poll_transmit() {
295        socket.send_to(&transmit.message, transmit.transport.peer_addr)?;
296    }
297
298    Ok(())
299}
300
301fn handle_offer_message(
302    server_states: &Rc<RefCell<ServerStates>>,
303    session_id: u64,
304    endpoint_id: u64,
305    offer: Bytes,
306    response_tx: SyncSender<SignalingProtocolMessage>,
307) -> anyhow::Result<()> {
308    let try_handle = || -> anyhow::Result<Bytes> {
309        let offer_str = String::from_utf8(offer.to_vec())?;
310        log::info!(
311            "handle_offer_message: {}/{}/{}",
312            session_id,
313            endpoint_id,
314            offer_str,
315        );
316        let mut server_states = server_states.borrow_mut();
317
318        let offer_sdp = serde_json::from_str::<RTCSessionDescription>(&offer_str)?;
319        let answer = server_states.accept_offer(session_id, endpoint_id, None, offer_sdp)?;
320        let answer_str = serde_json::to_string(&answer)?;
321        log::info!("generate answer sdp: {}", answer_str);
322        Ok(Bytes::from(answer_str))
323    };
324
325    match try_handle() {
326        Ok(answer_sdp) => Ok(response_tx
327            .send(SignalingProtocolMessage::Answer {
328                session_id,
329                endpoint_id,
330                answer_sdp,
331            })
332            .map_err(|_| {
333                Error::new(
334                    ErrorKind::Other,
335                    "failed to send back signaling message response".to_string(),
336                )
337            })?),
338        Err(err) => Ok(response_tx
339            .send(SignalingProtocolMessage::Err {
340                session_id,
341                endpoint_id,
342                reason: Bytes::from(err.to_string()),
343            })
344            .map_err(|_| {
345                Error::new(
346                    ErrorKind::Other,
347                    "failed to send back signaling message response".to_string(),
348                )
349            })?),
350    }
351}
352
353fn handle_leave_message(_server_states: &Rc<RefCell<ServerStates>>,  session_id: u64, endpoint_id: u64,  response_tx: SyncSender<SignalingProtocolMessage>) -> anyhow::Result<()> {
354    let try_handle = || -> anyhow::Result<()> {
355        log::info!("handle_leave_message: {}/{}", session_id, endpoint_id,);
356        Ok(())
357    };
358
359    match try_handle() {
360        Ok(_) => Ok(response_tx
361            .send(SignalingProtocolMessage::Ok {
362                session_id,
363                endpoint_id,
364            })
365            .map_err(|_| {
366                Error::new(
367                    ErrorKind::Other,
368                    "failed to send back signaling message response".to_string(),
369                )
370            })?),
371        Err(err) => Ok(response_tx
372            .send(SignalingProtocolMessage::Err {
373                session_id,
374                endpoint_id,
375                reason: Bytes::from(err.to_string()),
376            })
377            .map_err(|_| {
378                Error::new(
379                    ErrorKind::Other,
380                    "failed to send back signaling message response".to_string(),
381                )
382            })?),
383    }
384}
385
386pub fn handle_signaling_message(
387    server_states: &Rc<RefCell<ServerStates>>,
388    signaling_msg: SignalingMessage,
389) -> anyhow::Result<()> {
390    match signaling_msg.request {
391        SignalingProtocolMessage::Offer {
392            session_id,
393            endpoint_id,
394            offer_sdp,
395        } => handle_offer_message(
396            server_states,
397            session_id,
398            endpoint_id,
399            offer_sdp,
400            signaling_msg.response_tx,
401        ),
402        SignalingProtocolMessage::Leave {
403            session_id,
404            endpoint_id,
405        } => handle_leave_message(
406            server_states,
407            session_id,
408            endpoint_id,
409            signaling_msg.response_tx,
410        ),
411        SignalingProtocolMessage::Ok {
412            session_id,
413            endpoint_id,
414        }
415        | SignalingProtocolMessage::Err {
416            session_id,
417            endpoint_id,
418            reason: _,
419        }
420        | SignalingProtocolMessage::Answer {
421            session_id,
422            endpoint_id,
423            answer_sdp: _,
424        } => Ok(signaling_msg
425            .response_tx
426            .send(SignalingProtocolMessage::Err {
427                session_id,
428                endpoint_id,
429                reason: Bytes::from("Invalid Request"),
430            })
431            .map_err(|_| {
432                Error::new(
433                    ErrorKind::Other,
434                    "failed to send back signaling message response".to_string(),
435                )
436            })?),
437    }
438}
439
440pub fn sync_run(
441    stop_rx: crossbeam_channel::Receiver<()>,
442    socket: UdpSocket,
443    rx: Receiver<SignalingMessage>,
444    server_config: Arc<ServerConfig>,
445    _meter_provider: SdkMeterProvider,
446) -> anyhow::Result<()> {
447    let server_states = Rc::new(RefCell::new(ServerStates::new(
448        server_config,
449        socket.local_addr()?,
450    )?));
451
452    println!("listening {}...", socket.local_addr()?);
453
454    let pipeline = build_pipeline(socket.local_addr()?, server_states.clone());
455
456    let mut buf = vec![0; 2000];
457
458    pipeline.transport_active();
459    loop {
460        match stop_rx.try_recv() {
461            Ok(_) => break,
462            Err(err) => {
463                if err.is_disconnected() {
464                    break;
465                }
466            }
467        };
468
469        write_socket_output(&socket, &pipeline)?;
470
471        // Spawn new incoming signal message from the signaling server thread.
472        if let Ok(signal_message) = rx.try_recv() {
473            if let Err(err) = handle_signaling_message(&server_states, signal_message) {
474                error!("handle_signaling_message got error:{}", err);
475                continue;
476            }
477        }
478
479        // Poll clients until they return timeout
480        let mut eto = Instant::now() + Duration::from_millis(100);
481        pipeline.poll_timeout(&mut eto);
482
483        let delay_from_now = eto
484            .checked_duration_since(Instant::now())
485            .unwrap_or(Duration::from_secs(0));
486        if delay_from_now.is_zero() {
487            pipeline.handle_timeout(Instant::now());
488            continue;
489        }
490
491        socket
492            .set_read_timeout(Some(delay_from_now))
493            .expect("setting socket read timeout");
494
495        if let Some(input) = read_socket_input(&socket, &mut buf) {
496            pipeline.read(input);
497        }
498
499        // Drive time forward in all clients.
500        pipeline.handle_timeout(Instant::now());
501    }
502    pipeline.transport_inactive();
503
504    println!(
505        "media server on {} is gracefully down",
506        socket.local_addr()?
507    );
508    Ok(())
509}
510
511
512
513fn read_socket_input(socket: &UdpSocket, buf: &mut [u8]) -> Option<TaggedBytesMut> {
514    match socket.recv_from(buf) {
515        Ok((n, peer_addr)) => {
516            return Some(TaggedBytesMut {
517                now: Instant::now(),
518                transport: TransportContext {
519                    local_addr: socket.local_addr().unwrap(),
520                    peer_addr,
521                    ecn: None,
522                },
523                message: BytesMut::from(&buf[..n]),
524            });
525        }
526
527        Err(e) => match e.kind() {
528            // Expected error for set_read_timeout(). One for windows, one for the rest.
529            ErrorKind::WouldBlock | ErrorKind::TimedOut => None,
530            _ => panic!("UdpSocket read failed: {e:?}"),
531        },
532    }
533}
534
535
536
537pub async fn web_request(
538    req: HttpRequest,
539    bytes: web::Bytes,
540    path: web::Path<(String, u64, u64)>,
541    tera: web::Data<Tera>,
542    media_port_thread_map: web::Data<Arc<HashMap<u16, SyncSender<SignalingMessage>>>>,
543) -> HttpResponse {
544    let (path, session_id, endpoint_id) = path.into_inner();
545
546    if req.method() == actix_web::http::Method::GET {
547        let rendered = tera.render("chat.html.tera", &tera::Context::new()).unwrap();
548        HttpResponse::Ok().content_type("text/html").body(rendered)
549    } else if req.method() == actix_web::http::Method::POST {
550        let mut sorted_ports: Vec<u16> = media_port_thread_map.keys().copied().collect();
551        sorted_ports.sort();
552        assert!(!sorted_ports.is_empty());
553        let port = sorted_ports[(session_id as usize) % sorted_ports.len()];
554        let tx = media_port_thread_map.get(&port);
555
556        if let Some(tx  ) = tx {
557            let offer_sdp = bytes.to_vec();
558
559            let (response_tx, response_rx) = mpsc::sync_channel(1);
560            tx.send(SignalingMessage {
561                request: SignalingProtocolMessage::Offer {
562                    session_id,
563                    endpoint_id,
564                    offer_sdp: Bytes::from(offer_sdp),
565                },
566                response_tx,
567            })
568                .expect("to send SignalingMessage instance");
569
570            let response = response_rx.recv().expect("receive answer offer");
571            match response {
572                SignalingProtocolMessage::Answer {
573                    session_id: _,
574                    endpoint_id: _,
575                    answer_sdp,
576                } => HttpResponse::Ok()
577                    .content_type("application/json")
578                    .body(answer_sdp),
579                _ => HttpResponse::NotFound().finish(),
580            }
581        } else {
582            HttpResponse::NotAcceptable().finish()
583        }
584    } else {
585        HttpResponse::MethodNotAllowed().finish()
586    }
587}