webrtc_unreliable_client/
socket.rs

1use std::{sync::Arc, time::Duration};
2
3use anyhow::{Error, Result};
4use bytes::Bytes;
5use log::warn;
6use reqwest::{Client as HttpClient, Response};
7use tinyjson::JsonValue;
8use tokio::{
9    sync::{mpsc, oneshot},
10    time::sleep,
11};
12
13use crate::webrtc::{
14    data_channel::internal::data_channel::DataChannel,
15    peer_connection::{sdp::session_description::RTCSessionDescription, RTCPeerConnection},
16};
17
18use super::addr_cell::AddrCell;
19
/// Length in bytes of the scratch buffer that `read_loop` reads each data-channel message into.
const MESSAGE_SIZE: usize = 1_500;
21
22pub struct Socket {
23    addr_cell: AddrCell,
24    to_server_receiver: mpsc::UnboundedReceiver<Box<[u8]>>,
25    to_server_disconnect_receiver: mpsc::Receiver<()>,
26    to_client_sender: mpsc::UnboundedSender<Box<[u8]>>,
27    to_client_id_sender: oneshot::Sender<Result<String, u16>>,
28}
29
30pub struct SocketIo {
31    pub addr_cell: AddrCell,
32    pub to_server_sender: mpsc::UnboundedSender<Box<[u8]>>,
33    pub to_server_disconnect_sender: mpsc::Sender<()>,
34    pub to_client_receiver: mpsc::UnboundedReceiver<Box<[u8]>>,
35    pub to_client_id_receiver: oneshot::Receiver<Result<String, u16>>,
36}
37
38impl Socket {
39    pub fn new() -> (Self, SocketIo) {
40        let addr_cell = AddrCell::default();
41        let (to_server_sender, to_server_receiver) = mpsc::unbounded_channel();
42        let (to_server_disconnect_sender, to_server_disconnect_receiver) = mpsc::channel(1);
43        let (to_client_sender, to_client_receiver) = mpsc::unbounded_channel();
44        let (to_client_id_sender, to_client_id_receiver) = oneshot::channel();
45
46        (
47            Self {
48                addr_cell: addr_cell.clone(),
49                to_server_receiver,
50                to_server_disconnect_receiver,
51                to_client_sender,
52                to_client_id_sender,
53            },
54            SocketIo {
55                addr_cell,
56                to_server_sender,
57                to_server_disconnect_sender,
58                to_client_receiver,
59                to_client_id_receiver,
60            },
61        )
62    }
63
64    pub async fn connect(
65        self,
66        server_url: &str,
67        auth_bytes_opt: Option<Vec<u8>>,
68        auth_headers_opt: Option<Vec<(String, String)>>,
69    ) {
70        let Self {
71            addr_cell,
72            to_server_receiver,
73            to_server_disconnect_receiver,
74            to_client_sender,
75            to_client_id_sender,
76        } = self;
77
78        // create a new RTCPeerConnection
79        let peer_connection = RTCPeerConnection::new().await;
80
81        let label = "data";
82        let protocol = "";
83
84        // create a datachannel with label 'data'
85        let data_channel = peer_connection
86            .create_data_channel(label, protocol)
87            .await
88            .expect("cannot create data channel");
89
90        // datachannel on_error callback
91        data_channel
92            .on_error(Box::new(move |error| {
93                println!("data channel error: {:?}", error);
94                Box::pin(async {})
95            }))
96            .await;
97
98        // datachannel on_open callback
99        let peer_connection_ref = Arc::clone(&peer_connection);
100        let data_channel_ref = Arc::clone(&data_channel);
101        data_channel
102            .on_open(Box::new(move || {
103                let peer_connection_ref_2 = Arc::clone(&peer_connection_ref);
104                let data_channel_ref_2 = Arc::clone(&data_channel_ref);
105                Box::pin(async move {
106                    let detached_data_channel = data_channel_ref_2
107                        .detach()
108                        .await
109                        .expect("data channel detach got error");
110
111                    // Handle reading from the data channel
112                    let peer_connection_ref_3 = Arc::clone(&peer_connection_ref_2);
113                    let peer_connection_ref_4 = Arc::clone(&peer_connection_ref_2);
114
115                    let detached_data_channel_1 = Arc::clone(&detached_data_channel);
116                    let detached_data_channel_2 = Arc::clone(&detached_data_channel);
117                    tokio::spawn(async move {
118                        let _loop_result =
119                            read_loop(detached_data_channel_1, to_client_sender).await;
120
121                        // do nothing with result, just close thread
122                        peer_connection_ref_3.internal.close().await;
123                    });
124
125                    // Handle writing to the data channel
126                    tokio::spawn(async move {
127                        let detached_data_channel_3 = Arc::clone(&detached_data_channel_2);
128                        let _loop_result = write_loop(
129                            detached_data_channel_3,
130                            to_server_receiver,
131                            to_server_disconnect_receiver,
132                        )
133                        .await;
134
135                        // do nothing with result, just close thread
136                        detached_data_channel_2.close().await;
137
138                        peer_connection_ref_4.internal.close().await;
139                    });
140                })
141            }))
142            .await;
143
144        // create an offer to send to the server
145        let offer = peer_connection
146            .create_offer()
147            .await
148            .expect("cannot create offer");
149
150        // sets the LocalDescription, and starts our UDP listeners
151        peer_connection
152            .set_local_description(offer)
153            .await
154            .expect("cannot set local description");
155
156        // send a request to server to initiate connection (signaling, essentially)
157        let http_client = HttpClient::new();
158
159        let sdp = peer_connection.local_description().await.unwrap().sdp;
160
161        let sdp_len = sdp.len();
162
163        // wait to receive a response from server
164        let response: Response = loop {
165            let mut request = http_client
166                .post(server_url)
167                .header("Content-Length", sdp_len)
168                .body(sdp.clone());
169            if let Some(auth_bytes) = auth_bytes_opt.clone() {
170                let base64_encoded = base64::encode(auth_bytes);
171                request = request.header("Authorization", &base64_encoded);
172            }
173            if let Some(auth_headers) = auth_headers_opt.clone() {
174                for (key, value) in auth_headers {
175                    request = request.header(key, value);
176                }
177            }
178
179            match request.send().await {
180                Ok(resp) => {
181                    break resp;
182                }
183                Err(err) => {
184                    warn!("Could not send request, original error: {:?}", err);
185                    sleep(Duration::from_secs(1)).await;
186                }
187            };
188        };
189
190        if !response.status().is_success() {
191            let status_code = response.status().as_u16();
192            to_client_id_sender.send(Err(status_code)).unwrap();
193            return;
194        }
195
196        // get the body of the response as a string
197        let response_string = match response.text().await {
198            Ok(response_string) => response_string,
199            Err(_err) => {
200                // error reading response?
201                to_client_id_sender.send(Err(500)).unwrap();
202                return;
203            }
204        };
205
206        // parse session from server response
207        let session_response_result = get_session_response(response_string.as_str());
208        let session_response = match session_response_result {
209            Ok(session_response) => session_response,
210            Err(_err) => {
211                // parsing error?
212                to_client_id_sender.send(Err(500)).unwrap();
213                return;
214            }
215        };
216
217        // send the id token to the client
218        // info!("Sending id token to client: {:?}", auth_header);
219        to_client_id_sender
220            .send(Ok(session_response.id_token))
221            .unwrap();
222
223        // apply the server's response as the remote description
224        let session_description =
225            RTCSessionDescription::answer(session_response.answer.sdp).unwrap();
226
227        peer_connection
228            .set_remote_description(session_description)
229            .await
230            .expect("cannot set remote description");
231
232        addr_cell
233            .receive_candidate(session_response.candidate.candidate.as_str())
234            .await;
235
236        // add ice candidate to connection
237        if let Err(error) = peer_connection
238            .add_ice_candidate(session_response.candidate.candidate)
239            .await
240        {
241            panic!("Error during add_ice_candidate: {:?}", error);
242        }
243    }
244}
245
246// read_loop shows how to read from the datachannel directly
247async fn read_loop(
248    data_channel: Arc<DataChannel>,
249    to_client_sender: mpsc::UnboundedSender<Box<[u8]>>,
250) -> Result<()> {
251    let mut buffer = vec![0u8; MESSAGE_SIZE];
252    loop {
253        let message_length = match data_channel.read(&mut buffer).await {
254            Ok(length) => length,
255            Err(_err) => {
256                //println!("Datachannel closed; Exit the read_loop: {}", err);
257                return Ok(());
258            }
259        };
260
261        match to_client_sender.send(buffer[..message_length].into()) {
262            Ok(_) => {}
263            Err(e) => {
264                return Err(Error::new(e));
265            }
266        }
267    }
268}
269
270// write_loop shows how to write to the datachannel directly
271async fn write_loop(
272    data_channel: Arc<DataChannel>,
273    mut to_server_receiver: mpsc::UnboundedReceiver<Box<[u8]>>,
274    mut to_server_disconnect_receiver: mpsc::Receiver<()>,
275) -> Result<()> {
276    loop {
277        tokio::select! {
278            _ = to_server_disconnect_receiver.recv() => {
279                return Ok(());
280            }
281            result = to_server_receiver.recv() => {
282                if let Some(mut write_message) = result {
283                    let taken_message = std::mem::take(&mut write_message);
284                    let message_bytes = Bytes::from(taken_message);
285                    if let Err(e) = data_channel.write(&message_bytes).await {
286                        return Err(Error::new(e));
287                    }
288                } else {
289                    return Ok(());
290                }
291            }
292        }
293    }
294}
295
/// SDP answer carried in the signaling server's reply.
#[derive(Clone)]
pub(crate) struct SessionAnswer {
    /// Raw SDP text, taken from `sdp.answer.sdp` in the reply JSON.
    pub(crate) sdp: String,
}
300
/// ICE candidate carried in the signaling server's reply.
pub(crate) struct SessionCandidate {
    /// Candidate string, taken from `sdp.candidate.candidate` in the reply JSON.
    pub(crate) candidate: String,
}
304
305pub(crate) struct JsSessionResponse {
306    pub(crate) id_token: String,
307    pub(crate) answer: SessionAnswer,
308    pub(crate) candidate: SessionCandidate,
309}
310
311fn get_session_response(input: &str) -> Result<JsSessionResponse, String> {
312    // info!("{}", input);
313    let Ok(json_obj): Result<JsonValue, _> = input.parse() else {
314        return Err("Could not parse response JSON".to_string());
315    };
316
317    let sdp_opt: Option<&String> = json_obj["sdp"]["answer"]["sdp"].get();
318    let sdp: String = sdp_opt.unwrap().clone();
319
320    let candidate_opt: Option<&String> = json_obj["sdp"]["candidate"]["candidate"].get();
321    let candidate: String = candidate_opt.unwrap().clone();
322
323    let id_token_opt: Option<&String> = json_obj["id"].get();
324    let id_token: String = id_token_opt.unwrap().clone();
325
326    Ok(JsSessionResponse {
327        id_token,
328        answer: SessionAnswer { sdp },
329        candidate: SessionCandidate { candidate },
330    })
331}