cfg_async! {
use crate::{
channel::Receiver,
commands,
connector::Connector,
encoder::AsyncEncoder,
messages::{Capability, Commands, MessageId},
rate_limit::{RateClass, RateLimit},
twitch::UserConfig,
util::{Notify, NotifyHandle},
writer::{AsyncWriter, MpscWriter},
AsyncDecoder, DecodeError, Encodable, FromIrcMessage, IrcMessage,
};
use super::{
channel::Channels,
timeout::{TimeoutState, RATE_LIMIT_WINDOW, TIMEOUT, WINDOW},
Capabilities, Channel, Error, Identity, Status, StepResult,
};
use futures_lite::{AsyncRead, AsyncWrite, AsyncWriteExt, Stream};
use std::{
collections::{HashSet, VecDeque},
pin::Pin,
task::{Context, Poll},
};
/// Drives one asynchronous chat connection.
///
/// It reads server messages (answering PINGs and tracking channel state), runs the
/// PING/PONG keep-alive, and sends data written through [`AsyncRunner::writer`],
/// subject to per-channel and global rate limits.
pub struct AsyncRunner {
/// The identity established during the connection handshake (see `connect`).
pub identity: Identity,
/// Channels currently tracked. Entries are added when our own JOIN is seen (or a
/// PRIVMSG is written to a new channel) and removed on our PART or a ban notice.
/// Each entry holds that channel's rate-limited send queue.
channels: Channels,
/// Receiving half of the `activity_tx` handed to `writer`. Every value received
/// resets the keep-alive state to "activity". Presumably signalled on writes — confirm.
activity_rx: Receiver<()>,
/// Encoded messages produced through `writer`; consumed by `step`.
writer_rx: Receiver<Box<[u8]>>,
/// Resolves when a quit is requested through `notify_handle` (see `quit_handle`);
/// `step` then reports `Status::Quit`.
notify: Notify,
/// Handle paired with `notify`; clones are handed out by `quit_handle`.
notify_handle: NotifyHandle,
/// Keep-alive state: `Start`, last `Activity`, or `WaitingForPong` after a PING.
timeout_state: TimeoutState,
/// Read half of the connection.
decoder: AsyncDecoder<Box<dyn AsyncRead + Send + Sync + Unpin>>,
/// Write half of the connection; all outgoing data is written through it.
encoder: AsyncEncoder<Box<dyn AsyncWrite + Send + Sync + Unpin>>,
/// Writer given out to users; what it produces arrives on `writer_rx`.
writer: AsyncWriter<MpscWriter>,
/// Connection-wide token budget (created as `RateClass::Regular`), charged when
/// channel queues are drained.
global_rate_limit: RateLimit,
/// Messages already received but not yet returned to the caller (from the handshake
/// or while `join`/`part` were waiting). `step` returns these before reading anything new.
missed_messages: VecDeque<Commands<'static>>,
}
impl std::fmt::Debug for AsyncRunner {
    /// Prints a fixed placeholder; none of the runner's fields (connection halves,
    /// channels, queues) are included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("AsyncRunner { .. }")
    }
}
impl AsyncRunner {
    /// Connects using `connector`, registers with `user_config`, and waits for the
    /// server to finish the handshake.
    ///
    /// Every message read during the handshake is kept and replayed, in order, by
    /// [`AsyncRunner::step`] before anything new is read from the connection.
    ///
    /// # Errors
    /// Returns an error if connecting, writing the registration, or the readiness
    /// handshake fails (for example, a capability the server refused).
    pub async fn connect<C>(connector: C, user_config: &UserConfig) -> Result<Self, Error>
    where
        C: Connector,
        for<'a> &'a C::Output: AsyncRead + AsyncWrite + Send + Sync + Unpin,
    {
        log::debug!("connecting");
        let mut stream = { connector }.connect().await?;
        log::debug!("connection established");

        log::debug!("registering");
        let mut buf = vec![];
        commands::register(user_config).encode(&mut buf)?;
        stream.write_all(&buf).await?;
        log::debug!("registered");

        // Both halves share the stream through a reference-counted handle; the bound
        // on `&C::Output` is what makes the shared reference readable and writable.
        let read = async_dup::Arc::new(stream);
        let write = read.clone();
        let read: Box<dyn AsyncRead + Send + Sync + Unpin> = Box::new(read);
        let write: Box<dyn AsyncWrite + Send + Sync + Unpin> = Box::new(write);
        let mut decoder = AsyncDecoder::new(read);
        let mut encoder = AsyncEncoder::new(write);

        log::debug!("waiting for the connection to be ready");
        let mut missed_messages = VecDeque::new();
        let identity =
            Self::wait_for_ready(&mut decoder, &mut encoder, user_config, &mut missed_messages)
                .await?;
        log::debug!("connection is ready: {:?}", identity);

        let (writer_tx, writer_rx) = crate::channel::unbounded();
        let (notify, notify_handle) = Notify::new();
        let (activity_tx, activity_rx) = crate::channel::bounded(32);
        let writer = AsyncWriter::new(MpscWriter::new(writer_tx), activity_tx);

        Ok(Self {
            identity,
            channels: Channels::default(),
            activity_rx,
            writer_rx,
            notify,
            notify_handle,
            timeout_state: TimeoutState::Start,
            decoder,
            encoder,
            writer,
            global_rate_limit: RateLimit::from_class(RateClass::Regular),
            missed_messages,
        })
    }

