1#![forbid(unsafe_code)]
4
5use async_trait::async_trait;
6use rustmod_core::encoding::{Reader, Writer};
7use rustmod_core::frame::tcp;
8use rustmod_core::{DecodeError, EncodeError};
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU16, Ordering};
11use thiserror::Error;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::{TcpStream, ToSocketAddrs};
14use tokio::sync::Mutex;
15use tracing::trace;
16
17pub mod server;
18pub mod sim;
19pub use server::{ModbusRtuOverTcpServer, ModbusService, ModbusTcpServer, ServiceError};
20pub use sim::{CoilBank, InMemoryModbusService, InMemoryPointModel, RegisterBank};
21#[cfg(feature = "rtu")]
22pub mod rtu;
23#[cfg(feature = "rtu")]
24pub use rtu::{ModbusRtuConfig, ModbusRtuTransport};
25#[cfg(feature = "rtu")]
26pub mod rtu_server;
27#[cfg(feature = "rtu")]
28pub use rtu_server::{ModbusRtuServer, ModbusRtuServerConfig};
29
/// Largest response PDU, in bytes, that `ModbusTcpTransport::exchange` will accept.
///
/// 253 is the PDU size limit from the Modbus application protocol specification.
/// Longer responses are drained from the socket and rejected.
const MAX_TCP_PDU_LEN: usize = 253;
31
32#[derive(Debug, Error)]
33pub enum DataLinkError {
34 #[error("io error: {0}")]
35 Io(#[from] std::io::Error),
36 #[error("encode error: {0}")]
37 Encode(#[from] EncodeError),
38 #[error("decode error: {0}")]
39 Decode(#[from] DecodeError),
40 #[error("connection closed")]
41 ConnectionClosed,
42 #[error("request timed out")]
43 Timeout,
44 #[error("invalid response: {0}")]
45 InvalidResponse(&'static str),
46 #[error("transaction id mismatch: expected {expected}, got {got}")]
47 MismatchedTransactionId { expected: u16, got: u16 },
48 #[error("response buffer too small (needed {needed}, available {available})")]
49 ResponseBufferTooSmall { needed: usize, available: usize },
50}
51
/// A request/response channel that carries one Modbus PDU to a unit and
/// returns that unit's reply PDU.
#[async_trait]
pub trait DataLink: Send + Sync {
    /// Sends `request_pdu` addressed to `unit_id` and places the reply PDU at
    /// the front of `response_pdu`.
    ///
    /// Returns the number of bytes written to `response_pdu`; callers read the
    /// reply as `&response_pdu[..len]`.
    ///
    /// # Errors
    ///
    /// Transport, framing and validation failures are reported as
    /// [`DataLinkError`]. The exact variants depend on the implementation.
    async fn exchange(
        &self,
        unit_id: u8,
        request_pdu: &[u8],
        response_pdu: &mut [u8],
    ) -> Result<usize, DataLinkError>;
}
64
65#[derive(Debug)]
66pub struct ModbusTcpTransport {
67 stream: Arc<Mutex<TcpStream>>,
68 next_transaction_id: Arc<AtomicU16>,
69}
70
71impl ModbusTcpTransport {
72 pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self, DataLinkError> {
73 let stream = TcpStream::connect(addr).await?;
74 Ok(Self::from_stream(stream))
75 }
76
77 pub fn from_stream(stream: TcpStream) -> Self {
78 Self {
79 stream: Arc::new(Mutex::new(stream)),
80 next_transaction_id: Arc::new(AtomicU16::new(1)),
81 }
82 }
83
84 fn next_tid(&self) -> u16 {
85 self.next_transaction_id.fetch_add(1, Ordering::Relaxed)
86 }
87}
88
89async fn read_exact_or_connection_closed(
90 stream: &mut TcpStream,
91 buf: &mut [u8],
92) -> Result<(), DataLinkError> {
93 if let Err(err) = stream.read_exact(buf).await {
94 if err.kind() == std::io::ErrorKind::UnexpectedEof {
95 return Err(DataLinkError::ConnectionClosed);
96 }
97 return Err(DataLinkError::Io(err));
98 }
99 Ok(())
100}
101
102async fn drain_exact(stream: &mut TcpStream, mut len: usize) -> Result<(), DataLinkError> {
103 let mut scratch = [0u8; 256];
104 while len > 0 {
105 let chunk = len.min(scratch.len());
106 read_exact_or_connection_closed(stream, &mut scratch[..chunk]).await?;
107 len -= chunk;
108 }
109 Ok(())
110}
111
112#[async_trait]
113impl DataLink for ModbusTcpTransport {
114 async fn exchange(
115 &self,
116 unit_id: u8,
117 request_pdu: &[u8],
118 response_pdu: &mut [u8],
119 ) -> Result<usize, DataLinkError> {
120 if request_pdu.is_empty() {
121 return Err(DataLinkError::InvalidResponse("empty request pdu"));
122 }
123
124 let transaction_id = self.next_tid();
125 let mut req_frame = vec![0u8; tcp::MBAP_HEADER_LEN + request_pdu.len()];
126 let mut writer = Writer::new(&mut req_frame);
127 tcp::encode_frame(&mut writer, transaction_id, unit_id, request_pdu)?;
128
129 let mut stream = self.stream.lock().await;
130 trace!(
131 transaction_id,
132 unit_id,
133 pdu_len = request_pdu.len(),
134 "sending modbus tcp request"
135 );
136 stream.write_all(writer.as_written()).await?;
137
138 let mut mbap = [0u8; tcp::MBAP_HEADER_LEN];
139 read_exact_or_connection_closed(&mut stream, &mut mbap).await?;
140
141 let mut reader = Reader::new(&mbap);
142 let header = tcp::MbapHeader::decode(&mut reader)?;
143
144 let pdu_len = usize::from(header.length)
145 .checked_sub(1)
146 .ok_or(DataLinkError::InvalidResponse("invalid mbap length"))?;
147 if pdu_len == 0 {
148 return Err(DataLinkError::InvalidResponse("empty response pdu"));
149 }
150 let tid_mismatch = header.transaction_id != transaction_id;
151 let unit_mismatch = header.unit_id != unit_id;
152
153 if pdu_len > MAX_TCP_PDU_LEN {
154 drain_exact(&mut stream, pdu_len).await?;
155 if tid_mismatch {
156 return Err(DataLinkError::MismatchedTransactionId {
157 expected: transaction_id,
158 got: header.transaction_id,
159 });
160 }
161 if unit_mismatch {
162 return Err(DataLinkError::InvalidResponse("unit id mismatch"));
163 }
164 return Err(DataLinkError::InvalidResponse("response pdu too large"));
165 }
166
167 if pdu_len > response_pdu.len() {
168 drain_exact(&mut stream, pdu_len).await?;
169 if tid_mismatch {
170 return Err(DataLinkError::MismatchedTransactionId {
171 expected: transaction_id,
172 got: header.transaction_id,
173 });
174 }
175 if unit_mismatch {
176 return Err(DataLinkError::InvalidResponse("unit id mismatch"));
177 }
178 return Err(DataLinkError::ResponseBufferTooSmall {
179 needed: pdu_len,
180 available: response_pdu.len(),
181 });
182 }
183
184 read_exact_or_connection_closed(&mut stream, &mut response_pdu[..pdu_len]).await?;
185 if tid_mismatch {
186 return Err(DataLinkError::MismatchedTransactionId {
187 expected: transaction_id,
188 got: header.transaction_id,
189 });
190 }
191 if unit_mismatch {
192 return Err(DataLinkError::InvalidResponse("unit id mismatch"));
193 }
194 trace!(
195 transaction_id,
196 unit_id,
197 pdu_len,
198 "received modbus tcp response"
199 );
200 Ok(pdu_len)
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::{DataLink, DataLinkError, ModbusTcpTransport};
    use rustmod_core::encoding::Writer;
    use rustmod_core::frame::tcp;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    #[tokio::test]
    async fn exchange_roundtrip_over_tcp() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            assert_eq!(&req[7..], &[0x03, 0x00, 0x6B, 0x00, 0x03]);

            let mut frame = [0u8; 15];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(
                &mut w,
                1,
                1,
                &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64],
            )
            .unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 256];
        let len = transport
            .exchange(1, &[0x03, 0x00, 0x6B, 0x00, 0x03], &mut response)
            .await
            .unwrap();

        assert_eq!(
            &response[..len],
            &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64]
        );

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_rejects_mismatched_transaction_id() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();

            let mut frame = [0u8; 9];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(&mut w, 2, 1, &[0x83, 0x02]).unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();

        match err {
            DataLinkError::MismatchedTransactionId { expected, got } => {
                assert_eq!(expected, 1);
                assert_eq!(got, 2);
            }
            other => panic!("unexpected error: {other:?}"),
        }

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_drains_pdu_on_transaction_mismatch() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut mismatch = [0u8; 9];
            let mut mismatch_w = Writer::new(&mut mismatch);
            tcp::encode_frame(&mut mismatch_w, 2, 1, &[0x83, 0x02]).unwrap();
            socket.write_all(mismatch_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, 1, &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(err, DataLinkError::MismatchedTransactionId { .. }));

        let len = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_rejects_and_drains_oversized_response_pdu() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut oversized = vec![0u8; tcp::MBAP_HEADER_LEN + 254];
            oversized[0..2].copy_from_slice(&1u16.to_be_bytes());
            oversized[2..4].copy_from_slice(&0u16.to_be_bytes());
            oversized[4..6].copy_from_slice(&255u16.to_be_bytes());
            oversized[6] = 1;
            oversized[7] = 0x03;
            socket.write_all(&oversized).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, 1, &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 260];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::InvalidResponse("response pdu too large")
        ));

        let len = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    /// A valid-size PDU that does not fit the caller's buffer must be reported
    /// and drained, so the next exchange on the same connection succeeds.
    #[tokio::test]
    async fn exchange_reports_and_drains_response_larger_than_buffer() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut big = [0u8; 13];
            let mut big_w = Writer::new(&mut big);
            tcp::encode_frame(&mut big_w, 1, 1, &[0x03, 0x04, 0x00, 0x01, 0x00, 0x02]).unwrap();
            socket.write_all(big_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, 1, &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 4];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x02], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::ResponseBufferTooSmall {
                needed: 6,
                available: 4
            }
        ));

        let len = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    /// A response from a different unit id is rejected even when the
    /// transaction id matches.
    #[tokio::test]
    async fn exchange_rejects_mismatched_unit_id() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut frame = [0u8; 11];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(&mut w, 1, 2, &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(1, &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::InvalidResponse("unit id mismatch")
        ));

        server.await.unwrap();
    }
}