Skip to main content

rustmod_datalink/
lib.rs

1//! Async Modbus transport abstraction layer.
2
3#![forbid(unsafe_code)]
4
5use async_trait::async_trait;
6use rustmod_core::encoding::{Reader, Writer};
7use rustmod_core::frame::tcp;
8use rustmod_core::{DecodeError, EncodeError};
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU16, Ordering};
11use thiserror::Error;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpStream, ToSocketAddrs};
14use tokio::sync::Mutex;
15use tracing::trace;
16
17pub mod server;
18pub mod sim;
19pub use server::{ModbusRtuOverTcpServer, ModbusService, ModbusTcpServer, ServiceError};
20pub use sim::{CoilBank, InMemoryModbusService, InMemoryPointModel, RegisterBank};
21#[cfg(feature = "rtu")]
22pub mod rtu;
23#[cfg(feature = "rtu")]
24pub use rtu::{ModbusRtuConfig, ModbusRtuTransport};
25#[cfg(feature = "rtu")]
26pub mod rtu_server;
27#[cfg(feature = "rtu")]
28pub use rtu_server::{ModbusRtuServer, ModbusRtuServerConfig};
29
/// Largest response PDU, in bytes, that `ModbusTcpTransport::exchange` accepts.
///
/// 253 bytes is the PDU size limit of the Modbus application protocol. A response whose
/// MBAP length field implies a longer PDU has its body drained from the socket and is
/// rejected with `DataLinkError::InvalidResponse("response pdu too large")`.
const MAX_TCP_PDU_LEN: usize = 253;
31
/// Errors produced by [`DataLink`] implementations.
#[derive(Debug, Error)]
pub enum DataLinkError {
    /// An underlying I/O operation failed.
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// Encoding an outgoing frame failed.
    #[error("encode error: {0}")]
    Encode(#[from] EncodeError),
    /// Decoding a received frame (e.g. the MBAP header) failed.
    #[error("decode error: {0}")]
    Decode(#[from] DecodeError),
    /// The peer closed the connection (end-of-stream) before a complete frame was read.
    #[error("connection closed")]
    ConnectionClosed,
    /// No response arrived in time.
    ///
    /// NOTE(review): not constructed in this file; presumably raised by other transports
    /// or by callers enforcing a deadline — confirm.
    #[error("request timed out")]
    Timeout,
    /// The response was malformed or did not match the request; the message says why.
    ///
    /// NOTE(review): `ModbusTcpTransport::exchange` also returns this variant for an empty
    /// *request* PDU.
    #[error("invalid response: {0}")]
    InvalidResponse(&'static str),
    /// The response carried a different transaction id than the request that was sent.
    #[error("transaction id mismatch: expected {expected}, got {got}")]
    MismatchedTransactionId { expected: u16, got: u16 },
    /// The response PDU (`needed` bytes) does not fit in the caller's buffer (`available` bytes).
    #[error("response buffer too small (needed {needed}, available {available})")]
    ResponseBufferTooSmall { needed: usize, available: usize },
}
51
/// An async request/response channel to Modbus units.
///
/// Implementors must be `Send + Sync`, so one link can be shared between tasks.
#[async_trait]
pub trait DataLink: Send + Sync {
    /// Send a request PDU to a unit and write the response PDU into `response_pdu`.
    ///
    /// Returns the number of response bytes written to `response_pdu`.
    ///
    /// # Errors
    ///
    /// Returns a [`DataLinkError`] when the request cannot be sent, the response cannot be
    /// read, or the response is invalid or does not fit in `response_pdu`.
    async fn exchange(
        &self,
        unit_id: u8,
        request_pdu: &[u8],
        response_pdu: &mut [u8],
    ) -> Result<usize, DataLinkError>;
}
64
/// Modbus TCP client transport over a single [`TcpStream`].
///
/// Each [`DataLink::exchange`] call wraps the request in an MBAP frame and holds the stream
/// lock until its response has been read. Concurrent calls on one transport are therefore
/// processed one at a time. Transaction ids come from a counter that starts at 1 and wraps
/// around after `u16::MAX`.
#[derive(Debug)]
pub struct ModbusTcpTransport {
    // Connection shared by all exchanges; the async mutex serializes request/response pairs.
    stream: Arc<Mutex<TcpStream>>,
    // Transaction id to assign to the next request.
    next_transaction_id: Arc<AtomicU16>,
}
70
71impl ModbusTcpTransport {
72    pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self, DataLinkError> {
73        let stream = TcpStream::connect(addr).await?;
74        Ok(Self::from_stream(stream))
75    }
76
77    pub fn from_stream(stream: TcpStream) -> Self {
78        Self {
79            stream: Arc::new(Mutex::new(stream)),
80            next_transaction_id: Arc::new(AtomicU16::new(1)),
81        }
82    }
83
84    fn next_tid(&self) -> u16 {
85        self.next_transaction_id.fetch_add(1, Ordering::Relaxed)
86    }
87}
88
89async fn read_exact_or_connection_closed(
90    stream: &mut TcpStream,
91    buf: &mut [u8],
92) -> Result<(), DataLinkError> {
93    if let Err(err) = stream.read_exact(buf).await {
94        if err.kind() == std::io::ErrorKind::UnexpectedEof {
95            return Err(DataLinkError::ConnectionClosed);
96        }
97        return Err(DataLinkError::Io(err));
98    }
99    Ok(())
100}
101
102async fn drain_exact(stream: &mut TcpStream, mut len: usize) -> Result<(), DataLinkError> {
103    let mut scratch = [0u8; 256];
104    while len > 0 {
105        let chunk = len.min(scratch.len());
106        read_exact_or_connection_closed(stream, &mut scratch[..chunk]).await?;
107        len -= chunk;
108    }
109    Ok(())
110}
111
112#[async_trait]
113impl DataLink for ModbusTcpTransport {
114    async fn exchange(
115        &self,
116        unit_id: u8,
117        request_pdu: &[u8],
118        response_pdu: &mut [u8],
119    ) -> Result<usize, DataLinkError> {
120        if request_pdu.is_empty() {
121            return Err(DataLinkError::InvalidResponse("empty request pdu"));
122        }
123
124        let transaction_id = self.next_tid();
125        let mut req_frame = vec![0u8; tcp::MBAP_HEADER_LEN + request_pdu.len()];
126        let mut writer = Writer::new(&mut req_frame);
127        tcp::encode_frame(&mut writer, transaction_id, unit_id, request_pdu)?;
128
129        let mut stream = self.stream.lock().await;
130        trace!(
131            transaction_id,
132            unit_id,
133            pdu_len = request_pdu.len(),
134            "sending modbus tcp request"
135        );
136        stream.write_all(writer.as_written()).await?;
137
138        let mut mbap = [0u8; tcp::MBAP_HEADER_LEN];
139        read_exact_or_connection_closed(&mut stream, &mut mbap).await?;
140
141        let mut reader = Reader::new(&mbap);
142        let header = tcp::MbapHeader::decode(&mut reader)?;
143
144        let pdu_len = usize::from(header.length)
145            .checked_sub(1)
146            .ok_or(DataLinkError::InvalidResponse("invalid mbap length"))?;
147        if pdu_len == 0 {
148            return Err(DataLinkError::InvalidResponse("empty response pdu"));
149        }
150        let tid_mismatch = header.transaction_id != transaction_id;
151        let unit_mismatch = header.unit_id != unit_id;
152
153        if pdu_len > MAX_TCP_PDU_LEN {
154            drain_exact(&mut stream, pdu_len).await?;
155            if tid_mismatch {
156                return Err(DataLinkError::MismatchedTransactionId {
157                    expected: transaction_id,
158                    got: header.transaction_id,
159                });
160            }
161            if unit_mismatch {
162                return Err(DataLinkError::InvalidResponse("unit id mismatch"));
163            }
164            return Err(DataLinkError::InvalidResponse("response pdu too large"));
165        }
166
167        if pdu_len > response_pdu.len() {
168            drain_exact(&mut stream, pdu_len).await?;
169            if tid_mismatch {
170                return Err(DataLinkError::MismatchedTransactionId {
171                    expected: transaction_id,
172                    got: header.transaction_id,
173                });
174            }
175            if unit_mismatch {
176                return Err(DataLinkError::InvalidResponse("unit id mismatch"));
177            }
178            return Err(DataLinkError::ResponseBufferTooSmall {
179                needed: pdu_len,
180                available: response_pdu.len(),
181            });
182        }
183
184        read_exact_or_connection_closed(&mut stream, &mut response_pdu[..pdu_len]).await?;
185        if tid_mismatch {
186            return Err(DataLinkError::MismatchedTransactionId {
187                expected: transaction_id,
188                got: header.transaction_id,
189            });
190        }
191        if unit_mismatch {
192            return Err(DataLinkError::InvalidResponse("unit id mismatch"));
193        }
194        trace!(
195            transaction_id,
196            unit_id,
197            pdu_len,
198            "received modbus tcp response"
199        );
200        Ok(pdu_len)
201    }
202}
203
/// Loopback tests for `ModbusTcpTransport`.
///
/// Each test runs a scripted server on an ephemeral localhost port. The server reads the
/// 12-byte request frame (7-byte MBAP header + 5-byte PDU) and writes canned responses.
/// The transport's first request uses transaction id 1 and the second uses 2.
#[cfg(test)]
mod tests {
    use super::{DataLink, DataLinkError, ModbusTcpTransport};
    use rustmod_core::encoding::Writer;
    use rustmod_core::frame::tcp;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    #[tokio::test]
    async fn exchange_roundtrip_over_tcp() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            assert_eq!(&req[7..], &[0x03, 0x00, 0x6B, 0x00, 0x03]);

            let mut frame = [0u8; 15];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(
                &mut w,
                1,
                1,
                &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64],
            )
            .unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 256];
        let len = transport
            .exchange(1, &[0x03, 0x00, 0x6B, 0x00, 0x03], &mut response)
            .await
            .unwrap();

        assert_eq!(
            &response[..len],
            &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64]
        );

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_rejects_mismatched_transaction_id() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();

