statsd_mock/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::net::{SocketAddr, UdpSocket};
4use std::sync::{
5    atomic::{AtomicBool, Ordering},
6    mpsc::channel,
7    Arc,
8};
9use std::thread;
10use std::time::Duration;
11
/// Mock StatsD server.
///
/// Owns a UDP socket bound to `127.0.0.1` on an OS-assigned port and records
/// the packets sent to it while a test closure runs.
#[derive(Debug)]
pub struct StatsDServer {
    /// Address the socket is bound to, as reported by the socket after binding.
    local_addr: SocketAddr,
    /// Socket the mock server receives packets on.
    sock: UdpSocket,
}
17
18impl Default for StatsDServer {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl StatsDServer {
25    pub fn new() -> Self {
26        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
27        let sock = UdpSocket::bind(addr).unwrap();
28
29        sock.set_read_timeout(Some(Duration::from_millis(100)))
30            .unwrap();
31        let local_addr = sock.local_addr().unwrap();
32
33        Self { local_addr, sock }
34    }
35
36    /// Return the mock server address: `127.0.0.1:<random port>`
37    pub fn addr(&self) -> String {
38        self.local_addr.clone().to_string()
39    }
40
41    /// Run the given test function while receiving several packets.
42    /// Return a vector of the packets.
43    ///
44    /// ```
45    /// use statsd::Client;
46    ///
47    /// // Start the mock server
48    /// let mock = statsd_mock::start();
49    ///
50    /// let client = Client::new(&mock.addr(), "duyet").unwrap();
51    /// let response = mock.run_while_receiving_all(|| {
52    ///     client.incr("some.counter");
53    ///     client.count("some.counter", 123.0);
54    /// });
55    /// assert_eq!(
56    ///     response,
57    ///     vec!["duyet.some.counter:1|c", "duyet.some.counter:123|c"]
58    /// );
59    /// ```
60    pub fn run_while_receiving_all<F>(self, func: F) -> Vec<String>
61    where
62        F: Fn(),
63    {
64        let (serv_tx, serv_rx) = channel();
65        let func_ran = Arc::new(AtomicBool::new(false));
66        let bg_func_ran = Arc::clone(&func_ran);
67
68        let bg = thread::spawn(move || loop {
69            let mut buf = [0; 1500];
70            if let Ok((len, _)) = self.sock.recv_from(&mut buf) {
71                let bytes = Vec::from(&buf[0..len]);
72                serv_tx.send(bytes).unwrap();
73            }
74            // go through the loop least once (do...while)
75            if bg_func_ran.load(Ordering::SeqCst) {
76                break;
77            }
78        });
79
80        func();
81
82        std::thread::sleep(Duration::from_millis(200));
83        func_ran.store(true, Ordering::SeqCst);
84        bg.join().expect("background thread should join");
85
86        serv_rx
87            .into_iter()
88            .map(|bytes| String::from_utf8(bytes).unwrap())
89            .collect()
90    }
91
92    /// Run the given test function while receiving several packets.
93    /// Return a vector of the packets.
94    pub fn capture_all<F>(self, func: F) -> Vec<String>
95    where
96        F: Fn(),
97    {
98        self.run_while_receiving_all(func)
99    }
100
101    /// Run the given test function while receiving several packets.
102    /// Return the concatenation of the packets.
103    ///
104    /// ```
105    /// use statsd::Client;
106    ///
107    /// // Start the mock server
108    /// let mock = statsd_mock::start();
109    ///
110    /// let client = Client::new(&mock.addr(), "duyet").unwrap();
111    /// let response = mock.run_while_receiving(|| {
112    ///     client.count("some.counter", 123.0);
113    /// });
114    /// assert_eq!(response, "duyet.some.counter:123|c");
115    /// ```
116    pub fn run_while_receiving<F>(self, func: F) -> String
117    where
118        F: Fn(),
119    {
120        itertools::Itertools::intersperse(
121            self.run_while_receiving_all(func).into_iter(),
122            String::from("\n"),
123        )
124        .fold(String::new(), |acc, b| acc + &b)
125    }
126
127    /// Run the given test function while receiving several packets.
128    /// Return the concatenation of the packets.
129    pub fn capture<F>(self, func: F) -> String
130    where
131        F: Fn(),
132    {
133        self.run_while_receiving(func)
134    }
135}
136
137pub fn start() -> StatsDServer {
138    StatsDServer::default()
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use statsd::client::Client;

    /// Start a mock server and a client that reports to it under the
    /// `duyet` prefix.
    fn mock_with_client() -> (StatsDServer, Client) {
        let mock = start();
        let client = Client::new(&mock.addr(), "duyet").unwrap();
        (mock, client)
    }

    #[test]
    fn test_get_addr() {
        let mock = start();

        assert!(mock.addr().starts_with("127.0.0.1:"));
    }

    #[test]
    fn test_capture_incr() {
        let (mock, client) = mock_with_client();

        let response = mock.capture(|| client.incr("some.counter"));

        assert_eq!(response, "duyet.some.counter:1|c");
    }

    #[test]
    fn test_capture_decr() {
        let (mock, client) = mock_with_client();

        let response = mock.capture(|| client.decr("some.counter"));

        assert_eq!(response, "duyet.some.counter:-1|c");
    }

    #[test]
    fn test_capture_count() {
        let (mock, client) = mock_with_client();

        let response = mock.capture(|| {
            client.count("some.counter", 123.0);
        });

        assert_eq!(response, "duyet.some.counter:123|c");
    }

    #[test]
    fn test_capture_all() {
        let (mock, client) = mock_with_client();

        let response = mock.capture_all(|| {
            client.incr("some.counter");
            client.incr("some.counter2");
            client.count("some.counter3", 123.0);
        });

        assert_eq!(
            response,
            vec![
                "duyet.some.counter:1|c",
                "duyet.some.counter2:1|c",
                "duyet.some.counter3:123|c"
            ]
        );
    }

    // `capture` separates consecutive packets with a newline.
    #[test]
    fn test_capture_joins_packets_with_newline() {
        let (mock, client) = mock_with_client();

        let response = mock.capture(|| {
            client.incr("some.counter");
            client.decr("some.counter2");
        });

        assert_eq!(
            response,
            "duyet.some.counter:1|c\nduyet.some.counter2:-1|c"
        );
    }

    // Nothing sent means nothing captured.
    #[test]
    fn test_capture_nothing() {
        assert_eq!(start().capture(|| {}), "");
        assert!(start().capture_all(|| {}).is_empty());
    }
}