tower_resilience_executor/
service.rs1use crate::Executor;
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::sync::oneshot;
9use tower_service::Service;
10
/// A [`Service`] wrapper that hands each call of the inner service to an
/// executor instead of running it inline in the caller's task.
///
/// The inner service is cloned per request, so `S` must be `Clone`; see the
/// `Service` impl for the full set of bounds.
#[derive(Clone, Debug)]
pub struct ExecutorService<S, E> {
    inner: S,
    executor: E,
}

impl<S, E> ExecutorService<S, E> {
    /// Wraps `service` so that its calls are spawned on `executor`.
    pub fn new(service: S, executor: E) -> Self {
        Self {
            inner: service,
            executor,
        }
    }

    /// Returns a shared reference to the wrapped service.
    pub fn get_ref(&self) -> &S {
        &self.inner
    }

    /// Returns a mutable reference to the wrapped service.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner
    }

    /// Consumes the wrapper and returns the wrapped service, dropping the executor.
    pub fn into_inner(self) -> S {
        self.inner
    }
}
57
58impl<S, E, Req> Service<Req> for ExecutorService<S, E>
59where
60 S: Service<Req> + Clone + Send + 'static,
61 S::Future: Send,
62 S::Response: Send + 'static,
63 S::Error: Send + 'static,
64 E: Executor,
65 Req: Send + 'static,
66{
67 type Response = S::Response;
68 type Error = ExecutorError<S::Error>;
69 type Future = ExecutorFuture<S::Response, S::Error>;
70
71 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
72 self.inner.poll_ready(cx).map_err(ExecutorError::Service)
74 }
75
76 fn call(&mut self, req: Req) -> Self::Future {
77 let mut service = self.inner.clone();
79 let (tx, rx) = oneshot::channel();
80
81 let _handle = self.executor.spawn(async move {
83 let result = service.call(req).await;
85
86 let _ = tx.send(result.map_err(ExecutorError::Service));
90 });
91
92 ExecutorFuture { rx }
93 }
94}
95
/// Error produced by [`ExecutorService`].
///
/// Either the spawned task never delivered a result, or the wrapped service
/// itself returned an error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutorError<E> {
    /// The result channel was closed before a value was sent, i.e. the spawned
    /// task was dropped without completing (presumably executor shutdown or a
    /// panic inside the task — depends on the executor in use).
    TaskCancelled,
    /// The wrapped service returned this error.
    Service(E),
}

impl<E: std::fmt::Display> std::fmt::Display for ExecutorError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TaskCancelled => f.write_str("executor task was cancelled"),
            Self::Service(e) => write!(f, "service error: {}", e),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for ExecutorError<E> {
    /// Exposes the wrapped service error as the source; cancellation has none.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::TaskCancelled => None,
            Self::Service(e) => Some(e),
        }
    }
}
122
123pin_project! {
124 pub struct ExecutorFuture<T, E> {
126 #[pin]
127 rx: oneshot::Receiver<Result<T, ExecutorError<E>>>,
128 }
129}
130
131impl<T, E> Future for ExecutorFuture<T, E> {
132 type Output = Result<T, ExecutorError<E>>;
133
134 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 let this = self.project();
136 match this.rx.poll(cx) {
137 Poll::Ready(Ok(result)) => Poll::Ready(result),
138 Poll::Ready(Err(_)) => Poll::Ready(Err(ExecutorError::TaskCancelled)),
139 Poll::Pending => Poll::Pending,
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error as _;

    #[test]
    fn test_error_display() {
        let err: ExecutorError<std::io::Error> = ExecutorError::TaskCancelled;
        assert_eq!(err.to_string(), "executor task was cancelled");
    }

    #[test]
    fn test_error_display_service() {
        let err: ExecutorError<&str> = ExecutorError::Service("boom");
        assert_eq!(err.to_string(), "service error: boom");
    }

    #[test]
    fn test_error_eq() {
        let err1: ExecutorError<&str> = ExecutorError::TaskCancelled;
        let err2: ExecutorError<&str> = ExecutorError::TaskCancelled;
        assert_eq!(err1, err2);

        let err3: ExecutorError<&str> = ExecutorError::Service("test");
        let err4: ExecutorError<&str> = ExecutorError::Service("test");
        assert_eq!(err3, err4);

        // Different variants must never compare equal.
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_error_source() {
        let failed: ExecutorError<std::fmt::Error> = ExecutorError::Service(std::fmt::Error);
        assert!(failed.source().is_some());

        let cancelled: ExecutorError<std::fmt::Error> = ExecutorError::TaskCancelled;
        assert!(cancelled.source().is_none());
    }
}