use crate::connection::WSConnection;
use crate::error::Error;
use crate::message::Message;
use crate::read::ReadStream;
use crate::request::parse_to_http_request;
use crate::write::{Writer, WriterKind};
use base64::prelude::BASE64_STANDARD;
use base64::prelude::*;
use bytes::BytesMut;
use rand::random;
use sha1::{Digest, Sha1};
use std::sync::Arc;
use tokio::io::{
    split, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf,
};
use tokio::net::TcpStream;
use tokio::sync::mpsc::channel;
use tokio::sync::Mutex;
use tokio::time::{timeout, Duration};
use tokio_stream::wrappers::ReceiverStream;
/// Request header name, including its trailing colon, that carries the client's handshake key.
const SEC_WEBSOCKETS_KEY: &str = "Sec-WebSocket-Key:";
/// Fixed GUID from RFC 6455 that is appended to the client key before hashing.
const UUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Status text the client expects to find in the server's handshake response.
const SWITCHING_PROTOCOLS: &str = "101 Switching Protocols";
/// Server handshake response; the `{}` placeholder is replaced with the accept value.
const HTTP_ACCEPT_RESPONSE: &str = concat!(
    "HTTP/1.1 101 Switching Protocols\r\n",
    "Connection: Upgrade\r\n",
    "Upgrade: websocket\r\n",
    "Sec-WebSocket-Accept: {}\r\n",
    "\r\n",
);
pub type Result = std::result::Result<WSConnection, Error>;
pub async fn accept_async(stream: TcpStream) -> Result {
let (reader, mut write_half) = split(stream);
let mut buf_reader = BufReader::new(reader);
let sec_websockets_accept = header_read(&mut buf_reader).await;
match sec_websockets_accept {
Some(accept_value) => {
let response = HTTP_ACCEPT_RESPONSE.replace("{}", &accept_value);
write_half
.write_all(response.as_bytes())
.await
.map_err(|source| Error::IOError { source })?;
write_half.flush().await?;
}
None => Err(Error::NoSecWebsocketKey)?,
}
second_stage_handshake(buf_reader, write_half, WriterKind::Server).await
}
async fn second_stage_handshake(
buf_reader: BufReader<ReadHalf<TcpStream>>,
write_half: WriteHalf<TcpStream>,
kind: WriterKind,
) -> Result {
let writer = Arc::new(Mutex::new(Writer::new(write_half, kind)));
let stream_writer = writer.clone();
let (read_tx, read_rx) = channel::<std::result::Result<Message, Error>>(20);
let mut read_stream = ReadStream::new(buf_reader, read_tx, stream_writer);
let connection_writer = writer.clone();
let receiver_stream = ReceiverStream::new(read_rx);
let ws_connection = WSConnection::new(connection_writer, receiver_stream);
tokio::spawn(async move {
if let Err(err) = read_stream.poll_messages().await {
let _ = read_stream.read_tx.send(Err(err)).await;
}
});
Ok(ws_connection)
}
pub async fn connect_async(addr: &str) -> Result {
let client_websocket_key = generate_websocket_key();
let (request, hostname) = parse_to_http_request(addr, &client_websocket_key)?;
let stream = TcpStream::connect(hostname).await?;
let (reader, mut write_half) = split(stream);
let mut buf_reader = BufReader::new(reader);
write_half.write_all(request.as_bytes()).await?;
let mut buffer: Vec<u8> = vec![0; 206];
let number_read = buf_reader.read(&mut buffer).await?;
buffer.truncate(number_read);
let response = String::from_utf8(buffer)?;
if !response.contains(SWITCHING_PROTOCOLS) {
return Err(Error::NoUpgrade);
}
let expected_accept_value = generate_websocket_accept_value(client_websocket_key);
if !response.contains(&expected_accept_value) {
return Err(Error::InvalidAcceptKey);
}
second_stage_handshake(buf_reader, write_half, WriterKind::Client).await
}
async fn header_read<T: AsyncReadExt + Unpin>(buf_reader: &mut T) -> Option<String> {
let mut websocket_header: Option<String> = None;
let mut websocket_accept: Option<String> = None;
let mut header_buf = BytesMut::with_capacity(1024 * 16); while header_buf.len() <= 1024 * 16 {
let mut tmp_buf = vec![0; 1024];
match timeout(Duration::from_secs(10), buf_reader.read(&mut tmp_buf)).await {
Ok(Ok(0)) | Err(_) => break, Ok(Ok(n)) => {
header_buf.extend_from_slice(&tmp_buf[..n]);
let s = String::from_utf8_lossy(&header_buf);
if let Some(start) = s.find(SEC_WEBSOCKETS_KEY) {
websocket_header = Some(s[start..].lines().next().unwrap().to_string());
break;
}
}
_ => {}
}
}
if let Some(header) = websocket_header {
if let Some(key) = parse_websocket_key(header) {
websocket_accept = Some(generate_websocket_accept_value(key));
}
}
websocket_accept
}
fn parse_websocket_key(header: String) -> Option<String> {
for line in header.lines() {
if line.starts_with(SEC_WEBSOCKETS_KEY) {
if let Some(stripped) = line.strip_prefix(SEC_WEBSOCKETS_KEY) {
return stripped.split_whitespace().next().map(ToOwned::to_owned);
}
}
}
None
}
/// Computes the `Sec-WebSocket-Accept` value for a client key.
///
/// Hashes the key followed by the RFC 6455 GUID with SHA-1 and returns the digest encoded
/// with standard base64.
fn generate_websocket_accept_value(key: String) -> String {
    let mut hasher = Sha1::new();
    for chunk in [key.as_bytes(), UUID.as_bytes()] {
        hasher.update(chunk);
    }
    let digest = hasher.finalize();
    BASE64_STANDARD.encode(digest)
}
/// Generates a client handshake key: 16 random bytes encoded with standard base64.
fn generate_websocket_key() -> String {
    let nonce = random::<[u8; 16]>();
    BASE64_STANDARD.encode(nonce)
}