1use std::{
13 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
14 task::Poll, io::Write,
15};
16
17use anyhow::{anyhow, Context};
18use base64::{engine::general_purpose::STANDARD, Engine};
19use bytes::Bytes;
20use flate2::{Compression, write::{ZlibEncoder, ZlibDecoder}};
21use futures::{Sink, SinkExt, Stream, StreamExt};
22use pin_project::pin_project;
23use serde::{Deserialize, Serialize};
24use tokio::{net::UdpSocket, select};
25use turnclient::{
26 ChannelUsage, MessageFromTurnServer, MessageToTurnServer, TurnClient, TurnClientBuilder, ExportedParameters,
27};
28
/// Everything one end of a tie needs to re-attach to its TURN allocation.
///
/// Instances are bincode-encoded, zlib-compressed and base64-encoded into the
/// opaque specifier strings produced by [`tie`] and consumed by [`connect`].
/// NOTE: bincode encodes struct fields in declaration order, so the field order
/// below is part of the specifier format; reordering fields invalidates
/// previously generated specifiers.
#[derive(Serialize, Deserialize)]
struct Data {
    /// Address of the TURN server holding the allocation.
    turn_server: SocketAddr,
    /// TURN credential username.
    username: String,
    /// TURN credential password. It ends up in the specifier only compressed and
    /// base64-encoded, not encrypted.
    password: String,
    /// Realm exported from the TURN client session that created the allocation.
    realm: String,
    /// Nonce exported from the TURN client session that created the allocation.
    nonce: String,
    /// RFC 8016 mobility ticket exported from the original session; used by
    /// [`connect`] to restore the client on a freshly bound socket.
    mobility_ticket: Vec<u8>,
    /// Relay address of the other end of the tie (the only peer data is exchanged with).
    counterpart: SocketAddr,
}
39
40impl Data {
41 pub fn new(turn_server: SocketAddr, username: String, password: String, state: ExportedParameters, counterpart: SocketAddr) -> Data {
42 Data {
43 turn_server,
44 username,
45 password,
46 realm: state.realm,
47 nonce: state.nonce,
48 mobility_ticket: state.mobility_ticket,
49 counterpart,
50 }
51 }
52 pub fn serialize(&self) -> String {
53 let q = bincode::serialize(self).unwrap();
54 let mut e = ZlibEncoder::new(Vec::new(), Compression::default());
55 e.write_all(&q).unwrap();
56 STANDARD.encode(e.finish().unwrap())
57 }
58 pub fn deserialize(x: &str) -> anyhow::Result<Data> {
59 let z = STANDARD.decode(x)?;
60 let mut d = ZlibDecoder::new(Vec::new());
61 d.write_all(&z)?;
62 let b = d.finish()?;
63 Ok(bincode::deserialize(&b)?)
64 }
65}
66
67pub async fn tie(
69 turn_server: SocketAddr,
70 username: String,
71 password: String,
72) -> anyhow::Result<(String, String)> {
73 let mut t1 = TurnClientBuilder::new(turn_server, username.clone(), password.clone());
74 let mut t2 = TurnClientBuilder::new(turn_server, username.clone(), password.clone());
75 t1.enable_mobility = true;
76 t2.enable_mobility = true;
77
78 let neutral_sockaddr = match turn_server {
79 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
80 SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
81 };
82 let u1 = UdpSocket::bind(neutral_sockaddr).await?;
83 let u2 = UdpSocket::bind(neutral_sockaddr).await?;
84 u1.connect(turn_server).await?;
85 u2.connect(turn_server).await?;
86
87 let mut c1 = t1.build_and_send_request(u1);
88 let mut c2 = t2.build_and_send_request(u2);
89
90 let mut addr1 = None::<SocketAddr>;
94 let mut addr2 = None::<SocketAddr>;
95 let mut ready1 = false;
96 let mut ready2 = false;
97
98 let mut perm_requests_sent = false;
99
100 loop {
101 let (msg, first): (Option<Result<MessageFromTurnServer, _>>, bool) = select! {
102 msg = c1.next() => {
103 (msg, true)
104 }
105 msg = c2.next() => {
106 (msg, false)
107 }
108 };
109 let msg = msg.context(anyhow!("Sudden end of TURN client incoming messages"))??;
110 match msg {
111 MessageFromTurnServer::AllocationGranted {
112 relay_address,
113 mobility,
114 ..
115 } => {
116 if !mobility {
117 anyhow::bail!("No RFC 8016 mobility received from TURN server");
118 }
119 if first {
120 addr1 = Some(relay_address);
121 } else {
122 addr2 = Some(relay_address);
123 }
124 }
125 MessageFromTurnServer::RedirectedToAlternateServer(alt) => {
126 anyhow::bail!(
127 "We are being redirected to {alt}. This is not supported by turntie."
128 );
129 }
130 MessageFromTurnServer::PermissionCreated(addr) => {
131 if first && Some(addr) == addr2 {
132 ready1 = true;
133 } else if !first && Some(addr) == addr1 {
134 ready2 = true;
135 } else {
136 anyhow::bail!(
137 "Unexpected granted permission. Something is wrong with the code?"
138 );
139 }
140 }
141 MessageFromTurnServer::PermissionNotCreated(_) => {
142 anyhow::bail!("Failed to create permission on TURN server")
143 }
144 MessageFromTurnServer::Disconnected => anyhow::bail!("Disconnected from TURN server"),
145 _ => (),
146 }
147
148 if addr1.is_some() && addr2.is_some() && !perm_requests_sent {
149 perm_requests_sent = true;
150
151 c1.send(MessageToTurnServer::AddPermission(
153 addr2.unwrap(),
154 ChannelUsage::JustPermission,
155 ))
156 .await?;
157 c2.send(MessageToTurnServer::AddPermission(
158 addr1.unwrap(),
159 ChannelUsage::JustPermission,
160 ))
161 .await?;
162 }
163
164 if ready1 && ready2 {
165 break;
166 }
167 }
168
169 let params1 = c1.export_state();
170 let params2 = c2.export_state();
171
172 let spec1 = Data::new(turn_server, username.clone(), password.clone(), params1, addr2.unwrap());
173 let spec2 = Data::new(turn_server, username, password, params2, addr1.unwrap());
174
175 Ok((spec1.serialize(), spec2.serialize()))
176}
177
178pub async fn connect(specifier: &str) -> anyhow::Result<TurnTie> {
183 let data = Data::deserialize(specifier)?;
184
185 let mut t = TurnClientBuilder::new(data.turn_server, data.username, data.password);
186 t.enable_mobility = true;
187
188 let neutral_sockaddr = match data.turn_server {
189 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
190 SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
191 };
192 let u = UdpSocket::bind(neutral_sockaddr).await?;
193
194 let params = ExportedParameters {
195 realm: data.realm,
196 nonce: data.nonce,
197 mobility_ticket: data.mobility_ticket,
198 permissions: vec![
199 (data.counterpart, None)
200 ]
201 };
202
203 let turnclient = t.restore_from_exported_parameters(u, ¶ms)?;
204
205 Ok(TurnTie {
206 turnclient,
207 counterpart: data.counterpart,
208 })
209}
210
211#[pin_project]
212pub struct TurnTie {
213 #[pin]
214 turnclient: TurnClient,
215 counterpart: SocketAddr,
216}
217
218impl Stream for TurnTie {
219 type Item = anyhow::Result<Bytes>;
220
221 fn poll_next(
222 self: std::pin::Pin<&mut Self>,
223 cx: &mut std::task::Context<'_>,
224 ) -> std::task::Poll<Option<Self::Item>> {
225 let mut this = self.project();
226 'main_loop: loop {
227 return match this.turnclient.as_mut().poll_next(cx) {
228 Poll::Pending => Poll::Pending,
229 Poll::Ready(None) => Poll::Ready(None),
230 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
231 Poll::Ready(Some(Ok(msg))) => match msg {
232 MessageFromTurnServer::RecvFrom(fromaddr, buf) => {
233 if fromaddr == *this.counterpart {
234 return Poll::Ready(Some(Ok(buf.into())));
235 }
236 continue 'main_loop;
237 }
238 _ => continue 'main_loop,
239 },
240 };
241 }
242 }
243}
244
245impl Sink<Bytes> for TurnTie {
246 type Error = anyhow::Error;
247
248 fn poll_ready(
249 self: std::pin::Pin<&mut Self>,
250 cx: &mut std::task::Context<'_>,
251 ) -> std::task::Poll<Result<(), Self::Error>> {
252 let this = self.project();
253 this.turnclient.poll_ready(cx)
254 }
255
256 fn start_send(self: std::pin::Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
257 let msg = MessageToTurnServer::SendTo(self.counterpart, item.into());
258 let this = self.project();
259 this.turnclient.start_send(msg)
260 }
261
262 fn poll_flush(
263 self: std::pin::Pin<&mut Self>,
264 cx: &mut std::task::Context<'_>,
265 ) -> std::task::Poll<Result<(), Self::Error>> {
266 let this = self.project();
267 this.turnclient.poll_flush(cx)
268 }
269
270 fn poll_close(
271 self: std::pin::Pin<&mut Self>,
272 cx: &mut std::task::Context<'_>,
273 ) -> std::task::Poll<Result<(), Self::Error>> {
274 let this = self.project();
275 this.turnclient.poll_close(cx)
276 }
277}