use std::{
io::{Error as IoError, ErrorKind},
str::FromStr,
sync::Arc,
time::{Duration, Instant}
};
use futures::{
future::Future,
Sink,
stream::{SplitStream, Stream},
sync::mpsc::{self, UnboundedSender}
};
use native_tls::TlsConnector;
use parking_lot::Mutex;
use tokio::net::TcpStream as TokioTcpStream;
use tokio::timer::Interval;
use tokio_dns::TcpStream;
use tokio_tls::TlsStream;
use tokio_tungstenite::{
tungstenite::{
Error as TungsteniteError,
handshake::client::Request,
protocol::{Message as WebsocketMessage, WebSocketConfig},
},
WebSocketStream
};
use tokio_tungstenite::stream::Stream as TungsteniteStream;
use url::Url;
use spectacles_model::{
gateway::{
GatewayEvent,
HeartbeatPacket,
HelloPacket,
IdentifyPacket,
IdentifyProperties,
Opcodes,
ReadyPacket,
ReceivePacket,
ResumeSessionPacket,
SendablePacket,
},
presence::{ClientActivity, ClientPresence, Status}
};
use crate::{
constants::{GATEWAY_URL, GATEWAY_VERSION},
errors::{Error, Result}
};
/// Receiving half of a shard's websocket connection: a split stream over a Tokio TCP stream
/// that may be plain or wrapped in TLS.
pub type ShardSplitStream = SplitStream<
    WebSocketStream<TungsteniteStream<TokioTcpStream, TlsStream<TokioTcpStream>>>,
