use std::{
boxed::Box,
future::Future,
sync::Arc,
time::Duration,
pin::Pin,
task::{Context as FutContext, Poll},
};
use tokio::{
sync::mpsc::{
unbounded_channel,
UnboundedReceiver as Receiver,
UnboundedSender as Sender,
},
time::{Delay, delay_for},
};
use futures::{
future::BoxFuture,
stream::{Stream, StreamExt},
};
use crate::{
client::bridge::gateway::ShardMessenger,
model::channel::Reaction,
model::id::UserId,
};
/// Generates the builder-style configuration methods shared by every reaction
/// collector entry point. Each listed type must have a
/// `filter: Option<FilterOptions>` field and a `timeout: Option<Delay>` field.
macro_rules! impl_reaction_collector {
    ($($name:ident;)*) => {
        $(
            impl<'a> $name<'a> {
                /// Mutable access to the pending filter options.
                ///
                /// Panics if the options have already been moved out, which
                /// happens when the value is first polled.
                fn options_mut(&mut self) -> &mut FilterOptions {
                    self.filter.as_mut().unwrap()
                }

                /// Stops feeding the collector once `limit` reactions have been
                /// checked against the filter, whether or not they matched.
                pub fn filter_limit(mut self, limit: u32) -> Self {
                    self.options_mut().filter_limit = Some(limit);
                    self
                }

                /// Stops feeding the collector once `limit` reactions have
                /// passed the filter.
                pub fn collect_limit(mut self, limit: u32) -> Self {
                    self.options_mut().collect_limit = Some(limit);
                    self
                }

                /// Adds a custom predicate that a reaction must satisfy, in
                /// addition to the ID constraints.
                pub fn filter<F>(mut self, function: F) -> Self
                where
                    F: Fn(&Arc<Reaction>) -> bool + 'static + Send + Sync,
                {
                    self.options_mut().filter = Some(Arc::new(function));
                    self
                }

                /// Only accepts reactions made by this user.
                pub fn author_id(mut self, author_id: impl Into<u64>) -> Self {
                    self.options_mut().author_id = Some(author_id.into());
                    self
                }

                /// Only accepts reactions on this message.
                pub fn message_id(mut self, message_id: impl Into<u64>) -> Self {
                    self.options_mut().message_id = Some(message_id.into());
                    self
                }

                /// Only accepts reactions in this guild.
                pub fn guild_id(mut self, guild_id: impl Into<u64>) -> Self {
                    self.options_mut().guild_id = Some(guild_id.into());
                    self
                }

                /// Only accepts reactions in this channel.
                pub fn channel_id(mut self, channel_id: impl Into<u64>) -> Self {
                    self.options_mut().channel_id = Some(channel_id.into());
                    self
                }

                /// Whether added reactions are accepted (default: `true`).
                pub fn added(mut self, is_accepted: bool) -> Self {
                    self.options_mut().accept_added = is_accepted;
                    self
                }

                /// Whether removed reactions are accepted (default: `false`).
                pub fn removed(mut self, is_accepted: bool) -> Self {
                    self.options_mut().accept_removed = is_accepted;
                    self
                }

                /// Ends collection after `duration`.
                ///
                /// The `Delay` is created here via `delay_for`, so the countdown
                /// starts when this method is called, not on the first poll.
                pub fn timeout(mut self, duration: Duration) -> Self {
                    let delay = delay_for(duration);
                    self.timeout = Some(delay);
                    self
                }
            }
        )*
    }
}
/// A reaction event as seen by a collector: the reaction was either added or
/// removed. Collectors yield these as `Arc<ReactionAction>`.
#[derive(Debug)]
pub enum ReactionAction {
/// A reaction was added; passes a filter only when `accept_added` is set.
Added(Arc<Reaction>),
/// A reaction was removed; passes a filter only when `accept_removed` is set.
Removed(Arc<Reaction>),
}
impl ReactionAction {
    /// Borrows the wrapped reaction, regardless of whether it was added or removed.
    pub fn as_inner_ref(&self) -> &Arc<Reaction> {
        match self {
            Self::Added(reaction) | Self::Removed(reaction) => reaction,
        }
    }

    /// `true` for [`ReactionAction::Added`], `false` otherwise.
    pub fn is_added(&self) -> bool {
        match self {
            Self::Added(_) => true,
            Self::Removed(_) => false,
        }
    }