    /// Returns whether `channel` is currently tracked by this runner.
    pub fn is_on_channel(&self, channel: &str) -> bool {
        self.channels.is_on(channel)
    }

    /// Returns the tracked state for `channel`, if it is tracked.
    pub fn get_channel_mut(&mut self, channel: &str) -> Option<&mut Channel> {
        self.channels.get_mut(channel)
    }

    /// Returns a clone of the runner's writer; data written through it is sent by `step`.
    pub fn writer(&self) -> AsyncWriter<MpscWriter> {
        self.writer.clone()
    }

    /// Returns a clone of the handle that requests this runner to quit.
    pub fn quit_handle(&self) -> NotifyHandle {
        self.notify_handle.clone()
    }

    /// Joins `channel` and waits until the server confirms the JOIN for our user.
    ///
    /// Messages that arrive while waiting are queued and returned by later calls to
    /// [`AsyncRunner::step`] in the order they were received. This holds even if
    /// waiting fails.
    ///
    /// # Errors
    /// - [`Error::AlreadyOnChannel`] if the channel is already tracked.
    /// - [`Error::BannedFromChannel`] if a ban notice arrives while waiting.
    /// - [`Error::UnexpectedEof`] if the connection ends or a quit is requested first.
    /// - Any error from writing or from [`AsyncRunner::step`].
    pub async fn join(&mut self, channel: &str) -> Result<(), Error> {
        if self.is_on_channel(channel) {
            return Err(Error::AlreadyOnChannel {
                channel: channel.to_string(),
            });
        }

        log::debug!("joining '{}'", channel);
        self.encoder.encode(commands::join(channel)).await?;
        let channel = crate::commands::Channel::new(channel).to_string();

        log::debug!("waiting for a response");
        let mut queue = VecDeque::new();
        let outcome = self
            .wait_for(&mut queue, |msg, this| match msg {
                Commands::Join(msg) => {
                    Ok(msg.channel() == channel && msg.name() == this.identity.username())
                }
                Commands::Notice(msg) if matches!(msg.msg_id(), Some(MessageId::MsgBanned)) => {
                    Err(Error::BannedFromChannel {
                        channel: msg.channel().to_string(),
                    })
                }
                _ => Ok(false),
            })
            .await;
        // Restore the set-aside messages before propagating any failure so none are lost.
        self.requeue_missed(queue);
        Self::reply_received(outcome?)?;

        log::debug!("joined '{}'", channel);
        Ok(())
    }

    /// Leaves `channel` and waits until the server confirms the PART for our user.
    ///
    /// Messages that arrive while waiting are queued and returned by later calls to
    /// [`AsyncRunner::step`] in the order they were received. This holds even if
    /// waiting fails.
    ///
    /// # Errors
    /// - [`Error::NotOnChannel`] if the channel is not tracked.
    /// - [`Error::UnexpectedEof`] if the connection ends or a quit is requested first.
    /// - Any error from writing or from [`AsyncRunner::step`].
    pub async fn part(&mut self, channel: &str) -> Result<(), Error> {
        if !self.is_on_channel(channel) {
            return Err(Error::NotOnChannel {
                channel: channel.to_string(),
            });
        }

        log::debug!("leaving '{}'", channel);
        self.encoder.encode(commands::part(channel)).await?;
        let channel = crate::commands::Channel::new(channel).to_string();

        log::debug!("waiting for a response");
        let mut queue = VecDeque::new();
        let outcome = self
            .wait_for(&mut queue, |msg, this| match msg {
                Commands::Part(msg) => {
                    Ok(msg.channel() == channel && msg.name() == this.identity.username())
                }
                _ => Ok(false),
            })
            .await;
        // Restore the set-aside messages before propagating any failure so none are lost.
        self.requeue_missed(queue);
        Self::reply_received(outcome?)?;

        log::debug!("left '{}'", channel);
        Ok(())
    }

