use futures::FutureExt;
use futures::{poll, SinkExt};
use log::{error, warn};
use std::future::Future;
use std::hash::Hash;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::thread;
use std::time::Duration;
use tokio::spawn;
use tokio::sync::mpsc;
use dynamic_pool::{DynamicPool, DynamicPoolItem, DynamicReset};
use futures::future::{join_all, try_join_all, JoinAll};
use tokio::task::{spawn_blocking, JoinError, JoinHandle, JoinSet};
use tokio::time::error::Elapsed;
use super::sender::ShutdownHandleShardSender;
use super::types::GracefulShutdownFuture;
use super::{shard, Commit, ServiceHandleMessage, ServiceHandleShardSender, ShardStats, TakenData};
use crate::{ServiceData, ShardError, ShardShutdownStats};
/// The set of per-shard senders owned by one handle and recycled through a [`DynamicPool`].
///
/// This is a newtype over `Vec<ServiceHandleShardSender<Key, Data>>` so that [`DynamicReset`]
/// can be implemented for it; it derefs (mutably and immutably) to the inner vector.
struct ServiceHandleShardSenderVec<Key, Data>(Vec<ServiceHandleShardSender<Key, Data>>);

impl<Key, Data> DynamicReset for ServiceHandleShardSenderVec<Key, Data> {
    /// No-op: a pooled sender vector is handed out again unchanged.
    fn reset(&mut self) {
        // intentionally empty
    }
}

impl<Key, Data> Deref for ServiceHandleShardSenderVec<Key, Data> {
    type Target = Vec<ServiceHandleShardSender<Key, Data>>;

    fn deref(&self) -> &Self::Target {
        let Self(senders) = self;
        senders
    }
}

impl<Key, Data> DerefMut for ServiceHandleShardSenderVec<Key, Data> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self(senders) = self;
        senders
    }
}
/// A cloneable handle for sending requests to the sharded service.
///
/// Each handle owns one set of shard senders checked out of a shared [`DynamicPool`];
/// cloning a handle checks out another set from the same pool.
pub struct ServiceHandle<Key, Data> {
    /// Pool that hands out sender sets for cloned handles.
    pool: DynamicPool<ServiceHandleShardSenderVec<Key, Data>>,
    /// The sender set owned by this handle (one sender per shard).
    shards: DynamicPoolItem<ServiceHandleShardSenderVec<Key, Data>>,
}

impl<Key: Send + 'static, Data: ServiceData> ServiceHandle<Key, Data> {
    /// Builds a handle from one channel sender per shard.
    ///
    /// `handle_pool_capacity` is passed as both capacity arguments (initial and maximum) of
    /// `DynamicPool::new` for the pool that backs [`Clone`].
    ///
    /// # Panics
    ///
    /// Panics if `senders` is empty.
    pub(super) fn with_senders_and_pool_capacity(
        senders: Vec<mpsc::Sender<ServiceHandleMessage<Key, Data>>>,
        handle_pool_capacity: usize,
    ) -> Self {
        let mut template = Vec::with_capacity(senders.len());
        for sender in senders {
            template.push(ServiceHandleShardSender::from_sender(sender));
        }
        assert!(
            !template.is_empty(),
            "Somehow, a ServiceHandle was tried to be constructed that holds 0 shards."
        );

        // Every pooled item is a fresh clone of the template sender set.
        let pool = DynamicPool::new(handle_pool_capacity, handle_pool_capacity, move || {
            ServiceHandleShardSenderVec(template.clone())
        });

        Self {
            shards: pool.take(),
            pool,
        }
    }
}
impl<Key: Send + Hash, Data: ServiceData> ServiceHandle<Key, Data> {
    /// Returns a new handle to the same service (equivalent to [`Clone::clone`]).
    #[inline]
    pub fn handle(&self) -> Self {
        self.clone()
    }

    /// Picks the shard responsible for `key`.
    ///
    /// With a single shard no hashing is done; otherwise the shard index is
    /// `fxhash(key) % shard_count`, so a key always maps to the same shard for a fixed
    /// number of shards.
    #[inline]
    fn select_shard(&mut self, key: &Key) -> &mut ServiceHandleShardSender<Key, Data> {
        let shard_len = self.shards.len();
        let shard_idx = if shard_len == 1 {
            0
        } else {
            fxhash::hash(key) % shard_len
        };
        // SAFETY: `shard_idx` is either 0 with `shard_len == 1`, or `hash % shard_len`, which
        // is always `< shard_len` (a zero `shard_len` would already have panicked on `% 0`,
        // and the constructor asserts at least one shard). So the index is in bounds.
        unsafe { self.shards.get_unchecked_mut(shard_idx) }
    }

