// communication/behaviour/protocol.rs

// Copyright 2020-2021 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

4use async_trait::async_trait;
5use core::{fmt::Debug, marker::PhantomData};
6use futures::{prelude::*, AsyncRead, AsyncWrite};
7use libp2p::{
8    core::{
9        upgrade::{read_one, write_one},
10        ProtocolName,
11    },
12    request_response::RequestResponseCodec,
13};
14use serde::{de::DeserializeOwned, Serialize};
15// TODO: support no_std
16use std::io::{Error as IoError, ErrorKind as IoErrorKind, Result as IOResult};
17
18/// Trait for the generic Request and Response types
19pub trait MessageEvent: Serialize + DeserializeOwned + Debug + Send + Clone + Sync + 'static {}
20impl<T: Serialize + DeserializeOwned + Debug + Send + Clone + Sync + 'static> MessageEvent for T {}
21
22/// Custom protocol that extends libp2ps RequestResponseProtocol
23#[derive(Debug, Clone)]
24pub struct MessageProtocol();
25
26impl ProtocolName for MessageProtocol {
27    fn protocol_name(&self) -> &[u8] {
28        b"/stronghold-communication/1.0.0"
29    }
30}
31
/// Codec that frames requests and responses on an I/O stream and converts them to and from the
/// generic `Req` and `Res` types (see the `RequestResponseCodec` implementation).
///
/// The struct stores no data; the two type parameters are only tracked as zero-sized markers.
#[derive(Clone)]
pub struct MessageCodec<Req, Res> {
    // Zero-sized marker for the request type.
    p: PhantomData<Req>,
    // Zero-sized marker for the response type.
    q: PhantomData<Res>,
}

39impl<Req, Res> Default for MessageCodec<Req, Res> {
40    fn default() -> Self {
41        MessageCodec {
42            p: PhantomData,
43            q: PhantomData,
44        }
45    }
46}
47
48/// Read and write requests and responses, and parse them into the generic structs Req and Res.
49#[async_trait]
50impl<Req, Res> RequestResponseCodec for MessageCodec<Req, Res>
51where
52    Req: MessageEvent,
53    Res: MessageEvent,
54{
55    type Protocol = MessageProtocol;
56    type Request = Req;
57    type Response = Res;
58
59    // read requests from remote peers and parse them into the request struct
60    async fn read_request<R>(&mut self, _: &MessageProtocol, io: &mut R) -> IOResult<Self::Request>
61    where
62        R: AsyncRead + Unpin + Send,
63    {
64        read_one(io, usize::MAX)
65            .map(|req| match req {
66                Ok(bytes) => {
67                    serde_json::from_slice(bytes.as_slice()).map_err(|e| IoError::new(IoErrorKind::InvalidData, e))
68                }
69                Err(e) => Err(IoError::new(IoErrorKind::InvalidData, e)),
70            })
71            .await
72    }
73
74    // read responses from remote peers and parse them into the request struct
75    async fn read_response<R>(&mut self, _: &MessageProtocol, io: &mut R) -> IOResult<Self::Response>
76    where
77        R: AsyncRead + Unpin + Send,
78    {
79        read_one(io, usize::MAX)
80            .map(|res| match res {
81                Ok(bytes) => {
82                    serde_json::from_slice(bytes.as_slice()).map_err(|e| IoError::new(IoErrorKind::InvalidData, e))
83                }
84                Err(e) => Err(IoError::new(IoErrorKind::InvalidData, e)),
85            })
86            .await
87    }
88
89    // deserialize request and write to the io socket
90    async fn write_request<R>(&mut self, _: &MessageProtocol, io: &mut R, req: Self::Request) -> IOResult<()>
91    where
92        R: AsyncWrite + Unpin + Send,
93    {
94        let buf = serde_json::to_vec(&req).map_err(|e| IoError::new(IoErrorKind::InvalidData, e))?;
95        write_one(io, buf).await
96    }
97
98    //  deserialize response and write to the io socket
99    async fn write_response<R>(&mut self, _: &MessageProtocol, io: &mut R, res: Self::Response) -> IOResult<()>
100    where
101        R: AsyncWrite + Unpin + Send,
102    {
103        let buf = serde_json::to_vec(&res).map_err(|e| IoError::new(IoErrorKind::InvalidData, e))?;
104        write_one(io, buf).await
105    }
106}
107
#[cfg(test)]
mod test {
    use super::*;
    use async_std::{
        io,
        net::{Shutdown, SocketAddr, TcpListener, TcpStream},
        task,
        task::JoinHandle,
    };
    use stronghold_utils::test_utils;

    /// Number of random messages exchanged by every test.
    const MESSAGE_COUNT: usize = 20;

    /// Builds a fresh vector of random, non-empty byte strings used as message payloads.
    fn fresh_test_vector() -> Vec<Vec<u8>> {
        (0..MESSAGE_COUNT)
            .map(|_| test_utils::fresh::non_empty_bytestring())
            .collect()
    }

