upgrade2webrtc/
client.rs

1use std::{
2    fmt::{Debug, Display},
3    path::Path,
4    sync::Arc,
5};
6
7use tokio::{
8    io::{AsyncRead, AsyncWrite, BufStream},
9    net::{TcpStream, ToSocketAddrs},
10};
11pub use tokio_rustls::rustls::ServerName;
12use tokio_rustls::{client::TlsStream, TlsConnector};
13use webrtc::{
14    api::{
15        interceptor_registry::register_default_interceptors, media_engine::MediaEngine, APIBuilder,
16        API,
17    },
18    data_channel::{data_channel_init::RTCDataChannelInit, RTCDataChannel},
19    ice_transport::{ice_candidate::RTCIceCandidate, ice_server::RTCIceServer},
20    interceptor::registry::Registry,
21    peer_connection::{configuration::RTCConfiguration, RTCPeerConnection},
22};
23
24use crate::{
25    tls::{new_tls_connector, TlsInitError},
26    transport::{IDUpgradeTransport, RecvError, StreamTransport, UpgradeTransport},
27    RTCMessage, STUN_SERVERS,
28};
29
30pub struct UpgradeWebRTCClient<C: UpgradeTransport> {
31    client: C,
32    api: API,
33    config: RTCConfiguration,
34}
35
36impl<C: Debug + UpgradeTransport> Debug for UpgradeWebRTCClient<C> {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("UpgradeWebRTCClient")
39            .field("client", &self.client)
40            .finish()
41    }
42}
43
/// A negotiated peer connection together with the data channels created on it.
///
/// This is the value returned by `UpgradeWebRTCClient::upgrade`.
pub struct PeerAndChannels<'a> {
    /// The peer connection the channels belong to.
    pub peer: RTCPeerConnection,
    /// Each created data channel paired with the label it was requested under,
    /// in the order the channel configurations were supplied.
    pub channels: Vec<(&'a str, Arc<RTCDataChannel>)>,
}
48
/// Errors that can occur while upgrading a transport to a WebRTC connection.
///
/// `DE` is the deserialization error type of the underlying transport.
#[derive(Debug)]
pub enum ClientError<DE> {
    /// An error reported by the `webrtc` crate.
    WebRTCError(webrtc::Error),
    /// An I/O error.
    IOError(std::io::Error),
    /// A received object could not be deserialized.
    DeserializeError(DE),
    /// A message arrived that was not valid at that point of the exchange.
    UnexpectedMessage,
}
56
57impl<DE> From<webrtc::Error> for ClientError<DE> {
58    fn from(value: webrtc::Error) -> Self {
59        ClientError::WebRTCError(value)
60    }
61}
62
63impl<DE> From<std::io::Error> for ClientError<DE> {
64    fn from(value: std::io::Error) -> Self {
65        ClientError::IOError(value)
66    }
67}
68
69impl<DE> From<RecvError<DE>> for ClientError<DE> {
70    fn from(value: RecvError<DE>) -> Self {
71        match value {
72            RecvError::DeserializeError(e) => Self::DeserializeError(e),
73            RecvError::IOError(e) => Self::IOError(e),
74        }
75    }
76}
77
78impl<DE: Display> Display for ClientError<DE> {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        match self {
81            ClientError::WebRTCError(e) => write!(f, "{e}"),
82            ClientError::IOError(e) => write!(f, "{e}"),
83            ClientError::DeserializeError(e) => write!(f, "{e}"),
84            ClientError::UnexpectedMessage => write!(f, "Unexpected WebRTC SDP type"),
85        }
86    }
87}
88
89impl<DE: Display + Debug> std::error::Error for ClientError<DE> {}
90
91impl<C: UpgradeTransport> UpgradeWebRTCClient<C> {
92    pub fn new(client: C) -> Self {
93        let mut m = MediaEngine::default();
94        m.register_default_codecs()
95            .expect("Default codecs should have registered safely");
96
97        let mut registry = Registry::new();
98
99        // Use the default set of Interceptors
100        registry = register_default_interceptors(registry, &mut m)
101            .expect("Default interceptors should have registered safely");
102
103        Self {
104            client,
105            api: APIBuilder::new()
106                .with_media_engine(m)
107                .with_interceptor_registry(registry)
108                .build(),
109            config: RTCConfiguration {
110                ice_servers: vec![RTCIceServer {
111                    urls: STUN_SERVERS.map(Into::into).to_vec(),
112                    ..Default::default()
113                }],
114                ..Default::default()
115            },
116        }
117    }
118
119    pub async fn upgrade<'a>(
120        &mut self,
121        channel_configs: impl IntoIterator<Item = (&'a str, RTCDataChannelInit)>,
122    ) -> Result<PeerAndChannels<'a>, ClientError<C::DeserializationError>> {
123        let peer = self.api.new_peer_connection(self.config.clone()).await?;
124        let mut channels = vec![];
125
126        for (label, option) in channel_configs {
127            channels.push((label, peer.create_data_channel(label, Some(option)).await?));
128        }
129
130        let offer = peer.create_offer(None).await?;
131        self.client.send_obj(&offer).await?;
132        peer.set_local_description(offer).await?;
133        let mut ices = vec![];
134        let answer = loop {
135            let msg: RTCMessage = self.client.recv_obj().await?;
136            match msg {
137                RTCMessage::SDPAnswer(x) => break x,
138                RTCMessage::ICE(x) => ices.push(x),
139            }
140        };
141        peer.set_remote_description(answer).await?;
142        for ice in ices {
143            peer.add_ice_candidate(ice).await?;
144        }
145
146        let (ice_sender, mut ice_receiver) = tokio::sync::mpsc::channel(3);
147
148        peer.on_ice_candidate(Box::new(move |c: Option<RTCIceCandidate>| {
149            let ice_sender = ice_sender.clone();
150            Box::pin(async move {
151                let _ = ice_sender.send(c).await;
152            })
153        }));
154
155        let mut done_sending_ice = false;
156        let mut done_receiving_ice = false;
157
158        loop {
159            tokio::select! {
160                ice_to_send = ice_receiver.recv() => {
161                    let ice_to_send = ice_to_send.unwrap();
162                    self.client.send_obj(&ice_to_send).await?;
163                    if ice_to_send.is_none() {
164                        done_sending_ice = true;
165                        if done_receiving_ice {
166                            break
167                        }
168                    };
169                }
170                received_msg = self.client.recv_obj::<Option<RTCMessage>>() => {
171                    let received_msg = received_msg?;
172                    let received_ice = match received_msg {
173                        Some(RTCMessage::ICE(x)) => Some(x),
174                        None => None,
175                        _ => return Err(ClientError::UnexpectedMessage)
176                    };
177                    let Some(received_ice) = received_ice else {
178                        done_receiving_ice = true;
179                        if done_sending_ice {
180                            break
181                        }
182                        continue
183                    };
184                    peer.add_ice_candidate(received_ice).await?;
185                }
186            }
187        }
188
189        Ok(PeerAndChannels { peer, channels })
190    }
191}
192
193impl<C> UpgradeWebRTCClient<StreamTransport<C>>
194where
195    C: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static + IDUpgradeTransport,
196{
197    pub async fn add_tls(
198        self,
199        domain: ServerName,
200        connector: &TlsConnector,
201    ) -> std::io::Result<UpgradeWebRTCClient<StreamTransport<TlsStream<C>>>> {
202        let stream = connector.connect(domain, self.client.stream).await?;
203        Ok(UpgradeWebRTCClient {
204            client: StreamTransport::from(stream),
205            api: self.api,
206            config: self.config,
207        })
208    }
209
210    pub async fn add_tls_from_config(
211        self,
212        domain: ServerName,
213        root_cert_path: Option<impl AsRef<Path>>,
214    ) -> Result<UpgradeWebRTCClient<StreamTransport<TlsStream<C>>>, TlsInitError> {
215        let connector = new_tls_connector(root_cert_path)?;
216        self.add_tls(domain, &connector).await.map_err(Into::into)
217    }
218}
219
220pub async fn client_new_tcp(
221    addr: impl ToSocketAddrs,
222) -> std::io::Result<UpgradeWebRTCClient<StreamTransport<BufStream<TcpStream>>>> {
223    Ok(UpgradeWebRTCClient::new(
224        BufStream::new(TcpStream::connect(addr).await?).into(),
225    ))
226}
227
/// Connects to the local socket named by `addr` and wraps the buffered,
/// tokio-compatible stream in a new `UpgradeWebRTCClient`.
///
/// Only compiled when the `local_sockets` feature is enabled.
///
/// # Errors
///
/// Returns the I/O error from the local socket connection attempt.
#[cfg(feature = "local_sockets")]
pub async fn client_new_local_socket<'a>(
    addr: impl interprocess::local_socket::ToLocalSocketName<'a>,
) -> std::io::Result<
    UpgradeWebRTCClient<
        StreamTransport<
            BufStream<
                tokio_util::compat::Compat<interprocess::local_socket::tokio::LocalSocketStream>,
            >,
        >,
    >,
> {
    use interprocess::local_socket::tokio::LocalSocketStream;
    use tokio_util::compat::FuturesAsyncWriteCompatExt;

    let socket = LocalSocketStream::connect(addr).await?;
    // Adapt the futures-style stream to tokio's I/O traits before buffering it.
    let transport: StreamTransport<_> = BufStream::new(socket.compat_write()).into();
    Ok(UpgradeWebRTCClient::new(transport))
}
246}