>;
/// A single gateway shard together with the connection state it shares with its clones.
///
/// Cloning a `Shard` copies `token`, `info`, `presence`, `session_id` and `interval`, while the
/// `Arc`-wrapped fields (`sender`, `stream`, `heartbeat` and the connection state) are shared by
/// every clone.
#[derive(Clone)]
pub struct Shard {
    /// Bot token sent in identify and resume payloads.
    pub token: String,
    /// Shard info; `info[0]` is this shard's id (presumably `[shard_id, shard_count]` — confirm).
    pub info: [usize; 2],
    /// Presence sent when identifying.
    pub presence: ClientPresence,
    /// Session ID received in the gateway's READY dispatch; used when resuming.
    pub session_id: Option<String>,
    /// Heartbeat interval in milliseconds, as received in the gateway's Hello packet.
    pub interval: Option<u64>,
    /// Sending half of the channel whose messages are forwarded into the websocket sink.
    pub sender: Arc<Mutex<UnboundedSender<WebsocketMessage>>>,
    /// Receiving half of the websocket; an `Option` so it can be taken out by its consumer.
    pub stream: Arc<Mutex<Option<ShardSplitStream>>>,
    /// Heartbeat bookkeeping shared with the heartbeat loop.
    pub heartbeat: Arc<Mutex<Heartbeat>>,
    /// Connection state as a string, e.g. `"handshake"`, `"connected"`, `"resuming"`,
    /// `"disconnected"`.
    current_state: Arc<Mutex<String>>,
}
/// Follow-up the caller should perform after `Shard::fulfill_gateway` handles a packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardAction {
    /// Nothing further is required.
    NoneAction,
    /// Re-establish the connection, resuming when possible (see `Shard::autoreconnect`).
    Autoreconnect,
    /// The gateway asked for a reconnect.
    Reconnect,
    /// Send an identify payload (see `Shard::identify`).
    Identify,
    /// Resume the existing session (see `Shard::resume`).
    Resume,
}
/// Heartbeat bookkeeping shared between a shard and its heartbeat loop.
#[derive(Clone, Copy, Debug)]
pub struct Heartbeat {
    /// Whether the most recent heartbeat has been acknowledged (presumably by a gateway
    /// HeartbeatAck — confirm against the packet handler).
    pub acknowledged: bool,
    /// Sequence number included in each heartbeat payload.
    pub seq: u64,
}
impl Heartbeat {
fn new() -> Heartbeat {
Self {
acknowledged: false,
seq: 0
}
}
}
impl Shard {
pub fn new(token: String, info: [usize; 2]) -> impl Future<Item = Shard, Error = Error> {
Shard::begin_connection(GATEWAY_URL, info[0])
.map(move |(sender, stream)| {
Shard {
token,
session_id: None,
presence: ClientPresence {
status: String::from("online"),
..Default::default()
},
info,
interval: None,
sender: Arc::new(Mutex::new(sender)),
current_state: Arc::new(Mutex::new(String::from("handshake"))),
stream: Arc::new(Mutex::new(Some(stream))),
heartbeat: Arc::new(Mutex::new(Heartbeat::new()))
}
})
}
pub fn fulfill_gateway<'a>(&mut self, packet: ReceivePacket<'a>) -> Result<ShardAction> {
let info = self.info.clone();
let current_state = self.current_state.lock().clone();
match packet.op {
Opcodes::Dispatch => {
if let Some(GatewayEvent::READY) = packet.t {
let ready: ReadyPacket = serde_json::from_str(packet.d.get()).unwrap();
*self.current_state.lock() = "connected".to_string();
self.session_id = Some(ready.session_id.clone());
trace!("[Shard {}] Received ready, set session ID as {}", &info[0], ready.session_id)
};
Ok(ShardAction::NoneAction)
}
Opcodes::Hello => {
if self.current_state.lock().clone() == "resume".to_string() {
return Ok(ShardAction::NoneAction)
};
let hello: HelloPacket = serde_json::from_str(packet.d.get()).unwrap();
if hello.heartbeat_interval > 0 {
self.interval = Some(hello.heartbeat_interval);
}
if current_state == "handshake".to_string() {
let dur = Duration::from_millis(hello.heartbeat_interval);
tokio::spawn(Shard::begin_interval(self.clone(), dur));
return Ok(ShardAction::Identify);
}
Ok(ShardAction::Autoreconnect)
},
Opcodes::HeartbeatAck => {
let mut hb = self.heartbeat.lock().clone();
hb.acknowledged = true;
Ok(ShardAction::NoneAction)
},
Opcodes::Reconnect => Ok(ShardAction::Reconnect),
Opcodes::InvalidSession => {
let invalid: bool = serde_json::from_str(packet.d.get()).unwrap();
if !invalid {
Ok(ShardAction::Identify)
} else { Ok(ShardAction::Resume) }
},
_ => Ok(ShardAction::NoneAction)
}
}
pub fn identify(&mut self) -> Result<()> {
let token = self.token.clone();
let shard = self.info.clone();
let presence = self.presence.clone();
self.send_payload(IdentifyPacket {
large_threshold: 250,
token,
shard,
compress: false,
presence: Some(presence),
version: GATEWAY_VERSION,
properties: IdentifyProperties {
os: std::env::consts::OS.to_string(),
browser: String::from("spectacles-rs"),
device: String::from("spectacles-rs")
}
})
}
pub fn autoreconnect(&mut self) -> Box<Future<Item = (), Error = Error> + Send>{
if self.session_id.is_some() && self.heartbeat.lock().seq > 0 {
Box::new(self.resume())
} else {
Box::new(self.reconnect())
}
}
pub fn reconnect(&mut self) -> impl Future<Item = (), Error = Error> + Send {
debug!("[Shard {}] Attempting to reconnect to gateway.", &self.info[0]);
self.reset_values().expect("[Shard] Failed to reset this shard for autoreconnecting.");
self.dial_gateway()
}
pub fn resume(&mut self) -> impl Future<Item = (), Error = Error> + Send {
debug!("[Shard {}] Attempting to resume gateway connection.", &self.info[0]);
let seq = self.heartbeat.lock().seq;
let token = self.token.clone();
let state = self.current_state.clone();
let session = self.session_id.clone();
let sender = self.sender.clone();
self.dial_gateway().then(move |result|{
if result.is_err() { return result };
*state.lock() = "resuming".to_string();
let payload = ResumeSessionPacket {
session_id: session.unwrap(),
seq,
token
};
send(&sender, WebsocketMessage::text(payload.to_json()?))
})
}
pub fn resolve_packet<'a>(&self, mess: &'a WebsocketMessage) -> Result<ReceivePacket<'a>> {
match mess {
WebsocketMessage::Binary(v) => serde_json::from_slice(v),
WebsocketMessage::Text(v) => serde_json::from_str(v),
_ => unreachable!("Invalid type detected."),
}.map_err(Error::from)
}
pub fn send_payload<T: SendablePacket>(&self, payload: T) -> Result<()> {
let json = payload.to_json()?;
send(&self.sender, WebsocketMessage::text(json))
}
pub fn change_status(&mut self, status: Status) -> Result<()> {
self.presence.status = status.to_string();
let oldpresence = self.presence.clone();
self.change_presence(oldpresence)
}
pub fn change_activity(&mut self, activity: ClientActivity) -> Result<()> {
self.presence.game = Some(activity);
let oldpresence = self.presence.clone();
self.change_presence(oldpresence)
}
pub fn change_presence(&mut self, presence: ClientPresence) -> Result<()> {
debug!("[Shard {}] Sending a presence change payload. {:?}", self.info[0], presence.clone());
self.send_payload(presence.clone())?;
self.presence = presence;
Ok(())
}
fn reset_values(&mut self) -> Result<()> {
self.session_id = None;
*self.current_state.lock() = "disconnected".to_string();
let mut hb = self.heartbeat.lock();
hb.acknowledged = true;
hb.seq = 0;
Ok(())
}
fn heartbeat(&mut self) -> Result<()> {
debug!("[Shard {}] Sending heartbeat.", self.info[0]);
let seq = self.heartbeat.lock().seq;
self.send_payload(HeartbeatPacket { seq })
}
fn dial_gateway(&mut self) -> impl Future<Item = (), Error = Error> + Send {
let info = self.info.clone();
*self.current_state.lock() = String::from("connected");
let state = self.current_state.clone();
let orig_sender = self.sender.clone();
let orig_stream = self.stream.clone();
let heartbeat = self.heartbeat.clone();
Shard::begin_connection(GATEWAY_URL, info[0])
.map(move |(sender, stream)| {
*orig_sender.lock() = sender;
*heartbeat.lock() = Heartbeat::new();
*state.lock() = String::from("handshake");
*orig_stream.lock() = Some(stream);
})
}
fn begin_interval(mut shard: Shard, duration: Duration) -> impl Future<Item = (), Error = ()> {
let info = shard.info.clone();
Interval::new(Instant::now(), duration)
.map_err(move |err| {
warn!("[Shard {}] Failed to begin heartbeat interval. {:?}", info[0], err);
})
.for_each(move |_| {
if let Err(r) = shard.heartbeat() {
warn!("[Shard {}] Failed to perform heartbeat. {:?}", info[0], r);
return Err(());
}
Ok(())
})
}
fn begin_connection(ws: &str, shard_id: usize) -> impl Future<Item = (UnboundedSender<WebsocketMessage>, ShardSplitStream), Error = Error> {
let url = Url::from_str(ws).expect("Invalid Websocket URL has been provided.");
let req = Request::from(url);
let (host, port) = Shard::get_addr_info(&req);
let tlsconn = TlsConnector::new().unwrap();
let tlsconn = tokio_tls::TlsConnector::from(tlsconn);
let socket = TcpStream::connect((host.as_ref(), port));
let handshake = socket.and_then(move |socket| {
debug!("[Shard {}] Beginning handshake with gateway.", shard_id);
tlsconn.connect(host.as_ref(), socket)
.map(|s| TungsteniteStream::Tls(s))
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
});
let stream = handshake.and_then(|mut stream| {
tokio_tungstenite::stream::NoDelay::set_nodelay(&mut stream, true)
.map(move |()| stream)
});
let stream = stream.and_then(move |stream| {
tokio_tungstenite::client_async_with_config(req, stream, Some(WebSocketConfig {
max_message_size: Some(usize::max_value()),
max_frame_size: Some(usize::max_value()),
..Default::default()
})).map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
});
stream.map(move |(wstream, _)| {
let (tx, rx) = mpsc::unbounded();
let (sink, stream) = wstream.split();
tokio::spawn(rx.map_err(|err| {
error!("Failed to select sink. {:?}", err);
TungsteniteError::Io(IoError::new(ErrorKind::Other, "Error whilst attempting to select sink."))
}).forward(sink).map(|_| ()).map_err(|_| ()));
(tx, stream)
}).from_err()
}
fn get_addr_info(req: &Request) -> (String, u16) {
let host = req.url.host_str().expect("Could Not parse the Websocket Host.");
let port = req.url.port_or_known_default().expect("Could not parse the websocket port.");
(host.to_string(), port)
}
}
/// Queues `mess` on the shared outgoing channel.
///
/// The `Ok` value of `start_send` (ready / not-ready) is discarded; only a send error is
/// reported, converted into the crate's `Error`.
fn send(sender: &Arc<Mutex<UnboundedSender<WebsocketMessage>>>, mess: WebsocketMessage) -> Result<()> {
    let mut channel = sender.lock();
    channel.start_send(mess)?;
    Ok(())
}