1use crate::{sink::Router, BoxSink};
2use futures::{
3 channel::mpsc::{self, Receiver, Sender},
4 ready,
5 stream::BoxStream,
6 Future, Sink, SinkExt, Stream, StreamExt,
7};
8use log::{error, warn};
9use pin_project_lite::pin_project;
10use selium_protocol::{
11 error_codes::REPLIER_ALREADY_BOUND,
12 traits::{ShutdownSink, ShutdownStream},
13 ErrorPayload, Frame,
14};
15use selium_std::errors::{Result, SeliumError};
16use std::{
17 collections::HashMap,
18 pin::Pin,
19 task::{Context, Poll},
20};
21use tokio_stream::StreamMap;
22
23const SOCK_CHANNEL_SIZE: usize = 100;
24
25type BoxedBiStream = (
26 BoxSink<Frame, SeliumError>,
27 BoxStream<'static, Result<Frame>>,
28);
29
30pub enum Socket {
31 Client(
32 (
33 BoxSink<Frame, SeliumError>,
34 BoxStream<'static, Result<Frame>>,
35 ),
36 ),
37 Server(
38 (
39 BoxSink<Frame, SeliumError>,
40 BoxStream<'static, Result<Frame>>,
41 ),
42 ),
43}
44
// Request/reply topic: multiplexes many requestors onto a single replier.
// Requests are tagged with a `cid` header identifying the requestor they came
// from; replies are handed to `sink` (a `Router`), which presumably uses that
// header to pick the requestor's sink — TODO confirm against `Router`.
pin_project! {
    #[project = TopicProj]
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct Topic {
        // The bound replier's (sink, stream) pair; at most one at a time.
        #[pin]
        server: Option<BoxedBiStream>,
        // Requestor streams, keyed by the id assigned when they registered.
        #[pin]
        stream: StreamMap<usize, BoxStream<'static, Result<Frame>>>,
        // Requestor sinks, keyed by the same ids as `stream`.
        #[pin]
        sink: Router<usize, BoxSink<Frame, SeliumError>>,
        // Id given to the next requestor that registers.
        next_id: usize,
        // Receiving end of the channel created in `Topic::pair`; new sockets arrive here.
        #[pin]
        handle: Receiver<Socket>,
        // Request (already tagged with `cid`) waiting to be written to the replier.
        buffered_req: Option<Frame>,
        // Reply from the replier waiting to be written to the router.
        buffered_rep: Option<Frame>,
        // Sink of a rejected replier: `Some(err)` = error frame still to send,
        // `None` = error already sent and the sink still has to be closed.
        buffered_err: Option<(Option<ErrorPayload>, BoxSink<Frame, SeliumError>)>,
    }
}
63
64impl Topic {
65 pub fn pair() -> (Self, Sender<Socket>) {
66 let (tx, rx) = mpsc::channel(SOCK_CHANNEL_SIZE);
67
68 (
69 Self {
70 server: None,
71 stream: StreamMap::new(),
72 sink: Router::new(),
73 next_id: 0,
74 handle: rx,
75 buffered_req: None,
76 buffered_rep: None,
77 buffered_err: None,
78 },
79 tx,
80 )
81 }
82}
83
84impl Future for Topic {
85 type Output = ();
86
87 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
88 let TopicProj {
89 mut server,
90 mut stream,
91 mut sink,
92 next_id,
93 mut handle,
94 buffered_req,
95 buffered_rep,
96 buffered_err,
97 } = self.project();
98
99 loop {
100 let mut server_pending = false;
101 let mut stream_pending = false;
102
103 if buffered_req.is_some() && server.is_some() {
106 let si = &mut server.as_mut().as_pin_mut().unwrap().0;
107 ready!(si.poll_ready_unpin(cx)).unwrap();
109 si.start_send_unpin(buffered_req.take().unwrap()).unwrap();
110 }
111
112 if let Some((maybe_err, mut si)) = buffered_err.take() {
115 if let Some(err) = maybe_err {
116 match si.poll_ready_unpin(cx) {
117 Poll::Ready(Ok(_)) => {
118 if si.start_send_unpin(Frame::Error(err)).is_ok() {
119 *buffered_err = Some((None, si));
120 }
121 }
122 Poll::Ready(Err(e)) => warn!("Could not poll replier sink: {e:?}"),
123 Poll::Pending => {
124 *buffered_err = Some((Some(err), si));
125 return Poll::Pending;
126 }
127 }
128 } else {
129 match si.poll_close_unpin(cx) {
130 Poll::Ready(Ok(_)) => (),
131 Poll::Ready(Err(e)) => warn!("Could not close replier sink: {e:?}"),
132 Poll::Pending => {
133 *buffered_err = Some((None, si));
134 return Poll::Pending;
135 }
136 }
137 }
138 }
139
140 match handle.as_mut().poll_next(cx) {
141 Poll::Ready(Some(sock)) => match sock {
142 Socket::Client((si, st)) => {
143 stream.as_mut().insert(*next_id, st);
144 sink.as_mut().insert(*next_id, si);
145
146 *next_id += 1;
147 }
148 Socket::Server((si, st)) => {
149 if server.is_some() {
150 let error_payload = ErrorPayload {
151 code: REPLIER_ALREADY_BOUND,
152 message: "A replier already exists for this topic".into(),
153 };
154 *buffered_err = Some((Some(error_payload), si));
155 } else {
156 let _ = server.insert((si, st));
157 }
158 }
159 },
160 Poll::Ready(None) => {
162 ready!(sink.as_mut().poll_flush(cx)).unwrap();
163 stream.iter_mut().for_each(|(_, s)| s.shutdown_stream());
164 sink.iter_mut().for_each(|(_, s)| s.shutdown_sink());
165
166 if server.is_some() {
167 server.as_mut().as_pin_mut().unwrap().1.shutdown_stream();
168 }
169
170 return Poll::Ready(());
171 }
172 Poll::Pending
174 if stream.is_empty()
175 && server.is_none()
176 && buffered_req.is_none()
177 && buffered_rep.is_none() =>
178 {
179 return Poll::Pending
180 }
181 Poll::Pending => (),
183 }
184
185 if server.is_some() {
186 let st = &mut server.as_mut().as_pin_mut().unwrap().1;
187
188 match st.poll_next_unpin(cx) {
189 Poll::Ready(Some(Ok(item))) => {
191 *buffered_rep = Some(item);
192 }
193 Poll::Ready(Some(Err(e))) => {
195 error!("Received invalid message from replier: {e:?}")
196 }
197 Poll::Ready(None) => {
199 let si = &mut server.as_mut().as_pin_mut().unwrap().0;
200 ready!(si.poll_flush_unpin(cx)).unwrap();
201 ready!(sink.as_mut().poll_flush(cx)).unwrap();
202 *server = None;
203 }
204 Poll::Pending => {
206 server_pending = true;
207 }
208 }
209 }
210
211 if buffered_rep.is_some() {
214 ready!(sink.as_mut().poll_ready(cx)).unwrap();
216
217 let r = sink.as_mut().start_send(buffered_rep.take().unwrap());
218
219 if let Some(e) = r.err() {
220 error!("Failed to send reply to requestor: {e:?}");
221 }
222 }
223
224 match stream.as_mut().poll_next(cx) {
225 Poll::Ready(Some((id, Ok(item)))) => {
227 let mut payload = item.unwrap_message();
228 payload
229 .headers
230 .get_or_insert(HashMap::new())
231 .insert("cid".into(), format!("{id}"));
232 *buffered_req = Some(Frame::Message(payload));
233 }
234 Poll::Ready(Some((_, Err(e)))) => {
236 error!("Received invalid message from requestor: {e:?}")
237 }
238 Poll::Ready(None) => {
240 ready!(sink.as_mut().poll_flush(cx)).unwrap();
242
243 if server.is_some() {
244 let si = &mut server.as_mut().as_pin_mut().unwrap().0;
245 ready!(si.poll_flush_unpin(cx)).unwrap();
246 }
247 }
248 Poll::Pending => {
250 stream_pending = true;
251 }
252 }
253
254 if server_pending && stream_pending {
255 ready!(sink.poll_flush(cx)).unwrap();
257
258 if server.is_some() {
259 let si = &mut server.as_mut().as_pin_mut().unwrap().0;
260 ready!(si.poll_flush_unpin(cx)).unwrap();
261 }
262
263 return Poll::Pending;
264 }
265 }
266 }
267}