rusturn/client/core.rs

1use super::allocate::Allocate;
2use super::stun_transaction::StunTransaction;
3use crate::attribute::Attribute;
4use crate::auth::AuthParams;
5use crate::channel_data::ChannelData;
6use crate::{AsyncReply, AsyncResult, Error, ErrorKind, Result};
7use fibers_timeout_queue::TimeoutQueue;
8use fibers_transport::Transport;
9use futures::{Async, Future, Poll};
10use rustun::channel::{Channel as StunChannel, RecvMessage};
11use rustun::message::{ErrorResponse, Indication, Request, Response};
12use rustun::transport::StunTransport;
13use std::collections::HashMap;
14use std::net::{IpAddr, SocketAddr};
15use std::time::Duration;
16use stun_codec::rfc5766::attributes::ChannelNumber;
17use stun_codec::{rfc5389, rfc5766};
18
/// Lifetime of a TURN permission, in seconds (RFC 5766 fixes it at five minutes).
///
/// Permissions are refreshed at 90% of this value.
const PERMISSION_LIFETIME_SECONDS: u64 = 300;

/// Lifetime assumed for a channel binding, in seconds.
///
/// This is deliberately tied to the permission lifetime so that channel refreshes are sent
/// often enough to keep the peer's permission alive as well.
/// FIXME: Use `600` (and refresh permissions)
const CHANNEL_LIFETIME_SECONDS: u64 = PERMISSION_LIFETIME_SECONDS;
21
22#[derive(Debug)]
23pub struct ClientCore<S, C>
24where
25    S: StunTransport<Attribute, PeerAddr = ()>,
26    C: Transport<PeerAddr = (), SendItem = ChannelData, RecvItem = ChannelData>,
27{
28    stun_channel: StunChannel<Attribute, S>,
29    channel_data_transporter: C,
30    auth_params: AuthParams,
31    lifetime: Duration,
32    permissions: HashMap<IpAddr, Option<AsyncReply<()>>>,
33    channels: HashMap<SocketAddr, ChannelState>,
34    next_channel_number: ChannelNumber,
35    timeout_queue: TimeoutQueue<TimeoutEntry>,
36    refresh_transaction: StunTransaction,
37    create_permission_transaction: StunTransaction<(SocketAddr, Response<Attribute>)>,
38    channel_bind_transaction: StunTransaction<(SocketAddr, Response<Attribute>)>,
39    relay_addr: Option<SocketAddr>,
40}
41impl<S, C> ClientCore<S, C>
42where
43    S: StunTransport<Attribute, PeerAddr = ()> + 'static,
44    C: Transport<PeerAddr = (), SendItem = ChannelData, RecvItem = ChannelData>,
45{
46    pub fn allocate(
47        stun_transporter: S,
48        channel_data_transporter: C,
49        auth_params: AuthParams,
50    ) -> Allocate<S, C> {
51        Allocate::new(
52            StunChannel::new(stun_transporter),
53            channel_data_transporter,
54            auth_params,
55        )
56    }
57
58    pub fn new(
59        stun_channel: StunChannel<Attribute, S>,
60        channel_data_transporter: C,
61        auth_params: AuthParams,
62        lifetime: Duration,
63        relay_addr: Option<SocketAddr>,
64    ) -> Self {
65        let mut timeout_queue = TimeoutQueue::new();
66        timeout_queue.push(TimeoutEntry::Refresh, lifetime * 9 / 10);
67        ClientCore {
68            stun_channel,
69            channel_data_transporter,
70            auth_params,
71            lifetime,
72            permissions: HashMap::new(),
73            channels: HashMap::new(),
74            next_channel_number: ChannelNumber::min(),
75            timeout_queue,
76            refresh_transaction: StunTransaction::empty(),
77            create_permission_transaction: StunTransaction::empty(),
78            channel_bind_transaction: StunTransaction::empty(),
79            relay_addr,
80        }
81    }
82
83    pub fn stun_channel_ref(&self) -> &StunChannel<Attribute, S> {
84        &self.stun_channel
85    }
86
87    pub fn relay_addr(&self) -> Option<SocketAddr> {
88        self.relay_addr
89    }
90
91    fn start_refresh(&mut self) -> Result<()> {
92        let lifetime = track!(rfc5766::attributes::Lifetime::new(self.lifetime))?;
93
94        let mut request = Request::new(rfc5766::methods::REFRESH);
95        request.add_attribute(lifetime.into());
96        track!(self.auth_params.add_auth_attributes(&mut request))?;
97
98        self.refresh_transaction = StunTransaction::new(self.stun_channel.call((), request));
99        Ok(())
100    }
101
102    fn handle_refresh_response(&mut self, response: Response<Attribute>) -> Result<()> {
103        match response {
104            Err(response) => {
105                track!(self.handle_error_response(response))?;
106                track!(self.start_refresh())?;
107            }
108            Ok(response) => {
109                let mut lifetime = None;
110                for attr in response.attributes() {
111                    match attr {
112                        Attribute::Lifetime(a) => {
113                            lifetime = Some(a.lifetime());
114                        }
115                        Attribute::MessageIntegrity(a) => {
116                            track!(self.auth_params.validate(a))?;
117                        }
118                        _ => {}
119                    }
120                }
121
122                self.lifetime = track_assert_some!(lifetime, ErrorKind::Other; response);
123                self.timeout_queue
124                    .push(TimeoutEntry::Refresh, self.lifetime * 9 / 10);
125            }
126        }
127        Ok(())
128    }
129
130    fn handle_create_permission_response(
131        &mut self,
132        peer: SocketAddr,
133        response: Response<Attribute>,
134    ) -> Result<()> {
135        let reply = track_assert_some!(self.permissions.remove(&peer.ip()), ErrorKind::Other);
136        match response {
137            Err(response) => {
138                track!(self.handle_error_response(response))?;
139                if let Err(e) = track!(self.create_permission_inner(peer)) {
140                    if let Some(reply) = reply {
141                        reply.send(Err(e.clone()));
142                    }
143                    return Err(e);
144                }
145                self.permissions.insert(peer.ip(), reply);
146            }
147            Ok(response) => {
148                for attr in response.attributes() {
149                    if let Attribute::MessageIntegrity(a) = attr {
150                        track!(self.auth_params.validate(a))?;
151                    }
152                }
153                if let Some(reply) = reply {
154                    reply.send(Ok(()));
155                }
156                self.permissions.insert(peer.ip(), None);
157                self.timeout_queue.push(
158                    TimeoutEntry::Permission { peer },
159                    Duration::from_secs(PERMISSION_LIFETIME_SECONDS * 9 / 10),
160                );
161            }
162        }
163        Ok(())
164    }
165
166    fn handle_channel_bind_response(
167        &mut self,
168        peer: SocketAddr,
169        response: Response<Attribute>,
170    ) -> Result<()> {
171        let state = track_assert_some!(self.channels.remove(&peer), ErrorKind::Other);
172        match response {
173            Err(response) => {
174                track!(self.handle_error_response(response))?;
175                if let Err(e) = track!(self.channel_bind_inner(peer, state.channel_number())) {
176                    if let ChannelState::Creating { reply, .. } = state {
177                        reply.send(Err(e.clone()));
178                    }
179                    return Err(e);
180                }
181                self.channels.insert(peer, state);
182            }
183            Ok(response) => {
184                for attr in response.attributes() {
185                    if let Attribute::MessageIntegrity(a) = attr {
186                        track!(self.auth_params.validate(a))?;
187                    }
188                }
189
190                let number = state.channel_number();
191                if let ChannelState::Creating { reply, .. } = state {
192                    reply.send(Ok(()));
193                }
194                self.channels.insert(peer, ChannelState::Created { number });
195                self.timeout_queue.push(
196                    TimeoutEntry::Channel { peer },
197                    Duration::from_secs(CHANNEL_LIFETIME_SECONDS * 9 / 10),
198                );
199            }
200        }
201        Ok(())
202    }
203
204    fn handle_error_response(&mut self, response: ErrorResponse<Attribute>) -> Result<()> {
205        let error: &rfc5389::attributes::ErrorCode =
206            track_assert_some!(response.get_attribute(), ErrorKind::Other; response);
207        track_assert_eq!(
208            error.code(),
209            rfc5389::errors::StaleNonce::CODEPOINT,
210            ErrorKind::Other; response
211        );
212
213        let nonce: &rfc5389::attributes::Nonce =
214            track_assert_some!(response.get_attribute(), ErrorKind::Other; response);
215        self.auth_params.set_nonce(nonce.clone());
216
217        Ok(())
218    }
219
220    fn handle_timeout(&mut self, entry: TimeoutEntry) -> Result<()> {
221        match entry {
222            TimeoutEntry::Refresh => track!(self.start_refresh())?,
223            TimeoutEntry::Permission { peer } => {
224                if self.permissions.remove(&peer.ip()).is_some() {
225                    track!(self.create_permission_inner(peer))?;
226                    self.permissions.insert(peer.ip(), None);
227                    self.timeout_queue.push(
228                        TimeoutEntry::Permission { peer },
229                        Duration::from_secs(PERMISSION_LIFETIME_SECONDS * 9 / 10),
230                    );
231                }
232            }
233            TimeoutEntry::Channel { peer } => {
234                if let Some(state) = self.channels.remove(&peer) {
235                    track!(self.channel_bind_inner(peer, state.channel_number()))?;
236                    self.channels.insert(peer, state);
237                    self.timeout_queue.push(
238                        TimeoutEntry::Channel { peer },
239                        Duration::from_secs(CHANNEL_LIFETIME_SECONDS * 9 / 10),
240                    );
241                }
242            }
243        }
244        Ok(())
245    }
246
247    fn handle_stun_message(
248        &mut self,
249        message: RecvMessage<Attribute>,
250    ) -> Result<Option<(SocketAddr, Vec<u8>)>> {
251        match message {
252            RecvMessage::Invalid(message) => track_panic!(ErrorKind::Other; message),
253            RecvMessage::Request(request) => track_panic!(ErrorKind::Other; request),
254            RecvMessage::Indication(indication) => track!(self.handle_stun_indication(indication)),
255        }
256    }
257
258    fn handle_stun_indication(
259        &mut self,
260        indication: Indication<Attribute>,
261    ) -> Result<Option<(SocketAddr, Vec<u8>)>> {
262        match indication.method() {
263            rfc5766::methods::DATA => {
264                let data: &rfc5766::attributes::Data =
265                    track_assert_some!(indication.get_attribute(), ErrorKind::Other; indication);
266                let peer: &rfc5766::attributes::XorPeerAddress =
267                    track_assert_some!(indication.get_attribute(), ErrorKind::Other; indication);
268                track_assert!(
269                    self.permissions.contains_key(&peer.address().ip()),
270                    ErrorKind::Other; peer,  indication
271                );
272                Ok(Some((peer.address(), Vec::from(data.data()))))
273            }
274            _ => {
275                track_panic!(ErrorKind::Other; indication);
276            }
277        }
278    }
279
280    fn handle_channel_data(&mut self, data: ChannelData) -> Result<(SocketAddr, Vec<u8>)> {
281        // FIXME: optimize
282        let peer = track_assert_some!(
283            self.channels
284                .iter()
285                .find(|x| x.1.channel_number() == data.channel_number())
286                .map(|x| *x.0),
287            ErrorKind::Other
288        );
289        Ok((peer, data.into_data()))
290    }
291
292    fn create_permission_inner(&mut self, peer: SocketAddr) -> Result<()> {
293        // If a permission already exists on the TURN server, it will just be refreshed.
294        let mut request = Request::new(rfc5766::methods::CREATE_PERMISSION);
295        request.add_attribute(rfc5766::attributes::XorPeerAddress::new(peer).into());
296        track!(self.auth_params.add_auth_attributes(&mut request))?;
297
298        self.create_permission_transaction =
299            StunTransaction::with_peer(peer, self.stun_channel.call((), request));
300        Ok(())
301    }
302
303    fn channel_bind_inner(
304        &mut self,
305        peer: SocketAddr,
306        channel_number: ChannelNumber,
307    ) -> Result<()> {
308        track_assert!(!self.channels.contains_key(&peer), ErrorKind::InvalidInput; peer);
309
310        let mut request = Request::new(rfc5766::methods::CHANNEL_BIND);
311        request.add_attribute(rfc5766::attributes::XorPeerAddress::new(peer).into());
312        request.add_attribute(channel_number.into());
313        track!(self.auth_params.add_auth_attributes(&mut request))?;
314
315        self.channel_bind_transaction =
316            StunTransaction::with_peer(peer, self.stun_channel.call((), request));
317        Ok(())
318    }
319
320    fn next_channel_number(&mut self) -> ChannelNumber {
321        // FIXME: collision check
322        let curr = self.next_channel_number;
323        self.next_channel_number = curr.wrapping_increment();
324        curr
325    }
326
327    pub fn create_permission(&mut self, peer: SocketAddr) -> AsyncResult<()> {
328        let (result, reply) = AsyncResult::new();
329        match track!(self.create_permission_inner(peer)) {
330            Err(e) => {
331                reply.send(Err(e));
332            }
333            Ok(()) => {
334                self.permissions.insert(peer.ip(), Some(reply));
335            }
336        }
337        result
338    }
339
340    pub fn channel_bind(&mut self, peer: SocketAddr) -> AsyncResult<()> {
341        let (result, reply) = AsyncResult::new();
342        let channel_number = self.next_channel_number();
343        match track!(self.channel_bind_inner(peer, channel_number)) {
344            Err(e) => {
345                reply.send(Err(e));
346            }
347            Ok(()) => {
348                self.channels.insert(
349                    peer,
350                    ChannelState::Creating {
351                        number: channel_number,
352                        reply,
353                    },
354                );
355            }
356        }
357        result
358    }
359
360    pub fn start_send(&mut self, peer: SocketAddr, data: Vec<u8>) -> Result<()> {
361        if let Some(state) = self.channels.get(&peer) {
362            let data = track!(ChannelData::new(state.channel_number(), data,))?;
363            track!(self.channel_data_transporter.start_send((), data))?;
364        } else if self.permissions.contains_key(&peer.ip()) {
365            track_assert!(self.permissions.contains_key(&peer.ip()), ErrorKind::Other; peer);
366            let mut indication = Indication::new(rfc5766::methods::SEND);
367            indication.add_attribute(rfc5766::attributes::XorPeerAddress::new(peer).into());
368            indication.add_attribute(track!(rfc5766::attributes::Data::new(data))?.into());
369            track!(self.stun_channel.cast((), indication))?;
370        } else {
371            track_panic!(ErrorKind::InvalidInput, "Unknown peer: {:?}", peer);
372        }
373        Ok(())
374    }
375
376    pub fn poll_send(&mut self) -> Poll<(), Error> {
377        let is_ready = track!(self.stun_channel.poll_send())?.is_ready()
378            && track!(self.channel_data_transporter.poll_send())?.is_ready();
379        if is_ready {
380            Ok(Async::Ready(()))
381        } else {
382            Ok(Async::NotReady)
383        }
384    }
385
386    pub fn poll_recv(&mut self) -> Poll<Option<(SocketAddr, Vec<u8>)>, Error> {
387        let mut did_something = true;
388        while did_something {
389            did_something = false;
390
391            while let Async::Ready(message) = track!(self.stun_channel.poll_recv())? {
392                did_something = true;
393                if let Some((_, message)) = message {
394                    if let Some((peer, data)) = track!(self.handle_stun_message(message))? {
395                        return Ok(Async::Ready(Some((peer, data))));
396                    }
397                } else {
398                    track_panic!(ErrorKind::Other, "Unexpected termination");
399                }
400            }
401            if let Async::Ready(data) = track!(self.channel_data_transporter.poll_recv())? {
402                if let Some((_, data)) = data {
403                    let (peer, data) = track!(self.handle_channel_data(data))?;
404                    return Ok(Async::Ready(Some((peer, data))));
405                } else {
406                    track_panic!(ErrorKind::Other, "Unexpected termination");
407                }
408            }
409            while let Some(entry) = self.timeout_queue.pop() {
410                did_something = true;
411                track!(self.handle_timeout(entry))?;
412            }
413            if let Async::Ready(response) = track!(self.refresh_transaction.poll())? {
414                did_something = true;
415                track!(self.handle_refresh_response(response))?;
416            }
417            if let Async::Ready((peer, response)) =
418                track!(self.create_permission_transaction.poll())?
419            {
420                did_something = true;
421                track!(self.handle_create_permission_response(peer, response))?;
422            }
423            if let Async::Ready((peer, response)) = track!(self.channel_bind_transaction.poll())? {
424                did_something = true;
425                track!(self.handle_channel_bind_response(peer, response))?;
426            }
427            track!(self.channel_data_transporter.poll_send())?;
428        }
429        Ok(Async::NotReady)
430    }
431}
432
/// Kinds of deadline kept in the client's timeout queue.
#[derive(Debug)]
enum TimeoutEntry {
    /// The allocation itself must be refreshed.
    Refresh,
    /// The permission covering `peer`'s IP address must be refreshed.
    Permission { peer: SocketAddr },
    /// The channel bound to `peer` must be refreshed.
    Channel { peer: SocketAddr },
}
439
440#[derive(Debug)]
441enum ChannelState {
442    Creating {
443        number: ChannelNumber,
444        reply: AsyncReply<()>,
445    },
446    Created {
447        number: ChannelNumber,
448    },
449}
450impl ChannelState {
451    fn channel_number(&self) -> ChannelNumber {
452        match self {
453            ChannelState::Creating { number, .. } => *number,
454            ChannelState::Created { number } => *number,
455        }
456    }
457}