    /// Dispatches `key` and the read-only closure `func` to the shard selected for `key`.
    ///
    /// Delegates to `ServiceHandleShardSender::execute`; see it for the exact semantics.
    /// The returned future borrows this handle until it completes.
    #[inline]
    pub fn execute<F, T>(
        &mut self,
        key: Key,
        func: F,
    ) -> impl Future<Output = Result<T, ShardError>> + '_
    where
        F: FnOnce(&Data) -> T + Send + 'static,
        T: Send + 'static,
    {
        self.select_shard(&key).execute(key, func)
    }

    /// Like [`execute`](Self::execute), but resolves to an `Option<T>`.
    ///
    /// Presumably this is `Ok(None)` when `key` is not cached on its shard; see
    /// `ServiceHandleShardSender::execute_if_cached` for the exact semantics.
    #[inline]
    pub fn execute_if_cached<F, T>(
        &mut self,
        key: Key,
        func: F,
    ) -> impl Future<Output = Result<Option<T>, ShardError>> + '_
    where
        F: FnOnce(&Data) -> T + Send + 'static,
        T: Send + 'static,
    {
        self.select_shard(&key).execute_if_cached(key, func)
    }

    /// Asks the shard selected for `key` to hand over the data stored for `key`.
    ///
    /// Delegates to `ServiceHandleShardSender::take_data`. The result is an `Option`;
    /// presumably it is `None` when there is nothing to take (verify against the sender).
    #[inline]
    pub fn take_data(
        &mut self,
        key: Key,
    ) -> impl Future<Output = Result<Option<TakenData<Key, Data>>, ShardError>> + '_ {
        self.select_shard(&key).take_data(key)
    }

    /// Collects statistics from every shard concurrently, in shard order.
    ///
    /// # Errors
    ///
    /// Returns the first [`ShardError`] produced by any shard; the remaining in-flight
    /// requests are dropped at that point (`try_join_all` semantics).
    pub async fn get_shard_stats(&mut self) -> Result<Vec<ShardStats>, ShardError> {
        // `try_join_all` accepts any `IntoIterator`, so no intermediate `Vec` is needed.
        try_join_all(self.shards.iter_mut().map(|shard| shard.get_shard_stats())).await
    }
}
impl<Key, Data> Clone for ServiceHandle<Key, Data> {
    /// Checks out another sender set from the shared pool and pairs it with a clone of the
    /// pool handle itself.
    fn clone(&self) -> Self {
        let shared_pool = &self.pool;
        Self {
            shards: shared_pool.take(),
            pool: shared_pool.clone(),
        }
    }
}
/// A [`ServiceHandle`] that can additionally run mutating closures via
/// [`MutableServiceHandle::execute_mut`].
///
/// All methods of [`ServiceHandle`] remain reachable through `Deref`/`DerefMut`.
pub struct MutableServiceHandle<Key, Data>(ServiceHandle<Key, Data>);

impl<Key: Send + Hash, Data: ServiceData> MutableServiceHandle<Key, Data> {
    /// Wraps an existing handle so that it can issue mutating requests.
    pub(super) fn from_service_handle(service_handle: ServiceHandle<Key, Data>) -> Self {
        MutableServiceHandle(service_handle)
    }

    /// Unwraps into the underlying [`ServiceHandle`], dropping mutable access.
    pub fn into_immutable_handle(self) -> ServiceHandle<Key, Data> {
        let Self(inner) = self;
        inner
    }

    /// Returns a new mutable handle to the same service (equivalent to [`Clone::clone`]).
    #[inline]
    pub fn handle(&self) -> Self {
        Clone::clone(self)
    }

    /// Dispatches `key` and the mutating closure `func` to the shard selected for `key`.
    ///
    /// Delegates to `ServiceHandleShardSender::execute_mut`; `func` wraps its result in a
    /// [`Commit`]. The returned future borrows this handle until it completes.
    #[inline]
    pub fn execute_mut<F, T>(
        &mut self,
        key: Key,
        func: F,
    ) -> impl Future<Output = Result<T, ShardError>> + '_
    where
        F: FnOnce(&mut Data) -> Commit<T> + Send + 'static,
        T: Send + 'static,
    {
        let shard = self.0.select_shard(&key);
        shard.execute_mut(key, func)
    }
}

