1use rama::{Context, Layer, Service};
16use tansu_sans_io::Frame;
17use tokio::sync::{mpsc, oneshot};
18use tokio_util::sync::CancellationToken;
19use tracing::debug;
20
21use crate::Error;
22
23#[derive(Clone, Debug, Default)]
25pub struct ChannelFrameLayer {
26 cancellation: CancellationToken,
27}
28
29impl ChannelFrameLayer {
30 pub fn new(cancellation: CancellationToken) -> Self {
31 Self { cancellation }
32 }
33}
34
35impl<S> Layer<S> for ChannelFrameLayer {
36 type Service = ChannelFrameService<S>;
37
38 fn layer(&self, inner: S) -> Self::Service {
39 Self::Service {
40 inner,
41 cancellation: self.cancellation.clone(),
42 }
43 }
44}
45
46#[derive(Clone, Debug, Default)]
48pub struct ChannelFrameService<S> {
49 inner: S,
50 cancellation: CancellationToken,
51}
52
53pub type FrameReceiver = mpsc::Receiver<(Frame, oneshot::Sender<Frame>)>;
55
56impl<S, State> Service<State, FrameReceiver> for ChannelFrameService<S>
57where
58 S: Service<State, Frame, Response = Frame>,
59 State: Clone + Send + Sync + 'static,
60 S::Error: From<Error>,
61{
62 type Response = ();
63 type Error = S::Error;
64
65 async fn serve(
66 &self,
67 ctx: Context<State>,
68 mut req: FrameReceiver,
69 ) -> Result<Self::Response, Self::Error> {
70 loop {
71 tokio::select! {
72 Some((frame, tx)) = req.recv() => {
73 debug!(?frame, ?tx);
74
75 self.inner
76 .serve(ctx.clone(), frame)
77 .await
78 .and_then(|response| {
79 tx.send(response)
80 .map_err(|unsent| Error::UnableToSend(Box::new(unsent)))
81 .map_err(Into::into)
82 })?
83 }
84
85 cancelled = self.cancellation.cancelled() => {
86 debug!(?cancelled);
87 break;
88 }
89 }
90 }
91
92 Ok(())
93 }
94}
95
96pub type FrameSender = mpsc::Sender<(Frame, oneshot::Sender<Frame>)>;
98
99#[derive(Clone, Debug)]
101pub struct FrameChannelService {
102 tx: FrameSender,
103}
104
105impl FrameChannelService {
106 pub fn new(tx: FrameSender) -> Self {
107 Self { tx }
108 }
109}
110
111impl<State> Service<State, Frame> for FrameChannelService
112where
113 State: Send + Sync + 'static,
114{
115 type Response = Frame;
116
117 type Error = Error;
118
119 async fn serve(&self, _ctx: Context<State>, req: Frame) -> Result<Self::Response, Self::Error> {
120 let (resp_tx, resp_rx) = oneshot::channel();
121
122 self.tx
123 .send((req, resp_tx))
124 .await
125 .map_err(|send_error| Error::UnableToSend(Box::new(send_error.0.0)))?;
126
127 resp_rx.await.map_err(Error::OneshotRecv)
128 }
129}
130
131pub fn bounded_channel(buffer: usize) -> (FrameSender, FrameReceiver) {
133 mpsc::channel::<(Frame, oneshot::Sender<Frame>)>(buffer)
134}