1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
use openssl::error::ErrorStack; use openssl::ssl::{Ssl, SslContext, SslMethod, SslVersion}; use std::fmt::Debug; use std::io::{Read, Write}; use std::net::{ToSocketAddrs, TcpStream}; mod error; mod psk_providers; mod stream; #[cfg(test)] mod tests; pub use error::TunnelError; pub use psk_providers::{PskProvider, SimplePskProvider}; pub use stream::TunnelStream; pub fn connect_simple(addr: impl ToSocketAddrs, identity: &str, psk: &[u8]) -> Result<TunnelStream<TcpStream>, TunnelError> { let sock_addr = addr.to_socket_addrs() .map_err(|e| TunnelError::from(e, "Error when resolving the socket address!"))?.next(); let sock_addr = sock_addr.ok_or_else(|| TunnelError::new("No parseable addresses."))?; let stream = match std::net::TcpStream::connect(sock_addr) { Ok(s) => s, Err(e) => { eprintln!("Couldn't connect to server! {}", e); std::process::exit(4); } }; let stream = match client(stream, identity, psk) { Ok(s) => s, Err(e) => { eprintln!("Error while initialising the connection! {}", e); std::process::exit(5); } }; Ok(stream) } pub fn server<S>( stream: S, psk_provider: impl PskProvider + Send + Sync + 'static, ) -> Result<TunnelStream<S>, TunnelError> where S: Read + Write + Debug + 'static, { let mut ctx = SslContext::builder(SslMethod::tls()) .map_err(|e| TunnelError::from(e, "Error when building the SSL context!"))?; ctx.set_psk_server_callback(move |_ssl, identity, psk_buf| { let identity = identity.ok_or_else(|| ErrorStack::get())?; let identity = std::str::from_utf8(identity).map_err(|_| ErrorStack::get())?; let psk = psk_provider .get_psk(identity) .map_err(|_| ErrorStack::get())?; &mut psk_buf[..psk.len()].copy_from_slice(psk); Ok(psk.len()) }); ctx.set_min_proto_version(Some(SslVersion::TLS1_3)) .map_err(|e| { TunnelError::from(e, "Error setting the minimum protocol version to TLS 1.3") })?; ctx.set_cipher_list("PSK-CHACHA20-POLY1305").map_err(|e| { TunnelError::from( e, "Error setting the cipher suite list to PSK-CHACHA20-POLY1305", ) })?; let ssl = Ssl::new(&ctx.build()) .map_err(|e| TunnelError::from(e, "Error on starting a new TLS session"))?; let tls_stream = ssl .accept(stream) .map_err(|e| TunnelError::from(e, "Error on accepting a new TLS connection"))?; Ok(TunnelStream { tls_stream }) } pub fn client<S>( stream: S, identity: impl Into<String>, psk: impl Into<Vec<u8>>, ) -> Result<TunnelStream<S>, TunnelError> where S: Read + Write + Debug + 'static, { let mut ctx = SslContext::builder(SslMethod::tls()) .map_err(|e| TunnelError::from(e, "Error when building the SSL context!"))?; let identity = identity.into().into_bytes(); let psk = psk.into(); ctx.set_psk_client_callback(move |_ssl, _hint, identity_buf, psk_buf| { &mut identity_buf[..identity.len()].copy_from_slice(&identity); identity_buf[identity.len()] = b'\0'; &mut psk_buf[..psk.len()].copy_from_slice(&psk); Ok(psk.len()) }); ctx.set_min_proto_version(Some(SslVersion::TLS1_3)) .map_err(|e| { TunnelError::from(e, "Error setting the minimum protocol version to TLS 1.3") })?; ctx.set_cipher_list("PSK-CHACHA20-POLY1305").map_err(|e| { TunnelError::from( e, "Error setting the cipher suite list to PSK-CHACHA20-POLY1305", ) })?; let ssl = Ssl::new(&ctx.build()) .map_err(|e| TunnelError::from(e, "Error when starting a new TLS session"))?; let tls_stream = ssl .connect(stream) .map_err(|e| TunnelError::from(e, "Error when connecting to a TLS socket"))?; Ok(TunnelStream { tls_stream }) }