1#![deny(missing_docs, missing_debug_implementations)]
28
29use bincode::Config;
30use bytes::{BufMut, BytesMut};
31use serde::{Deserialize, Serialize};
32use std::fmt;
33use std::io::{self, Read};
34use std::marker::PhantomData;
35use tokio_codec::{Decoder, Encoder};
36
37pub struct BinCodec<T> {
39 config: Config,
40 _pd: PhantomData<T>,
41}
42
43impl<T> BinCodec<T> {
44 pub fn new() -> Self {
46 let config = bincode::config();
47 BinCodec::with_config(config)
48 }
49
50 pub fn with_config(config: Config) -> Self {
52 BinCodec {
53 config,
54 _pd: PhantomData,
55 }
56 }
57}
58
59impl<T> Decoder for BinCodec<T>
60where
61 for<'de> T: Deserialize<'de>,
62{
63 type Item = T;
64 type Error = bincode::Error;
65
66 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
67 if !buf.is_empty() {
68 let mut reader = Reader::new(&buf[..]);
69 let message = self.config.deserialize_from(&mut reader)?;
70 buf.split_to(reader.amount());
71 Ok(Some(message))
72 } else {
73 Ok(None)
74 }
75 }
76}
77
78impl<T> Encoder for BinCodec<T>
79where
80 T: Serialize,
81{
82 type Item = T;
83 type Error = bincode::Error;
84
85 fn encode(&mut self, item: T, buf: &mut BytesMut) -> Result<(), Self::Error> {
86 let size = self.config.serialized_size(&item)?;
87 buf.reserve(size as usize);
88 let message = self.config.serialize(&item)?;
89 buf.put(&message[..]);
90 Ok(())
91 }
92}
93
94impl<T> fmt::Debug for BinCodec<T> {
95 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
96 f.debug_struct("BinCodec").finish()
97 }
98}
99
/// `Read` adapter over a borrowed byte slice that counts the bytes pulled from it.
///
/// The caller uses the count to learn how much input a deserializer consumed.
#[derive(Debug)]
struct Reader<'buf> {
    /// Bytes not yet read; the slice is advanced past whatever `read` hands out.
    buf: &'buf [u8],
    /// Running total of bytes returned by `read`.
    amount: usize,
}

impl<'buf> Reader<'buf> {
    /// Wraps `buf` with a consumed-byte counter starting at zero.
    fn new(buf: &'buf [u8]) -> Self {
        Self { buf, amount: 0 }
    }

    /// Number of bytes read through this reader so far.
    fn amount(&self) -> usize {
        self.amount
    }
}

// Implemented on `Reader` itself; `&mut Reader` picks up `Read` from std's
// blanket `impl Read for &mut R`, so callers can still pass `&mut reader`.
impl<'buf> Read for Reader<'buf> {
    fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
        let count = Read::read(&mut self.buf, out)?;
        self.amount += count;
        Ok(count)
    }
}
123
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{Future, Sink, Stream};
    use serde_derive::{Deserialize, Serialize};
    use std::net::SocketAddr;
    use tokio::{
        codec::Framed,
        net::{TcpListener, TcpStream},
        runtime::current_thread,
    };

    #[derive(Deserialize, Serialize, Debug, Clone, Eq, PartialEq)]
    enum Mock {
        One,
        Two,
    }

    /// Encoding two messages into one buffer and decoding them back yields both,
    /// in order, and consumes the buffer exactly.
    #[test]
    fn decodes_back_to_back_messages() {
        let mut codec = BinCodec::<Mock>::new();
        let mut buf = BytesMut::new();
        codec.encode(Mock::One, &mut buf).unwrap();
        codec.encode(Mock::Two, &mut buf).unwrap();

        assert_eq!(codec.decode(&mut buf).unwrap(), Some(Mock::One));
        assert_eq!(codec.decode(&mut buf).unwrap(), Some(Mock::Two));
        assert!(buf.is_empty());
        assert_eq!(codec.decode(&mut buf).unwrap(), None);
    }

    /// Round-trips two messages through a TCP echo server framed with `BinCodec`.
    #[test]
    fn it_works() {
        // Port 0 lets the OS pick a free port, so concurrent or repeated runs
        // (and unrelated local services) cannot collide with a fixed port.
        let bind_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), 0);
        let echo = TcpListener::bind(&bind_addr).unwrap();
        let addr = echo.local_addr().unwrap();

        let jh = std::thread::spawn(move || {
            current_thread::run(
                echo.incoming()
                    .map_err(bincode::Error::from)
                    .take(1)
                    .for_each(|stream| {
                        let (w, r) = Framed::new(stream, BinCodec::<Mock>::new()).split();
                        r.forward(w).map(|_| ())
                    })
                    .map_err(|_| ()),
            )
        });

        let client = TcpStream::connect(&addr).wait().unwrap();
        let client = Framed::new(client, BinCodec::<Mock>::new());

        let client = client.send(Mock::One).wait().unwrap();

        let (got, client) = match client.into_future().wait() {
            Ok(x) => x,
            Err((e, _)) => panic!("failed to read first echo: {}", e),
        };

        assert_eq!(got, Some(Mock::One));

        let client = client.send(Mock::Two).wait().unwrap();

        let (got2, client) = match client.into_future().wait() {
            Ok(x) => x,
            Err((e, _)) => panic!("failed to read second echo: {}", e),
        };

        assert_eq!(got2, Some(Mock::Two));

        drop(client);
        jh.join().unwrap();
    }
}