#[cfg(feature = "driver-core")]
use crate::driver::Config;
use crate::{
error::{JoinError, JoinResult},
id::{ChannelId, GuildId, UserId},
shards::Sharder,
Call,
ConnectionInfo,
};
#[cfg(feature = "serenity")]
use async_trait::async_trait;
use dashmap::DashMap;
#[cfg(feature = "serenity")]
use futures::channel::mpsc::UnboundedSender as Sender;
use parking_lot::RwLock as PRwLock;
#[cfg(feature = "serenity")]
use serenity::{
client::bridge::voice::VoiceGatewayManager,
gateway::InterMessage,
model::{
id::{GuildId as SerenityGuild, UserId as SerenityUser},
voice::VoiceState,
},
};
use std::sync::Arc;
#[cfg(not(feature = "tokio-02-marker"))]
use tokio::sync::Mutex;
#[cfg(feature = "tokio-02-marker")]
use tokio_compat::sync::Mutex;
#[cfg(feature = "twilight")]
use twilight_gateway::Cluster;
#[cfg(feature = "twilight")]
use twilight_model::gateway::event::Event as TwilightEvent;
/// Per-client state needed whenever the manager creates a new call.
///
/// Filled in either by the twilight constructor or by `initialise_client_data`;
/// until then it holds `Default` values (`initialised == false`, `shard_count == 0`).
#[derive(Clone, Copy, Debug, Default)]
struct ClientData {
/// Total number of shards; used to map a guild ID onto the shard serving it.
shard_count: u64,
/// Set once the data has been filled in; later `initialise_client_data` calls are ignored.
initialised: bool,
/// ID of the user this manager acts as; voice state updates for other users are ignored.
user_id: UserId,
}
/// Manager for voice calls, holding at most one call handler per guild.
///
/// Constructors hand it out as `Arc<Songbird>`; each call handler lives behind
/// an async `Mutex` so it can be locked across `.await` points.
#[derive(Debug)]
pub struct Songbird {
/// Shard count and user ID used when creating calls (a `parking_lot` lock, never held across awaits).
client_data: PRwLock<ClientData>,
/// Call handlers keyed by guild.
calls: DashMap<GuildId, Arc<Mutex<Call>>>,
/// Source of the per-shard gateway handles given to newly created calls.
sharder: Sharder,
/// Driver config applied to newly created calls; `None` means `Config::default()`.
#[cfg(feature = "driver-core")]
driver_config: PRwLock<Option<Config>>,
}
impl Songbird {
    /// Creates a manager for use with serenity.
    ///
    /// Client data (shard count and user ID) starts out uninitialised. It must be
    /// supplied via `initialise_client_data` before any call is created; serenity
    /// does this through `VoiceGatewayManager::initialise`.
    #[cfg(feature = "serenity")]
    pub fn serenity() -> Arc<Self> {
        Arc::new(Self {
            client_data: Default::default(),
            calls: Default::default(),
            sharder: Sharder::Serenity(Default::default()),
            #[cfg(feature = "driver-core")]
            driver_config: Default::default(),
        })
    }

    /// Creates a manager for use with twilight.
    ///
    /// Client data is initialised immediately from `shard_count` and `user_id`.
    /// Shard handles for new calls are taken from `cluster`.
    #[cfg(feature = "twilight")]
    pub fn twilight<U>(cluster: Cluster, shard_count: u64, user_id: U) -> Arc<Self>
    where
        U: Into<UserId>,
    {
        Arc::new(Self {
            client_data: PRwLock::new(ClientData {
                shard_count,
                initialised: true,
                user_id: user_id.into(),
            }),
            calls: Default::default(),
            sharder: Sharder::Twilight(cluster),
            #[cfg(feature = "driver-core")]
            driver_config: Default::default(),
        })
    }

    /// Sets the shard count and user ID used by this manager.
    ///
    /// Only the first call has any effect. Once client data is initialised
    /// (including by the twilight constructor), later calls are silently ignored.
    pub fn initialise_client_data<U: Into<UserId>>(&self, shard_count: u64, user_id: U) {
        let mut client_data = self.client_data.write();

        if client_data.initialised {
            return;
        }

        client_data.shard_count = shard_count;
        client_data.user_id = user_id.into();
        client_data.initialised = true;
    }

    /// Returns the call handler for `guild_id`, if one exists.
    pub fn get<G: Into<GuildId>>(&self, guild_id: G) -> Option<Arc<Mutex<Call>>> {
        self.calls
            .get(&guild_id.into())
            .map(|mapref| Arc::clone(&mapref))
    }

