1#![forbid(unsafe_code)] use super::MockNetRuntime;
11use super::io::{LocalStream, stream_pair};
12use crate::util::mpsc_channel;
13use core::fmt;
14use std::borrow::Cow;
15use tor_rtcompat::tls::{TlsAcceptorSettings, TlsConnector};
16use tor_rtcompat::{
17 CertifiedConn, NetStreamListener, NetStreamProvider, Runtime, StreamOps, TlsProvider,
18};
19use tor_rtcompat::{UdpProvider, UdpSocket};
20
21use async_trait::async_trait;
22use futures::FutureExt;
23use futures::channel::mpsc;
24use futures::io::{AsyncRead, AsyncWrite};
25use futures::lock::Mutex as AsyncMutex;
26use futures::sink::SinkExt;
27use futures::stream::{Stream, StreamExt};
28use std::collections::HashMap;
29use std::fmt::Formatter;
30use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
31use std::net::{IpAddr, SocketAddr};
32use std::pin::Pin;
33use std::sync::atomic::{AtomicU16, Ordering};
34use std::sync::{Arc, Mutex};
35use std::task::{Context, Poll};
36use thiserror::Error;
37use void::Void;
38
/// Sending half of the channel a listener uses to receive new connections.
///
/// Each item is the listener's end of a fresh stream pair, together with the
/// address of the connecting side.
type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
/// Receiving half of the channel a listener uses to receive new connections.
///
/// See [`ConnSender`] for the meaning of each item.
type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
44
/// A simulated network that connects any number of [`MockNetProvider`]s.
///
/// The network only tracks what is registered at each address.
/// Providers hold the local addresses themselves.
#[derive(Default)]
pub struct MockNetwork {
    /// Map from each registered address to what happens when someone
    /// connects to it (deliver to a listener, or hang forever).
    listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
}
56
/// Information about a single listener registered on a [`MockNetwork`].
#[derive(Clone)]
struct ListenerEntry {
    /// Channel used to hand new connections to the listener.
    send: ConnSender,