impl<Key, Data> Clone for MutableServiceHandle<Key, Data> {
    fn clone(&self) -> Self {
        let Self(inner) = self;
        Self(inner.clone())
    }
}

impl<Key, Data> Deref for MutableServiceHandle<Key, Data> {
    type Target = ServiceHandle<Key, Data>;

    fn deref(&self) -> &ServiceHandle<Key, Data> {
        let Self(inner) = self;
        inner
    }
}

impl<Key, Data> DerefMut for MutableServiceHandle<Key, Data> {
    fn deref_mut(&mut self) -> &mut ServiceHandle<Key, Data> {
        let Self(inner) = self;
        inner
    }
}
/// State owned by a [`ShutdownHandle`] until a shutdown is initiated.
struct ShutdownHandleInner {
    /// Join handles of the spawned shard tasks.
    shard_join_handles: Vec<JoinHandle<()>>,
    /// Future that awaits `shutdown()` on every shard (via `try_join_all`); it is only polled
    /// once `gracefully_shutdown` spawns it.
    graceful_shutdown_future: GracefulShutdownFuture,
}

/// Owner of the service's shutdown path.
///
/// Call [`ShutdownHandle::gracefully_shutdown`] to start a graceful shutdown; dropping the
/// handle before that aborts every shard task.
pub struct ShutdownHandle {
    /// `Some` until moved out by `gracefully_shutdown` or `drop`.
    inner: Option<ShutdownHandleInner>,
}

impl ShutdownHandle {
    /// Builds a shutdown handle from one channel sender per shard and the shard tasks'
    /// join handles.
    ///
    /// # Panics
    ///
    /// Panics if `senders` is empty.
    pub(super) fn with_senders_and_join_handles<Key, Data>(
        senders: Vec<mpsc::Sender<ServiceHandleMessage<Key, Data>>>,
        shard_join_handles: Vec<JoinHandle<()>>,
    ) -> Self
    where
        Key: Send + 'static,
        Data: ServiceData,
    {
        assert!(
            !senders.is_empty(),
            "Somehow, a ShutdownHandle tried to be constructed that holds 0 shards."
        );

        let shutdown_requests = senders
            .into_iter()
            .map(ShutdownHandleShardSender::from_sender)
            .map(|shard| shard.shutdown());
        let graceful_shutdown_future: GracefulShutdownFuture =
            Box::pin(try_join_all(shutdown_requests));

        Self {
            inner: Some(ShutdownHandleInner {
                shard_join_handles,
                graceful_shutdown_future,
            }),
        }
    }
}
impl ShutdownHandle {
pub fn gracefully_shutdown(mut self) -> GracefulShutdownHandle {
let inner = self
.inner
.take()
.expect("invariant: missing shutdown handle inner");
GracefulShutdownHandle {
inner: Some(GracefulShutdownHandleInner {
shard_join_handles: inner.shard_join_handles,
graceful_shutdown_join_handle: tokio::spawn(async move {
let result = inner.graceful_shutdown_future.await;
match result {
Ok(shutdown_stats_vec) => {
ShutdownResult::GracefullyShutdown(shutdown_stats_vec)
}
Err(_) => {
error!("one or more shards were already shutdown during the graceful shutdown. performing hard shutdown");
ShutdownResult::HardShutdown
}
}
}),
}),
}
}
}
impl Drop for ShutdownHandle {
fn drop(&mut self) {
let Some(inner) = self.inner.take() else {
return;
};
warn!("ShutdownHandle was dropped before a shutdown was initiated; hard shutting down immediately.");
for handle in inner.shard_join_handles {
handle.abort();
}
}
}
/// State owned by a [`GracefulShutdownHandle`] while a graceful shutdown is in flight.
struct GracefulShutdownHandleInner {
    /// Join handles of the spawned shard tasks; aborted on a hard shutdown.
    shard_join_handles: Vec<JoinHandle<()>>,
    /// The task spawned by `ShutdownHandle::gracefully_shutdown` that drives the shutdown.
    graceful_shutdown_join_handle: JoinHandle<ShutdownResult>,
}

