tcp_test/
lib.rs

1/*!
2Programmatically test TCP programs using real TCP streams.
3
4# Example
5
Everything can be done with the [`channel()`] function, which returns a pair of TCP streams connected to each other:
7
8```
9use tcp_test::{channel, read_assert};
10use std::io::{Read, Write};
11
12#[test]
13fn first_test() {
14    let sent = b"Hello, reader";
15
16    let (mut reader, mut writer) = channel();
17
18    writer.write_all(sent).unwrap();
19
20    let mut read = Vec::new();
21    reader.read_to_end(&mut read).unwrap();
22
23    assert_eq!(read, sent);
24}
25
26#[test]
27fn second_test() {
28    let sent = b"Interesting story";
29
30    let (mut reader, mut writer) = channel();
31
32    writer.write_all(sent).unwrap();
33
34    read_assert!(reader, sent.len(), sent);
35}
36
37#[test]
38fn third_test() {
39    let sent = b"...";
40
41    let (mut reader, mut writer) = channel();
42
43    writer.write_all(sent).unwrap();
44
45    read_assert!(reader, sent.len(), sent);
46}
47```
48
49[`channel()`]: fn.channel.html
50*/
51
// TODO: Stop taking `ToSocketAddrs`; define a dedicated trait for listener addresses instead.
53
54extern crate lazy_static;
55
56use lazy_static::lazy_static;
57
58use std::net::*;
59use std::sync::{mpsc, Arc, Mutex, Once};
60use std::thread::Builder;
61
62lazy_static! {
63    /// `127.0.0.1:31398`
64    static ref DEFAULT_ADDRESS: SocketAddr =
65        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 31398));
66}
67
/// Request/response pair shared with the background thread: sending `()` on the
/// sender asks for a new connection, and the receiver yields the resulting stream pair.
type ChannelHandle = (mpsc::Sender<()>, mpsc::Receiver<(TcpStream, TcpStream)>);

/// Shared handle to the background thread; `None` until `init()` has run.
///
/// Written exactly once, inside `INIT.call_once`, and only read after `init()` returns.
static mut CHANNEL: Option<Arc<Mutex<ChannelHandle>>> = None;