    /// If present, the listener accepts TLS, and this certificate is handed to
    /// each connecting client (it ends up as the client's peer certificate).
    tls_cert: Option<Vec<u8>>,
}
68
/// What a [`MockNetwork`] does with a connection attempt to a registered address.
#[derive(Clone)]
enum AddrBehavior {
    /// Deliver the connection to this listener.
    Listener(ListenerEntry),
    /// Never answer: the connection attempt stays pending forever.
    Timeout,
}
77
/// A stream/TLS provider attached to a [`MockNetwork`].
///
/// Clones share the same inner state via an `Arc`.
/// Build one with [`MockNetwork::builder`].
#[derive(Clone)]
pub struct MockNetProvider {
    /// Shared state: our local addresses, the network, and a port counter.
    inner: Arc<MockNetProviderInner>,
}
123
124impl fmt::Debug for MockNetProvider {
125 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("MockNetProvider").finish_non_exhaustive()
127 }
128}
129
/// Shared state behind a [`MockNetProvider`].
struct MockNetProviderInner {
    /// The local addresses of this provider.
    /// Lookups return the first address of the matching family.
    addrs: Vec<IpAddr>,
    /// The network this provider is attached to.
    net: Arc<MockNetwork>,
    /// The next port handed out when an arbitrary port is needed.
    /// `ProviderBuilder::provider` initializes it to 1.
    next_port: AtomicU16,
}
145
/// A listener on a [`MockNetwork`].
///
/// It is also a [`Stream`] of incoming connections.
pub struct MockNetListener {
    /// The address this listener is bound to.
    addr: SocketAddr,
    /// Receiver for incoming connections.
    // NOTE(review): the async mutex is presumably here to make the listener
    // `Sync`; confirm before removing it.
    receiver: AsyncMutex<ConnReceiver>,
}
157
/// Builder for a [`MockNetProvider`]; obtain one from [`MockNetwork::builder`].
pub struct ProviderBuilder {
    /// Local addresses the built provider will own, in insertion order.
    addrs: Vec<IpAddr>,
    /// The network the built provider will be attached to.
    net: Arc<MockNetwork>,
}
167
168impl Default for MockNetProvider {
169 fn default() -> Self {
170 Arc::new(MockNetwork::default()).builder().provider()
171 }
172}
173
174impl MockNetwork {
175 pub fn new() -> Arc<Self> {
177 Default::default()
178 }
179
180 pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
193 ProviderBuilder {
194 addrs: vec![],
195 net: Arc::clone(self),
196 }
197 }
198
199 pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
201 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
202 if listener_map.contains_key(&address) {
203 return Err(err(ErrorKind::AddrInUse));
204 }
205 listener_map.insert(address, AddrBehavior::Timeout);
206 Ok(())
207 }
208
209 async fn send_connection(
218 &self,
219 source_addr: SocketAddr,
220 target_addr: SocketAddr,
221 peer_stream: LocalStream,
222 ) -> IoResult<Option<Vec<u8>>> {
223 let entry = {
224 let listener_map = self.listening.lock().expect("Poisoned lock for listener");
225 listener_map.get(&target_addr).cloned()
226 };
227 match entry {
228 Some(AddrBehavior::Listener(mut entry)) => {
229 if entry.send.send((peer_stream, source_addr)).await.is_ok() {
230 return Ok(entry.tls_cert);
231 }
232 Err(err(ErrorKind::ConnectionRefused))
233 }
234 Some(AddrBehavior::Timeout) => futures::future::pending().await,
235 None => Err(err(ErrorKind::ConnectionRefused)),
236 }
237 }
238
239 fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
247 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
248 if listener_map.contains_key(&addr) {
249 return Err(err(ErrorKind::AddrInUse));
251 }
252
253 let (send, recv) = mpsc_channel(16);
254
255 let entry = ListenerEntry { send, tls_cert };
256
257 listener_map.insert(addr, AddrBehavior::Listener(entry));
258
259 Ok(recv)
260 }
261}
262
263impl ProviderBuilder {
264 pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
266 self.addrs.push(addr);
267 self
268 }
269 pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
272 MockNetRuntime::new(runtime, self.provider())
273 }
274 pub fn provider(&self) -> MockNetProvider {
276 let inner = MockNetProviderInner {
277 addrs: self.addrs.clone(),
278 net: Arc::clone(&self.net),
279 next_port: AtomicU16::new(1),
280 };
281 MockNetProvider {
282 inner: Arc::new(inner),
283 }
284 }
285}
286
287impl NetStreamListener for MockNetListener {
288 type Stream = LocalStream;
289
290 type Incoming = Self;
291
292 fn local_addr(&self) -> IoResult<SocketAddr> {
293 Ok(self.addr)
294 }
295
296 fn incoming(self) -> Self {
297 self
298 }
299}
300
301impl Stream for MockNetListener {
302 type Item = IoResult<(LocalStream, SocketAddr)>;
303 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
304 let mut recv = futures::ready!(self.receiver.lock().poll_unpin(cx));
305 match recv.poll_next_unpin(cx) {
306 Poll::Pending => Poll::Pending,
307 Poll::Ready(None) => Poll::Ready(None),
308 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
309 }
310 }
311}
312
/// Placeholder UDP socket type for [`MockNetProvider`].
///
/// It holds a [`Void`], so no value of this type can ever exist.
/// `MockNetProvider`'s `UdpProvider::bind` always fails with `Unsupported`.
#[derive(Debug)]
#[non_exhaustive]
pub struct MockUdpSocket {
    /// Uninhabited field that makes this type impossible to construct.
    void: Void,
}
323
324#[async_trait]
325impl UdpProvider for MockNetProvider {
326 type UdpSocket = MockUdpSocket;
327
328 async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
329 let _ = addr; Err(io::ErrorKind::Unsupported.into())
331 }
332}
333
// `MockUdpSocket` contains a `Void`, so none of these methods can ever be
// called. Each body discharges the impossible case with `void::unreachable`.
// The tuple `(self.void, …).0` mentions the otherwise-unused arguments.
// NOTE(review): the clippy allow presumably exists because that tuple (or its
// `async_trait` expansion) is flagged as a diverging sub-expression; confirm
// before removing it.
#[allow(clippy::diverging_sub_expression)]
#[async_trait]
impl UdpSocket for MockUdpSocket {
    /// Unreachable: `MockUdpSocket` cannot be constructed.
    async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
        void::unreachable((self.void, buf).0)
    }
    /// Unreachable: `MockUdpSocket` cannot be constructed.
    async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
        void::unreachable((self.void, buf, target).0)
    }
    /// Unreachable: `MockUdpSocket` cannot be constructed.
    fn local_addr(&self) -> IoResult<SocketAddr> {
        void::unreachable(self.void)
    }
}
350
351impl MockNetProvider {
352 fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
355 self.inner
356 .addrs
357 .iter()
358 .find(|a| a.is_ipv4() == other.is_ipv4())
359 .copied()
360 }
361
362 fn arbitrary_port(&self) -> u16 {
370 let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
371 assert!(next != 0);
372 next
373 }
374
375 fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
382 let my_addr = self
383 .get_addr_in_family(&addr.ip())
384 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?;
385 Ok(SocketAddr::new(my_addr, self.arbitrary_port()))
386 }
387
388 fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
398 let ipaddr = {
399 let ip = spec.ip();
400 if ip.is_unspecified() {
401 self.get_addr_in_family(&ip)
402 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
403 } else if self.inner.addrs.iter().any(|a| a == &ip) {
404 ip
405 } else {
406 return Err(err(ErrorKind::AddrNotAvailable));
407 }
408 };
409 let port = {
410 if spec.port() == 0 {
411 self.arbitrary_port()
412 } else {
413 spec.port()
414 }
415 };
416
417 Ok(SocketAddr::new(ipaddr, port))
418 }
419
420 pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
426 let addr = self.get_listener_addr(addr)?;
427
428 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, Some(tls_cert))?);
429
430 Ok(MockNetListener { addr, receiver })
431 }
432}
433
434#[async_trait]
435impl NetStreamProvider for MockNetProvider {
436 type Stream = LocalStream;
437 type Listener = MockNetListener;
438
439 async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
440 let my_addr = self.get_origin_addr_for(addr)?;
441 let (mut mine, theirs) = stream_pair();
442
443 let cert = self
444 .inner
445 .net
446 .send_connection(my_addr, *addr, theirs)
447 .await?;
448
449 mine.tls_cert = cert;
450
451 Ok(mine)
452 }
453
454 async fn listen(&self, addr: &SocketAddr) -> IoResult<Self::Listener> {
455 let addr = self.get_listener_addr(addr)?;
456
457 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, None)?);
458
459 Ok(MockNetListener { addr, receiver })
460 }
461}
462
463#[async_trait]
464impl TlsProvider<LocalStream> for MockNetProvider {
465 type Connector = MockTlsConnector;
466 type TlsStream = MockTlsStream;
467 type Acceptor = MockTlsAcceptor;
468 type TlsServerStream = MockTlsStream;
469
470 fn tls_connector(&self) -> MockTlsConnector {
471 MockTlsConnector {}
472 }
473 fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<MockTlsAcceptor> {
474 Ok(MockTlsAcceptor {
475 own_cert: settings.cert_der().to_vec(),
476 })
477 }
478
479 fn supports_keying_material_export(&self) -> bool {
480 false
481 }
482}
483
/// Client-side mock TLS connector.
///
/// It performs no cryptography. It only checks that the stream came from a
/// TLS listener and records that listener's certificate.
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsConnector;
491
/// Server-side mock TLS acceptor.
///
/// It wraps accepted streams so that they report `own_cert` as their own certificate.
#[derive(Clone)]
#[non_exhaustive]
pub struct MockTlsAcceptor {
    /// DER bytes of the certificate this acceptor presents.
    own_cert: Vec<u8>,
}
502
/// A "TLS" stream on the mock network.
///
/// Data passes through unencrypted; only certificate bookkeeping is simulated.
pub struct MockTlsStream {
    /// The peer's certificate (set on the client side).
    peer_cert: Option<Vec<u8>>,
    /// Our own certificate (set on the server side).
    own_cert: Option<Vec<u8>>,
    /// The underlying plaintext stream.
    stream: LocalStream,
}
519
520#[async_trait]
521impl TlsConnector<LocalStream> for MockTlsConnector {
522 type Conn = MockTlsStream;
523
524 async fn negotiate_unvalidated(
525 &self,
526 mut stream: LocalStream,
527 _sni_hostname: &str,
528 ) -> IoResult<MockTlsStream> {
529 let peer_cert = stream.tls_cert.take();
530
531 if peer_cert.is_none() {
532 return Err(std::io::Error::other("attempted to wrap non-TLS stream!"));
533 }
534
535 Ok(MockTlsStream {
536 peer_cert,
537 own_cert: None,
538 stream,
539 })
540 }
541}
542
543#[async_trait]
544impl TlsConnector<LocalStream> for MockTlsAcceptor {
545 type Conn = MockTlsStream;
546
547 async fn negotiate_unvalidated(
548 &self,
549 stream: LocalStream,
550 _sni_hostname: &str,
551 ) -> IoResult<MockTlsStream> {
552 Ok(MockTlsStream {
553 peer_cert: None,
554 own_cert: Some(self.own_cert.clone()),
555 stream,
556 })
557 }
558}
559
560impl CertifiedConn for MockTlsStream {
561 fn peer_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
562 Ok(self.peer_cert.clone().map(Cow::from))
563 }
564
565 fn own_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
566 Ok(self.own_cert.clone().map(Cow::from))
567 }
568 fn export_keying_material(
569 &self,
570 _len: usize,
571 _label: &[u8],
572 _context: Option<&[u8]>,
573 ) -> IoResult<Vec<u8>> {
574 Ok(Vec::new())
575 }
576}
577
578impl AsyncRead for MockTlsStream {
579 fn poll_read(
580 mut self: Pin<&mut Self>,
581 cx: &mut Context<'_>,
582 buf: &mut [u8],
583 ) -> Poll<IoResult<usize>> {
584 Pin::new(&mut self.stream).poll_read(cx, buf)
585 }
586}
587impl AsyncWrite for MockTlsStream {
588 fn poll_write(
589 mut self: Pin<&mut Self>,
590 cx: &mut Context<'_>,
591 buf: &[u8],
592 ) -> Poll<IoResult<usize>> {
593 Pin::new(&mut self.stream).poll_write(cx, buf)
594 }
595 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
596 Pin::new(&mut self.stream).poll_flush(cx)
597 }
598 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
599 Pin::new(&mut self.stream).poll_close(cx)
600 }
601}
602
603impl StreamOps for MockTlsStream {
604 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
605 Err(std::io::Error::new(
606 std::io::ErrorKind::Unsupported,
607 "not supported on non-StreamOps stream!",
608 ))
609 }
610
611 fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
612 Box::new(tor_rtcompat::NoOpStreamOpsHandle::default())
613 }
614}
615
/// Error type used as the inner error of the `io::Error`s produced by the mock network.
#[derive(Clone, Error, Debug)]
#[non_exhaustive]
pub enum MockNetError {
    /// The requested operation is not valid on the mock network.
    /// The wrapping `io::Error`'s kind says what went wrong.
    #[error("Invalid operation on mock network")]
    BadOp,
}
624
625fn err(k: ErrorKind) -> IoError {
627 IoError::new(k, MockNetError::BadOp)
628}
629
// Tests for the mock network. They are not built under miri.
#[cfg(all(test, not(miri)))]
mod test {
    // Standard set of lints relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use futures::io::{AsyncReadExt, AsyncWriteExt};
    use tor_rtcompat::test_with_all_runtimes;

