Skip to main content

rustmod_datalink/
lib.rs

1//! Async Modbus transport abstraction layer.
2//!
3//! This crate provides the [`DataLink`] trait that abstracts over TCP and RTU
4//! transports, along with concrete implementations:
5//!
6//! - [`ModbusTcpTransport`] — Modbus TCP client transport
7//! - [`ModbusTcpServer`] / [`ModbusRtuOverTcpServer`] — server implementations
8//! - [`InMemoryModbusService`] — in-memory simulator for testing
9//!
10//! Enable the `rtu` feature for serial RTU support via `tokio-serial`.
11
12#![forbid(unsafe_code)]
13
14use async_trait::async_trait;
15use rustmod_core::encoding::{Reader, Writer};
16use rustmod_core::frame::tcp;
17pub use rustmod_core::UnitId;
18use rustmod_core::{DecodeError, EncodeError};
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU16, Ordering};
21use thiserror::Error;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::{TcpStream, ToSocketAddrs};
24use tokio::sync::Mutex;
25use tracing::trace;
26
27pub mod server;
28pub mod sim;
29pub use server::{ModbusRtuOverTcpServer, ModbusService, ModbusTcpServer, ServiceError};
30pub use sim::{CoilBank, InMemoryModbusService, InMemoryPointModel, RegisterBank};
31#[cfg(feature = "rtu")]
32pub mod rtu;
33#[cfg(feature = "rtu")]
34pub use rtu::{ModbusRtuConfig, ModbusRtuTransport};
35#[cfg(feature = "rtu")]
36pub mod rtu_server;
37#[cfg(feature = "rtu")]
38pub use rtu_server::{ModbusRtuServer, ModbusRtuServerConfig};
39
/// Largest PDU the Modbus application protocol allows, in bytes (253).
///
/// The limit comes from the serial-line ADU size: 256 bytes minus the 1-byte
/// address field and the 2-byte CRC. The TCP transport rejects any response
/// whose MBAP header announces a longer PDU.
const MAX_TCP_PDU_LEN: usize = 256 - 1 - 2;
41
42/// Errors that can occur during a transport-level operation.
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum DataLinkError {
46    /// Underlying I/O error (TCP socket, serial port, etc.).
47    #[error("io error: {0}")]
48    Io(#[from] std::io::Error),
49    /// Failed to encode a request frame.
50    #[error("encode error: {0}")]
51    Encode(#[from] EncodeError),
52    /// Failed to decode a response frame.
53    #[error("decode error: {0}")]
54    Decode(#[from] DecodeError),
55    /// The remote peer closed the connection.
56    #[error("connection closed")]
57    ConnectionClosed,
58    /// The request timed out waiting for a response.
59    #[error("request timed out")]
60    Timeout,
61    /// The response was structurally invalid.
62    #[error("invalid response: {0}")]
63    InvalidResponse(&'static str),
64    /// The response transaction ID did not match the request.
65    #[error("transaction id mismatch: expected {expected}, got {got}")]
66    MismatchedTransactionId { expected: u16, got: u16 },
67    /// The caller-provided response buffer was too small for the response PDU.
68    #[error("response buffer too small (needed {needed}, available {available})")]
69    ResponseBufferTooSmall { needed: usize, available: usize },
70}
71
72/// Async transport abstraction for Modbus request/response exchanges.
73///
74/// Implementations handle framing (TCP MBAP or RTU CRC) and wire I/O.
75/// The client crate uses this trait to remain transport-agnostic.
76#[async_trait]
77pub trait DataLink: Send + Sync {
78    /// Send a request PDU to a unit and write the response PDU into `response_pdu`.
79    ///
80    /// Returns the number of response bytes written to `response_pdu`.
81    async fn exchange(
82        &self,
83        unit_id: UnitId,
84        request_pdu: &[u8],
85        response_pdu: &mut [u8],
86    ) -> Result<usize, DataLinkError>;
87
88    /// Attempt to re-establish the underlying connection.
89    ///
90    /// Called by the client before retrying after a transport error.
91    /// The default implementation is a no-op (suitable for transports
92    /// that do not support reconnection or for mock implementations).
93    async fn reconnect(&self) -> Result<(), DataLinkError> {
94        Ok(())
95    }
96
97    /// Check if the transport is connected.
98    ///
99    /// The default implementation always returns `true`.
100    fn is_connected(&self) -> bool {
101        true
102    }
103}
104
105/// Modbus TCP client transport implementing the [`DataLink`] trait.
106///
107/// Uses MBAP framing with auto-incrementing transaction IDs. The transport
108/// is internally mutex-protected, so it can be shared behind an `Arc`.
109#[derive(Debug)]
110pub struct ModbusTcpTransport {
111    stream: Arc<Mutex<TcpStream>>,
112    next_transaction_id: Arc<AtomicU16>,
113    peer_addr: std::net::SocketAddr,
114}
115
116impl ModbusTcpTransport {
117    /// Connect to a Modbus TCP device (e.g. `"192.168.1.10:502"`).
118    pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self, DataLinkError> {
119        let stream = TcpStream::connect(addr).await?;
120        let peer_addr = stream.peer_addr()?;
121        Ok(Self {
122            stream: Arc::new(Mutex::new(stream)),
123            next_transaction_id: Arc::new(AtomicU16::new(1)),
124            peer_addr,
125        })
126    }
127
128    /// Wrap an existing [`TcpStream`] as a Modbus TCP transport.
129    pub fn from_stream(stream: TcpStream) -> Self {
130        let peer_addr = stream
131            .peer_addr()
132            .expect("TcpStream must have a peer address");
133        Self {
134            stream: Arc::new(Mutex::new(stream)),
135            next_transaction_id: Arc::new(AtomicU16::new(1)),
136            peer_addr,
137        }
138    }
139
140    fn next_tid(&self) -> u16 {
141        self.next_transaction_id.fetch_add(1, Ordering::Relaxed)
142    }
143}
144
145async fn read_exact_or_connection_closed(
146    stream: &mut TcpStream,
147    buf: &mut [u8],
148) -> Result<(), DataLinkError> {
149    if let Err(err) = stream.read_exact(buf).await {
150        if err.kind() == std::io::ErrorKind::UnexpectedEof {
151            return Err(DataLinkError::ConnectionClosed);
152        }
153        return Err(DataLinkError::Io(err));
154    }
155    Ok(())
156}
157
158async fn drain_exact(stream: &mut TcpStream, mut len: usize) -> Result<(), DataLinkError> {
159    let mut scratch = [0u8; 256];
160    while len > 0 {
161        let chunk = len.min(scratch.len());
162        read_exact_or_connection_closed(stream, &mut scratch[..chunk]).await?;
163        len -= chunk;
164    }
165    Ok(())
166}
167
168#[async_trait]
169impl DataLink for ModbusTcpTransport {
170    async fn reconnect(&self) -> Result<(), DataLinkError> {
171        let new_stream = TcpStream::connect(self.peer_addr).await?;
172        let mut guard = self.stream.lock().await;
173        *guard = new_stream;
174        tracing::info!(peer = %self.peer_addr, "reconnected modbus tcp transport");
175        Ok(())
176    }
177
178    async fn exchange(
179        &self,
180        unit_id: UnitId,
181        request_pdu: &[u8],
182        response_pdu: &mut [u8],
183    ) -> Result<usize, DataLinkError> {
184        if request_pdu.is_empty() {
185            return Err(DataLinkError::InvalidResponse("empty request pdu"));
186        }
187
188        let transaction_id = self.next_tid();
189        let mut req_frame = vec![0u8; tcp::MBAP_HEADER_LEN + request_pdu.len()];
190        let mut writer = Writer::new(&mut req_frame);
191        tcp::encode_frame(&mut writer, transaction_id, unit_id, request_pdu)?;
192
193        let mut stream = self.stream.lock().await;
194        trace!(
195            transaction_id,
196            unit_id = unit_id.as_u8(),
197            pdu_len = request_pdu.len(),
198            "sending modbus tcp request"
199        );
200        stream.write_all(writer.as_written()).await?;
201
202        let mut mbap = [0u8; tcp::MBAP_HEADER_LEN];
203        read_exact_or_connection_closed(&mut stream, &mut mbap).await?;
204
205        let mut reader = Reader::new(&mbap);
206        let header = tcp::MbapHeader::decode(&mut reader)?;
207
208        let pdu_len = usize::from(header.length)
209            .checked_sub(1)
210            .ok_or(DataLinkError::InvalidResponse("invalid mbap length"))?;
211        if pdu_len == 0 {
212            return Err(DataLinkError::InvalidResponse("empty response pdu"));
213        }
214        let tid_mismatch = header.transaction_id != transaction_id;
215        let unit_mismatch = header.unit_id != unit_id;
216
217        if pdu_len > MAX_TCP_PDU_LEN {
218            drain_exact(&mut stream, pdu_len).await?;
219            if tid_mismatch {
220                return Err(DataLinkError::MismatchedTransactionId {
221                    expected: transaction_id,
222                    got: header.transaction_id,
223                });
224            }
225            if unit_mismatch {
226                return Err(DataLinkError::InvalidResponse("unit id mismatch"));
227            }
228            return Err(DataLinkError::InvalidResponse("response pdu too large"));
229        }
230
231        if pdu_len > response_pdu.len() {
232            drain_exact(&mut stream, pdu_len).await?;
233            if tid_mismatch {
234                return Err(DataLinkError::MismatchedTransactionId {
235                    expected: transaction_id,
236                    got: header.transaction_id,
237                });
238            }
239            if unit_mismatch {
240                return Err(DataLinkError::InvalidResponse("unit id mismatch"));
241            }
242            return Err(DataLinkError::ResponseBufferTooSmall {
243                needed: pdu_len,
244                available: response_pdu.len(),
245            });
246        }
247
248        read_exact_or_connection_closed(&mut stream, &mut response_pdu[..pdu_len]).await?;
249        if tid_mismatch {
250            return Err(DataLinkError::MismatchedTransactionId {
251                expected: transaction_id,
252                got: header.transaction_id,
253            });
254        }
255        if unit_mismatch {
256            return Err(DataLinkError::InvalidResponse("unit id mismatch"));
257        }
258        trace!(
259            transaction_id,
260            unit_id = unit_id.as_u8(),
261            pdu_len,
262            "received modbus tcp response"
263        );
264        Ok(pdu_len)
265    }
266}
267
// Compile-time check that the TCP transport can be shared across tasks and
// threads (e.g. behind an `Arc`). If a non-`Send`/`Sync` field is ever added,
// the test build fails to compile here.
#[cfg(test)]
const _: fn() = || {
    fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<ModbusTcpTransport>();
};
275
#[cfg(test)]
mod tests {
    use super::{DataLink, DataLinkError, ModbusTcpTransport};
    use rustmod_core::encoding::Writer;
    use rustmod_core::frame::tcp;
    use rustmod_core::UnitId;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Happy path: the request is framed as expected and the response PDU is returned.
    #[tokio::test]
    async fn exchange_roundtrip_over_tcp() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            assert_eq!(&req[7..], &[0x03, 0x00, 0x6B, 0x00, 0x03]);

            let mut frame = [0u8; 15];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(
                &mut w,
                1,
                UnitId::new(1),
                &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64],
            )
            .unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 256];
        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x6B, 0x00, 0x03], &mut response)
            .await
            .unwrap();

        assert_eq!(
            &response[..len],
            &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64]
        );

        server.await.unwrap();
    }

    /// A response with a different transaction ID is reported with both IDs.
    #[tokio::test]
    async fn exchange_rejects_mismatched_transaction_id() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();

            let mut frame = [0u8; 9];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(&mut w, 2, UnitId::new(1), &[0x83, 0x02]).unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();

        match err {
            DataLinkError::MismatchedTransactionId { expected, got } => {
                assert_eq!(expected, 1);
                assert_eq!(got, 2);
            }
            other => panic!("unexpected error: {other:?}"),
        }

        server.await.unwrap();
    }

    /// After a transaction ID mismatch, the rejected PDU has been consumed, so
    /// the next exchange reads its own response.
    #[tokio::test]
    async fn exchange_drains_pdu_on_transaction_mismatch() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut mismatch = [0u8; 9];
            let mut mismatch_w = Writer::new(&mut mismatch);
            tcp::encode_frame(&mut mismatch_w, 2, UnitId::new(1), &[0x83, 0x02]).unwrap();
            socket.write_all(mismatch_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, UnitId::new(1), &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(err, DataLinkError::MismatchedTransactionId { .. }));

        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    /// A response announcing a PDU above the Modbus limit is rejected and
    /// drained, so the next exchange still works.
    #[tokio::test]
    async fn exchange_rejects_and_drains_oversized_response_pdu() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut oversized = vec![0u8; tcp::MBAP_HEADER_LEN + 254];
            oversized[0..2].copy_from_slice(&1u16.to_be_bytes());
            oversized[2..4].copy_from_slice(&0u16.to_be_bytes());
            oversized[4..6].copy_from_slice(&255u16.to_be_bytes());
            oversized[6] = 1;
            oversized[7] = 0x03;
            socket.write_all(&oversized).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, UnitId::new(1), &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 260];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::InvalidResponse("response pdu too large")
        ));

        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    /// A response PDU that does not fit the caller's buffer is reported as
    /// `ResponseBufferTooSmall` and drained, so the next exchange still works.
    #[tokio::test]
    async fn exchange_rejects_and_drains_when_response_buffer_too_small() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut big = [0u8; 15];
            let mut big_w = Writer::new(&mut big);
            tcp::encode_frame(
                &mut big_w,
                1,
                UnitId::new(1),
                &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64],
            )
            .unwrap();
            socket.write_all(big_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, UnitId::new(1), &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 4];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x03], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::ResponseBufferTooSmall {
                needed: 8,
                available: 4
            }
        ));

        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    /// A response with the right transaction ID but the wrong unit ID is
    /// rejected after its PDU was consumed, so the next exchange still works.
    #[tokio::test]
    async fn exchange_rejects_unit_id_mismatch() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut wrong_unit = [0u8; 11];
            let mut wrong_unit_w = Writer::new(&mut wrong_unit);
            tcp::encode_frame(
                &mut wrong_unit_w,
                1,
                UnitId::new(2),
                &[0x03, 0x02, 0x00, 0x2A],
            )
            .unwrap();
            socket.write_all(wrong_unit_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, UnitId::new(1), &[0x03, 0x02, 0x00, 0x07]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::InvalidResponse("unit id mismatch")
        ));

        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x07]);

        server.await.unwrap();
    }

    /// A peer that hangs up instead of answering surfaces as `ConnectionClosed`.
    #[tokio::test]
    async fn exchange_reports_connection_closed_when_peer_hangs_up() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            // Read the whole request first so closing sends a clean FIN.
            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            drop(socket);
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(err, DataLinkError::ConnectionClosed));

        server.await.unwrap();
    }
}