use xitca_http::util::middleware::catch_unwind::{self, CatchUnwindError};
use crate::{
error::{Error, ThreadJoinError},
service::{ready::ReadyService, Service},
WebContext,
};
pub struct CatchUnwind;
impl<Arg> Service<Arg> for CatchUnwind
where
catch_unwind::CatchUnwind: Service<Arg>,
{
type Response = CatchUnwindService<<catch_unwind::CatchUnwind as Service<Arg>>::Response>;
type Error = <catch_unwind::CatchUnwind as Service<Arg>>::Error;
async fn call(&self, arg: Arg) -> Result<Self::Response, Self::Error> {
catch_unwind::CatchUnwind.call(arg).await.map(CatchUnwindService)
}
}
/// Service produced by [`CatchUnwind`].
///
/// Forwards every call to the wrapped service `S` and converts its error into [`Error`].
pub struct CatchUnwindService<S>(S);

impl<'r, C, B, S> Service<WebContext<'r, C, B>> for CatchUnwindService<S>
where
    S: Service<WebContext<'r, C, B>>,
    S::Error: Into<Error<C>>,
{
    type Response = S::Response;
    type Error = Error<C>;

    /// Calls the wrapped service with `ctx`.
    ///
    /// # Errors
    ///
    /// Any error of the wrapped service is converted into [`Error`] through its
    /// `Into<Error<C>>` implementation.
    #[inline]
    async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
        let Self(inner) = self;
        match inner.call(ctx).await {
            Ok(response) => Ok(response),
            Err(inner_err) => Err(inner_err.into()),
        }
    }
}
/// Conversion from the catch-unwind middleware error into the crate's [`Error`].
impl<C, E> From<CatchUnwindError<E>> for Error<C>
where
    E: Into<Error<C>>,
{
    /// Maps each variant of [`CatchUnwindError`] onto [`Error`]:
    ///
    /// - `First` is wrapped in a [`ThreadJoinError`] and then converted
    ///   (presumably this variant carries the caught panic payload; confirm against
    ///   the `xitca_http` definition).
    /// - `Second` is converted through the `E: Into<Error<C>>` bound.
    fn from(e: CatchUnwindError<E>) -> Self {
        match e {
            CatchUnwindError::First(payload) => {
                let join_err = ThreadJoinError::new(payload);
                Self::from(join_err)
            }
            CatchUnwindError::Second(service_err) => service_err.into(),
        }
    }
}
/// Readiness is delegated entirely to the wrapped service.
impl<S> ReadyService for CatchUnwindService<S>
where
    S: ReadyService,
{
    type Ready = S::Ready;

    /// Awaits the wrapped service's `ready` and returns its result as is.
    #[inline]
    async fn ready(&self) -> Self::Ready {
        let Self(inner) = self;
        inner.ready().await
    }
}
#[cfg(test)]
mod test {
    use xitca_unsafe_collection::futures::NowOrPanic;

    use crate::{
        handler::handler_service,
        http::{Request, StatusCode},
        App,
    };

    use super::*;

    /// A handler that panics behind `CatchUnwind` must yield a 500 response.
    #[test]
    fn catch_panic() {
        async fn handler() -> &'static str {
            panic!("");
        }

        // Application with the panicking handler enclosed by the middleware under test.
        let app = App::new()
            .with_state("996")
            .at("/", handler_service(handler))
            .enclosed(CatchUnwind)
            .finish();

        // Build the service, then drive a single default request through it.
        let service = app.call(()).now_or_panic().unwrap();
        let res = service.call(Request::default()).now_or_panic().unwrap();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }
}