    /// Returns the call handler for `guild_id`, creating it if none exists.
    ///
    /// A new call is bound to the shard responsible for the guild. With the
    /// `driver-core` feature, it also uses the config last passed to `set_config`,
    /// or `Config::default()` if none was set.
    ///
    /// # Panics
    ///
    /// Panics when a new call must be created and either:
    /// - the shard count is zero (client data has not been initialised), or
    /// - the sharder cannot provide a handle for the guild's shard.
    pub fn get_or_insert(&self, guild_id: GuildId) -> Arc<Mutex<Call>> {
        // Fast path: the map's read lock is enough when the call already exists.
        self.get(guild_id).unwrap_or_else(|| {
            // Slow path: `entry` re-checks under the write lock, so concurrent
            // callers cannot create two calls for the same guild.
            let entry = self.calls.entry(guild_id).or_insert_with(|| {
                let info = self.manager_info();

                // Without this check an uninitialised manager fails inside
                // `shard_id` with an opaque "divisor of zero" panic.
                assert!(
                    info.shard_count != 0,
                    "Songbird client data not initialised: shard_count is 0 \
                     (call initialise_client_data before creating calls)"
                );

                let shard = shard_id(guild_id.0, info.shard_count);
                let shard_handle = self
                    .sharder
                    .get_shard(shard)
                    .expect("Failed to get shard handle: shard_count incorrect?");

                #[cfg(feature = "driver-core")]
                let call = Call::from_driver_config(
                    guild_id,
                    shard_handle,
                    info.user_id,
                    self.driver_config.read().clone().unwrap_or_default(),
                );

                #[cfg(not(feature = "driver-core"))]
                let call = Call::new(guild_id, shard_handle, info.user_id);

                Arc::new(Mutex::new(call))
            });

            Arc::clone(&entry)
        })
    }

    /// Returns a copy of the current client data.
    fn manager_info(&self) -> ClientData {
        // A shared read lock is sufficient to copy the data out; an exclusive
        // write lock would needlessly serialise concurrent call creation.
        *self.client_data.read()
    }

    /// Connects to `channel_id` in `guild_id` and waits for the driver connection.
    ///
    /// A call handler is created for the guild if needed. The handler is always
    /// returned alongside the result, even when joining fails.
    ///
    /// # Errors
    ///
    /// - Returns any error from `Call::join`.
    /// - Returns `JoinError::Dropped` if the completion channel is dropped before
    ///   a result arrives.
    /// - Returns the driver's connection error, converted into `JoinError`.
    #[cfg(feature = "driver-core")]
    #[inline]
    pub async fn join<C, G>(
        &self,
        guild_id: G,
        channel_id: C,
    ) -> (Arc<Mutex<Call>>, JoinResult<()>)
    where
        C: Into<ChannelId>,
        G: Into<GuildId>,
    {
        self._join(guild_id.into(), channel_id.into()).await
    }

    #[cfg(feature = "driver-core")]
    async fn _join(
        &self,
        guild_id: GuildId,
        channel_id: ChannelId,
    ) -> (Arc<Mutex<Call>>, JoinResult<()>) {
        let call = self.get_or_insert(guild_id);

        // The handler lock is released at the end of this block, before waiting
        // for the connection, so other tasks can use the call meanwhile.
        let stage_1 = {
            let mut handler = call.lock().await;
            handler.join(channel_id).await
        };

        let result = match stage_1 {
            Ok(chan) => chan
                .await
                .map_err(|_| JoinError::Dropped)
                .and_then(|x| x.map_err(JoinError::from)),
            Err(e) => Err(e),
        };

        (call, result)
    }

    /// Like `join`, but uses `Call::join_gateway` and resolves to the
    /// `ConnectionInfo` it produces.
    ///
    /// The call handler is always returned alongside the result.
    ///
    /// # Errors
    ///
    /// - Returns any error from `Call::join_gateway`.
    /// - Returns `JoinError::Dropped` if the channel carrying the
    ///   `ConnectionInfo` is dropped before it is sent.
    #[inline]
    pub async fn join_gateway<C, G>(
        &self,
        guild_id: G,
        channel_id: C,
    ) -> (Arc<Mutex<Call>>, JoinResult<ConnectionInfo>)
    where
        C: Into<ChannelId>,
        G: Into<GuildId>,
    {
        self._join_gateway(guild_id.into(), channel_id.into()).await
    }

    async fn _join_gateway(
        &self,
        guild_id: GuildId,
        channel_id: ChannelId,
    ) -> (Arc<Mutex<Call>>, JoinResult<ConnectionInfo>) {
        let call = self.get_or_insert(guild_id);

        // Release the handler lock before awaiting the connection info.
        let stage_1 = {
            let mut handler = call.lock().await;
            handler.join_gateway(channel_id).await
        };

        let result = match stage_1 {
            Ok(chan) => chan.await.map_err(|_| JoinError::Dropped),
            Err(e) => Err(e),
        };

        (call, result)
    }

    /// Leaves the voice channel in `guild_id`.
    ///
    /// The call handler is kept in the manager; use `remove` to discard it as well.
    ///
    /// # Errors
    ///
    /// Returns `JoinError::NoCall` if no handler exists for the guild.
    /// Otherwise returns any error from `Call::leave`.
    #[inline]
    pub async fn leave<G: Into<GuildId>>(&self, guild_id: G) -> JoinResult<()> {
        self._leave(guild_id.into()).await
    }

