1#![forbid(unsafe_code)] use super::MockNetRuntime;
11use super::io::{LocalStream, stream_pair};
12use crate::util::mpsc_channel;
13use core::fmt;
14use tor_rtcompat::tls::{TlsAcceptorSettings, TlsConnector};
15use tor_rtcompat::{
16 CertifiedConn, NetStreamListener, NetStreamProvider, Runtime, StreamOps, TlsProvider,
17};
18use tor_rtcompat::{UdpProvider, UdpSocket};
19
20use async_trait::async_trait;
21use futures::FutureExt;
22use futures::channel::mpsc;
23use futures::io::{AsyncRead, AsyncWrite};
24use futures::lock::Mutex as AsyncMutex;
25use futures::sink::SinkExt;
26use futures::stream::{Stream, StreamExt};
27use std::collections::HashMap;
28use std::fmt::Formatter;
29use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
30use std::net::{IpAddr, SocketAddr};
31use std::pin::Pin;
32use std::sync::atomic::{AtomicU16, Ordering};
33use std::sync::{Arc, Mutex};
34use std::task::{Context, Poll};
35use thiserror::Error;
36use void::Void;
37
38type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
41type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
43
44#[derive(Default)]
51pub struct MockNetwork {
52 listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
54}
55
56#[derive(Clone)]
58struct ListenerEntry {
59 send: ConnSender,
62
63 tls_cert: Option<Vec<u8>>,
66}
67
68#[derive(Clone)]
70enum AddrBehavior {
71 Listener(ListenerEntry),
73 Timeout,
75}
76
77#[derive(Clone)]
115pub struct MockNetProvider {
116 inner: Arc<MockNetProviderInner>,
121}
122
123impl fmt::Debug for MockNetProvider {
124 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
125 f.debug_struct("MockNetProvider").finish_non_exhaustive()
126 }
127}
128
129struct MockNetProviderInner {
134 addrs: Vec<IpAddr>,
136 net: Arc<MockNetwork>,
138 next_port: AtomicU16,
143}
144
145pub struct MockNetListener {
149 addr: SocketAddr,
151 receiver: AsyncMutex<ConnReceiver>,
155}
156
157pub struct ProviderBuilder {
161 addrs: Vec<IpAddr>,
163 net: Arc<MockNetwork>,
165}
166
167impl Default for MockNetProvider {
168 fn default() -> Self {
169 Arc::new(MockNetwork::default()).builder().provider()
170 }
171}
172
173impl MockNetwork {
174 pub fn new() -> Arc<Self> {
176 Default::default()
177 }
178
179 pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
192 ProviderBuilder {
193 addrs: vec![],
194 net: Arc::clone(self),
195 }
196 }
197
198 pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
200 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
201 if listener_map.contains_key(&address) {
202 return Err(err(ErrorKind::AddrInUse));
203 }
204 listener_map.insert(address, AddrBehavior::Timeout);
205 Ok(())
206 }
207
208 async fn send_connection(
217 &self,
218 source_addr: SocketAddr,
219 target_addr: SocketAddr,
220 peer_stream: LocalStream,
221 ) -> IoResult<Option<Vec<u8>>> {
222 let entry = {
223 let listener_map = self.listening.lock().expect("Poisoned lock for listener");
224 listener_map.get(&target_addr).cloned()
225 };
226 match entry {
227 Some(AddrBehavior::Listener(mut entry)) => {
228 if entry.send.send((peer_stream, source_addr)).await.is_ok() {
229 return Ok(entry.tls_cert);
230 }
231 Err(err(ErrorKind::ConnectionRefused))
232 }
233 Some(AddrBehavior::Timeout) => futures::future::pending().await,
234 None => Err(err(ErrorKind::ConnectionRefused)),
235 }
236 }
237
238 fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
246 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
247 if listener_map.contains_key(&addr) {
248 return Err(err(ErrorKind::AddrInUse));
250 }
251
252 let (send, recv) = mpsc_channel(16);
253
254 let entry = ListenerEntry { send, tls_cert };
255
256 listener_map.insert(addr, AddrBehavior::Listener(entry));
257
258 Ok(recv)
259 }
260}
261
262impl ProviderBuilder {
263 pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
265 self.addrs.push(addr);
266 self
267 }
268 pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
271 MockNetRuntime::new(runtime, self.provider())
272 }
273 pub fn provider(&self) -> MockNetProvider {
275 let inner = MockNetProviderInner {
276 addrs: self.addrs.clone(),
277 net: Arc::clone(&self.net),
278 next_port: AtomicU16::new(1),
279 };
280 MockNetProvider {
281 inner: Arc::new(inner),
282 }
283 }
284}
285
286impl NetStreamListener for MockNetListener {
287 type Stream = LocalStream;
288
289 type Incoming = Self;
290
291 fn local_addr(&self) -> IoResult<SocketAddr> {
292 Ok(self.addr)
293 }
294
295 fn incoming(self) -> Self {
296 self
297 }
298}
299
300impl Stream for MockNetListener {
301 type Item = IoResult<(LocalStream, SocketAddr)>;
302 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303 let mut recv = futures::ready!(self.receiver.lock().poll_unpin(cx));
304 match recv.poll_next_unpin(cx) {
305 Poll::Pending => Poll::Pending,
306 Poll::Ready(None) => Poll::Ready(None),
307 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
308 }
309 }
310}
311
312#[derive(Debug)]
314#[non_exhaustive]
315pub struct MockUdpSocket {
316 void: Void,
321}
322
323#[async_trait]
324impl UdpProvider for MockNetProvider {
325 type UdpSocket = MockUdpSocket;
326
327 async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
328 let _ = addr; Err(io::ErrorKind::Unsupported.into())
330 }
331}
332
333#[allow(clippy::diverging_sub_expression)] #[async_trait]
335impl UdpSocket for MockUdpSocket {
336 async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
337 void::unreachable((self.void, buf).0)
341 }
342 async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
343 void::unreachable((self.void, buf, target).0)
344 }
345 fn local_addr(&self) -> IoResult<SocketAddr> {
346 void::unreachable(self.void)
347 }
348}
349
350impl MockNetProvider {
351 fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
354 self.inner
355 .addrs
356 .iter()
357 .find(|a| a.is_ipv4() == other.is_ipv4())
358 .copied()
359 }
360
361 fn arbitrary_port(&self) -> u16 {
369 let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
370 assert!(next != 0);
371 next
372 }
373
374 fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
381 let my_addr = self
382 .get_addr_in_family(&addr.ip())
383 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?;
384 Ok(SocketAddr::new(my_addr, self.arbitrary_port()))
385 }
386
387 fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
397 let ipaddr = {
398 let ip = spec.ip();
399 if ip.is_unspecified() {
400 self.get_addr_in_family(&ip)
401 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
402 } else if self.inner.addrs.iter().any(|a| a == &ip) {
403 ip
404 } else {
405 return Err(err(ErrorKind::AddrNotAvailable));
406 }
407 };
408 let port = {
409 if spec.port() == 0 {
410 self.arbitrary_port()
411 } else {
412 spec.port()
413 }
414 };
415
416 Ok(SocketAddr::new(ipaddr, port))
417 }
418
419 pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
425 let addr = self.get_listener_addr(addr)?;
426
427 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, Some(tls_cert))?);
428
429 Ok(MockNetListener { addr, receiver })
430 }
431}
432
433#[async_trait]
434impl NetStreamProvider for MockNetProvider {
435 type Stream = LocalStream;
436 type Listener = MockNetListener;
437
438 async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
439 let my_addr = self.get_origin_addr_for(addr)?;
440 let (mut mine, theirs) = stream_pair();
441
442 let cert = self
443 .inner
444 .net
445 .send_connection(my_addr, *addr, theirs)
446 .await?;
447
448 mine.tls_cert = cert;
449
450 Ok(mine)
451 }
452
453 async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
454 let addr = self.get_listener_addr(addr)?;
455
456 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, None)?);
457
458 Ok(MockNetListener { addr, receiver })
459 }
460}
461
462#[async_trait]
463impl TlsProvider<LocalStream> for MockNetProvider {
464 type Connector = MockTlsConnector;
465 type TlsStream = MockTlsStream;
466 type Acceptor = MockTlsAcceptor;
467 type TlsServerStream = MockTlsStream;
468
469 fn tls_connector(&self) -> MockTlsConnector {
470 MockTlsConnector {}
471 }
472 fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<MockTlsAcceptor> {
473 Ok(MockTlsAcceptor {
474 own_cert: settings.cert_der().to_vec(),
475 })
476 }
477
478 fn supports_keying_material_export(&self) -> bool {
479 false
480 }
481}
482
/// Client-side "TLS" connector for the mock network.
///
/// It performs no cryptography. It only moves the certificate that the
/// listener attached to the stream into a [`MockTlsStream`].
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsConnector;
490
/// Server-side "TLS" acceptor for the mock network.
///
/// It performs no cryptography. It only records its own certificate on each
/// accepted [`MockTlsStream`].
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsAcceptor {
    /// DER bytes of the certificate this acceptor presents.
    own_cert: Vec<u8>,
}
501
502pub struct MockTlsStream {
511 peer_cert: Option<Vec<u8>>,
513 own_cert: Option<Vec<u8>>,
515 stream: LocalStream,
517}
518
519#[async_trait]
520impl TlsConnector<LocalStream> for MockTlsConnector {
521 type Conn = MockTlsStream;
522
523 async fn negotiate_unvalidated(
524 &self,
525 mut stream: LocalStream,
526 _sni_hostname: &str,
527 ) -> IoResult<MockTlsStream> {
528 let peer_cert = stream.tls_cert.take();
529
530 if peer_cert.is_none() {
531 return Err(std::io::Error::other("attempted to wrap non-TLS stream!"));
532 }
533
534 Ok(MockTlsStream {
535 peer_cert,
536 own_cert: None,
537 stream,
538 })
539 }
540}
541
542#[async_trait]
543impl TlsConnector<LocalStream> for MockTlsAcceptor {
544 type Conn = MockTlsStream;
545
546 async fn negotiate_unvalidated(
547 &self,
548 stream: LocalStream,
549 _sni_hostname: &str,
550 ) -> IoResult<MockTlsStream> {
551 Ok(MockTlsStream {
552 peer_cert: None,
553 own_cert: Some(self.own_cert.clone()),
554 stream,
555 })
556 }
557}
558
559impl CertifiedConn for MockTlsStream {
560 fn peer_certificate(&self) -> IoResult<Option<Vec<u8>>> {
561 Ok(self.peer_cert.clone())
562 }
563
564 fn own_certificate(&self) -> IoResult<Option<Vec<u8>>> {
565 Ok(self.own_cert.clone())
566 }
567 fn export_keying_material(
568 &self,
569 _len: usize,
570 _label: &[u8],
571 _context: Option<&[u8]>,
572 ) -> IoResult<Vec<u8>> {
573 Ok(Vec::new())
574 }
575}
576
577impl AsyncRead for MockTlsStream {
578 fn poll_read(
579 mut self: Pin<&mut Self>,
580 cx: &mut Context<'_>,
581 buf: &mut [u8],
582 ) -> Poll<IoResult<usize>> {
583 Pin::new(&mut self.stream).poll_read(cx, buf)
584 }
585}
586impl AsyncWrite for MockTlsStream {
587 fn poll_write(
588 mut self: Pin<&mut Self>,
589 cx: &mut Context<'_>,
590 buf: &[u8],
591 ) -> Poll<IoResult<usize>> {
592 Pin::new(&mut self.stream).poll_write(cx, buf)
593 }
594 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
595 Pin::new(&mut self.stream).poll_flush(cx)
596 }
597 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
598 Pin::new(&mut self.stream).poll_close(cx)
599 }
600}
601
602impl StreamOps for MockTlsStream {
603 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
604 Err(std::io::Error::new(
605 std::io::ErrorKind::Unsupported,
606 "not supported on non-StreamOps stream!",
607 ))
608 }
609
610 fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
611 Box::new(tor_rtcompat::NoOpStreamOpsHandle::default())
612 }
613}
614
615#[derive(Clone, Error, Debug)]
617#[non_exhaustive]
618pub enum MockNetError {
619 #[error("Invalid operation on mock network")]
621 BadOp,
622}
623
624fn err(k: ErrorKind) -> IoError {
626 IoError::new(k, MockNetError::BadOp)
627}
628
#[cfg(all(test, not(miri)))]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use tor_rtcompat::test_with_all_runtimes;

    /// Make two providers on one fresh network, one IPv4 address each.
    fn client_pair() -> (MockNetProvider, MockNetProvider) {
        let net = MockNetwork::new();
        let client1 = net
            .builder()
            .add_address("192.0.2.55".parse().unwrap())
            .provider();
        let client2 = net
            .builder()
            .add_address("198.51.100.7".parse().unwrap())
            .provider();

        (client1, client2)
    }

    #[test]
    fn end_to_end() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();
            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let mut conn = client1.connect(&address).await?;
                    conn.write_all(b"This is totally a network.").await?;
                    conn.close().await?;

                    // Nothing is registered at this address, so connecting must fail.
                    let a2 = "192.0.2.200:99".parse().unwrap();
                    let cant_connect = client1.connect(&a2).await;
                    assert!(cant_connect.is_err());
                    Ok(())
                },
                async {
                    let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut inp = Vec::new();
                    conn.read_to_end(&mut inp).await?;
                    assert_eq!(&inp[..], &b"This is totally a network."[..]);
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    #[test]
    fn pick_listener_addr() -> IoResult<()> {
        let net = MockNetwork::new();
        let ip4 = "192.0.2.55".parse().unwrap();
        let ip6 = "2001:db8::7".parse().unwrap();
        let client = net.builder().add_address(ip4).add_address(ip6).provider();

        // Successful resolutions.
        let a1 = client.get_listener_addr(&"0.0.0.0:99".parse().unwrap())?;
        assert_eq!(a1.ip(), ip4);
        assert_eq!(a1.port(), 99);
        let a2 = client.get_listener_addr(&"192.0.2.55:100".parse().unwrap())?;
        assert_eq!(a2.ip(), ip4);
        assert_eq!(a2.port(), 100);
        let a3 = client.get_listener_addr(&"192.0.2.55:0".parse().unwrap())?;
        assert_eq!(a3.ip(), ip4);
        assert!(a3.port() != 0);
        let a4 = client.get_listener_addr(&"0.0.0.0:0".parse().unwrap())?;
        assert_eq!(a4.ip(), ip4);
        assert!(a4.port() != 0);
        assert!(a4.port() != a3.port());
        let a5 = client.get_listener_addr(&"[::]:99".parse().unwrap())?;
        assert_eq!(a5.ip(), ip6);
        assert_eq!(a5.port(), 99);
        let a6 = client.get_listener_addr(&"[2001:db8::7]:100".parse().unwrap())?;
        assert_eq!(a6.ip(), ip6);
        assert_eq!(a6.port(), 100);

        // Addresses this host does not own.
        let e1 = client.get_listener_addr(&"192.0.2.56:0".parse().unwrap());
        let e2 = client.get_listener_addr(&"[2001:db8::8]:0".parse().unwrap());
        assert!(e1.is_err());
        assert!(e2.is_err());

        IoResult::Ok(())
    }

    #[test]
    fn listener_stream() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();

            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;
            let mut incoming = lis.incoming();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    for _ in 0..3_u8 {
                        let mut c = client1.connect(&address).await?;
                        c.close().await?;
                    }
                    Ok(())
                },
                async {
                    for _ in 0..3_u8 {
                        let (mut c, a) = incoming.next().await.unwrap()?;
                        let mut v = Vec::new();
                        let _ = c.read_to_end(&mut v).await?;
                        assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    }
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    #[test]
    fn listen_twice_is_addr_in_use() {
        test_with_all_runtimes!(|_rt| async {
            let (_client1, client2) = client_pair();
            let spec: SocketAddr = "0.0.0.0:99".parse().unwrap();

            // Keep the first listener alive for the whole test.
            let _first = client2.listen(&spec).await?;
            let second = client2.listen(&spec).await;
            assert_eq!(
                second.err().map(|e| e.kind()),
                Some(ErrorKind::AddrInUse)
            );
            IoResult::Ok(())
        });
    }

    #[test]
    fn blackhole_never_connects() {
        let net = MockNetwork::new();
        let client = net
            .builder()
            .add_address("192.0.2.55".parse().unwrap())
            .provider();
        let hole: SocketAddr = "192.0.2.66:443".parse().unwrap();

        net.add_blackhole(hole).unwrap();
        // The address is now taken, so a second blackhole is rejected.
        assert_eq!(
            net.add_blackhole(hole).unwrap_err().kind(),
            ErrorKind::AddrInUse
        );

        // A connection attempt to a blackhole must stay pending.
        assert!(client.connect(&hole).now_or_never().is_none());
    }

    #[test]
    fn tls_basics() {
        let (client1, client2) = client_pair();
        let cert = b"I am certified for something I assure you.";

        test_with_all_runtimes!(|_rt| async {
            let lis = client2
                .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert[..].into())
                .unwrap();
            let address = lis.local_addr().unwrap();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let connector = client1.tls_connector();
                    let conn = client1.connect(&address).await?;
                    let mut conn = connector
                        .negotiate_unvalidated(conn, "zombo.example.com")
                        .await?;
                    assert_eq!(&conn.peer_certificate()?.unwrap()[..], &cert[..]);
                    conn.write_all(b"This is totally encrypted.").await?;
                    let mut v = Vec::new();
                    conn.read_to_end(&mut v).await?;
                    conn.close().await?;
                    assert_eq!(v[..], b"Yup, your secrets is safe"[..]);
                    Ok(())
                },
                async {
                    let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut inp = [0_u8; 26];
                    conn.read_exact(&mut inp[..]).await?;
                    assert_eq!(&inp[..], &b"This is totally encrypted."[..]);
                    conn.write_all(b"Yup, your secrets is safe").await?;
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }
}