socket_server_mocker/
tcp_server.rs

1use std::io::{Read, Write};
2use std::net::{SocketAddr, TcpListener, TcpStream};
3use std::sync::mpsc::{Receiver, Sender};
4use std::thread;
5use std::time::Duration;
6
7use crate::server_mocker::MockerOptions;
8use crate::Instruction::{
9    self, ReceiveMessageWithMaxSize, SendMessage, SendMessageDependingOnLastReceivedMessage,
10};
11use crate::ServerMockerError::{
12    self, UnableToAcceptConnection, UnableToBindListener, UnableToGetLocalAddress,
13    UnableToReadTcpStream, UnableToSetReadTimeout, UnableToWriteTcpStream,
14};
15
/// Configuration of the TCP server mocker.
///
/// Every field is public, so a value can be built with struct-update syntax on top of
/// [`TcpMocker::default`].
#[derive(Debug, Clone)]
pub struct TcpMocker {
    /// Address the mock server binds its listener to.
    ///
    /// The default is `127.0.0.1:0`, i.e. the loopback interface with port `0`.
    pub socket_addr: SocketAddr,
    /// How long the server waits for data coming from the client before giving up on a read.
    pub net_timeout: Duration,
    /// How long the server waits for the next batch of instructions before it stops on its
    /// own, in case [`Instruction::StopExchange`] was never sent.
    pub rx_timeout: Duration,
    /// Size, in bytes, of the buffer used when reading from the TCP socket.
    pub reader_buffer_size: usize,
}
28
29impl Default for TcpMocker {
30    fn default() -> Self {
31        Self {
32            socket_addr: SocketAddr::from(([127, 0, 0, 1], 0)),
33            net_timeout: Duration::from_millis(100),
34            rx_timeout: Duration::from_millis(100),
35            reader_buffer_size: 1024,
36        }
37    }
38}
39
40impl MockerOptions for TcpMocker {
41    fn socket_address(&self) -> SocketAddr {
42        self.socket_addr
43    }
44
45    fn net_timeout(&self) -> Duration {
46        self.net_timeout
47    }
48
49    fn run(
50        self,
51        instruction_rx: Receiver<Vec<Instruction>>,
52        message_tx: Sender<Vec<u8>>,
53        error_tx: Sender<ServerMockerError>,
54    ) -> Result<SocketAddr, ServerMockerError> {
55        let listener = TcpListener::bind(self.socket_addr)
56            .map_err(|e| UnableToBindListener(self.socket_addr, e))?;
57        let socket_addr = listener.local_addr().map_err(UnableToGetLocalAddress)?;
58
59        thread::spawn(move || match listener.accept() {
60            Ok((stream, _addr)) => {
61                TcpServerImpl {
62                    options: self,
63                    stream,
64                    instruction_rx,
65                    message_tx,
66                    error_tx,
67                }
68                .run();
69            }
70            Err(err) => {
71                error_tx
72                    .send(UnableToAcceptConnection(socket_addr, err))
73                    .unwrap();
74            }
75        });
76
77        Ok(socket_addr)
78    }
79}
80
81/// TCP server mocker thread implementation
82pub(crate) struct TcpServerImpl {
83    options: TcpMocker,
84    stream: TcpStream,
85    instruction_rx: Receiver<Vec<Instruction>>,
86    message_tx: Sender<Vec<u8>>,
87    error_tx: Sender<ServerMockerError>,
88}
89
90/// TCP server mocker thread implementation
91impl TcpServerImpl {
92    fn run(mut self) {
93        let timeout = Some(self.options.net_timeout);
94        if let Err(e) = self.stream.set_read_timeout(timeout) {
95            self.error_tx.send(UnableToSetReadTimeout(e)).unwrap();
96            return;
97        }
98        let mut last_received_message: Option<Vec<u8>> = None;
99
100        // Timeout: if no more instruction is available and StopExchange hasn't been sent
101        // Stop server if no more instruction is available and StopExchange hasn't been sent
102        while let Ok(instructions) = self.instruction_rx.recv_timeout(self.options.rx_timeout) {
103            for instruction in instructions {
104                match instruction {
105                    SendMessage(binary_message) => {
106                        if let Err(e) = self.send_packet(&binary_message) {
107                            self.error_tx.send(e).unwrap();
108                        }
109                    }
110                    SendMessageDependingOnLastReceivedMessage(sent_message_calculator) => {
111                        // Call the closure to get the message to send
112                        let message_to_send =
113                            sent_message_calculator(last_received_message.clone());
114                        // Send the message or skip if the closure returned None
115                        if let Some(message_to_send) = message_to_send {
116                            if let Err(e) = self.send_packet(&message_to_send) {
117                                self.error_tx.send(e).unwrap();
118                            }
119                        }
120                    }
121                    Instruction::ReceiveMessage => {
122                        match self.read_packet() {
123                            Ok(whole_received_packet) => {
124                                last_received_message = Some(whole_received_packet.clone());
125                                self.message_tx.send(whole_received_packet).unwrap();
126                            }
127                            Err(e) => self.error_tx.send(e).unwrap(),
128                        };
129                    }
130                    ReceiveMessageWithMaxSize(max_message_size) => {
131                        match self.read_packet() {
132                            Ok(mut whole_received_packet) => {
133                                whole_received_packet.truncate(max_message_size);
134                                last_received_message = Some(whole_received_packet.clone());
135                                self.message_tx.send(whole_received_packet).unwrap();
136                            }
137                            Err(e) => self.error_tx.send(e).unwrap(),
138                        };
139                    }
140                    Instruction::StopExchange => {
141                        return;
142                    }
143                }
144            }
145        }
146    }
147
148    /// Read a TCP packet from the client, using temporary buffer
149    fn read_packet(&mut self) -> Result<Vec<u8>, ServerMockerError> {
150        let mut whole_received_packet: Vec<u8> = Vec::new();
151        // FIXME: not much point in reading into a buffer and copying, perhaps need to consolidate
152        let mut buffer = vec![0; self.options.reader_buffer_size];
153
154        loop {
155            let bytes_read = self
156                .stream
157                .read(&mut buffer)
158                .map_err(UnableToReadTcpStream)?;
159            whole_received_packet.extend_from_slice(&buffer[..bytes_read]);
160            if bytes_read < self.options.reader_buffer_size {
161                break;
162            }
163        }
164        Ok(whole_received_packet)
165    }
166
167    fn send_packet(&mut self, packet: &[u8]) -> Result<(), ServerMockerError> {
168        self.stream
169            .write_all(packet)
170            .map_err(UnableToWriteTcpStream)
171    }
172}