/// Guards the one-time setup performed by `init()`.
static INIT: Once = Once::new();
71
72fn init(address: impl ToSocketAddrs) {
73    INIT.call_once(move || {
74        let address = resolve(address);
75
76        // channel for blocking
77        let (ex_send, receiver) = mpsc::channel();
78
79        // channel for sending the streams
80        let (sender, ex_recv) = mpsc::channel();
81
82        unsafe {
83            CHANNEL = Some(Arc::new(Mutex::new((ex_send, ex_recv))));
84        };
85
86        let listener = TcpListener::bind(address)
87            .expect(concat!("TcpListener::bind() at init(), line ", line!()));
88
89        Builder::new()
90            .name(String::from("tcp-test background thread"))
91            .spawn(move || loop {
92                receiver
93                    .recv()
94                    .expect(concat!("Receiver::recv() at init(), line ", line!()));
95
96                let local = TcpStream::connect(address)
97                    .expect(concat!("TcpStream::connect() at init(), line ", line!()));
98                let remote = listener
99                    .accept()
100                    .expect(concat!("TcpListener::accept() at init(), line ", line!()))
101                    .0;
102
103                sender
104                    .send((local, remote))
105                    .expect(concat!("Sender::send() at init(), line ", line!()));
106            })
107            .expect(concat!("Builder::spawn() at init(), line ", line!()));
108    });
109}
110
111/// Returns two TCP streams pointing at each other.
112///
113/// The internal TCP listener is bound to `127.0.0.1:31398`.
114///
115/// # Example
116///
117/// ```
118/// use tcp_test::channel;
119/// use std::io::{Read, Write};
120///
121/// #[test]
122/// fn test() {
123///     let data = b"Hello world!";
124///     let (mut local, mut remote) = channel();
125///
126///     let local_addr = local.local_addr().unwrap();
127///     let peer_addr = remote.peer_addr().unwrap();
128///
129///     assert_eq!(local_addr, peer_addr);
130///     assert_eq!(local.peer_addr().unwrap(), "127.0.0.1:31398".parse().unwrap()); // default address
131///
132///     local.write_all(data).unwrap();
133///
134///     let mut buf = [0; 12];
135///     remote.read_exact(&mut buf).unwrap();
136///
137///     assert_eq!(&buf, data);
138/// }
139/// ```
140///
141/// Also see the [module level example](index.html#example).
142///
143/// [`listen()`]: fn.listen.html
144#[inline]
145pub fn channel() -> (TcpStream, TcpStream) {
146    channel_on(*DEFAULT_ADDRESS)
147}
148
149/// Returns two TCP streams pointing at each other.
150///
151/// The internal TCP listener is bound to `address`.
152/// Only one listener is used throughout the entire program,
153/// so the address should match in all calls to this function,
154/// otherwise it is not specified which address is finally used.
155///
156/// # Example
157///
158/// ```
159/// use tcp_test::channel_on;
160/// use std::io::{Read, Write};
161///
162/// #[test]
163/// fn test() {
164///     let data = b"Hello world!";
165///     let (mut local, mut remote) = channel_on("127.0.0.1:31399");
166///
167///     assert_eq!(local.peer_addr().unwrap(), "127.0.0.1:31399".parse().unwrap());
168///     assert_eq!(remote.local_addr().unwrap(), "127.0.0.1:31399".parse().unwrap());
169///
170///     local.write_all(data).unwrap();
171///
172///     let mut buf = [0; 12];
173///     remote.read_exact(&mut buf).unwrap();
174///
175///     assert_eq!(&buf, data);
176/// }
177/// ```
178///
179/// [`listen_on()`]: fn.listen_on.html
180#[inline]
181pub fn channel_on(address: impl ToSocketAddrs) -> (TcpStream, TcpStream) {
182    init(address);
183
184    let lock = unsafe { CHANNEL.clone().unwrap() };
185
186    let guard = lock
187        .lock()
188        .expect(concat!("Mutex::lock() at channel_on(), line ", line!()));
189
190    guard
191        .0
192        .send(())
193        .expect(concat!("Sender::send() at channel_on(), line ", line!()));
194
195    guard
196        .1
197        .recv()
198        .expect(concat!("Receiver::recv() at channel_on(), line ", line!()))
199}
200
/// Resolves `address` and returns the first socket address it yields.
///
/// # Panics
///
/// Panics if resolution fails or produces no addresses at all.
#[inline]
fn resolve(address: impl ToSocketAddrs) -> SocketAddr {
    let mut candidates = address.to_socket_addrs().expect(concat!(
        "<impl ToSocketAddrs>::to_socket_addrs() at resolve(), line ",
        line!()
    ));

    candidates.next().expect(concat!(
        "ToSocketAddrs::Iter::next() at resolve(), line ",
        line!()
    ))
}
216
/// Convenience macro for reading and comparing a specific amount of bytes.
///
/// Reads exactly `$n` bytes from `$resource` (anything implementing [`std::io::Read`]).
/// It then compares that buffer with `$expected`, which can be anything sliceable with
/// `[..]` into bytes: an array, a byte string, a `Vec<u8>`, and so on.
///
/// `$n` may be any `usize` expression, including a runtime value such as
/// `expected.len()`; it does not have to be a constant.
///
/// # Panics
///
/// Panics if reading `$n` bytes fails (including reaching end-of-stream early), or if
/// the bytes read are not equal to `$expected`.
///
/// # Example
///
/// ```
/// use tcp_test::read_assert;
/// use std::io::{self, Read};
///
/// struct Placeholder;
///
/// impl Read for Placeholder {
///     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
///         buf[0] = 1;
///         buf[1] = 2;
///         buf[2] = 3;
///
///         Ok(3)
///     }
/// }
///
/// read_assert!(Placeholder {}, 3, [1, 2, 3]);
/// ```
#[macro_export]
macro_rules! read_assert {
    ($resource:expr, $n:expr, $expected:expr) => {{
        // Evaluate `$expected` exactly once, before reading, and only borrow it.
        match &$expected {
            expected => {
                use ::std::io::Read;

                // A heap buffer instead of `[0; $n]`: array lengths must be compile-time
                // constants, which would reject `read_assert!(reader, sent.len(), sent)`.
                let mut buf = vec![0u8; $n];
                $resource
                    .read_exact(&mut buf)
                    .expect("failed to read in read_assert!");

                assert_eq!(
                    &buf[..],
                    &expected[..],
                    "read_assert! buffers are not equal"
                );
            }
        };
    }};
}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, Read, Write};

    #[test]
    fn resolve_ok() {
        assert_eq!(resolve("127.0.0.1:80"), "127.0.0.1:80".parse().unwrap());
        assert_eq!(resolve("[::1]:80"), "[::1]:80".parse().unwrap());

        let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 80));
        let addrs = [addr; 3];
        assert_eq!(resolve(addrs.as_ref()), addr);
    }

    #[test]
    #[should_panic]
    fn resolve_err() {
        let addrs: [SocketAddr; 0] = [];
        resolve(addrs.as_ref());
    }

    /// Reader that reports every read as complete without writing anything,
    /// so `read_assert!` always sees a zero-filled buffer.
    struct Placeholder;

    impl Read for Placeholder {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            Ok(buf.len())
        }
    }

    #[test]
    fn read_assert_ok() {
        read_assert!(Placeholder {}, 9, [0; 9]);
    }

    #[test]
    #[should_panic]
    fn read_assert_panic() {
        read_assert!(Placeholder {}, 1, [0xff]);
    }

    /// Asserts that the two streams returned by `channel()` are the two ends of one
    /// connection, checking both directions of the address pairing.
    fn assert_connected_pair() {
        let (local, remote) = channel();

        assert_eq!(remote.local_addr().unwrap(), local.peer_addr().unwrap());
        assert_eq!(local.local_addr().unwrap(), remote.peer_addr().unwrap());
    }

    #[test]
    fn channel_0() {
        assert_connected_pair();
    }

    #[test]
    fn channel_1() {
        assert_connected_pair();
    }

    #[test]
    fn channel_2() {
        assert_connected_pair();
    }

    #[test]
    fn channel_3() {
        assert_connected_pair();
    }

    #[test]
    fn channel_4() {
        assert_connected_pair();
    }

    /// Bytes written on either stream must arrive on the other one.
    #[test]
    fn channel_round_trip() {
        let (mut local, mut remote) = channel();
        let mut buf = [0; 4];

        local.write_all(b"ping").unwrap();
        remote.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ping");

        remote.write_all(b"pong").unwrap();
        local.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"pong");
    }
}