    /// Binds an ephemeral port on 127.0.0.1 and spawns a task that accepts one connection and
    /// copies everything it reads from that connection back into it.
    fn spawn_listener() -> (SocketAddr, JoinHandle<()>) {
        let listener =
            task::block_on(TcpListener::bind("127.0.0.1:0")).expect("Failed to bind tcp listener.");
        let addr = listener.local_addr().expect("Faulty local address");
        let handle = task::spawn(async move {
            let mut connections = listener.incoming();
            let stream = connections
                .next()
                .await
                .expect("Incoming connection is none.")
                .expect("Tcp stream is none.");
            // Both halves are the same socket, so this echoes the data back to the sender.
            let (mut source, mut sink) = (&stream, &stream);
            io::copy(&mut source, &mut sink)
                .await
                .expect("Failed to copy reader into writer.");
        });
        (addr, handle)
    }

    #[test]
    fn send_request() {
        let messages = fresh_test_vector();
        let (addr, listener_handle) = spawn_listener();

        let writer_handle = task::spawn(async move {
            let protocol = MessageProtocol();
            let mut codec = MessageCodec::<Vec<u8>, Vec<u8>>::default();
            let mut socket = TcpStream::connect(addr).await.expect("Failed to connect tcp stream.");

            for message in &messages {
                codec
                    .write_request(&protocol, &mut socket, message.clone())
                    .await
                    .expect("Failed to write request.");
            }
            for expected in &messages {
                let echoed = codec
                    .read_request(&protocol, &mut socket)
                    .await
                    .expect("Failed to read request.");
                assert_eq!(expected, &echoed);
            }
            socket.shutdown(Shutdown::Both).expect("Failed to shutdown socket.");
        });

        task::block_on(future::join(listener_handle, writer_handle));
    }

    #[test]
    fn send_response() {
        let messages = fresh_test_vector();
        let (addr, listener_handle) = spawn_listener();

        let writer_handle = task::spawn(async move {
            let protocol = MessageProtocol();
            let mut codec = MessageCodec::<Vec<u8>, Vec<u8>>::default();
            let mut socket = TcpStream::connect(addr).await.expect("Failed to connect tcp stream.");

            for message in &messages {
                codec
                    .write_response(&protocol, &mut socket, message.clone())
                    .await
                    .expect("Failed to write response.");
            }
            for expected in &messages {
                let echoed = codec
                    .read_response(&protocol, &mut socket)
                    .await
                    .expect("Failed to read response.");
                assert_eq!(expected, &echoed);
            }
            socket.shutdown(Shutdown::Both).expect("Failed to shutdown socket.");
        });

        task::block_on(future::join(listener_handle, writer_handle));
    }

    #[test]
    #[should_panic(expected = "All requests are corrupted.")]
    fn corrupt_request() {
        let messages = fresh_test_vector();
        let (addr, listener_handle) = spawn_listener();

        let writer_handle = task::spawn(async move {
            let protocol = MessageProtocol();
            let mut codec = MessageCodec::<Vec<u8>, Vec<u8>>::default();
            let mut socket = TcpStream::connect(addr).await.expect("Failed to connect tcp stream.");

            // Send a corrupted copy of every message; the originals are kept for comparison.
            for message in &messages {
                let mut corrupted = message.clone();
                test_utils::corrupt(&mut corrupted);
                codec
                    .write_request(&protocol, &mut socket, corrupted)
                    .await
                    .expect("Failed to write request.");
            }
            // Read every echoed frame and record whether any matches its uncorrupted original.
            let mut any_intact = false;
            for original in &messages {
                let echoed = codec
                    .read_request(&protocol, &mut socket)
                    .await
                    .expect("Failed to read request.");
                any_intact |= original == &echoed;
            }
            socket.shutdown(Shutdown::Both).expect("Failed to shutdown socket.");
            any_intact
        });

        let (_, any_intact) = task::block_on(future::join(listener_handle, writer_handle));
        assert!(any_intact, "All requests are corrupted.")
    }

    #[test]
    #[should_panic(expected = "All responses are corrupted.")]
    fn corrupt_response() {
        let messages = fresh_test_vector();
        let (addr, listener_handle) = spawn_listener();

        let writer_handle = task::spawn(async move {
            let protocol = MessageProtocol();
            let mut codec = MessageCodec::<Vec<u8>, Vec<u8>>::default();
            let mut socket = TcpStream::connect(addr).await.expect("Failed to connect tcp stream.");

            // Send a corrupted copy of every message; the originals are kept for comparison.
            for message in &messages {
                let mut corrupted = message.clone();
                test_utils::corrupt(&mut corrupted);
                codec
                    .write_response(&protocol, &mut socket, corrupted)
                    .await
                    .expect("Failed to write response.");
            }
            // Read every echoed frame and record whether any matches its uncorrupted original.
            let mut any_intact = false;
            for original in &messages {
                let echoed = codec
                    .read_response(&protocol, &mut socket)
                    .await
                    .expect("Failed to read response.");
                any_intact |= original == &echoed;
            }
            socket.shutdown(Shutdown::Both).expect("Failed to shutdown socket.");
            any_intact
        });

        let (_, any_intact) = task::block_on(future::join(listener_handle, writer_handle));
        assert!(any_intact, "All responses are corrupted.")
    }
}