1#![forbid(unsafe_code)] use super::MockNetRuntime;
11use super::io::{LocalStream, stream_pair};
12use crate::util::mpsc_channel;
13use core::fmt;
14use std::borrow::Cow;
15use tor_rtcompat::tls::{TlsAcceptorSettings, TlsConnector};
16use tor_rtcompat::{
17 CertifiedConn, NetStreamListener, NetStreamProvider, Runtime, StreamOps, TlsProvider,
18};
19use tor_rtcompat::{UdpProvider, UdpSocket};
20
21use async_trait::async_trait;
22use futures::FutureExt;
23use futures::channel::mpsc;
24use futures::io::{AsyncRead, AsyncWrite};
25use futures::lock::Mutex as AsyncMutex;
26use futures::sink::SinkExt;
27use futures::stream::{Stream, StreamExt};
28use std::collections::HashMap;
29use std::fmt::Formatter;
30use std::io::{self, Error as IoError, ErrorKind, Result as IoResult};
31use std::net::{IpAddr, SocketAddr};
32use std::pin::Pin;
33use std::sync::atomic::{AtomicU16, Ordering};
34use std::sync::{Arc, Mutex};
35use std::task::{Context, Poll};
36use thiserror::Error;
37use void::Void;
38
39type ConnSender = mpsc::Sender<(LocalStream, SocketAddr)>;
42type ConnReceiver = mpsc::Receiver<(LocalStream, SocketAddr)>;
44
45#[derive(Default)]
52pub struct MockNetwork {
53 listening: Mutex<HashMap<SocketAddr, AddrBehavior>>,
55}
56
57#[derive(Clone)]
59struct ListenerEntry {
60 send: ConnSender,
63
64 tls_cert: Option<Vec<u8>>,
67}
68
69#[derive(Clone)]
71enum AddrBehavior {
72 Listener(ListenerEntry),
74 Timeout,
76}
77
78#[derive(Clone)]
116pub struct MockNetProvider {
117 inner: Arc<MockNetProviderInner>,
122}
123
124impl fmt::Debug for MockNetProvider {
125 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("MockNetProvider").finish_non_exhaustive()
127 }
128}
129
130struct MockNetProviderInner {
135 addrs: Vec<IpAddr>,
137 net: Arc<MockNetwork>,
139 next_port: AtomicU16,
144}
145
146pub struct MockNetListener {
150 addr: SocketAddr,
152 receiver: AsyncMutex<ConnReceiver>,
156}
157
158pub struct ProviderBuilder {
162 addrs: Vec<IpAddr>,
164 net: Arc<MockNetwork>,
166}
167
168impl Default for MockNetProvider {
169 fn default() -> Self {
170 Arc::new(MockNetwork::default()).builder().provider()
171 }
172}
173
174impl MockNetwork {
175 pub fn new() -> Arc<Self> {
177 Default::default()
178 }
179
180 pub fn builder(self: &Arc<Self>) -> ProviderBuilder {
193 ProviderBuilder {
194 addrs: vec![],
195 net: Arc::clone(self),
196 }
197 }
198
199 pub fn add_blackhole(&self, address: SocketAddr) -> IoResult<()> {
201 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
202 if listener_map.contains_key(&address) {
203 return Err(err(ErrorKind::AddrInUse));
204 }
205 listener_map.insert(address, AddrBehavior::Timeout);
206 Ok(())
207 }
208
209 async fn send_connection(
218 &self,
219 source_addr: SocketAddr,
220 target_addr: SocketAddr,
221 peer_stream: LocalStream,
222 ) -> IoResult<Option<Vec<u8>>> {
223 let entry = {
224 let listener_map = self.listening.lock().expect("Poisoned lock for listener");
225 listener_map.get(&target_addr).cloned()
226 };
227 match entry {
228 Some(AddrBehavior::Listener(mut entry)) => {
229 if entry.send.send((peer_stream, source_addr)).await.is_ok() {
230 return Ok(entry.tls_cert);
231 }
232 Err(err(ErrorKind::ConnectionRefused))
233 }
234 Some(AddrBehavior::Timeout) => futures::future::pending().await,
235 None => Err(err(ErrorKind::ConnectionRefused)),
236 }
237 }
238
239 fn add_listener(&self, addr: SocketAddr, tls_cert: Option<Vec<u8>>) -> IoResult<ConnReceiver> {
247 let mut listener_map = self.listening.lock().expect("Poisoned lock for listener");
248 if listener_map.contains_key(&addr) {
249 return Err(err(ErrorKind::AddrInUse));
251 }
252
253 let (send, recv) = mpsc_channel(16);
254
255 let entry = ListenerEntry { send, tls_cert };
256
257 listener_map.insert(addr, AddrBehavior::Listener(entry));
258
259 Ok(recv)
260 }
261}
262
263impl ProviderBuilder {
264 pub fn add_address(&mut self, addr: IpAddr) -> &mut Self {
266 self.addrs.push(addr);
267 self
268 }
269 pub fn runtime<R: Runtime>(&self, runtime: R) -> super::MockNetRuntime<R> {
272 MockNetRuntime::new(runtime, self.provider())
273 }
274 pub fn provider(&self) -> MockNetProvider {
276 let inner = MockNetProviderInner {
277 addrs: self.addrs.clone(),
278 net: Arc::clone(&self.net),
279 next_port: AtomicU16::new(1),
280 };
281 MockNetProvider {
282 inner: Arc::new(inner),
283 }
284 }
285}
286
287impl NetStreamListener for MockNetListener {
288 type Stream = LocalStream;
289
290 type Incoming = Self;
291
292 fn local_addr(&self) -> IoResult<SocketAddr> {
293 Ok(self.addr)
294 }
295
296 fn incoming(self) -> Self {
297 self
298 }
299}
300
301impl Stream for MockNetListener {
302 type Item = IoResult<(LocalStream, SocketAddr)>;
303 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
304 let mut recv = futures::ready!(self.receiver.lock().poll_unpin(cx));
305 match recv.poll_next_unpin(cx) {
306 Poll::Pending => Poll::Pending,
307 Poll::Ready(None) => Poll::Ready(None),
308 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
309 }
310 }
311}
312
313#[derive(Debug)]
315#[non_exhaustive]
316pub struct MockUdpSocket {
317 void: Void,
322}
323
324#[async_trait]
325impl UdpProvider for MockNetProvider {
326 type UdpSocket = MockUdpSocket;
327
328 async fn bind(&self, addr: &SocketAddr) -> IoResult<MockUdpSocket> {
329 let _ = addr; Err(io::ErrorKind::Unsupported.into())
331 }
332}
333
334#[allow(clippy::diverging_sub_expression)] #[async_trait]
336impl UdpSocket for MockUdpSocket {
337 async fn recv(&self, buf: &mut [u8]) -> IoResult<(usize, SocketAddr)> {
338 void::unreachable((self.void, buf).0)
342 }
343 async fn send(&self, buf: &[u8], target: &SocketAddr) -> IoResult<usize> {
344 void::unreachable((self.void, buf, target).0)
345 }
346 fn local_addr(&self) -> IoResult<SocketAddr> {
347 void::unreachable(self.void)
348 }
349}
350
351impl MockNetProvider {
352 fn get_addr_in_family(&self, other: &IpAddr) -> Option<IpAddr> {
355 self.inner
356 .addrs
357 .iter()
358 .find(|a| a.is_ipv4() == other.is_ipv4())
359 .copied()
360 }
361
362 fn arbitrary_port(&self) -> u16 {
370 let next = self.inner.next_port.fetch_add(1, Ordering::Relaxed);
371 assert!(next != 0);
372 next
373 }
374
375 fn get_origin_addr_for(&self, addr: &SocketAddr) -> IoResult<SocketAddr> {
382 let my_addr = self
383 .get_addr_in_family(&addr.ip())
384 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?;
385 Ok(SocketAddr::new(my_addr, self.arbitrary_port()))
386 }
387
388 fn get_listener_addr(&self, spec: &SocketAddr) -> IoResult<SocketAddr> {
398 let ipaddr = {
399 let ip = spec.ip();
400 if ip.is_unspecified() {
401 self.get_addr_in_family(&ip)
402 .ok_or_else(|| err(ErrorKind::AddrNotAvailable))?
403 } else if self.inner.addrs.iter().any(|a| a == &ip) {
404 ip
405 } else {
406 return Err(err(ErrorKind::AddrNotAvailable));
407 }
408 };
409 let port = {
410 if spec.port() == 0 {
411 self.arbitrary_port()
412 } else {
413 spec.port()
414 }
415 };
416
417 Ok(SocketAddr::new(ipaddr, port))
418 }
419
420 pub fn listen_tls(&self, addr: &SocketAddr, tls_cert: Vec<u8>) -> IoResult<MockNetListener> {
426 let addr = self.get_listener_addr(addr)?;
427
428 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, Some(tls_cert))?);
429
430 Ok(MockNetListener { addr, receiver })
431 }
432}
433
434#[async_trait]
435impl NetStreamProvider for MockNetProvider {
436 type Stream = LocalStream;
437 type Listener = MockNetListener;
438 type ListenOptions = tor_rtcompat::TcpListenOptions;
439
440 async fn connect(&self, addr: &SocketAddr) -> IoResult<LocalStream> {
441 let my_addr = self.get_origin_addr_for(addr)?;
442 let (mut mine, theirs) = stream_pair();
443
444 let cert = self
445 .inner
446 .net
447 .send_connection(my_addr, *addr, theirs)
448 .await?;
449
450 mine.tls_cert = cert;
451
452 Ok(mine)
453 }
454
455 async fn listen(
456 &self,
457 addr: &SocketAddr,
458 _options: &Self::ListenOptions,
459 ) -> IoResult<Self::Listener> {
460 let addr = self.get_listener_addr(addr)?;
461
462 let receiver = AsyncMutex::new(self.inner.net.add_listener(addr, None)?);
463
464 Ok(MockNetListener { addr, receiver })
465 }
466}
467
468#[async_trait]
469impl TlsProvider<LocalStream> for MockNetProvider {
470 type Connector = MockTlsConnector;
471 type TlsStream = MockTlsStream;
472 type Acceptor = MockTlsAcceptor;
473 type TlsServerStream = MockTlsStream;
474
475 fn tls_connector(&self) -> MockTlsConnector {
476 MockTlsConnector {}
477 }
478 fn tls_acceptor(&self, settings: TlsAcceptorSettings) -> IoResult<MockTlsAcceptor> {
479 Ok(MockTlsAcceptor {
480 own_cert: settings.cert_der().to_vec(),
481 })
482 }
483
484 fn supports_keying_material_export(&self) -> bool {
485 false
486 }
487}
488
489#[derive(Clone)]
494#[non_exhaustive]
495pub struct MockTlsConnector;
496
497#[derive(Clone)]
502#[non_exhaustive]
503pub struct MockTlsAcceptor {
504 own_cert: Vec<u8>,
506}
507
508pub struct MockTlsStream {
517 peer_cert: Option<Vec<u8>>,
519 own_cert: Option<Vec<u8>>,
521 stream: LocalStream,
523}
524
525#[async_trait]
526impl TlsConnector<LocalStream> for MockTlsConnector {
527 type Conn = MockTlsStream;
528
529 async fn negotiate_unvalidated(
530 &self,
531 mut stream: LocalStream,
532 _sni_hostname: &str,
533 ) -> IoResult<MockTlsStream> {
534 let peer_cert = stream.tls_cert.take();
535
536 if peer_cert.is_none() {
537 return Err(std::io::Error::other("attempted to wrap non-TLS stream!"));
538 }
539
540 Ok(MockTlsStream {
541 peer_cert,
542 own_cert: None,
543 stream,
544 })
545 }
546}
547
548#[async_trait]
549impl TlsConnector<LocalStream> for MockTlsAcceptor {
550 type Conn = MockTlsStream;
551
552 async fn negotiate_unvalidated(
553 &self,
554 stream: LocalStream,
555 _sni_hostname: &str,
556 ) -> IoResult<MockTlsStream> {
557 Ok(MockTlsStream {
558 peer_cert: None,
559 own_cert: Some(self.own_cert.clone()),
560 stream,
561 })
562 }
563}
564
565impl CertifiedConn for MockTlsStream {
566 fn peer_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
567 Ok(self.peer_cert.clone().map(Cow::from))
568 }
569
570 fn own_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
571 Ok(self.own_cert.clone().map(Cow::from))
572 }
573 fn export_keying_material(
574 &self,
575 _len: usize,
576 _label: &[u8],
577 _context: Option<&[u8]>,
578 ) -> IoResult<Vec<u8>> {
579 Ok(Vec::new())
580 }
581}
582
583impl AsyncRead for MockTlsStream {
584 fn poll_read(
585 mut self: Pin<&mut Self>,
586 cx: &mut Context<'_>,
587 buf: &mut [u8],
588 ) -> Poll<IoResult<usize>> {
589 Pin::new(&mut self.stream).poll_read(cx, buf)
590 }
591}
592impl AsyncWrite for MockTlsStream {
593 fn poll_write(
594 mut self: Pin<&mut Self>,
595 cx: &mut Context<'_>,
596 buf: &[u8],
597 ) -> Poll<IoResult<usize>> {
598 Pin::new(&mut self.stream).poll_write(cx, buf)
599 }
600 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
601 Pin::new(&mut self.stream).poll_flush(cx)
602 }
603 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
604 Pin::new(&mut self.stream).poll_close(cx)
605 }
606}
607
608impl StreamOps for MockTlsStream {
609 fn set_tcp_notsent_lowat(&self, _notsent_lowat: u32) -> IoResult<()> {
610 Err(std::io::Error::new(
611 std::io::ErrorKind::Unsupported,
612 "not supported on non-StreamOps stream!",
613 ))
614 }
615
616 fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
617 Box::new(tor_rtcompat::NoOpStreamOpsHandle::default())
618 }
619}
620
621#[derive(Clone, Error, Debug)]
623#[non_exhaustive]
624pub enum MockNetError {
625 #[error("Invalid operation on mock network")]
627 BadOp,
628}
629
630fn err(k: ErrorKind) -> IoError {
632 IoError::new(k, MockNetError::BadOp)
633}
634
635#[cfg(all(test, not(miri)))] mod test {
637 #![allow(clippy::bool_assert_comparison)]
639 #![allow(clippy::clone_on_copy)]
640 #![allow(clippy::dbg_macro)]
641 #![allow(clippy::mixed_attributes_style)]
642 #![allow(clippy::print_stderr)]
643 #![allow(clippy::print_stdout)]
644 #![allow(clippy::single_char_pattern)]
645 #![allow(clippy::unwrap_used)]
646 #![allow(clippy::unchecked_time_subtraction)]
647 #![allow(clippy::useless_vec)]
648 #![allow(clippy::needless_pass_by_value)]
649 use super::*;
651 use futures::io::{AsyncReadExt, AsyncWriteExt};
652 use tor_rtcompat::test_with_all_runtimes;
653
654 fn client_pair() -> (MockNetProvider, MockNetProvider) {
655 let net = MockNetwork::new();
656 let client1 = net
657 .builder()
658 .add_address("192.0.2.55".parse().unwrap())
659 .provider();
660 let client2 = net
661 .builder()
662 .add_address("198.51.100.7".parse().unwrap())
663 .provider();
664
665 (client1, client2)
666 }
667
668 #[test]
669 fn end_to_end() {
670 test_with_all_runtimes!(|_rt| async {
671 let (client1, client2) = client_pair();
672 let listen_options = Default::default();
673 let lis = client2
674 .listen(&"0.0.0.0:99".parse().unwrap(), &listen_options)
675 .await?;
676 let address = lis.local_addr()?;
677
678 let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
679 async {
680 let mut conn = client1.connect(&address).await?;
681 conn.write_all(b"This is totally a network.").await?;
682 conn.close().await?;
683
684 let a2 = "192.0.2.200:99".parse().unwrap();
686 let cant_connect = client1.connect(&a2).await;
687 assert!(cant_connect.is_err());
688 Ok(())
689 },
690 async {
691 let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
692 assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
693 let mut inp = Vec::new();
694 conn.read_to_end(&mut inp).await?;
695 assert_eq!(&inp[..], &b"This is totally a network."[..]);
696 Ok(())
697 }
698 );
699 r1?;
700 r2?;
701 IoResult::Ok(())
702 });
703 }
704
705 #[test]
706 fn pick_listener_addr() -> IoResult<()> {
707 let net = MockNetwork::new();
708 let ip4 = "192.0.2.55".parse().unwrap();
709 let ip6 = "2001:db8::7".parse().unwrap();
710 let client = net.builder().add_address(ip4).add_address(ip6).provider();
711
712 let a1 = client.get_listener_addr(&"0.0.0.0:99".parse().unwrap())?;
714 assert_eq!(a1.ip(), ip4);
715 assert_eq!(a1.port(), 99);
716 let a2 = client.get_listener_addr(&"192.0.2.55:100".parse().unwrap())?;
717 assert_eq!(a2.ip(), ip4);
718 assert_eq!(a2.port(), 100);
719 let a3 = client.get_listener_addr(&"192.0.2.55:0".parse().unwrap())?;
720 assert_eq!(a3.ip(), ip4);
721 assert!(a3.port() != 0);
722 let a4 = client.get_listener_addr(&"0.0.0.0:0".parse().unwrap())?;
723 assert_eq!(a4.ip(), ip4);
724 assert!(a4.port() != 0);
725 assert!(a4.port() != a3.port());
726 let a5 = client.get_listener_addr(&"[::]:99".parse().unwrap())?;
727 assert_eq!(a5.ip(), ip6);
728 assert_eq!(a5.port(), 99);
729 let a6 = client.get_listener_addr(&"[2001:db8::7]:100".parse().unwrap())?;
730 assert_eq!(a6.ip(), ip6);
731 assert_eq!(a6.port(), 100);
732
733 let e1 = client.get_listener_addr(&"192.0.2.56:0".parse().unwrap());
735 let e2 = client.get_listener_addr(&"[2001:db8::8]:0".parse().unwrap());
736 assert!(e1.is_err());
737 assert!(e2.is_err());
738
739 IoResult::Ok(())
740 }
741
742 #[test]
743 fn listener_stream() {
744 test_with_all_runtimes!(|_rt| async {
745 let (client1, client2) = client_pair();
746
747 let listen_options = Default::default();
748 let lis = client2
749 .listen(&"0.0.0.0:99".parse().unwrap(), &listen_options)
750 .await?;
751 let address = lis.local_addr()?;
752 let mut incoming = lis.incoming();
753
754 let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
755 async {
756 for _ in 0..3_u8 {
757 let mut c = client1.connect(&address).await?;
758 c.close().await?;
759 }
760 Ok(())
761 },
762 async {
763 for _ in 0..3_u8 {
764 let (mut c, a) = incoming.next().await.unwrap()?;
765 let mut v = Vec::new();
766 let _ = c.read_to_end(&mut v).await?;
767 assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
768 }
769 Ok(())
770 }
771 );
772 r1?;
773 r2?;
774 IoResult::Ok(())
775 });
776 }
777
778 #[test]
779 fn tls_basics() {
780 let (client1, client2) = client_pair();
781 let cert = b"I am certified for something I assure you.";
782
783 test_with_all_runtimes!(|_rt| async {
784 let lis = client2
785 .listen_tls(&"0.0.0.0:0".parse().unwrap(), cert[..].into())
786 .unwrap();
787 let address = lis.local_addr().unwrap();
788
789 let (r1, r2): (IoResult<()>, IoResult<()>) = futures::join!(
790 async {
791 let connector = client1.tls_connector();
792 let conn = client1.connect(&address).await?;
793 let mut conn = connector
794 .negotiate_unvalidated(conn, "zombo.example.com")
795 .await?;
796 assert_eq!(&conn.peer_certificate()?.unwrap()[..], &cert[..]);
797 conn.write_all(b"This is totally encrypted.").await?;
798 let mut v = Vec::new();
799 conn.read_to_end(&mut v).await?;
800 conn.close().await?;
801 assert_eq!(v[..], b"Yup, your secrets is safe"[..]);
802 Ok(())
803 },
804 async {
805 let (mut conn, a) = lis.incoming().next().await.expect("closed?")?;
806 assert_eq!(a.ip(), "192.0.2.55".parse::<IpAddr>().unwrap());
807 let mut inp = [0_u8; 26];
808 conn.read_exact(&mut inp[..]).await?;
809 assert_eq!(&inp[..], &b"This is totally encrypted."[..]);
810 conn.write_all(b"Yup, your secrets is safe").await?;
811 Ok(())
812 }
813 );
814 r1?;
815 r2?;
816 IoResult::Ok(())
817 });
818 }
819}