#![doc = include_str!("../README.md")]
#![feature(impl_trait_in_assoc_type)]
/// Service wrapper that limits concurrent calls into an inner service.
///
/// Instances are normally produced by `ConcurrencyLimiterServiceLayer`. The
/// limit and counter live behind an `Arc`, so every clone of one instance
/// shares (and is limited by) the same counter.
#[derive(Clone)]
pub struct ConcurrencyLimiterService<S> {
    /// The wrapped service that handles admitted requests.
    inner: S,
    /// Limit and live counter, shared by all clones of this service.
    status: std::sync::Arc<ConcurrencyLimiterServiceSharedStatus>,
}

/// State shared by every clone of a single `ConcurrencyLimiterService`.
struct ConcurrencyLimiterServiceSharedStatus {
    /// Configured concurrency limit.
    limit: u64,
    /// Atomic counter of requests currently being tracked as in flight.
    current: std::sync::atomic::AtomicU64,
}
/// Error produced when a request is turned away by the concurrency limiter.
#[derive(Debug)]
struct ConcurrencyLimitError {
    /// Configured concurrency limit at the time of rejection.
    limit: u64,
    /// Counter value observed when the request was rejected.
    current: u64,
}

impl std::fmt::Display for ConcurrencyLimitError {
    /// Renders as `concurrency limited (<current>/<limit>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { limit, current } = self;
        write!(f, "concurrency limited ({current}/{limit})")
    }
}

impl std::error::Error for ConcurrencyLimitError {}
/// Maps a concurrency rejection onto gRPC `RESOURCE_EXHAUSTED`, using the
/// error's `Display` text as the status message.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl then provides `Into<volo_grpc::Status>` automatically, so the
/// `ConcurrencyLimitError: Into<S::Error>` bound on the service still holds.
#[cfg(feature = "volo-grpc")]
impl From<ConcurrencyLimitError> for volo_grpc::Status {
    fn from(err: ConcurrencyLimitError) -> Self {
        volo_grpc::Status::resource_exhausted(err.to_string())
    }
}
/// Releases one in-flight slot of the shared counter when dropped.
///
/// Tying the decrement to `Drop` means the slot is returned even if the
/// request future is cancelled (dropped mid-`await`) or the inner service
/// panics; otherwise the counter would leak and eventually reject all traffic.
struct InFlightGuard<'a> {
    current: &'a std::sync::atomic::AtomicU64,
}

impl Drop for InFlightGuard<'_> {
    fn drop(&mut self) {
        self.current
            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
    }
}

// Admits a request only while fewer than `limit` requests are in flight;
// excess requests are rejected immediately with `ConcurrencyLimitError`
// (converted into the inner service's error type).
#[volo::service]
impl<Cx, Req, S> volo::Service<Cx, Req> for ConcurrencyLimiterService<S>
where
    Req: std::fmt::Debug + Send + 'static,
    S: Send + 'static + volo::Service<Cx, Req> + Sync,
    Cx: Send + 'static,
    ConcurrencyLimitError: Into<S::Error>,
{
    // Errors: returns `ConcurrencyLimitError.into()` when the limit is
    // reached; otherwise returns whatever the inner service returns.
    async fn call<'cx, 's>(&'s self, cx: &'cx mut Cx, req: Req) -> Result<S::Response, S::Error>
    where
        's: 'cx,
    {
        let status = &*self.status;
        let limit = status.limit;
        // Reserve a slot only if one is free. A rejected request never touches
        // the counter, so it cannot cause other requests to be spuriously
        // rejected (as fetch_add-then-rollback can). `in_flight < limit` also
        // rules out overflow of `in_flight + 1`. Relaxed ordering is enough
        // because the counter does not guard any other memory.
        if let Err(current) = status.current.fetch_update(
            std::sync::atomic::Ordering::Relaxed,
            std::sync::atomic::Ordering::Relaxed,
            |in_flight| (in_flight < limit).then_some(in_flight + 1),
        ) {
            return Err(ConcurrencyLimitError { limit, current }.into());
        }
        // Named binding (not `_`) so the guard lives until the inner call ends.
        let _slot = InFlightGuard {
            current: &status.current,
        };
        self.inner.call(cx, req).await
    }
}
/// Layer that wraps a service in a `ConcurrencyLimiterService`.
///
/// `Layer::layer` consumes the layer by value, so the type is `Clone` + `Copy`
/// to let one configuration be applied to several services.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConcurrencyLimiterServiceLayer {
    /// Concurrency limit handed to every service this layer produces.
    limit: u64,
}

impl ConcurrencyLimiterServiceLayer {
    /// Creates a layer configured with the given concurrency `limit`.
    #[must_use]
    pub fn new(limit: u64) -> Self {
        Self { limit }
    }
}
impl<S> volo::Layer<S> for ConcurrencyLimiterServiceLayer {
    type Service = ConcurrencyLimiterService<S>;

    /// Wraps `inner` together with a brand-new shared status whose counter
    /// starts at zero.
    ///
    /// Each call allocates its own counter, so separately layered services are
    /// limited independently; only clones of the returned service share one.
    fn layer(self, inner: S) -> Self::Service {
        let shared = ConcurrencyLimiterServiceSharedStatus {
            limit: self.limit,
            current: std::sync::atomic::AtomicU64::new(0),
        };
        let status = std::sync::Arc::new(shared);
        ConcurrencyLimiterService { inner, status }
    }
}