    /// Drives the runner until it produces a status worth reporting.
    ///
    /// When a quit was requested, the runner closes its writer and activity receivers,
    /// flushes every rate-limited message still queued, sends `QUIT`, and returns
    /// [`Status::Quit`].
    ///
    /// # Errors
    /// Propagates errors from [`AsyncRunner::step`], from flushing, and from sending `QUIT`.
    pub async fn next_message(&mut self) -> Result<Status<'static>, Error> {
        use crate::util::{Either::*, FutExt as _};
        loop {
            match self.step().await? {
                StepResult::Nothing => continue,
                StepResult::Status(Status::Quit) => {
                    // NOTE(review): `step` already observed the notification; this relies on
                    // a second `wait()` resolving immediately. If it does not, the quit is
                    // ignored and the loop keeps stepping — confirm `Notify` semantics.
                    if let Left(_notified) = self.notify.wait().now_or_never().await {
                        self.writer_rx.close();
                        self.activity_rx.close();
                        // Keeps draining (yielding between attempts) until every queue is
                        // empty, so shutdown may wait on the rate limits.
                        while self.available_queued_messages() > 0 {
                            self.drain_queued_messages().await?;
                            futures_lite::future::yield_now().await;
                        }
                        self.encoder.encode(commands::raw("QUIT\r\n")).await?;
                        break Ok(Status::Quit);
                    }
                }
                StepResult::Status(status) => break Ok(status),
            }
        }
    }

    /// Advances the runner by a single event.
    ///
    /// Messages left over from the handshake or from `join`/`part` are returned first.
    /// Otherwise this waits for whichever comes first:
    /// - an incoming message,
    /// - writer activity,
    /// - data from the writer,
    /// - a quit notification,
    /// - the idle timer.
    ///
    /// After handling the event it checks the keep-alive state and drains the
    /// rate-limited queues.
    ///
    /// # Errors
    /// - decoding failures other than EOF;
    /// - [`Error::TimedOut`] if a PING stayed unanswered for longer than `TIMEOUT`;
    /// - [`Error::ShouldReconnect`] if the server sent RECONNECT;
    /// - write errors.
    pub async fn step(&mut self) -> Result<StepResult<'static>, Error> {
        use crate::util::*;
        use crate::IntoOwned as _;

        if let Some(msg) = self.missed_messages.pop_front() {
            return Ok(StepResult::Status(Status::Message(msg)));
        }

        let select = self
            .decoder
            .read_message()
            .either(self.activity_rx.recv())
            .either(self.writer_rx.recv())
            .either(self.notify.wait())
            .either(super::timeout::next_delay())
            .await;

        match select {
            Left(Left(Left(Left(msg)))) => {
                let msg = match msg {
                    Err(DecodeError::Eof) => {
                        log::info!("got an EOF, exiting main loop");
                        return Ok(StepResult::Status(Status::Eof));
                    }
                    Err(err) => {
                        log::warn!("read an error: {}", err);
                        return Err(err.into());
                    }
                    Ok(msg) => msg,
                };
                self.timeout_state = TimeoutState::activity();
                let all = Commands::from_irc(msg)
                    .expect("msg identity conversion should be upheld")
                    .into_owned();
                self.check_messages(&all).await?;
                return Ok(StepResult::Status(Status::Message(all)));
            }
            Left(Left(Left(Right(Some(_activity))))) => {
                self.timeout_state = TimeoutState::activity();
            }
            Left(Left(Right(Some(write_data)))) => self.route_written(write_data).await?,
            Left(Right(_notified)) => return Ok(StepResult::Status(Status::Quit)),
            Right(_timeout) => {
                if let TimeoutState::WaitingForPong(..) = self.timeout_state {
                    // A PING is already outstanding. Re-sending would restart the pong
                    // timer (`waiting_for_pong()` presumably stamps "now"). On a dead
                    // connection only this timer fires, so the TIMEOUT check below
                    // would then never trip.
                    log::debug!("still waiting for a pong");
                } else {
                    log::info!("idle connection detected, sending a ping");
                    let ts = timestamp().to_string();
                    self.encoder.encode(commands::ping(&ts)).await?;
                    self.timeout_state = TimeoutState::waiting_for_pong();
                }
            }
            // `recv()` yielded `None`: the activity or writer channel has closed.
            _ => return Ok(StepResult::Status(Status::Eof)),
        }

        match self.timeout_state {
            TimeoutState::WaitingForPong(dt) => {
                if dt.elapsed() > TIMEOUT {
                    log::warn!("PING timeout detected, exiting");
                    return Err(Error::TimedOut);
                }
            }
            TimeoutState::Activity(dt) => {
                if dt.elapsed() > WINDOW {
                    log::warn!("idle connection detected, sending a PING");
                    let ts = timestamp().to_string();
                    self.encoder.encode(crate::commands::ping(&ts)).await?;
                    self.timeout_state = TimeoutState::waiting_for_pong();
                }
            }
            TimeoutState::Start => {}
        }

        log::trace!("draining messages");
        self.drain_queued_messages().await?;
        Ok(StepResult::Nothing)
    }

    /// Routes one encoded message taken from the writer channel.
    ///
    /// A PRIVMSG with a target is placed on that channel's rate-limited queue; the
    /// channel is tracked first if needed. Everything else is written straight to the
    /// connection instead of being dropped.
    async fn route_written(&mut self, write_data: Box<[u8]>) -> Result<(), Error> {
        let raw = std::str::from_utf8(&*write_data).map_err(Error::InvalidUtf8)?;
        let parsed = crate::irc::parse_one(raw)
            .expect("encoder should produce valid IRC messages")
            .1;

        let privmsg_target = if let crate::irc::IrcMessage::PRIVMSG = parsed.get_command() {
            parsed.nth_arg(0)
        } else {
            None
        };

        match privmsg_target {
            Some(target) => {
                if !self.channels.is_on(target) {
                    self.channels.add(target)
                }
                let ch = self
                    .channels
                    .get_mut(target)
                    .expect("channel is tracked: it was added above if missing");
                if ch.rated_limited_at.map(|s| s.elapsed()) > Some(RATE_LIMIT_WINDOW) {
                    ch.reset_rate_limit();
                }
                ch.rate_limited.enqueue(write_data);
            }
            None => {
                self.encoder.encode(commands::raw(raw)).await?;
            }
        }
        Ok(())
    }

    /// Puts messages set aside by `wait_for` back at the front of the replay queue.
    ///
    /// `step` replays `missed_messages` from the front before reading anything new.
    /// The set-aside messages are therefore older than whatever is still queued, and
    /// must come first to keep the original order.
    fn requeue_missed(&mut self, mut set_aside: VecDeque<Commands<'static>>) {
        set_aside.append(&mut self.missed_messages);
        self.missed_messages = set_aside;
    }

    /// Turns the status `wait_for` stopped on into the outcome of a `join`/`part`.
    fn reply_received(status: Option<Status<'static>>) -> Result<(), Error> {
        match status {
            None => Ok(()),
            Some(Status::Quit) | Some(Status::Eof) => Err(Error::UnexpectedEof),
            Some(Status::Message(..)) => {
                unreachable!("wait_for passes messages to its predicate, never back as a status")
            }
        }
    }

    /// Applies the runner's bookkeeping for one incoming message:
    /// - answers PINGs;
    /// - clears a pending pong wait;
    /// - tracks our own JOIN/PART;
    /// - records slow-mode and rate-limit state;
    /// - forgets channels we were banned from.
    ///
    /// # Errors
    /// Returns [`Error::ShouldReconnect`] on RECONNECT, or a write error when answering a PING.
    async fn check_messages(&mut self, all: &Commands<'static>) -> Result<(), Error> {
        use {Commands::*, TimeoutState::*};

        log::trace!("< {}", all.raw().escape_debug());
        match &all {
            Ping(msg) => {
                let token = msg.token();
                log::debug!(
                    "got a ping from the server. responding with token '{}'",
                    token
                );
                self.encoder.encode(commands::pong(token)).await?;
                self.timeout_state = TimeoutState::activity();
            }
            Pong(..) if matches!(self.timeout_state, WaitingForPong { .. }) => {
                self.timeout_state = TimeoutState::activity()
            }
            Join(msg) if msg.name() == self.identity.username() => {
                log::debug!("starting tracking channel for '{}'", msg.channel());
                self.channels.add(msg.channel());
            }
            Part(msg) if msg.name() == self.identity.username() => {
                log::debug!("stopping tracking of channel '{}'", msg.channel());
                self.channels.remove(msg.channel());
            }
            RoomState(msg) => {
                if let Some(dur) = msg.is_slow_mode() {
                    if let Some(ch) = self.channels.get_mut(msg.channel()) {
                        ch.enable_slow_mode(dur)
                    }
                }
            }
            Notice(msg) => {
                let ch = self.channels.get_mut(msg.channel());
                match (msg.msg_id(), ch) {
                    // The notice carries no duration, so a fixed 30 is used here.
                    (Some(MessageId::SlowOn), Some(ch)) => ch.enable_slow_mode(30),
                    (Some(MessageId::SlowOff), Some(ch)) => ch.disable_slow_mode(),
                    (Some(MessageId::MsgRatelimit), Some(ch)) => ch.set_rate_limited(),
                    (Some(MessageId::MsgBanned), ..) => self.channels.remove(msg.channel()),
                    _ => {}
                }
            }
            Reconnect(_) => return Err(Error::ShouldReconnect),
            _ => {}
        }
        Ok(())
    }
}
impl AsyncRunner {
    /// Steps the runner until `func` accepts a message or a non-message status appears.
    ///
    /// Messages `func` rejects are appended to `missed` in the order they were seen.
    /// Returns:
    /// - `Ok(None)` once `func` returns `Ok(true)` (the accepted message is not stored);
    /// - `Ok(Some(status))` when `step` reports a status other than a message (quit or EOF).
    ///
    /// # Errors
    /// Propagates errors from [`AsyncRunner::step`] and from `func`.
    async fn wait_for<F>(
        &mut self,
        missed: &mut VecDeque<Commands<'static>>,
        func: F,
    ) -> Result<Option<Status<'static>>, Error>
    where
        F: Fn(&Commands<'static>, &Self) -> Result<bool, Error> + Send + Sync,
    {
        loop {
            match self.step().await? {
                StepResult::Status(Status::Message(msg)) => {
                    if func(&msg, self)? {
                        break Ok(None);
                    }
                    missed.push_back(msg);
                }
                StepResult::Status(d) => return Ok(Some(d)),
                StepResult::Nothing => continue,
            }
        }
    }

    /// Total number of messages waiting in every channel's rate-limited queue.
    fn available_queued_messages(&self) -> usize {
        self.channels
            .map
            .values()
            .map(|s| s.rate_limited.queue.len())
            .sum()
    }

    /// Sends as many queued channel messages as both the per-channel limits and the
    /// global limit allow.
    ///
    /// Each channel is charged to the global limiter for exactly the tokens its own
    /// drain spent. Draining stops once the global budget is exhausted.
    async fn drain_queued_messages(&mut self) -> std::io::Result<()> {
        let enc = &mut self.encoder;
        let mut available = self.global_rate_limit.get_available_tokens();

        for channel in self.channels.map.values_mut() {
            if channel.rated_limited_at.map(|s| s.elapsed()) > Some(RATE_LIMIT_WINDOW) {
                channel.reset_rate_limit();
            }

            // Snapshot the budget per channel, so only this channel's sends are charged
            // below rather than everything spent since the loop started.
            let before = available;
            channel
                .rate_limited
                .drain_until_blocked(&channel.name, &mut available, enc)
                .await?;
            let spent = std::cmp::max(before, available) - std::cmp::min(before, available);

            // Charge the global limiter *before* checking for exhaustion. Otherwise the
            // batch that used the last tokens would never be accounted for.
            match self.global_rate_limit.consume(spent) {
                Ok(rem) => available = rem,
                Err(..) => {
                    log::warn!(target: "twitchchat::rate_limit", "global rate limit hit while draining '{}'", &channel.name);
                    break;
                }
            }
            if available == 0 {
                log::warn!(target: "twitchchat::rate_limit", "global rate limit hit while draining '{}'", &channel.name);
                break;
            }
        }
        Ok(())
    }

    /// Reads messages until the server has finished the handshake, then returns our
    /// [`Identity`].
    ///
    /// Every message read is also pushed onto `missed_messages` so it can be replayed
    /// later, and PINGs are answered along the way. The identity is:
    /// - `Anonymous` on READY for anonymous configs;
    /// - `Full` when a GLOBALUSERSTATE with a usable user-id arrives;
    /// - `Basic` otherwise, once no acknowledgements or GLOBALUSERSTATE are still expected.
    ///
    /// # Errors
    /// - [`Error::InvalidCap`] if the server refuses a requested capability.
    /// - Decode, conversion, and write errors are propagated.
    async fn wait_for_ready<R, W>(
        decoder: &mut AsyncDecoder<R>,
        encoder: &mut AsyncEncoder<W>,
        user_config: &UserConfig,
        missed_messages: &mut VecDeque<Commands<'static>>,
    ) -> Result<Identity, Error>
    where
        R: AsyncRead + Send + Sync + Unpin,
        W: AsyncWrite + Send + Sync + Unpin,
    {
        use crate::twitch::Capability as TwitchCap;
        use crate::IntoOwned as _;

        let is_anonymous = user_config.is_anonymous();
        let mut looking_for: HashSet<_> = user_config.capabilities.iter().collect();
        let mut caps = Capabilities::default();
        let mut our_name = None;

        // With both `tags` and `commands` requested, a GLOBALUSERSTATE is expected; it
        // carries the user-id needed for a full identity.
        let expects_global_user_state = user_config.capabilities.contains(&TwitchCap::Tags)
            && user_config.capabilities.contains(&TwitchCap::Commands);

        let identity = loop {
            let msg: IrcMessage<'_> = decoder.read_message().await?;
            use Commands::*;
            let commands = Commands::from_irc(msg)?;
            // Everything seen during the handshake is replayed to the caller later.
            missed_messages.push_back(commands.clone().into_owned());

            match commands {
                Ready(msg) => {
                    our_name.replace(msg.username().to_string());
                    if is_anonymous {
                        break Identity::Anonymous { caps };
                    }
                    if looking_for.is_empty() && !expects_global_user_state {
                        break Identity::Basic {
                            name: our_name.take().unwrap(),
                            caps,
                        };
                    }
                }
                Cap(msg) => match msg.capability() {
                    Capability::Acknowledged(name) => {
                        let cap = match TwitchCap::maybe_from_str(name) {
                            Some(cap) => cap,
                            None => {
                                caps.unknown.insert(name.to_string());
                                continue;
                            }
                        };
                        *match cap {
                            TwitchCap::Tags => &mut caps.tags,
                            TwitchCap::Membership => &mut caps.membership,
                            TwitchCap::Commands => &mut caps.commands,
                        } = true;
                        looking_for.remove(&cap);
                    }
                    Capability::NotAcknowledged(name) => {
                        return Err(Error::InvalidCap {
                            cap: name.to_string(),
                        })
                    }
                },
                GlobalUserState(msg) => {
                    // The user-id is server-supplied; a missing or unparsable one degrades
                    // to a basic identity instead of panicking.
                    // NOTE(review): the unwraps on `our_name` assume READY arrived first.
                    let user_id = match msg.user_id.and_then(|id| id.parse().ok()) {
                        Some(id) => id,
                        None => {
                            break Identity::Basic {
                                name: our_name.take().unwrap(),
                                caps,
                            };
                        }
                    };
                    break Identity::Full {
                        name: our_name.unwrap(),
                        user_id,
                        display_name: msg.display_name.map(|s| s.to_string()),
                        color: msg.color,
                        caps,
                    };
                }
                Ping(msg) => encoder.encode(commands::pong(msg.token())).await?,
                _ => {
                    if our_name.is_some() && !expects_global_user_state && looking_for.is_empty()
                    {
                        break Identity::Basic {
                            name: our_name.take().unwrap(),
                            caps,
                        };
                    }
                }
            };
        };
        Ok(identity)
    }
}
impl Stream for AsyncRunner {
type Item = Commands<'static>;
fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use std::future::Future;
let fut = self.get_mut().next_message();
futures_lite::pin!(fut);
match futures_lite::ready!(fut.poll(ctx)) {
Ok(status) => match status {
Status::Message(msg) => Poll::Ready(Some(msg)),
Status::Quit | Status::Eof => Poll::Ready(None),
},
Err(..) => Poll::Ready(None),
}
}
}
}