Skip to main content

rusty_modbus_server/
server.rs

1//! `ModbusServer` — async Modbus server with pluggable data store.
2
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use bytes::Bytes;
7use rusty_modbus_frame::frame::{Frame, FrameHeader};
8use rusty_modbus_tcp::config::TcpServerConfig;
9use rusty_modbus_tcp::listener::TcpServerListener;
10use rusty_modbus_tcp::transport::{TransportSink, TransportStream};
11use rusty_modbus_types::{ExceptionCode, MAX_PDU_SIZE, MbapHeader, UnitId};
12use tokio::sync::watch;
13use tracing::{debug, info, trace, warn};
14
15use crate::config::{DeviceIdentification, ServerConfig};
16use crate::error::ServerError;
17use crate::handler;
18use crate::store::DataStore;
19
20/// Async Modbus server, generic over the data store implementation.
21pub struct ModbusServer<S: DataStore> {
22    config: ServerConfig,
23    store: Arc<S>,
24    local_addr: SocketAddr,
25    shutdown_tx: watch::Sender<bool>,
26    accept_handle: Option<tokio::task::JoinHandle<()>>,
27}
28
29impl<S: DataStore + 'static> ModbusServer<S> {
30    /// Create and start a new Modbus server.
31    ///
32    /// Binds to the configured address and begins accepting connections immediately.
33    ///
34    /// # Errors
35    ///
36    /// Returns [`ServerError::Bind`] if the address cannot be bound.
37    #[tracing::instrument(level = "debug", skip(config, store), fields(addr = %config.listen_addr, unit_id = config.unit_id.0))]
38    pub async fn start(config: ServerConfig, store: Arc<S>) -> Result<Self, ServerError> {
39        let tcp_config = TcpServerConfig {
40            max_connections: config.max_connections,
41            ..config.tcp_config.clone()
42        };
43
44        let listener = TcpServerListener::bind(config.listen_addr, tcp_config)
45            .await
46            .map_err(|e| match e {
47                rusty_modbus_tcp::TransportError::Io(io) => ServerError::Bind(io),
48                other => ServerError::Transport(other),
49            })?;
50
51        let local_addr = listener.local_addr().map_err(|e| match e {
52            rusty_modbus_tcp::TransportError::Io(io) => ServerError::Bind(io),
53            other => ServerError::Transport(other),
54        })?;
55        info!(addr = %local_addr, unit_id = config.unit_id.0, "Modbus server listening");
56
57        let (shutdown_tx, shutdown_rx) = watch::channel(false);
58
59        let server_unit_id = config.unit_id;
60        let server_store = Arc::clone(&store);
61        let server_device_id = config.device_id.clone();
62
63        let accept_handle = tokio::spawn(async move {
64            accept_loop(
65                listener,
66                server_unit_id,
67                server_store,
68                server_device_id,
69                shutdown_rx,
70            )
71            .await;
72        });
73
74        Ok(Self {
75            config,
76            store,
77            local_addr,
78            shutdown_tx,
79            accept_handle: Some(accept_handle),
80        })
81    }
82
83    /// Graceful shutdown: stop accepting, wait for in-flight, close connections.
84    pub async fn stop(&self) {
85        info!(addr = %self.local_addr, "stopping Modbus server");
86        let _ = self.shutdown_tx.send(true);
87        // Give in-flight requests time to complete.
88        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
89    }
90
91    /// Get a reference to the data store.
92    #[must_use]
93    pub fn store(&self) -> &S {
94        self.store.as_ref()
95    }
96
97    /// Local address the server is bound to.
98    #[must_use]
99    pub fn local_addr(&self) -> SocketAddr {
100        self.local_addr
101    }
102}
103
104impl<S: DataStore> Drop for ModbusServer<S> {
105    fn drop(&mut self) {
106        let _ = self.shutdown_tx.send(true);
107        if let Some(h) = self.accept_handle.take() {
108            h.abort();
109        }
110    }
111}
112
113impl<S: DataStore> std::fmt::Debug for ModbusServer<S> {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("ModbusServer")
116            .field("addr", &self.local_addr)
117            .field("unit_id", &self.config.unit_id)
118            .finish_non_exhaustive()
119    }
120}
121
122async fn accept_loop<S: DataStore + 'static>(
123    listener: TcpServerListener,
124    unit_id: UnitId,
125    store: Arc<S>,
126    device_id: DeviceIdentification,
127    mut shutdown_rx: watch::Receiver<bool>,
128) {
129    loop {
130        tokio::select! {
131            result = listener.accept() => {
132                if let Ok((sink, stream, addr, guard)) = result {
133                    debug!(peer_addr = %addr, "accepted Modbus server connection");
134                    let conn_store = Arc::clone(&store);
135                    let conn_device_id = device_id.clone();
136                    tokio::spawn(async move {
137                        handle_connection(sink, stream, addr, unit_id, conn_store, conn_device_id).await;
138                        drop(guard);
139                    });
140                } else if let Err(error) = result {
141                    warn!(error = %error, "Modbus server accept failed");
142                }
143                // Accept error could be transient; continue.
144            }
145            _ = shutdown_rx.changed() => {
146                if *shutdown_rx.borrow() {
147                    debug!("Modbus server accept loop received shutdown");
148                    break;
149                }
150            }
151        }
152    }
153}
154
155async fn handle_connection<S: DataStore>(
156    mut sink: rusty_modbus_tcp::TcpSink,
157    mut stream: rusty_modbus_tcp::TcpRecvStream,
158    peer_addr: SocketAddr,
159    unit_id: UnitId,
160    store: Arc<S>,
161    device_id: DeviceIdentification,
162) {
163    while let Ok(frame) = stream.recv().await {
164        let request_unit_id = UnitId(frame.unit_id());
165        let pdu_len = frame.pdu.len();
166        trace!(
167            peer_addr = %peer_addr,
168            request_unit_id = request_unit_id.0,
169            pdu_len,
170            "received Modbus server request"
171        );
172
173        // Check unit ID: accept if it matches, is broadcast (0x00), or is TCP direct (0xFF).
174        if request_unit_id.0 != unit_id.0
175            && !request_unit_id.is_broadcast()
176            && !request_unit_id.is_tcp_device()
177        {
178            // Not for us — discard silently.
179            debug!(
180                peer_addr = %peer_addr,
181                request_unit_id = request_unit_id.0,
182                server_unit_id = unit_id.0,
183                "discarding request for different unit id"
184            );
185            continue;
186        }
187
188        let txn_id = match frame.header {
189            FrameHeader::Mbap(h) => h.transaction_id.get(),
190            FrameHeader::Rtu { .. } => 0,
191        };
192
193        // Process the request.
194        if let Some(response_pdu) =
195            handler::process_request(&frame.pdu, request_unit_id, store.as_ref(), &device_id).await
196        {
197            let Some(response_frame) = response_frame(txn_id, request_unit_id, response_pdu) else {
198                warn!(peer_addr = %peer_addr, txn_id, "dropping empty Modbus response PDU");
199                break;
200            };
201            if let Err(error) = sink.send(response_frame).await {
202                debug!(peer_addr = %peer_addr, txn_id, error = %error, "failed to send Modbus response");
203                break; // Connection lost.
204            }
205            trace!(peer_addr = %peer_addr, txn_id, "sent Modbus server response");
206        }
207        // If process_request returned None, it was a broadcast — no response.
208    }
209    debug!(peer_addr = %peer_addr, "Modbus server connection closed");
210}
211
212fn response_frame(txn_id: u16, unit_id: UnitId, response_pdu: Vec<u8>) -> Option<Frame> {
213    let pdu = bounded_response_pdu(response_pdu)?;
214    let pdu_len = u16::try_from(pdu.len()).expect("MAX_PDU_SIZE fits in u16");
215    let header = MbapHeader::new(txn_id, unit_id.0, pdu_len);
216    Some(Frame {
217        header: FrameHeader::Mbap(header),
218        pdu: Bytes::from(pdu),
219    })
220}
221
222fn bounded_response_pdu(response_pdu: Vec<u8>) -> Option<Vec<u8>> {
223    let fc = response_pdu.first().copied()?;
224    if response_pdu.len() <= MAX_PDU_SIZE {
225        return Some(response_pdu);
226    }
227
228    warn!(
229        function_code = fc,
230        pdu_len = response_pdu.len(),
231        max_pdu_size = MAX_PDU_SIZE,
232        "server response exceeded Modbus PDU limit"
233    );
234    Some(vec![fc | 0x80, ExceptionCode::ServerDeviceFailure.code()])
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn response_frame_preserves_valid_pdu() {
        let frame = response_frame(0x1234, UnitId(7), vec![0x03, 0x02, 0xAA, 0xBB])
            .expect("valid response should produce a frame");

        match frame.header {
            FrameHeader::Mbap(header) => {
                assert_eq!(header.transaction_id.get(), 0x1234);
                assert_eq!(header.unit_id, 7);
                assert_eq!(header.pdu_length(), 4);
            }
            FrameHeader::Rtu { .. } => panic!("expected MBAP response"),
        }
        assert_eq!(frame.pdu.as_ref(), &[0x03, 0x02, 0xAA, 0xBB]);
    }

    /// Pins the inclusive boundary: exactly `MAX_PDU_SIZE` bytes is still sent as-is.
    #[test]
    fn response_frame_preserves_pdu_at_size_limit() {
        let pdu = vec![0x03; MAX_PDU_SIZE];
        let frame = response_frame(0x0001, UnitId(1), pdu.clone())
            .expect("a PDU of exactly MAX_PDU_SIZE bytes should produce a frame");

        match frame.header {
            FrameHeader::Mbap(header) => assert_eq!(header.transaction_id.get(), 0x0001),
            FrameHeader::Rtu { .. } => panic!("expected MBAP response"),
        }
        assert_eq!(frame.pdu.as_ref(), pdu.as_slice());
    }

    #[test]
    fn response_frame_turns_oversized_pdu_into_exception() {
        let frame = response_frame(0xBEEF, UnitId(2), vec![0x03; MAX_PDU_SIZE + 1])
            .expect("oversized response should become an exception frame");

        match frame.header {
            FrameHeader::Mbap(header) => {
                assert_eq!(header.transaction_id.get(), 0xBEEF);
                assert_eq!(header.unit_id, 2);
                assert_eq!(header.pdu_length(), 2);
            }
            FrameHeader::Rtu { .. } => panic!("expected MBAP response"),
        }
        assert_eq!(
            frame.pdu.as_ref(),
            &[0x83, ExceptionCode::ServerDeviceFailure.code()]
        );
    }

    /// The exception keeps the original function code and only sets bit 0x80.
    #[test]
    fn oversized_exception_sets_high_bit_of_function_code() {
        let pdu = bounded_response_pdu(vec![0x04; MAX_PDU_SIZE + 1])
            .expect("oversized response should become an exception PDU");
        assert_eq!(pdu, vec![0x84, ExceptionCode::ServerDeviceFailure.code()]);
    }

    #[test]
    fn bounded_response_pdu_rejects_empty_pdu() {
        assert!(bounded_response_pdu(Vec::new()).is_none());
    }

    #[test]
    fn response_frame_drops_empty_pdu() {
        assert!(response_frame(0, UnitId(1), Vec::new()).is_none());
    }
}