1use super::allocate::Allocate;
2use super::stun_transaction::StunTransaction;
3use crate::attribute::Attribute;
4use crate::auth::AuthParams;
5use crate::channel_data::ChannelData;
6use crate::{AsyncReply, AsyncResult, Error, ErrorKind, Result};
7use fibers_timeout_queue::TimeoutQueue;
8use fibers_transport::Transport;
9use futures::{Async, Future, Poll};
10use rustun::channel::{Channel as StunChannel, RecvMessage};
11use rustun::message::{ErrorResponse, Indication, Request, Response};
12use rustun::transport::StunTransport;
13use std::collections::HashMap;
14use std::net::{IpAddr, SocketAddr};
15use std::time::Duration;
16use stun_codec::rfc5766::attributes::ChannelNumber;
17use stun_codec::{rfc5389, rfc5766};
18
19const PERMISSION_LIFETIME_SECONDS: u64 = 300;
20const CHANNEL_LIFETIME_SECONDS: u64 = PERMISSION_LIFETIME_SECONDS; #[derive(Debug)]
23pub struct ClientCore<S, C>
24where
25 S: StunTransport<Attribute, PeerAddr = ()>,
26 C: Transport<PeerAddr = (), SendItem = ChannelData, RecvItem = ChannelData>,
27{
28 stun_channel: StunChannel<Attribute, S>,
29 channel_data_transporter: C,
30 auth_params: AuthParams,
31 lifetime: Duration,
32 permissions: HashMap<IpAddr, Option<AsyncReply<()>>>,
33 channels: HashMap<SocketAddr, ChannelState>,
34 next_channel_number: ChannelNumber,
35 timeout_queue: TimeoutQueue<TimeoutEntry>,
36 refresh_transaction: StunTransaction,
37 create_permission_transaction: StunTransaction<(SocketAddr, Response<Attribute>)>,
38 channel_bind_transaction: StunTransaction<(SocketAddr, Response<Attribute>)>,
39 relay_addr: Option<SocketAddr>,
40}
41impl<S, C> ClientCore<S, C>
42where
43 S: StunTransport<Attribute, PeerAddr = ()> + 'static,
44 C: Transport<PeerAddr = (), SendItem = ChannelData, RecvItem = ChannelData>,
45{
46 pub fn allocate(
47 stun_transporter: S,
48 channel_data_transporter: C,
49 auth_params: AuthParams,
50 ) -> Allocate<S, C> {
51 Allocate::new(
52 StunChannel::new(stun_transporter),
53 channel_data_transporter,
54 auth_params,
55 )
56 }
57
58 pub fn new(
59 stun_channel: StunChannel<Attribute, S>,
60 channel_data_transporter: C,
61 auth_params: AuthParams,
62 lifetime: Duration,
63 relay_addr: Option<SocketAddr>,
64 ) -> Self {
65 let mut timeout_queue = TimeoutQueue::new();
66 timeout_queue.push(TimeoutEntry::Refresh, lifetime * 9 / 10);
67 ClientCore {
68 stun_channel,
69 channel_data_transporter,
70 auth_params,
71 lifetime,
72 permissions: HashMap::new(),
73 channels: HashMap::new(),
74 next_channel_number: ChannelNumber::min(),
75 timeout_queue,
76 refresh_transaction: StunTransaction::empty(),
77 create_permission_transaction: StunTransaction::empty(),
78 channel_bind_transaction: StunTransaction::empty(),
79 relay_addr,
80 }
81 }
82
83 pub fn stun_channel_ref(&self) -> &StunChannel<Attribute, S> {
84 &self.stun_channel
85 }
86
87 pub fn relay_addr(&self) -> Option<SocketAddr> {
88 self.relay_addr
89 }
90
91 fn start_refresh(&mut self) -> Result<()> {
92 let lifetime = track!(rfc5766::attributes::Lifetime::new(self.lifetime))?;
93
94 let mut request = Request::new(rfc5766::methods::REFRESH);
95 request.add_attribute(lifetime.into());
96 track!(self.auth_params.add_auth_attributes(&mut request))?;
97
98 self.refresh_transaction = StunTransaction::new(self.stun_channel.call((), request));
99 Ok(())
100 }
101
102 fn handle_refresh_response(&mut self, response: Response<Attribute>) -> Result<()> {
103 match response {
104 Err(response) => {
105 track!(self.handle_error_response(response))?;
106 track!(self.start_refresh())?;
107 }
108 Ok(response) => {
109 let mut lifetime = None;
110 for attr in response.attributes() {
111 match attr {
112 Attribute::Lifetime(a) => {
113 lifetime = Some(a.lifetime());
114 }
115 Attribute::MessageIntegrity(a) => {
116 track!(self.auth_params.validate(a))?;
117 }
118 _ => {}
119 }
120 }
121
122 self.lifetime = track_assert_some!(lifetime, ErrorKind::Other; response);
123 self.timeout_queue
124 .push(TimeoutEntry::Refresh, self.lifetime * 9 / 10);
125 }
126 }
127 Ok(())
128 }
129
130 fn handle_create_permission_response(
131 &mut self,
132 peer: SocketAddr,
133 response: Response<Attribute>,
134 ) -> Result<()> {
135 let reply = track_assert_some!(self.permissions.remove(&peer.ip()), ErrorKind::Other);
136 match response {
137 Err(response) => {
138 track!(self.handle_error_response(response))?;
139 if let Err(e) = track!(self.create_permission_inner(peer)) {
140 if let Some(reply) = reply {
141 reply.send(Err(e.clone()));
142 }
143 return Err(e);
144 }
145 self.permissions.insert(peer.ip(), reply);
146 }
147 Ok(response) => {
148 for attr in response.attributes() {
149 if let Attribute::MessageIntegrity(a) = attr {
150 track!(self.auth_params.validate(a))?;
151 }
152 }
153 if let Some(reply) = reply {
154 reply.send(Ok(()));
155 }
156 self.permissions.insert(peer.ip(), None);
157 self.timeout_queue.push(
158 TimeoutEntry::Permission { peer },
159 Duration::from_secs(PERMISSION_LIFETIME_SECONDS * 9 / 10),
160 );
161 }
162 }
163 Ok(())
164 }
165
166 fn handle_channel_bind_response(
167 &mut self,
168 peer: SocketAddr,
169 response: Response<Attribute>,
170 ) -> Result<()> {
171 let state = track_assert_some!(self.channels.remove(&peer), ErrorKind::Other);
172 match response {
173 Err(response) => {
174 track!(self.handle_error_response(response))?;
175 if let Err(e) = track!(self.channel_bind_inner(peer, state.channel_number())) {
176 if let ChannelState::Creating { reply, .. } = state {
177 reply.send(Err(e.clone()));
178 }
179 return Err(e);
180 }
181 self.channels.insert(peer, state);
182 }
183 Ok(response) => {
184 for attr in response.attributes() {
185 if let Attribute::MessageIntegrity(a) = attr {
186 track!(self.auth_params.validate(a))?;
187 }
188 }
189
190 let number = state.channel_number();
191 if let ChannelState::Creating { reply, .. } = state {
192 reply.send(Ok(()));
193 }
194 self.channels.insert(peer, ChannelState::Created { number });
195 self.timeout_queue.push(
196 TimeoutEntry::Channel { peer },
197 Duration::from_secs(CHANNEL_LIFETIME_SECONDS * 9 / 10),
198 );
199 }
200 }
201 Ok(())
202 }
203
204 fn handle_error_response(&mut self, response: ErrorResponse<Attribute>) -> Result<()> {
205 let error: &rfc5389::attributes::ErrorCode =
206 track_assert_some!(response.get_attribute(), ErrorKind::Other; response);
207 track_assert_eq!(
208 error.code(),
209 rfc5389::errors::StaleNonce::CODEPOINT,
210 ErrorKind::Other; response
211 );
212
213 let nonce: &rfc5389::attributes::Nonce =
214 track_assert_some!(response.get_attribute(), ErrorKind::Other; response);
215 self.auth_params.set_nonce(nonce.clone());
216
217 Ok(())
218 }
219
220 fn handle_timeout(&mut self, entry: TimeoutEntry) -> Result<()> {
221 match entry {
222 TimeoutEntry::Refresh => track!(self.start_refresh())?,
223 TimeoutEntry::Permission { peer } => {
224 if self.permissions.remove(&peer.ip()).is_some() {
225 track!(self.create_permission_inner(peer))?;
226 self.permissions.insert(peer.ip(), None);
227 self.timeout_queue.push(
228 TimeoutEntry::Permission { peer },
229 Duration::from_secs(PERMISSION_LIFETIME_SECONDS * 9 / 10),
230 );
231 }
232 }
233 TimeoutEntry::Channel { peer } => {
234 if let Some(state) = self.channels.remove(&peer) {
235 track!(self.channel_bind_inner(peer, state.channel_number()))?;
236 self.channels.insert(peer, state);
237 self.timeout_queue.push(
238 TimeoutEntry::Channel { peer },
239 Duration::from_secs(CHANNEL_LIFETIME_SECONDS * 9 / 10),
240 );
241 }
242 }
243 }
244 Ok(())
245 }
246
247 fn handle_stun_message(
248 &mut self,
249 message: RecvMessage<Attribute>,
250 ) -> Result<Option<(SocketAddr, Vec<u8>)>> {
251 match message {
252 RecvMessage::Invalid(message) => track_panic!(ErrorKind::Other; message),
253 RecvMessage::Request(request) => track_panic!(ErrorKind::Other; request),
254 RecvMessage::Indication(indication) => track!(self.handle_stun_indication(indication)),
255 }
256 }
257
258 fn handle_stun_indication(
259 &mut self,
260 indication: Indication<Attribute>,
261 ) -> Result<Option<(SocketAddr, Vec<u8>)>> {
262 match indication.method() {
263 rfc5766::methods::DATA => {
264 let data: &rfc5766::attributes::Data =
265 track_assert_some!(indication.get_attribute(), ErrorKind::Other; indication);
266 let peer: &rfc5766::attributes::XorPeerAddress =
267 track_assert_some!(indication.get_attribute(), ErrorKind::Other; indication);
268 track_assert!(
269 self.permissions.contains_key(&peer.address().ip()),
270 ErrorKind::Other; peer, indication
271 );
272 Ok(Some((peer.address(), Vec::from(data.data()))))
273 }
274 _ => {
275 track_panic!(ErrorKind::Other; indication);
276 }
277 }
278 }
279
280 fn handle_channel_data(&mut self, data: ChannelData) -> Result<(SocketAddr, Vec<u8>)> {
281 let peer = track_assert_some!(
283 self.channels
284 .iter()
285 .find(|x| x.1.channel_number() == data.channel_number())
286 .map(|x| *x.0),
287 ErrorKind::Other
288 );
289 Ok((peer, data.into_data()))
290 }
291
292 fn create_permission_inner(&mut self, peer: SocketAddr) -> Result<()> {
293 let mut request = Request::new(rfc5766::methods::CREATE_PERMISSION);
295 request.add_attribute(rfc5766::attributes::XorPeerAddress::new(peer).into());
296 track!(self.auth_params.add_auth_attributes(&mut request))?;
297
298 self.create_permission_transaction =
299 StunTransaction::with_peer(peer, self.stun_channel.call((), request));
300 Ok(())
301 }
302
303 fn channel_bind_inner(
304 &mut self,
305 peer: SocketAddr,
306 channel_number: ChannelNumber,
307 ) -> Result<()> {
308 track_assert!(!self.channels.contains_key(&peer), ErrorKind::InvalidInput; peer);
309
310 let mut request = Request::new(rfc5766::methods::CHANNEL_BIND);
311 request.add_attribute(rfc5766::attributes::XorPeerAddress::new(peer).into());
312 request.add_attribute(channel_number.into());
313 track!(self.auth_params.add_auth_attributes(&mut request))?;
314
315 self.channel_bind_transaction =
316 StunTransaction::with_peer(peer, self.stun_channel.call((), request));
317 Ok(())
318 }
319
320 fn next_channel_number(&mut self) -> ChannelNumber {
321 let curr = self.next_channel_number;
323 self.next_channel_number = curr.wrapping_increment();
324 curr
325 }
326
327 pub fn create_permission(&mut self, peer: SocketAddr) -> AsyncResult<()> {
328 let (result, reply) = AsyncResult::new();
329 match track!(self.create_permission_inner(peer)) {
330 Err(e) => {
331 reply.send(Err(e));
332 }
333 Ok(()) => {
334 self.permissions.insert(peer.ip(), Some(reply));
335 }
336 }
337 result
338 }
339
340 pub fn channel_bind(&mut self, peer: SocketAddr) -> AsyncResult<()> {
341 let (result, reply) = AsyncResult::new();
342 let channel_number = self.next_channel_number();
343 match track!(self.channel_bind_inner(peer, channel_number)) {
344 Err(e) => {
345 reply.send(Err(e));
346 }
347 Ok(()) => {
348 self.channels.insert(
349 peer,
350 ChannelState::Creating {
351 number: channel_number,
352 reply,
353 },
354 );
355 }
356 }
357 result
358 }
359
360 pub fn start_send(&mut self, peer: SocketAddr, data: Vec<u8>) -> Result<()> {
361 if let Some(state) = self.channels.get(&peer) {
362 let data = track!(ChannelData::new(state.channel_number(), data,))?;
363 track!(self.channel_data_transporter.start_send((), data))?;
364 } else if self.permissions.contains_key(&peer.ip()) {
365 track_assert!(self.permissions.contains_key(&peer.ip()), ErrorKind::Other; peer);
366 let mut indication = Indication::new(rfc5766::methods::SEND);
367 indication.add_attribute(rfc5766::attributes::XorPeerAddress::new(peer).into());
368 indication.add_attribute(track!(rfc5766::attributes::Data::new(data))?.into());
369 track!(self.stun_channel.cast((), indication))?;
370 } else {
371 track_panic!(ErrorKind::InvalidInput, "Unknown peer: {:?}", peer);
372 }
373 Ok(())
374 }
375
376 pub fn poll_send(&mut self) -> Poll<(), Error> {
377 let is_ready = track!(self.stun_channel.poll_send())?.is_ready()
378 && track!(self.channel_data_transporter.poll_send())?.is_ready();
379 if is_ready {
380 Ok(Async::Ready(()))
381 } else {
382 Ok(Async::NotReady)
383 }
384 }
385
386 pub fn poll_recv(&mut self) -> Poll<Option<(SocketAddr, Vec<u8>)>, Error> {
387 let mut did_something = true;
388 while did_something {
389 did_something = false;
390
391 while let Async::Ready(message) = track!(self.stun_channel.poll_recv())? {
392 did_something = true;
393 if let Some((_, message)) = message {
394 if let Some((peer, data)) = track!(self.handle_stun_message(message))? {
395 return Ok(Async::Ready(Some((peer, data))));
396 }
397 } else {
398 track_panic!(ErrorKind::Other, "Unexpected termination");
399 }
400 }
401 if let Async::Ready(data) = track!(self.channel_data_transporter.poll_recv())? {
402 if let Some((_, data)) = data {
403 let (peer, data) = track!(self.handle_channel_data(data))?;
404 return Ok(Async::Ready(Some((peer, data))));
405 } else {
406 track_panic!(ErrorKind::Other, "Unexpected termination");
407 }
408 }
409 while let Some(entry) = self.timeout_queue.pop() {
410 did_something = true;
411 track!(self.handle_timeout(entry))?;
412 }
413 if let Async::Ready(response) = track!(self.refresh_transaction.poll())? {
414 did_something = true;
415 track!(self.handle_refresh_response(response))?;
416 }
417 if let Async::Ready((peer, response)) =
418 track!(self.create_permission_transaction.poll())?
419 {
420 did_something = true;
421 track!(self.handle_create_permission_response(peer, response))?;
422 }
423 if let Async::Ready((peer, response)) = track!(self.channel_bind_transaction.poll())? {
424 did_something = true;
425 track!(self.handle_channel_bind_response(peer, response))?;
426 }
427 track!(self.channel_data_transporter.poll_send())?;
428 }
429 Ok(Async::NotReady)
430 }
431}
432
/// Timers stored in the client core's timeout queue; each one triggers a refresh.
#[derive(Debug)]
enum TimeoutEntry {
    /// Refresh the allocation (sends a REFRESH request).
    Refresh,
    /// Refresh the permission for `peer`'s IP address (re-sends CreatePermission).
    Permission { peer: SocketAddr },
    /// Refresh the channel binding for `peer` (re-sends ChannelBind with the same number).
    Channel { peer: SocketAddr },
}
439
440#[derive(Debug)]
441enum ChannelState {
442 Creating {
443 number: ChannelNumber,
444 reply: AsyncReply<()>,
445 },
446 Created {
447 number: ChannelNumber,
448 },
449}
450impl ChannelState {
451 fn channel_number(&self) -> ChannelNumber {
452 match self {
453 ChannelState::Creating { number, .. } => *number,
454 ChannelState::Created { number } => *number,
455 }
456 }
457}