use async_trait::async_trait;
use core::{fmt::Debug, marker::PhantomData};
use futures::{prelude::*, AsyncRead, AsyncWrite};
use libp2p::{
    core::{
        upgrade::{read_one, write_one, ReadOneError},
        ProtocolName,
    },
    request_response::RequestResponseCodec,
};
use serde::{de::DeserializeOwned, Serialize};
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IOResult};
17
18pub trait MessageEvent: Serialize + DeserializeOwned + Debug + Send + Clone + Sync + 'static {}
20impl<T: Serialize + DeserializeOwned + Debug + Send + Clone + Sync + 'static> MessageEvent for T {}
21
22#[derive(Debug, Clone)]
24pub struct MessageProtocol();
25
26impl ProtocolName for MessageProtocol {
27 fn protocol_name(&self) -> &[u8] {
28 b"/stronghold-communication/1.0.0"
29 }
30}
31
32#[derive(Clone)]
34pub struct MessageCodec<Req, Res> {
35 p: PhantomData<Req>,
36 q: PhantomData<Res>,
37}
38
39impl<Req, Res> Default for MessageCodec<Req, Res> {
40 fn default() -> Self {
41 MessageCodec {
42 p: PhantomData,
43 q: PhantomData,
44 }
45 }
46}
47
48#[async_trait]
50impl<Req, Res> RequestResponseCodec for MessageCodec<Req, Res>
51where
52 Req: MessageEvent,
53 Res: MessageEvent,
54{
55 type Protocol = MessageProtocol;
56 type Request = Req;
57 type Response = Res;
58
59 async fn read_request<R>(&mut self, _: &MessageProtocol, io: &mut R) -> IOResult<Self::Request>
61 where
62 R: AsyncRead + Unpin + Send,
63 {
64 read_one(io, usize::MAX)
65 .map(|req| match req {
66 Ok(bytes) => {
67 serde_json::from_slice(bytes.as_slice()).map_err(|e| IoError::new(IoErrorKind::InvalidData, e))
68 }
69 Err(e) => Err(IoError::new(IoErrorKind::InvalidData, e)),
70 })
71 .await
72 }
73
74 async fn read_response<R>(&mut self, _: &MessageProtocol, io: &mut R) -> IOResult<Self::Response>
76 where
77 R: AsyncRead + Unpin + Send,
78 {
79 read_one(io, usize::MAX)
80 .map(|res| match res {
81 Ok(bytes) => {
82 serde_json::from_slice(bytes.as_slice()).map_err(|e| IoError::new(IoErrorKind::InvalidData, e))
83 }
84 Err(e) => Err(IoError::new(IoErrorKind::InvalidData, e)),
85 })
86 .await
87 }
88
89 async fn write_request<R>(&mut self, _: &MessageProtocol, io: &mut R, req: Self::Request) -> IOResult<()>
91 where
92 R: AsyncWrite + Unpin + Send,
93 {
94 let buf = serde_json::to_vec(&req).map_err(|e| IoError::new(IoErrorKind::InvalidData, e))?;
95 write_one(io, buf).await
96 }
97
98 async fn write_response<R>(&mut self, _: &MessageProtocol, io: &mut R, res: Self::Response) -> IOResult<()>
100 where
101 R: AsyncWrite + Unpin + Send,
102 {
103 let buf = serde_json::to_vec(&res).map_err(|e| IoError::new(IoErrorKind::InvalidData, e))?;
104 write_one(io, buf).await
105 }
106}
107
#[cfg(test)]
mod test {
    // Round-trip tests for `MessageCodec`. Each test connects to a local TCP echo server,
    // writes a batch of messages with the codec, then reads the echoed frames back.
    //
    // NOTE(review): libp2p's `write_one` may shut down the writer after each message; confirm
    // that several consecutive writes on one socket succeed with the libp2p version in use.

    use super::*;
    use async_std::{
        io,
        net::{Shutdown, SocketAddr, TcpListener, TcpStream},
        task,
        task::JoinHandle,
    };
    use stronghold_utils::test_utils;

    /// Starts a TCP echo server on an OS-assigned port of 127.0.0.1.
    ///
    /// The spawned task accepts exactly one connection and copies everything it reads from it
    /// back into it until the read side reaches EOF. Returns the bound address and the handle of
    /// the echo task.
    fn spawn_listener() -> (SocketAddr, JoinHandle<()>) {
        let listener = task::block_on(async {
            TcpListener::bind("127.0.0.1:0")
                .await
                .expect("Failed to bind tcp listener.")
        });
        let addr = listener.local_addr().expect("Faulty local address");
        let handle = task::spawn(async move {
            let mut incoming = listener.incoming();
            let stream = incoming
                .next()
                .await
                .expect("Incoming connection is none.")
                .expect("Tcp stream is none.");
            // `&TcpStream` is both readable and writable, so the same stream is used as the
            // source and the destination of the copy.
            let (reader, writer) = &mut (&stream, &stream);
            io::copy(reader, writer)
                .await
                .expect("Failed to copy reader into writer.");
        });
        (addr, handle)
    }

