Skip to main content

thruster_core_async_await/
middleware.rs

1use std::boxed::Box;
2use futures::future::Future;
3use std::pin::Pin;
4#[cfg(feature = "thruster_error_handling")]
5use crate::errors::ThrusterError;
6
/// Value a middleware future resolves to when the `thruster_error_handling`
/// feature is disabled: the context itself.
#[cfg(not(feature = "thruster_error_handling"))]
pub type MiddlewareResult<C> = C;

/// Value a middleware future resolves to when the `thruster_error_handling`
/// feature is enabled: either the context or a `ThrusterError<C>`.
#[cfg(feature = "thruster_error_handling")]
pub type MiddlewareResult<C> = Result<C, ThrusterError<C>>;

// The aliases below only differ between the two feature configurations
// through `MiddlewareResult`, so they are declared once in terms of it
// instead of being duplicated under each `cfg`.

/// Boxed, pinned, `Send` future produced by a middleware function or by a
/// `next` continuation.
pub type MiddlewareReturnValue<C> = Pin<Box<dyn Future<Output=MiddlewareResult<C>> + Send>>;

/// The `next` continuation handed to a middleware function: given the
/// context, it returns the future for the rest of the chain.
pub type MiddlewareNext<C> = Box<dyn Fn(C) -> MiddlewareReturnValue<C> + Send + Sync>;

/// Signature of a single middleware function: it receives the context and
/// the `next` continuation and returns the resulting future.
type MiddlewareFn<C> = fn(C, MiddlewareNext<C>) -> MiddlewareReturnValue<C>;
24
25pub struct Middleware<C: 'static> {
26  pub middleware: &'static [
27    MiddlewareFn<C>
28  ]
29}
30
31fn chained_run<C: 'static>(i: usize, j: usize, nodes: Vec<&'static Middleware<C>>) -> MiddlewareNext<C> {
32  Box::new(move |ctx| {
33    match nodes.get(i) {
34      Some(n) => {
35        match n.middleware.get(j) {
36          Some(m) => m(ctx, chained_run(i, j + 1, nodes.clone())),
37          None => chained_run(i + 1, 0, nodes.clone())(ctx),
38        }
39      },
40      None => panic!("Chain ran into end of cycle")
41    }
42  })
43}
44
45pub struct Chain<C: 'static> {
46  pub nodes: Vec<&'static Middleware<C>>,
47  built: MiddlewareNext<C>
48}
49
50impl<C: 'static> Chain<C> {
51  pub fn new(nodes: Vec<&'static Middleware<C>>) -> Chain<C> {
52    Chain {
53      nodes,
54      built: Box::new(|_| panic!("Tried to run an unbuilt chain!"))
55    }
56  }
57
58  fn chained_run(&self, i: usize, j: usize) -> MiddlewareNext<C> {
59    chained_run(i, j, self.nodes.clone())
60  }
61
62  fn build(&mut self) {
63    self.built = self.chained_run(0, 0);
64  }
65
66  fn run(&self, context: C) -> Pin<Box<dyn Future<Output=MiddlewareResult<C>> + Send>> {
67    (self.built)(context)
68  }
69}
70
71impl<C: 'static> Clone for Chain<C> {
72  fn clone(&self) -> Self {
73    let mut chain = Chain::new(self.nodes.clone());
74    chain.build();
75    chain
76  }
77}
78
79///
80/// The MiddlewareChain is used to wrap a series of middleware functions in such a way that the tail can
81/// be accessed and modified later on. This allows Thruster to properly compose pieces of middleware
82/// into a single long chain rather than relying on disperate parts.
83///
84pub struct MiddlewareChain<T: 'static> {
85  pub chain: Chain<T>,
86  pub assigned: bool
87}
88
89impl<T: 'static> MiddlewareChain<T> {
90  ///
91  /// Creates a new, blank (i.e. will panic if run,) MiddlewareChain
92  ///
93  pub fn new() -> Self {
94    MiddlewareChain {
95      chain: Chain::new(vec![]),
96      assigned: false
97    }
98  }
99
100  ///
101  /// Assign a runnable function to this middleware chain
102  ///
103  pub fn assign(&mut self, chain: Chain<T>) {
104    self.chain = chain;
105    self.assigned = true;
106  }
107
108  pub fn assign_legacy(&mut self, chain: Chain<T>) {
109    self.assign(chain);
110  }
111
112  ///
113  /// Run the middleware chain once
114  ///
115  #[cfg(not(feature = "thruster_error_handling"))]
116  pub fn run(&self, context: T) -> Pin<Box<dyn Future<Output=T> + Send>> {
117    self.chain.run(context)
118  }
119
120  #[cfg(feature = "thruster_error_handling")]
121  pub fn run(&self, context: T) -> Pin<Box<dyn Future<Output=MiddlewareResult<T>> + Send>> {
122    self.chain.run(context)
123  }
124
125  ///
126  /// Concatenate two middleware chains. This will make this chains tail point
127  /// to the next chain. That means that calling `next` in the final piece of
128  /// this chain will invoke the next chain rather than an "End of chain" panic
129  ///
130  pub fn chain(&mut self, mut chain: MiddlewareChain<T>) {
131    self.chain.nodes.append(&mut chain.chain.nodes);
132    self.assigned = self.assigned || chain.is_assigned();
133  }
134
135  ///
136  /// Tells if the chain has been assigned OR whether it is unassigned but has
137  /// an assigned tail. If is only chained but has no assigned runnable, then
138  /// this chain acts as a passthrough to the next one.
139  ///
140  pub fn is_assigned(&self) -> bool {
141    self.assigned
142  }
143}
144
145impl<T: 'static> Default for MiddlewareChain<T> {
146  fn default() -> Self {
147    Self::new()
148  }
149}
150
151impl<T: 'static> Clone for MiddlewareChain<T> {
152  fn clone(&self) -> Self {
153    MiddlewareChain {
154      chain: self.chain.clone(),
155      assigned: self.assigned,
156    }
157  }
158}