    async fn _leave(&self, guild_id: GuildId) -> JoinResult<()> {
        if let Some(call) = self.get(guild_id) {
            let mut handler = call.lock().await;
            handler.leave().await
        } else {
            Err(JoinError::NoCall)
        }
    }

    /// Leaves the voice channel in `guild_id`, then discards its call handler.
    ///
    /// # Errors
    ///
    /// Propagates any error from `leave`. In that case the handler is **not** removed.
    #[inline]
    pub async fn remove<G: Into<GuildId>>(&self, guild_id: G) -> JoinResult<()> {
        self._remove(guild_id.into()).await
    }

    async fn _remove(&self, guild_id: GuildId) -> JoinResult<()> {
        self.leave(guild_id).await?;
        self.calls.remove(&guild_id);
        Ok(())
    }
}
#[cfg(feature = "twilight")]
impl Songbird {
    /// Feeds a twilight gateway event into the manager.
    ///
    /// Only voice server updates and voice state updates are acted upon; every
    /// other event is ignored. The following are also skipped:
    /// - updates without a guild ID;
    /// - updates for guilds with no call handler;
    /// - server updates without an endpoint;
    /// - state updates for users other than this manager's user.
    pub async fn process(&self, event: &TwilightEvent) {
        match event {
            TwilightEvent::VoiceServerUpdate(server) => {
                let guild = match server.guild_id {
                    Some(id) => GuildId::from(id),
                    None => return,
                };

                if let Some(call) = self.get(guild) {
                    // The handler is locked before checking the endpoint, as before.
                    let mut handler = call.lock().await;

                    if let Some(endpoint) = server.endpoint.as_ref() {
                        handler.update_server(endpoint.clone(), server.token.clone());
                    }
                }
            },
            TwilightEvent::VoiceStateUpdate(update) => {
                let state = &update.0;

                // The parking_lot guard is dropped at the end of this statement,
                // so it is never held across an await.
                let own_id = self.client_data.read().user_id.0;
                if state.user_id.0 != own_id {
                    return;
                }

                let guild = match state.guild_id {
                    Some(id) => GuildId::from(id),
                    None => return,
                };

                if let Some(call) = self.get(guild) {
                    let mut handler = call.lock().await;
                    handler.update_state(state.session_id.clone());
                }
            },
            _ => {},
        }
    }
}
#[cfg(feature = "serenity")]
#[async_trait]
impl VoiceGatewayManager for Songbird {
    /// Records serenity's shard count and bot user ID (first call only).
    async fn initialise(&self, shard_count: u64, user_id: SerenityUser) {
        Songbird::initialise_client_data(self, shard_count, user_id);
    }

    /// Makes `sender` available as the gateway handle for shard `shard_id`.
    async fn register_shard(&self, shard_id: u64, sender: Sender<InterMessage>) {
        self.sharder.register_shard_handle(shard_id, sender);
    }

    /// Forgets the gateway handle for shard `shard_id`.
    async fn deregister_shard(&self, shard_id: u64) {
        self.sharder.deregister_shard_handle(shard_id);
    }

    /// Forwards a voice server update to the guild's call handler, if there is one.
    ///
    /// Updates without an endpoint are ignored (after the handler lock is taken).
    async fn server_update(&self, guild_id: SerenityGuild, endpoint: &Option<String>, token: &str) {
        let call = match self.get(guild_id) {
            Some(call) => call,
            None => return,
        };

        let mut handler = call.lock().await;

        if let Some(address) = endpoint.as_ref() {
            handler.update_server(address.to_owned(), token.to_owned());
        }
    }

    /// Forwards this bot's own voice state to the guild's call handler, if there is one.
    ///
    /// Voice states belonging to other users are ignored.
    async fn state_update(&self, guild_id: SerenityGuild, voice_state: &VoiceState) {
        // The read guard is released at the end of this statement, before any await.
        let is_own_state = voice_state.user_id.0 == self.client_data.read().user_id.0;
        if !is_own_state {
            return;
        }

        let call = match self.get(guild_id) {
            Some(call) => call,
            None => return,
        };

        let mut handler = call.lock().await;
        handler.update_state(voice_state.session_id.clone());
    }
}
#[cfg(feature = "driver-core")]
impl Songbird {
    /// Replaces the driver configuration given to newly created calls.
    ///
    /// Calls that already exist keep the configuration they were created with.
    pub fn set_config(&self, new_config: Config) {
        *self.driver_config.write() = Some(new_config);
    }
}
/// Maps a guild ID onto one of `shard_count` shards.
///
/// Discards the low 22 bits of `guild_id` and reduces the remaining bits
/// modulo `shard_count`.
///
/// # Panics
///
/// Panics if `shard_count` is zero, because integer remainder by zero panics.
#[inline]
fn shard_id(guild_id: u64, shard_count: u64) -> u64 {
    let upper_bits = guild_id >> 22;
    upper_bits % shard_count
}