1#![forbid(unsafe_code)]
13
14use async_trait::async_trait;
15use rustmod_core::encoding::{Reader, Writer};
16use rustmod_core::frame::tcp;
17pub use rustmod_core::UnitId;
18use rustmod_core::{DecodeError, EncodeError};
19use std::sync::Arc;
20use std::sync::atomic::{AtomicU16, Ordering};
21use thiserror::Error;
22use tokio::io::{AsyncReadExt, AsyncWriteExt};
23use tokio::net::{TcpStream, ToSocketAddrs};
24use tokio::sync::Mutex;
25use tracing::trace;
26
27pub mod server;
28pub mod sim;
29pub use server::{ModbusRtuOverTcpServer, ModbusService, ModbusTcpServer, ServiceError};
30pub use sim::{CoilBank, InMemoryModbusService, InMemoryPointModel, RegisterBank};
31#[cfg(feature = "rtu")]
32pub mod rtu;
33#[cfg(feature = "rtu")]
34pub use rtu::{ModbusRtuConfig, ModbusRtuTransport};
35#[cfg(feature = "rtu")]
36pub mod rtu_server;
37#[cfg(feature = "rtu")]
38pub use rtu_server::{ModbusRtuServer, ModbusRtuServerConfig};
39
/// Largest response PDU, in bytes, that the TCP transport will accept.
///
/// This matches the 253-byte PDU limit of the Modbus application protocol;
/// longer responses are drained from the socket and rejected.
const MAX_TCP_PDU_LEN: usize = 253;
41
42#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum DataLinkError {
46 #[error("io error: {0}")]
48 Io(#[from] std::io::Error),
49 #[error("encode error: {0}")]
51 Encode(#[from] EncodeError),
52 #[error("decode error: {0}")]
54 Decode(#[from] DecodeError),
55 #[error("connection closed")]
57 ConnectionClosed,
58 #[error("request timed out")]
60 Timeout,
61 #[error("invalid response: {0}")]
63 InvalidResponse(&'static str),
64 #[error("transaction id mismatch: expected {expected}, got {got}")]
66 MismatchedTransactionId { expected: u16, got: u16 },
67 #[error("response buffer too small (needed {needed}, available {available})")]
69 ResponseBufferTooSmall { needed: usize, available: usize },
70}
71
72#[async_trait]
77pub trait DataLink: Send + Sync {
78 async fn exchange(
82 &self,
83 unit_id: UnitId,
84 request_pdu: &[u8],
85 response_pdu: &mut [u8],
86 ) -> Result<usize, DataLinkError>;
87
88 async fn reconnect(&self) -> Result<(), DataLinkError> {
94 Ok(())
95 }
96
97 fn is_connected(&self) -> bool {
101 true
102 }
103}
104
105#[derive(Debug)]
110pub struct ModbusTcpTransport {
111 stream: Arc<Mutex<TcpStream>>,
112 next_transaction_id: Arc<AtomicU16>,
113 peer_addr: std::net::SocketAddr,
114}
115
116impl ModbusTcpTransport {
117 pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self, DataLinkError> {
119 let stream = TcpStream::connect(addr).await?;
120 let peer_addr = stream.peer_addr()?;
121 Ok(Self {
122 stream: Arc::new(Mutex::new(stream)),
123 next_transaction_id: Arc::new(AtomicU16::new(1)),
124 peer_addr,
125 })
126 }
127
128 pub fn from_stream(stream: TcpStream) -> Self {
130 let peer_addr = stream
131 .peer_addr()
132 .expect("TcpStream must have a peer address");
133 Self {
134 stream: Arc::new(Mutex::new(stream)),
135 next_transaction_id: Arc::new(AtomicU16::new(1)),
136 peer_addr,
137 }
138 }
139
140 fn next_tid(&self) -> u16 {
141 self.next_transaction_id.fetch_add(1, Ordering::Relaxed)
142 }
143}
144
145async fn read_exact_or_connection_closed(
146 stream: &mut TcpStream,
147 buf: &mut [u8],
148) -> Result<(), DataLinkError> {
149 if let Err(err) = stream.read_exact(buf).await {
150 if err.kind() == std::io::ErrorKind::UnexpectedEof {
151 return Err(DataLinkError::ConnectionClosed);
152 }
153 return Err(DataLinkError::Io(err));
154 }
155 Ok(())
156}
157
158async fn drain_exact(stream: &mut TcpStream, mut len: usize) -> Result<(), DataLinkError> {
159 let mut scratch = [0u8; 256];
160 while len > 0 {
161 let chunk = len.min(scratch.len());
162 read_exact_or_connection_closed(stream, &mut scratch[..chunk]).await?;
163 len -= chunk;
164 }
165 Ok(())
166}
167
168#[async_trait]
169impl DataLink for ModbusTcpTransport {
170 async fn reconnect(&self) -> Result<(), DataLinkError> {
171 let new_stream = TcpStream::connect(self.peer_addr).await?;
172 let mut guard = self.stream.lock().await;
173 *guard = new_stream;
174 tracing::info!(peer = %self.peer_addr, "reconnected modbus tcp transport");
175 Ok(())
176 }
177
178 async fn exchange(
179 &self,
180 unit_id: UnitId,
181 request_pdu: &[u8],
182 response_pdu: &mut [u8],
183 ) -> Result<usize, DataLinkError> {
184 if request_pdu.is_empty() {
185 return Err(DataLinkError::InvalidResponse("empty request pdu"));
186 }
187
188 let transaction_id = self.next_tid();
189 let mut req_frame = vec![0u8; tcp::MBAP_HEADER_LEN + request_pdu.len()];
190 let mut writer = Writer::new(&mut req_frame);
191 tcp::encode_frame(&mut writer, transaction_id, unit_id, request_pdu)?;
192
193 let mut stream = self.stream.lock().await;
194 trace!(
195 transaction_id,
196 unit_id = unit_id.as_u8(),
197 pdu_len = request_pdu.len(),
198 "sending modbus tcp request"
199 );
200 stream.write_all(writer.as_written()).await?;
201
202 let mut mbap = [0u8; tcp::MBAP_HEADER_LEN];
203 read_exact_or_connection_closed(&mut stream, &mut mbap).await?;
204
205 let mut reader = Reader::new(&mbap);
206 let header = tcp::MbapHeader::decode(&mut reader)?;
207
208 let pdu_len = usize::from(header.length)
209 .checked_sub(1)
210 .ok_or(DataLinkError::InvalidResponse("invalid mbap length"))?;
211 if pdu_len == 0 {
212 return Err(DataLinkError::InvalidResponse("empty response pdu"));
213 }
214 let tid_mismatch = header.transaction_id != transaction_id;
215 let unit_mismatch = header.unit_id != unit_id;
216
217 if pdu_len > MAX_TCP_PDU_LEN {
218 drain_exact(&mut stream, pdu_len).await?;
219 if tid_mismatch {
220 return Err(DataLinkError::MismatchedTransactionId {
221 expected: transaction_id,
222 got: header.transaction_id,
223 });
224 }
225 if unit_mismatch {
226 return Err(DataLinkError::InvalidResponse("unit id mismatch"));
227 }
228 return Err(DataLinkError::InvalidResponse("response pdu too large"));
229 }
230
231 if pdu_len > response_pdu.len() {
232 drain_exact(&mut stream, pdu_len).await?;
233 if tid_mismatch {
234 return Err(DataLinkError::MismatchedTransactionId {
235 expected: transaction_id,
236 got: header.transaction_id,
237 });
238 }
239 if unit_mismatch {
240 return Err(DataLinkError::InvalidResponse("unit id mismatch"));
241 }
242 return Err(DataLinkError::ResponseBufferTooSmall {
243 needed: pdu_len,
244 available: response_pdu.len(),
245 });
246 }
247
248 read_exact_or_connection_closed(&mut stream, &mut response_pdu[..pdu_len]).await?;
249 if tid_mismatch {
250 return Err(DataLinkError::MismatchedTransactionId {
251 expected: transaction_id,
252 got: header.transaction_id,
253 });
254 }
255 if unit_mismatch {
256 return Err(DataLinkError::InvalidResponse("unit id mismatch"));
257 }
258 trace!(
259 transaction_id,
260 unit_id = unit_id.as_u8(),
261 pdu_len,
262 "received modbus tcp response"
263 );
264 Ok(pdu_len)
265 }
266}
267
// Compile-time guarantee that the transport can be shared across tasks and
// threads: this anonymous const fails to build if `ModbusTcpTransport`
// stops being `Send + Sync`.
#[cfg(test)]
const _: () = {
    fn _require_send_sync<T: Send + Sync>() {}
    fn _check_transport() {
        _require_send_sync::<ModbusTcpTransport>();
    }
};
275
#[cfg(test)]
mod tests {
    use super::{DataLink, DataLinkError, ModbusTcpTransport};
    use rustmod_core::encoding::Writer;
    use rustmod_core::frame::tcp;
    use rustmod_core::UnitId;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    #[tokio::test]
    async fn exchange_roundtrip_over_tcp() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            assert_eq!(&req[7..], &[0x03, 0x00, 0x6B, 0x00, 0x03]);

            let mut frame = [0u8; 15];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(
                &mut w,
                1,
                UnitId::new(1),
                &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64],
            )
            .unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 256];
        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x6B, 0x00, 0x03], &mut response)
            .await
            .unwrap();

        assert_eq!(
            &response[..len],
            &[0x03, 0x06, 0x02, 0x2B, 0x00, 0x00, 0x00, 0x64]
        );

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_rejects_mismatched_transaction_id() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();

            let mut frame = [0u8; 9];
            let mut w = Writer::new(&mut frame);
            tcp::encode_frame(&mut w, 2, UnitId::new(1), &[0x83, 0x02]).unwrap();
            socket.write_all(w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();

        match err {
            DataLinkError::MismatchedTransactionId { expected, got } => {
                assert_eq!(expected, 1);
                assert_eq!(got, 2);
            }
            other => panic!("unexpected error: {other:?}"),
        }

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_drains_pdu_on_transaction_mismatch() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut mismatch = [0u8; 9];
            let mut mismatch_w = Writer::new(&mut mismatch);
            tcp::encode_frame(&mut mismatch_w, 2, UnitId::new(1), &[0x83, 0x02]).unwrap();
            socket.write_all(mismatch_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, UnitId::new(1), &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(err, DataLinkError::MismatchedTransactionId { .. }));

        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    #[tokio::test]
    async fn exchange_rejects_and_drains_oversized_response_pdu() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut oversized = vec![0u8; tcp::MBAP_HEADER_LEN + 254];
            oversized[0..2].copy_from_slice(&1u16.to_be_bytes());
            oversized[2..4].copy_from_slice(&0u16.to_be_bytes());
            oversized[4..6].copy_from_slice(&255u16.to_be_bytes());
            oversized[6] = 1;
            oversized[7] = 0x03;
            socket.write_all(&oversized).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, UnitId::new(1), &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 260];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::InvalidResponse("response pdu too large")
        ));

        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    /// A PDU that does not fit the caller's buffer is reported with its size
    /// and drained, so the following exchange still succeeds.
    #[tokio::test]
    async fn exchange_rejects_and_drains_response_larger_than_buffer() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut big = [0u8; 13];
            let mut big_w = Writer::new(&mut big);
            tcp::encode_frame(
                &mut big_w,
                1,
                UnitId::new(1),
                &[0x03, 0x04, 0x00, 0x01, 0x00, 0x02],
            )
            .unwrap();
            socket.write_all(big_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, UnitId::new(1), &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut small = [0u8; 4];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x02], &mut small)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::ResponseBufferTooSmall {
                needed: 6,
                available: 4
            }
        ));

        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut small)
            .await
            .unwrap();
        assert_eq!(&small[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }

    /// A response addressed from a different unit is rejected, and the link
    /// stays usable afterwards.
    #[tokio::test]
    async fn exchange_rejects_mismatched_unit_id() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            let mut req = [0u8; 12];
            socket.read_exact(&mut req).await.unwrap();
            let mut wrong_unit = [0u8; 9];
            let mut wrong_w = Writer::new(&mut wrong_unit);
            tcp::encode_frame(&mut wrong_w, 1, UnitId::new(2), &[0x83, 0x02]).unwrap();
            socket.write_all(wrong_w.as_written()).await.unwrap();

            let mut req2 = [0u8; 12];
            socket.read_exact(&mut req2).await.unwrap();
            let mut ok = [0u8; 11];
            let mut ok_w = Writer::new(&mut ok);
            tcp::encode_frame(&mut ok_w, 2, UnitId::new(1), &[0x03, 0x02, 0x00, 0x2A]).unwrap();
            socket.write_all(ok_w.as_written()).await.unwrap();
        });

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 16];
        let err = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            DataLinkError::InvalidResponse("unit id mismatch")
        ));

        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        server.await.unwrap();
    }
}