    /// Writes 20 fresh byte strings as requests through the echo server, then reads them back
    /// with `read_request` and checks that each one round-trips unchanged and in order.
    #[test]
    fn send_request() {
        let mut test_vector = Vec::new();
        for _ in 0..20 {
            test_vector.push(test_utils::fresh::non_empty_bytestring());
        }

        let (addr, listener_handle) = spawn_listener();

        let writer_handle = task::spawn(async move {
            let protocol = MessageProtocol();
            let mut codec = MessageCodec::<Vec<u8>, Vec<u8>>::default();
            let mut socket = TcpStream::connect(addr).await.expect("Failed to connect tcp stream.");
            // All frames are written before any is read; the echo server buffers them.
            for bytes in test_vector.iter() {
                codec
                    .write_request(&protocol, &mut socket, bytes.clone())
                    .await
                    .expect("Failed to write request.");
            }
            for bytes in test_vector.iter() {
                let received = codec
                    .read_request(&protocol, &mut socket)
                    .await
                    .expect("Failed to read request.");
                assert_eq!(bytes, &received);
            }
            // Closing both halves lets the echo task's `io::copy` hit EOF and finish.
            socket.shutdown(Shutdown::Both).expect("Failed to shutdown socket.");
        });
        task::block_on(async {
            future::join(listener_handle, writer_handle).await;
        });
    }

    /// Same as `send_request`, but uses `write_response`/`read_response`.
    #[test]
    fn send_response() {
        let mut test_vector = Vec::new();
        for _ in 0..20 {
            test_vector.push(test_utils::fresh::non_empty_bytestring());
        }

        let (addr, listener_handle) = spawn_listener();

        let writer_handle = task::spawn(async move {
            let protocol = MessageProtocol();
            let mut codec = MessageCodec::<Vec<u8>, Vec<u8>>::default();
            let mut socket = TcpStream::connect(addr).await.expect("Failed to connect tcp stream.");
            for bytes in test_vector.iter() {
                codec
                    .write_response(&protocol, &mut socket, bytes.clone())
                    .await
                    .expect("Failed to write response.");
            }
            for bytes in test_vector.iter() {
                let received = codec
                    .read_response(&protocol, &mut socket)
                    .await
                    .expect("Failed to read response.");
                assert_eq!(bytes, &received);
            }
            socket.shutdown(Shutdown::Both).expect("Failed to shutdown socket.");
        });
        task::block_on(async {
            future::join(listener_handle, writer_handle).await;
        });
    }

    /// Writes corrupted copies of fresh byte strings as requests and reads the echoes back.
    ///
    /// `results` records whether each echoed request equals the *uncorrupted* original. The
    /// writer task returns `true` if any did, so the final `assert!` panics with
    /// "All requests are corrupted." when none matched; that is the panic `should_panic` expects.
    /// Assumes `test_utils::corrupt` always alters its input — TODO confirm.
    #[test]
    #[should_panic(expected = "All requests are corrupted.")]
    fn corrupt_request() {
        let mut test_vector = Vec::new();
        for _ in 0..20 {
            test_vector.push(test_utils::fresh::non_empty_bytestring());
        }

        let (addr, listener_handle) = spawn_listener();

        let writer_handle = task::spawn(async move {
            let protocol = MessageProtocol();
            let mut codec = MessageCodec::<Vec<u8>, Vec<u8>>::default();
            let mut socket = TcpStream::connect(addr).await.expect("Failed to connect tcp stream.");
            // Corrupt a clone so `test_vector` keeps the originals for the comparison below.
            for bytes in test_vector.clone().iter_mut() {
                test_utils::corrupt(bytes);
                codec
                    .write_request(&protocol, &mut socket, bytes.clone())
                    .await
                    .expect("Failed to write request.");
            }
            let mut results = Vec::new();
            for bytes in test_vector.iter() {
                let received = codec
                    .read_request(&protocol, &mut socket)
                    .await
                    .expect("Failed to read request.");
                results.push(bytes == &received)
            }
            socket.shutdown(Shutdown::Both).expect("Failed to shutdown socket.");
            results.into_iter().any(|res| res)
        });
        task::block_on(async {
            let (_, results) = future::join(listener_handle, writer_handle).await;
            assert!(results, "All requests are corrupted.")
        });
    }

    /// Response-side counterpart of `corrupt_request`, using `write_response`/`read_response`.
    /// Assumes `test_utils::corrupt` always alters its input — TODO confirm.
    #[test]
    #[should_panic(expected = "All responses are corrupted.")]
    fn corrupt_response() {
        let mut test_vector = Vec::new();
        for _ in 0..20 {
            test_vector.push(test_utils::fresh::non_empty_bytestring());
        }

        let (addr, listener_handle) = spawn_listener();

        let writer_handle = task::spawn(async move {
            let protocol = MessageProtocol();
            let mut codec = MessageCodec::<Vec<u8>, Vec<u8>>::default();
            let mut socket = TcpStream::connect(addr).await.expect("Failed to connect tcp stream.");
            // Corrupt a clone so `test_vector` keeps the originals for the comparison below.
            for bytes in test_vector.clone().iter_mut() {
                test_utils::corrupt(bytes);
                codec
                    .write_response(&protocol, &mut socket, bytes.clone())
                    .await
                    .expect("Failed to write response.");
            }
            let mut results = Vec::new();
            for bytes in test_vector.iter() {
                let received = codec
                    .read_response(&protocol, &mut socket)
                    .await
                    .expect("Failed to read response.");
                results.push(bytes == &received)
            }
            socket.shutdown(Shutdown::Both).expect("Failed to shutdown socket.");
            results.into_iter().any(|res| res)
        });
        task::block_on(async {
            let (_, results) = future::join(listener_handle, writer_handle).await;
            assert!(results, "All responses are corrupted.")
        });
    }
}