            let mut frame = [0u8; 9];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(&mut w, 2, 1, &[0x83, 0x02]).unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();

        match err {
            DataLinkError::MismatchedTransactionId { expected, got } => {
                assert_eq!(expected, 1);
                assert_eq!(got, 2);
            }
            other => panic!("unexpected error: {other:?}"),
        }

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_rejects_mismatched_unit_id() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();

            // Correct transaction id, but answered from unit 2 instead of unit 1.
            let mut frame = [0u8; 11];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(&mut w, 1, 2, &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::InvalidResponse("unit id mismatch")
        ));

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_drains_pdu_on_transaction_mismatch() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut mismatch = [0u8; 9];
            let mut mismatch_w = Writer::new(&mut mismatch);
            tcp::encode_frame(&mut mismatch_w, 2, 1, &[0x83, 0x02]).unwrap();
            socket.write_all(mismatch_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, 1, &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(err, DataLinkError::MismatchedTransactionId { .. }));

        let len = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_drains_pdu_when_response_buffer_too_small() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut first = [0u8; 11];
            let mut first_w = Writer::new(&mut first);
            tcp::encode_frame(&mut first_w, 1, 1, &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(first_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut second = [0u8; 11];
            let mut second_w = Writer::new(&mut second);
            tcp::encode_frame(&mut second_w, 2, 1, &[0x03, 0x02, 0x00, 0x2B]).unwrap();
            socket.write_all(second_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();

        // A 4-byte PDU cannot fit in a 2-byte buffer; the body must still be drained.
        let mut small = [0u8; 2];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut small)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::ResponseBufferTooSmall {
                needed: 4,
                available: 2
            }
        ));

        // If the drain worked, the next exchange reads the second frame cleanly.
        let mut response = [0u8; 16];
        let len = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2B]);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_rejects_and_drains_oversized_response_pdu() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            // Hand-built frame: MBAP length 255 => 254-byte PDU, one over the limit.
            let mut oversized = vec![0u8; tcp::MBAP_HEADER_LEN + 254];
            oversized[0..2].copy_from_slice(&1u16.to_be_bytes());
            oversized[2..4].copy_from_slice(&0u16.to_be_bytes());
            oversized[4..6].copy_from_slice(&255u16.to_be_bytes());
            oversized[6] = 1;
            oversized[7] = 0x03;
            socket.write_all(&oversized).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, 1, &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 260];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::InvalidResponse("response pdu too large")
        ));

        let len = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }
}