    /// `true` for [`ReactionAction::Removed`], `false` otherwise.
    pub fn is_removed(&self) -> bool {
        match self {
            Self::Added(_) => false,
            Self::Removed(_) => true,
        }
    }
}
/// Decides which reactions are forwarded to a collector's channel and counts
/// how many it has seen. Handed to the shard messenger via
/// `set_reaction_filter` when a collector is built.
#[derive(Clone, Debug)]
pub struct ReactionFilter {
/// Number of reactions offered through `send_reaction` (matching or not).
filtered: u32,
/// Number of reactions that passed the constraints.
collected: u32,
/// Constraints and limits configured through the builder methods.
options: FilterOptions,
/// Sending half of the channel read by the `ReactionCollector`.
sender: Sender<Arc<ReactionAction>>,
}
impl ReactionFilter {
fn new(options: FilterOptions) -> (Self, Receiver<Arc<ReactionAction>>) {
let (sender, receiver) = unbounded_channel();
let filter = Self {
filtered: 0,
collected: 0,
sender,
options,
};
(filter, receiver)
}
pub(crate) fn send_reaction(&mut self, reaction: &Arc<ReactionAction>) -> bool {
if self.is_passing_constraints(&reaction) {
self.collected += 1;
if self.sender.send(Arc::clone(reaction)).is_err() {
return false;
}
}
self.filtered += 1;
self.is_within_limits()
}
fn is_passing_constraints(&self, reaction: &Arc<ReactionAction>) -> bool {
let reaction = match **reaction {
ReactionAction::Added(ref reaction) => if self.options.accept_added {
reaction
} else {
return false;
},
ReactionAction::Removed(ref reaction) => if self.options.accept_removed {
reaction
} else {
return false;
},
};
self.options.guild_id.map_or(true, |id| { Some(id) == reaction.guild_id.map(|g| g.0) })
&& self.options.message_id.map_or(true, |id| { id == reaction.message_id.0 })
&& self.options.channel_id.map_or(true, |id| { id == reaction.channel_id.0 })
&& self.options.author_id.map_or(true, |id| { id == reaction.user_id.unwrap_or(UserId(0)).0 })
&& self.options.filter.as_ref().map_or(true, |f| f(&reaction))
}
fn is_within_limits(&self) -> bool {
self.options.filter_limit.map_or(true, |limit| { self.filtered < limit })
&& self.options.collect_limit.map_or(true, |limit| { self.collected < limit })
}
}
/// Configuration collected by the builder methods and consumed by
/// `ReactionFilter`. `None` on any optional field means "no constraint".
#[derive(Clone)]
struct FilterOptions {
/// Stop once this many reactions have been offered to the filter.
filter_limit: Option<u32>,
/// Stop once this many reactions have passed the filter.
collect_limit: Option<u32>,
/// Custom predicate applied after the ID constraints.
filter: Option<Arc<dyn Fn(&Arc<Reaction>) -> bool + 'static + Send + Sync>>,
channel_id: Option<u64>,
guild_id: Option<u64>,
/// Compared against the reaction's `user_id`.
author_id: Option<u64>,
message_id: Option<u64>,
/// Whether `ReactionAction::Added` events can pass.
accept_added: bool,
/// Whether `ReactionAction::Removed` events can pass.
accept_removed: bool,
}
impl Default for FilterOptions {
    /// No limits, no ID constraints, no custom predicate; only added reactions
    /// are accepted.
    fn default() -> Self {
        Self {
            accept_added: true,
            accept_removed: false,
            filter: None,
            filter_limit: None,
            collect_limit: None,
            channel_id: None,
            guild_id: None,
            author_id: None,
            message_id: None,
        }
    }
}
// Give both entry points (the single-reaction future and the stream builder)
// the same configuration methods.
impl_reaction_collector! {
CollectReaction;
ReactionCollectorBuilder;
}
/// Builder that, when awaited, registers a `ReactionFilter` with the shard
/// messenger and resolves to a [`ReactionCollector`] stream.
pub struct ReactionCollectorBuilder<'a> {
/// Pending options; taken on the first poll.
filter: Option<FilterOptions>,
/// Messenger the filter is registered with; taken on the first poll.
shard: Option<ShardMessenger>,
/// Optional timeout, moved into the resulting collector.
timeout: Option<Delay>,
/// Setup future created on the first poll.
fut: Option<BoxFuture<'a, ReactionCollector>>,
}
impl<'a> ReactionCollectorBuilder<'a> {
    /// Starts a builder with default filter options (added reactions only, no
    /// limits, no timeout) bound to a clone of `shard_messenger`.
    pub fn new(shard_messenger: impl AsRef<ShardMessenger>) -> Self {
        let messenger = shard_messenger.as_ref().clone();

        Self {
            filter: Some(FilterOptions::default()),
            shard: Some(messenger),
            timeout: None,
            fut: None,
        }
    }
}
impl<'a> Future for ReactionCollectorBuilder<'a> {
type Output = ReactionCollector;
fn poll(mut self: Pin<&mut Self>, ctx: &mut FutContext<'_>) -> Poll<Self::Output> {
if self.fut.is_none() {
let shard_messenger = self.shard.take().unwrap();
let (filter, receiver) = ReactionFilter::new(self.filter.take().unwrap());
let timeout = self.timeout.take();
self.fut = Some(Box::pin(async move {
shard_messenger.set_reaction_filter(filter);
ReactionCollector {
receiver: Box::pin(receiver),
timeout: timeout.map(Box::pin),
}
}))
}
self.fut.as_mut().unwrap().as_mut().poll(ctx)
}
}
/// Future that registers a `ReactionFilter` with the shard messenger and
/// resolves to the first matching reaction, or `None` if the timeout elapses or
/// the channel closes first.
pub struct CollectReaction<'a> {
/// Pending options; taken on the first poll.
filter: Option<FilterOptions>,
/// Messenger the filter is registered with; taken on the first poll.
shard: Option<ShardMessenger>,
/// Optional timeout, moved into the internal collector.
timeout: Option<Delay>,
/// Future created on the first poll that awaits the first reaction.
fut: Option<BoxFuture<'a, Option<Arc<ReactionAction>>>>,
}
impl<'a> CollectReaction<'a> {
    /// Starts a single-reaction collector with default filter options (added
    /// reactions only, no limits, no timeout) bound to a clone of
    /// `shard_messenger`.
    pub fn new(shard_messenger: impl AsRef<ShardMessenger>) -> Self {
        let messenger = shard_messenger.as_ref().clone();

        Self {
            filter: Some(FilterOptions::default()),
            shard: Some(messenger),
            timeout: None,
            fut: None,
        }
    }
}
impl<'a> Future for CollectReaction<'a> {
type Output = Option<Arc<ReactionAction>>;
fn poll(mut self: Pin<&mut Self>, ctx: &mut FutContext<'_>) -> Poll<Self::Output> {
if self.fut.is_none() {
let shard_messenger = self.shard.take().unwrap();
let (filter, receiver) = ReactionFilter::new(self.filter.take().unwrap());
let timeout = self.timeout.take();
self.fut = Some(Box::pin(async move {
shard_messenger.set_reaction_filter(filter);
ReactionCollector {
receiver: Box::pin(receiver),
timeout: timeout.map(Box::pin),
}.next().await
}))
}
self.fut.as_mut().unwrap().as_mut().poll(ctx)
}
}
impl std::fmt::Debug for FilterOptions {
    /// Formats every option under the struct's real name. The custom predicate
    /// is not `Debug`, so its type is printed in its place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FilterOptions")
            .field("filter_limit", &self.filter_limit)
            .field("collect_limit", &self.collect_limit)
            .field("filter", &"Option<Arc<dyn Fn(&Arc<Reaction>) -> bool + 'static + Send + Sync>>")
            .field("channel_id", &self.channel_id)
            .field("guild_id", &self.guild_id)
            .field("author_id", &self.author_id)
            .field("message_id", &self.message_id)
            .field("accept_added", &self.accept_added)
            .field("accept_removed", &self.accept_removed)
            .finish()
    }
}
/// Stream of reactions that passed a `ReactionFilter`. The stream ends when the
/// optional timeout elapses or the channel yields `None`. Dropping or stopping
/// the collector closes its receiver.
pub struct ReactionCollector {
/// Receiving half of the channel fed by the `ReactionFilter`.
receiver: Pin<Box<Receiver<Arc<ReactionAction>>>>,
/// Optional deadline after which the stream ends.
timeout: Option<Pin<Box<Delay>>>,
}
impl ReactionCollector {
    /// Stops collecting by closing the receiver and consuming the collector.
    ///
    /// Once closed, sends from the filter fail, which makes
    /// `ReactionFilter::send_reaction` return `false`.
    pub fn stop(mut self) {
        let receiver = &mut self.receiver;
        receiver.close();
    }
}
impl Stream for ReactionCollector {
    type Item = Arc<ReactionAction>;

    /// Yields the next reaction that passed the filter.
    ///
    /// Returns `Ready(None)` once the optional timeout has elapsed, and also
    /// when the channel itself is finished. On timeout the receiver is closed
    /// as well. Further sends from the filter then fail, so `send_reaction`
    /// reports `false`, instead of buffering reactions in the unbounded channel
    /// for a collector that will never yield them.
    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut FutContext<'_>) -> Poll<Option<Self::Item>> {
        let timed_out = match self.timeout.as_mut() {
            Some(delay) => delay.as_mut().poll(ctx).is_ready(),
            None => false,
        };

        if timed_out {
            // Closing is idempotent; `stop` and `Drop` may close again later.
            self.receiver.close();
            return Poll::Ready(None);
        }

        self.receiver.as_mut().poll_next(ctx)
    }
}
impl Drop for ReactionCollector {
    /// Closes the receiver so the filter's sends fail from now on.
    fn drop(&mut self) {
        let receiver = &mut self.receiver;
        receiver.close();
    }
}