    /// Build two providers on one shared network, each with a single IPv4 address.
    fn client_pair() -> (MockNetProvider, MockNetProvider) {
        let net = MockNetwork::new();
        let client1 = net
            .builder()
            .add_address("192.0.2.55".parse().unwrap())
            .provider();
        let client2 = net
            .builder()
            .add_address("198.51.100.7".parse().unwrap())
            .provider();

        (client1, client2)
    }

    /// One provider listens and the other connects and sends data.
    /// Also checks that connecting to an address nobody listens on fails.
    #[test]
    fn end_to_end() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();
            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let mut conn = client1.connect(&address).await?;
                    conn.write_all(b"This is totally a network.").await?;
                    conn.close().await?;

                    let a2 = "192.0.2.200:99".parse().unwrap();
                    // Nothing is registered at this address, so this must fail.
                    let cant_connect = client1.connect(&a2).await;
                    assert!(cant_connect.is_err());
                    Ok(())
                },
                async {
                    let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    let mut inp = Vec::new();
                    conn.read_to_end(&mut inp).await?;
                    assert_eq!(&inp[..], &b"This is totally a network."[..]);
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    /// Exercise `get_listener_addr` with unspecified and specific IPs,
    /// explicit and zero ports, and addresses we do not own.
    #[test]
    fn pick_listener_addr() -> IoResult<()> {
        let net = MockNetwork::new();
        let ip4 = "192.0.2.55".parse().unwrap();
        let ip6 = "2001:db8::7".parse().unwrap();
        let client = net.builder().add_address(ip4).add_address(ip6).provider();

        // An unspecified IPv4 address resolves to our IPv4 address.
        let a1 = client.get_listener_addr(&"0.0.0.0:99".parse().unwrap())?;
        assert_eq!(a1.ip(), ip4);
        assert_eq!(a1.port(), 99);
        let a2 = client.get_listener_addr(&"192.0.2.55:100".parse().unwrap())?;
        assert_eq!(a2.ip(), ip4);
        assert_eq!(a2.port(), 100);
        let a3 = client.get_listener_addr(&"192.0.2.55:0".parse().unwrap())?;
        assert_eq!(a3.ip(), ip4);
        assert!(a3.port() != 0);
        let a4 = client.get_listener_addr(&"0.0.0.0:0".parse().unwrap())?;
        assert_eq!(a4.ip(), ip4);
        assert!(a4.port() != 0);
        assert!(a4.port() != a3.port());
        let a5 = client.get_listener_addr(&"[::]:99".parse().unwrap())?;
        assert_eq!(a5.ip(), ip6);
        assert_eq!(a5.port(), 99);
        let a6 = client.get_listener_addr(&"[2001:db8::7]:100".parse().unwrap())?;
        assert_eq!(a6.ip(), ip6);
        assert_eq!(a6.port(), 100);

        // Specific addresses that this provider does not own are rejected.
        let e1 = client.get_listener_addr(&"192.0.2.56:0".parse().unwrap());
        let e2 = client.get_listener_addr(&"[2001:db8::8]:0".parse().unwrap());
        assert!(e1.is_err());
        assert!(e2.is_err());

        IoResult::Ok(())
    }

    /// A listener's incoming stream yields several connections in turn.
    #[test]
    fn listener_stream() {
        test_with_all_runtimes!(|_rt| async {
            let (client1, client2) = client_pair();

            let lis = client2.listen(&"0.0.0.0:99".parse().unwrap()).await?;
            let address = lis.local_addr()?;
            let mut incoming = lis.incoming();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    for _ in 0..3_u8 {
                        let mut c = client1.connect(&address).await?;
                        c.close().await?;
                    }
                    Ok(())
                },
                async {
                    for _ in 0..3_u8 {
                        let (mut c, a) = incoming.next().await.unwrap()?;
                        let mut v = Vec::new();
                        let _ = c.read_to_end(&mut v).await?;
                        assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    }
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }

    /// Mock TLS: the client sees the listener's certificate, and data flows
    /// both ways over the wrapped stream.
    #[test]
    fn tls_basics() {
        let (client1, client2) = client_pair();
        let cert = b"I am certified for something I assure you.";

        test_with_all_runtimes!(|_rt| async {
            let lis = client2
                .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert[..].into())
                .unwrap();
            let address = lis.local_addr().unwrap();

            let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
                async {
                    let connector = client1.tls_connector();
                    let conn = client1.connect(&address).await?;
                    let mut conn = connector
                        .negotiate_unvalidated(conn, "zombo.example.com")
                        .await?;
                    assert_eq!(&conn.peer_certificate()?.unwrap()[..], &cert[..]);
                    conn.write_all(b"This is totally encrypted.").await?;
                    let mut v = Vec::new();
                    conn.read_to_end(&mut v).await?;
                    conn.close().await?;
                    assert_eq!(v[..], b"Yup, your secrets is safe"[..]);
                    Ok(())
                },
                async {
                    let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
                    assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
                    // Exactly the length of the client's message.
                    let mut inp = [0_u8; 26];
                    conn.read_exact(&mut inp[..]).await?;
                    assert_eq!(&inp[..], &b"This is totally encrypted."[..]);
                    conn.write_all(b"Yup, your secrets is safe").await?;
                    Ok(())
                }
            );
            r1?;
            r2?;
            IoResult::Ok(())
        });
    }
}