/// Handle to a graceful shutdown that is in progress.
///
/// Dropping it before the result was observed (via `join`) or before a hard shutdown was
/// performed aborts the shutdown task and every shard task.
pub struct GracefulShutdownHandle {
    /// `Some` until the shutdown result has been observed or a hard shutdown was performed.
    inner: Option<GracefulShutdownHandleInner>,
}

impl GracefulShutdownHandle {
    /// Waits for the graceful shutdown to finish and returns its result.
    ///
    /// If the shutdown task reports [`ShutdownResult::HardShutdown`], the shard tasks are
    /// aborted so that the hard shutdown really happens. Once `join` has returned, later
    /// calls return [`ShutdownResult::AlreadyShutdown`].
    ///
    /// Cancel safety: if the returned future is dropped before completion (e.g. by a
    /// timeout), the handle is left intact and `join` or a hard shutdown can still be used.
    ///
    /// # Panics
    ///
    /// Panics if the shutdown task panicked.
    pub async fn join(&mut self) -> ShutdownResult {
        let Some(inner) = self.inner.as_mut() else {
            return ShutdownResult::AlreadyShutdown;
        };
        // Poll through `&mut` so cancelling this future does not consume the JoinHandle.
        let result = (&mut inner.graceful_shutdown_join_handle)
            .await
            .expect("shutdown task has panicked");
        if matches!(result, ShutdownResult::HardShutdown) {
            // The graceful path failed; stop the shards instead of merely detaching them.
            for handle in &inner.shard_join_handles {
                handle.abort();
            }
        }
        self.inner = None;
        result
    }

    /// Waits up to `timeout` for the graceful shutdown; if it has not finished by then,
    /// performs a hard shutdown (see [`hard_shutdown_immediately`](Self::hard_shutdown_immediately)).
    pub async fn hard_shutdown_after(mut self, timeout: std::time::Duration) -> ShutdownResult {
        match tokio::time::timeout(timeout, self.join()).await {
            Ok(shutdown_result) => shutdown_result,
            Err(_elapsed) => self.hard_shutdown_immediately(),
        }
    }

    /// Stops the service right now.
    ///
    /// If the graceful shutdown task has already finished, its result is returned.
    /// Otherwise the shutdown task and every shard task are aborted and
    /// [`ShutdownResult::HardShutdown`] is returned. If the result was already observed via
    /// [`join`](Self::join), returns [`ShutdownResult::AlreadyShutdown`].
    ///
    /// # Panics
    ///
    /// Panics if the shutdown task panicked.
    pub fn hard_shutdown_immediately(mut self) -> ShutdownResult {
        let Some(mut inner) = self.inner.take() else {
            return ShutdownResult::AlreadyShutdown;
        };
        // Poll through `&mut` so the JoinHandle is still available for `abort` below.
        let result = match (&mut inner.graceful_shutdown_join_handle).now_or_never() {
            Some(joined) => joined.expect("shutdown task panicked"),
            None => {
                // Abort the shutdown task too (as `Drop` does); left detached it would keep
                // awaiting shards that are about to be aborted.
                inner.graceful_shutdown_join_handle.abort();
                ShutdownResult::HardShutdown
            }
        };
        if matches!(result, ShutdownResult::HardShutdown) {
            for handle in &inner.shard_join_handles {
                handle.abort();
            }
        }
        result
    }

    /// Misspelled alias of [`hard_shutdown_immediately`](Self::hard_shutdown_immediately),
    /// kept for backward compatibility.
    pub fn hard_shutdown_immediatey(self) -> ShutdownResult {
        self.hard_shutdown_immediately()
    }
}
impl Drop for GracefulShutdownHandle {
fn drop(&mut self) {
let Some(inner) = self.inner.take() else {
return;
};
warn!("GracefulShutdownHandle was dropped before shutdown was complete; hard shutting down immediately.");
inner.graceful_shutdown_join_handle.abort();
for handle in inner.shard_join_handles {
handle.abort();
}
}
}
/// Outcome of shutting down the service.
#[derive(Debug)]
pub enum ShutdownResult {
    /// Every shard's graceful shutdown succeeded; holds the per-shard stats collected by
    /// `try_join_all`, in shard order.
    GracefullyShutdown(Vec<ShardShutdownStats>),
    /// A shard had already shut down during the graceful shutdown, or the graceful shutdown
    /// had not finished when a hard shutdown aborted the shard tasks.
    HardShutdown,
    /// Returned by `GracefulShutdownHandle::join` when an earlier `join` already consumed
    /// the shutdown result.
    AlreadyShutdown,
}