1use crate::DataLinkError;
2use rustmod_core::encoding::{Reader, Writer};
3use rustmod_core::frame::{rtu as rtu_frame, tcp};
4use rustmod_core::pdu::{DecodedRequest, ExceptionCode, ExceptionResponse};
5use rustmod_core::{DecodeError, UnitId};
6use std::future::Future;
7use std::sync::Arc;
8use thiserror::Error;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
11use tokio::sync::Semaphore;
12use tracing::{debug, warn};
13
14#[cfg(feature = "metrics")]
15use std::sync::atomic::{AtomicU64, Ordering};
16
/// Default upper bound, in bytes, on a request PDU and on the response buffer handed to
/// the service.
///
/// 253 = `DEFAULT_MAX_RTU_FRAME_LEN` (256) minus the 3 bytes of RTU framing overhead that
/// `send_rtu_pdu` adds around every PDU.
const DEFAULT_MAX_PDU_LEN: usize = 253;
/// Default size, in bytes, of the receive window the RTU-over-TCP server searches for a
/// complete frame.
const DEFAULT_MAX_RTU_FRAME_LEN: usize = 256;
19
/// Error a [`ModbusService`] returns instead of a response PDU.
///
/// For unicast requests the connection handlers answer every variant with a Modbus
/// exception response: `Exception(code)` is sent with `code` as-is, `InvalidRequest` is
/// sent as `IllegalDataValue`, and `Internal` is sent as `ServerDeviceFailure`.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ServiceError {
    /// Reply with this specific Modbus exception code.
    #[error("modbus exception: {0:?}")]
    Exception(ExceptionCode),
    /// The service rejects the request's contents; answered with `IllegalDataValue`.
    #[error("invalid request: {0}")]
    InvalidRequest(&'static str),
    /// The service failed internally; answered with `ServerDeviceFailure`.
    #[error("internal error: {0}")]
    Internal(&'static str),
}
36
/// Application logic behind a Modbus server.
///
/// A single instance is shared (behind an `Arc`) by every connection task, hence the
/// `Send + Sync + 'static` bounds.
pub trait ModbusService: Send + Sync + 'static {
    /// Handles one decoded request addressed to `unit_id`.
    ///
    /// Write the response PDU into the start of `response_pdu` and return the number of
    /// bytes written. The servers allocate `response_pdu` with their configured
    /// `max_pdu_len`. A returned length of `0`, or one larger than that limit, is treated
    /// as an internal failure and answered with `ServerDeviceFailure`.
    ///
    /// This is called synchronously from the connection's async task, so it should not
    /// block for long.
    ///
    /// # Errors
    ///
    /// Return a [`ServiceError`] to make the server send a Modbus exception response
    /// instead of the contents of `response_pdu`.
    fn handle(
        &self,
        unit_id: UnitId,
        request: DecodedRequest<'_>,
        response_pdu: &mut [u8],
    ) -> Result<usize, ServiceError>;
}
54
55impl<T> ModbusService for Arc<T>
56where
57 T: ModbusService + ?Sized,
58{
59 fn handle(
60 &self,
61 unit_id: UnitId,
62 request: DecodedRequest<'_>,
63 response_pdu: &mut [u8],
64 ) -> Result<usize, ServiceError> {
65 (**self).handle(unit_id, request, response_pdu)
66 }
67}
68
#[cfg(feature = "metrics")]
/// Request/response counters shared by a server and all of its connection tasks.
///
/// Every counter is updated with `Ordering::Relaxed`; they are independent tallies and
/// are not kept consistent with one another.
#[derive(Debug, Default)]
pub struct ServerMetrics {
    /// Requests received (counted once the request PDU has been read, before decoding).
    requests_total: AtomicU64,
    /// Non-exception responses sent.
    responses_ok: AtomicU64,
    /// Exception responses sent.
    exceptions_sent: AtomicU64,
    /// Requests rejected because the PDU could not be decoded, had trailing bytes, or had
    /// an invalid length.
    decode_errors: AtomicU64,
    /// Service failures: an internal service error or an invalid response length.
    internal_errors: AtomicU64,
}
79
#[cfg(feature = "metrics")]
/// Plain-value copy of the [`ServerMetrics`] counters at one point in time.
///
/// Each field is loaded separately, so while requests are in flight the values need not be
/// mutually consistent.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ServerMetricsSnapshot {
    /// Requests received.
    pub requests_total: u64,
    /// Non-exception responses sent.
    pub responses_ok: u64,
    /// Exception responses sent.
    pub exceptions_sent: u64,
    /// Requests rejected as undecodable.
    pub decode_errors: u64,
    /// Service failures.
    pub internal_errors: u64,
}
90
#[cfg(feature = "metrics")]
impl ServerMetrics {
    /// Reads every counter into a [`ServerMetricsSnapshot`].
    ///
    /// This is public because the counter fields are private. Without it, holders of the
    /// `Arc<ServerMetrics>` returned by a server's `metrics_handle()` could not read
    /// anything. That handle is the only way to observe metrics once the server itself has
    /// been moved into `run`/`run_until`.
    ///
    /// Each counter is loaded independently with `Ordering::Relaxed`, so the snapshot is
    /// not atomic across fields.
    #[must_use]
    pub fn snapshot(&self) -> ServerMetricsSnapshot {
        ServerMetricsSnapshot {
            requests_total: self.requests_total.load(Ordering::Relaxed),
            responses_ok: self.responses_ok.load(Ordering::Relaxed),
            exceptions_sent: self.exceptions_sent.load(Ordering::Relaxed),
            decode_errors: self.decode_errors.load(Ordering::Relaxed),
            internal_errors: self.internal_errors.load(Ordering::Relaxed),
        }
    }
}
103
/// Default cap on the number of connections a server serves at the same time.
const DEFAULT_MAX_CONNECTIONS: usize = 256;
105
/// Modbus TCP (MBAP-framed) server that dispatches decoded requests to a [`ModbusService`].
///
/// Create it with `bind` or `from_listener`, optionally adjust limits with the `with_*`
/// methods, then drive it with `run` or `run_until`. Each accepted connection is served on
/// its own spawned task.
pub struct ModbusTcpServer<S> {
    /// Socket that new connections are accepted from.
    listener: TcpListener,
    /// Service shared by every connection task.
    service: Arc<S>,
    /// Largest request PDU accepted, and size of the response buffer given to the service.
    max_pdu_len: usize,
    /// Maximum number of connections served concurrently.
    max_connections: usize,
    /// Counters shared with every connection task.
    #[cfg(feature = "metrics")]
    metrics: Arc<ServerMetrics>,
}
119
120impl<S: ModbusService> ModbusTcpServer<S> {
121 pub async fn bind<A: ToSocketAddrs>(addr: A, service: S) -> Result<Self, DataLinkError> {
123 let listener = TcpListener::bind(addr).await?;
124 Ok(Self::from_listener(listener, service))
125 }
126
127 #[must_use]
129 pub fn from_listener(listener: TcpListener, service: S) -> Self {
130 Self {
131 listener,
132 service: Arc::new(service),
133 max_pdu_len: DEFAULT_MAX_PDU_LEN,
134 max_connections: DEFAULT_MAX_CONNECTIONS,
135 #[cfg(feature = "metrics")]
136 metrics: Arc::new(ServerMetrics::default()),
137 }
138 }
139
140 pub fn local_addr(&self) -> Result<std::net::SocketAddr, DataLinkError> {
142 Ok(self.listener.local_addr()?)
143 }
144
145 #[must_use]
147 pub fn with_max_pdu_len(mut self, max_pdu_len: usize) -> Self {
148 self.max_pdu_len = max_pdu_len;
149 self
150 }
151
152 #[must_use]
154 pub fn with_max_connections(mut self, max_connections: usize) -> Self {
155 self.max_connections = max_connections;
156 self
157 }
158
159 #[cfg(feature = "metrics")]
161 pub fn metrics_handle(&self) -> Arc<ServerMetrics> {
162 Arc::clone(&self.metrics)
163 }
164
165 #[cfg(feature = "metrics")]
167 pub fn metrics_snapshot(&self) -> ServerMetricsSnapshot {
168 self.metrics.snapshot()
169 }
170
171 pub async fn run(self) -> Result<(), DataLinkError> {
173 let semaphore = Arc::new(Semaphore::new(self.max_connections));
174 loop {
175 let (socket, peer) = self.listener.accept().await?;
176 let service = Arc::clone(&self.service);
177 let max_pdu_len = self.max_pdu_len;
178 let permit = Arc::clone(&semaphore);
179 #[cfg(feature = "metrics")]
180 let metrics = Arc::clone(&self.metrics);
181
182 tokio::spawn(async move {
183 let _permit = permit.acquire().await;
184 if let Err(err) = handle_connection(
185 socket,
186 service,
187 max_pdu_len,
188 #[cfg(feature = "metrics")]
189 metrics,
190 )
191 .await
192 {
193 warn!(%peer, error = %err, "modbus tcp server connection ended with error");
194 }
195 });
196 }
197 }
198
199 pub async fn run_until(self, shutdown: impl Future<Output = ()> + Send) -> Result<(), DataLinkError> {
201 let semaphore = Arc::new(Semaphore::new(self.max_connections));
202 tokio::pin!(shutdown);
203 loop {
204 tokio::select! {
205 result = self.listener.accept() => {
206 let (socket, peer) = result?;
207 let service = Arc::clone(&self.service);
208 let max_pdu_len = self.max_pdu_len;
209 let permit = Arc::clone(&semaphore);
210 #[cfg(feature = "metrics")]
211 let metrics = Arc::clone(&self.metrics);
212
213 tokio::spawn(async move {
214 let _permit = permit.acquire().await;
215 if let Err(err) = handle_connection(
216 socket,
217 service,
218 max_pdu_len,
219 #[cfg(feature = "metrics")]
220 metrics,
221 )
222 .await
223 {
224 warn!(%peer, error = %err, "modbus tcp server connection ended with error");
225 }
226 });
227 }
228 () = &mut shutdown => {
229 return Ok(());
230 }
231 }
232 }
233 }
234}
235
/// Server that speaks Modbus RTU framing (unit id + PDU + CRC) over TCP connections and
/// dispatches decoded requests to a [`ModbusService`].
///
/// Create it with `bind` or `from_listener`, optionally adjust limits with the `with_*`
/// methods, then drive it with `run` or `run_until`. Each accepted connection is served on
/// its own spawned task.
pub struct ModbusRtuOverTcpServer<S> {
    /// Socket that new connections are accepted from.
    listener: TcpListener,
    /// Service shared by every connection task.
    service: Arc<S>,
    /// Largest request PDU accepted, and size of the response buffer given to the service.
    max_pdu_len: usize,
    /// Size of the per-connection receive window searched for a complete RTU frame.
    max_frame_len: usize,
    /// Maximum number of connections served concurrently.
    max_connections: usize,
    /// Counters shared with every connection task.
    #[cfg(feature = "metrics")]
    metrics: Arc<ServerMetrics>,
}
249
250impl<S: ModbusService> ModbusRtuOverTcpServer<S> {
251 pub async fn bind<A: ToSocketAddrs>(addr: A, service: S) -> Result<Self, DataLinkError> {
253 let listener = TcpListener::bind(addr).await?;
254 Ok(Self::from_listener(listener, service))
255 }
256
257 #[must_use]
259 pub fn from_listener(listener: TcpListener, service: S) -> Self {
260 Self {
261 listener,
262 service: Arc::new(service),
263 max_pdu_len: DEFAULT_MAX_PDU_LEN,
264 max_frame_len: DEFAULT_MAX_RTU_FRAME_LEN,
265 max_connections: DEFAULT_MAX_CONNECTIONS,
266 #[cfg(feature = "metrics")]
267 metrics: Arc::new(ServerMetrics::default()),
268 }
269 }
270
271 pub fn local_addr(&self) -> Result<std::net::SocketAddr, DataLinkError> {
273 Ok(self.listener.local_addr()?)
274 }
275
276 #[must_use]
278 pub fn with_max_pdu_len(mut self, max_pdu_len: usize) -> Self {
279 self.max_pdu_len = max_pdu_len;
280 self
281 }
282
283 #[must_use]
285 pub fn with_max_frame_len(mut self, max_frame_len: usize) -> Self {
286 self.max_frame_len = max_frame_len;
287 self
288 }
289
290 #[must_use]
292 pub fn with_max_connections(mut self, max_connections: usize) -> Self {
293 self.max_connections = max_connections;
294 self
295 }
296
297 #[cfg(feature = "metrics")]
299 pub fn metrics_handle(&self) -> Arc<ServerMetrics> {
300 Arc::clone(&self.metrics)
301 }
302
303 #[cfg(feature = "metrics")]
305 pub fn metrics_snapshot(&self) -> ServerMetricsSnapshot {
306 self.metrics.snapshot()
307 }
308
309 pub async fn run(self) -> Result<(), DataLinkError> {
311 let semaphore = Arc::new(Semaphore::new(self.max_connections));
312 loop {
313 let (socket, peer) = self.listener.accept().await?;
314 let service = Arc::clone(&self.service);
315 let max_pdu_len = self.max_pdu_len;
316 let max_frame_len = self.max_frame_len;
317 let permit = Arc::clone(&semaphore);
318 #[cfg(feature = "metrics")]
319 let metrics = Arc::clone(&self.metrics);
320
321 tokio::spawn(async move {
322 let _permit = permit.acquire().await;
323 if let Err(err) = handle_rtu_over_tcp_connection(
324 socket,
325 service,
326 max_pdu_len,
327 max_frame_len,
328 #[cfg(feature = "metrics")]
329 metrics,
330 )
331 .await
332 {
333 warn!(
334 %peer,
335 error = %err,
336 "modbus rtu-over-tcp server connection ended with error"
337 );
338 }
339 });
340 }
341 }
342
343 pub async fn run_until(self, shutdown: impl Future<Output = ()> + Send) -> Result<(), DataLinkError> {
345 let semaphore = Arc::new(Semaphore::new(self.max_connections));
346 tokio::pin!(shutdown);
347 loop {
348 tokio::select! {
349 result = self.listener.accept() => {
350 let (socket, peer) = result?;
351 let service = Arc::clone(&self.service);
352 let max_pdu_len = self.max_pdu_len;
353 let max_frame_len = self.max_frame_len;
354 let permit = Arc::clone(&semaphore);
355 #[cfg(feature = "metrics")]
356 let metrics = Arc::clone(&self.metrics);
357
358 tokio::spawn(async move {
359 let _permit = permit.acquire().await;
360 if let Err(err) = handle_rtu_over_tcp_connection(
361 socket,
362 service,
363 max_pdu_len,
364 max_frame_len,
365 #[cfg(feature = "metrics")]
366 metrics,
367 )
368 .await
369 {
370 warn!(
371 %peer,
372 error = %err,
373 "modbus rtu-over-tcp server connection ended with error"
374 );
375 }
376 });
377 }
378 () = &mut shutdown => {
379 return Ok(());
380 }
381 }
382 }
383 }
384}
385
386fn is_write_request(request: &DecodedRequest<'_>) -> bool {
387 matches!(
388 request,
389 DecodedRequest::WriteSingleCoil(_)
390 | DecodedRequest::WriteSingleRegister(_)
391 | DecodedRequest::WriteMultipleCoils(_)
392 | DecodedRequest::WriteMultipleRegisters(_)
393 | DecodedRequest::MaskWriteRegister(_)
394 | DecodedRequest::ReadWriteMultipleRegisters(_)
395 )
396}
397
398async fn handle_connection<S: ModbusService>(
399 mut socket: TcpStream,
400 service: Arc<S>,
401 max_pdu_len: usize,
402 #[cfg(feature = "metrics")] metrics: Arc<ServerMetrics>,
403) -> Result<(), DataLinkError> {
404 let mut request_pdu_buf = [0u8; 253];
405 let mut response_pdu = vec![0u8; max_pdu_len];
406
407 loop {
408 let mut mbap = [0u8; tcp::MBAP_HEADER_LEN];
409 if let Err(err) = socket.read_exact(&mut mbap).await {
410 if err.kind() == std::io::ErrorKind::UnexpectedEof {
411 return Ok(());
412 }
413 return Err(DataLinkError::Io(err));
414 }
415
416 let mut mbap_reader = Reader::new(&mbap);
417 let header = tcp::MbapHeader::decode(&mut mbap_reader)?;
418 let pdu_len = usize::from(header.length)
419 .checked_sub(1)
420 .ok_or(DataLinkError::InvalidResponse("invalid mbap length"))?;
421
422 if pdu_len == 0 || pdu_len > max_pdu_len {
423 return Err(DataLinkError::InvalidResponse("invalid request pdu length"));
424 }
425
426 socket.read_exact(&mut request_pdu_buf[..pdu_len]).await?;
427 let request_pdu = &request_pdu_buf[..pdu_len];
428
429 #[cfg(feature = "metrics")]
430 metrics.requests_total.fetch_add(1, Ordering::Relaxed);
431
432 let mut request_reader = Reader::new(request_pdu);
433 let decoded = match DecodedRequest::decode(&mut request_reader) {
434 Ok(req) if request_reader.is_empty() => req,
435 Ok(_) => {
436 #[cfg(feature = "metrics")]
437 {
438 metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
439 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
440 }
441 let function = request_pdu[0] & 0x7F;
442 send_exception(
443 &mut socket,
444 header.transaction_id,
445 header.unit_id,
446 function,
447 ExceptionCode::IllegalDataValue,
448 )
449 .await?;
450 continue;
451 }
452 Err(err) => {
453 #[cfg(feature = "metrics")]
454 {
455 metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
456 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
457 }
458 let function = request_pdu.first().copied().unwrap_or(0) & 0x7F;
459 send_exception(
460 &mut socket,
461 header.transaction_id,
462 header.unit_id,
463 function,
464 map_decode_error_to_exception(err),
465 )
466 .await?;
467 continue;
468 }
469 };
470
471 debug!(
472 correlation_id = header.transaction_id,
473 unit_id = header.unit_id.as_u8(),
474 function = decoded.function_code().as_u8(),
475 pdu_len,
476 "received modbus tcp request"
477 );
478
479 if header.unit_id == UnitId::BROADCAST {
481 if is_write_request(&decoded) {
482 let _ = service.handle(header.unit_id, decoded, &mut response_pdu);
484 continue;
485 } else {
486 send_exception(
488 &mut socket,
489 header.transaction_id,
490 header.unit_id,
491 decoded.function_code().as_u8(),
492 ExceptionCode::IllegalFunction,
493 )
494 .await?;
495 continue;
496 }
497 }
498
499 match service.handle(header.unit_id, decoded, &mut response_pdu) {
500 Ok(response_len) => {
501 if response_len == 0 || response_len > max_pdu_len {
502 #[cfg(feature = "metrics")]
503 {
504 metrics.internal_errors.fetch_add(1, Ordering::Relaxed);
505 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
506 }
507 send_exception(
508 &mut socket,
509 header.transaction_id,
510 header.unit_id,
511 decoded.function_code().as_u8(),
512 ExceptionCode::ServerDeviceFailure,
513 )
514 .await?;
515 continue;
516 }
517
518 #[cfg(feature = "metrics")]
519 metrics.responses_ok.fetch_add(1, Ordering::Relaxed);
520
521 send_pdu(
522 &mut socket,
523 header.transaction_id,
524 header.unit_id,
525 &response_pdu[..response_len],
526 )
527 .await?;
528 }
529 Err(ServiceError::Exception(code)) => {
530 #[cfg(feature = "metrics")]
531 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
532
533 send_exception(
534 &mut socket,
535 header.transaction_id,
536 header.unit_id,
537 decoded.function_code().as_u8(),
538 code,
539 )
540 .await?;
541 }
542 Err(ServiceError::InvalidRequest(_)) => {
543 #[cfg(feature = "metrics")]
544 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
545
546 send_exception(
547 &mut socket,
548 header.transaction_id,
549 header.unit_id,
550 decoded.function_code().as_u8(),
551 ExceptionCode::IllegalDataValue,
552 )
553 .await?;
554 }
555 Err(_) => {
556 #[cfg(feature = "metrics")]
557 {
558 metrics.internal_errors.fetch_add(1, Ordering::Relaxed);
559 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
560 }
561
562 send_exception(
563 &mut socket,
564 header.transaction_id,
565 header.unit_id,
566 decoded.function_code().as_u8(),
567 ExceptionCode::ServerDeviceFailure,
568 )
569 .await?;
570 }
571 }
572 }
573}
574
575fn decode_rtu_suffix_frame(buffer: &[u8]) -> Option<(usize, UnitId, &[u8])> {
576 if buffer.len() < 4 {
577 return None;
578 }
579 for start in 0..=buffer.len() - 4 {
580 if let Ok((unit_id, pdu)) = rtu_frame::decode_frame(&buffer[start..]) {
581 return Some((start, unit_id, pdu));
582 }
583 }
584 None
585}
586
587async fn handle_rtu_over_tcp_connection<S: ModbusService>(
588 mut socket: TcpStream,
589 service: Arc<S>,
590 max_pdu_len: usize,
591 max_frame_len: usize,
592 #[cfg(feature = "metrics")] metrics: Arc<ServerMetrics>,
593) -> Result<(), DataLinkError> {
594 if max_frame_len < 4 {
595 return Err(DataLinkError::InvalidResponse(
596 "rtu frame length must be at least 4 bytes",
597 ));
598 }
599
600 let mut frame = vec![0u8; max_frame_len];
601 let mut len = 0usize;
602 let mut response_pdu = vec![0u8; max_pdu_len];
603
604 loop {
605 if len == max_frame_len {
606 frame.copy_within(1..max_frame_len, 0);
608 len -= 1;
609 }
610
611 let n = socket.read(&mut frame[len..len + 1]).await?;
612 if n == 0 {
613 return Ok(());
614 }
615 len += n;
616
617 let Some((_, unit_id, request_pdu)) = decode_rtu_suffix_frame(&frame[..len]) else {
618 continue;
619 };
620 len = 0;
621
622 #[cfg(feature = "metrics")]
623 metrics.requests_total.fetch_add(1, Ordering::Relaxed);
624
625 if request_pdu.is_empty() || request_pdu.len() > max_pdu_len {
626 #[cfg(feature = "metrics")]
627 {
628 metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
629 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
630 }
631 send_rtu_exception(&mut socket, unit_id, 0, ExceptionCode::IllegalDataValue).await?;
632 continue;
633 }
634
635 let mut request_reader = Reader::new(request_pdu);
636 let decoded = match DecodedRequest::decode(&mut request_reader) {
637 Ok(req) if request_reader.is_empty() => req,
638 Ok(_) => {
639 #[cfg(feature = "metrics")]
640 {
641 metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
642 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
643 }
644 let function = request_pdu[0] & 0x7F;
645 send_rtu_exception(
646 &mut socket,
647 unit_id,
648 function,
649 ExceptionCode::IllegalDataValue,
650 )
651 .await?;
652 continue;
653 }
654 Err(err) => {
655 #[cfg(feature = "metrics")]
656 {
657 metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
658 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
659 }
660 let function = request_pdu.first().copied().unwrap_or(0) & 0x7F;
661 send_rtu_exception(
662 &mut socket,
663 unit_id,
664 function,
665 map_decode_error_to_exception(err),
666 )
667 .await?;
668 continue;
669 }
670 };
671
672 debug!(
673 unit_id = unit_id.as_u8(),
674 function = decoded.function_code().as_u8(),
675 pdu_len = request_pdu.len(),
676 "received modbus rtu-over-tcp request"
677 );
678
679 match service.handle(unit_id, decoded, &mut response_pdu) {
680 Ok(response_len) => {
681 if response_len == 0 || response_len > max_pdu_len {
682 #[cfg(feature = "metrics")]
683 {
684 metrics.internal_errors.fetch_add(1, Ordering::Relaxed);
685 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
686 }
687 send_rtu_exception(
688 &mut socket,
689 unit_id,
690 decoded.function_code().as_u8(),
691 ExceptionCode::ServerDeviceFailure,
692 )
693 .await?;
694 continue;
695 }
696
697 #[cfg(feature = "metrics")]
698 metrics.responses_ok.fetch_add(1, Ordering::Relaxed);
699
700 send_rtu_pdu(&mut socket, unit_id, &response_pdu[..response_len]).await?;
701 }
702 Err(ServiceError::Exception(code)) => {
703 #[cfg(feature = "metrics")]
704 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
705
706 send_rtu_exception(&mut socket, unit_id, decoded.function_code().as_u8(), code)
707 .await?;
708 }
709 Err(ServiceError::InvalidRequest(_)) => {
710 #[cfg(feature = "metrics")]
711 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
712
713 send_rtu_exception(
714 &mut socket,
715 unit_id,
716 decoded.function_code().as_u8(),
717 ExceptionCode::IllegalDataValue,
718 )
719 .await?;
720 }
721 Err(ServiceError::Internal(_)) => {
722 #[cfg(feature = "metrics")]
723 {
724 metrics.internal_errors.fetch_add(1, Ordering::Relaxed);
725 metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
726 }
727
728 send_rtu_exception(
729 &mut socket,
730 unit_id,
731 decoded.function_code().as_u8(),
732 ExceptionCode::ServerDeviceFailure,
733 )
734 .await?;
735 }
736 }
737 }
738}
739
740fn map_decode_error_to_exception(err: DecodeError) -> ExceptionCode {
741 match err {
742 DecodeError::InvalidFunctionCode => ExceptionCode::IllegalFunction,
743 DecodeError::InvalidLength | DecodeError::InvalidValue | DecodeError::UnexpectedEof => {
744 ExceptionCode::IllegalDataValue
745 }
746 DecodeError::InvalidCrc | DecodeError::Unsupported | DecodeError::Message(_) => {
747 ExceptionCode::ServerDeviceFailure
748 }
749 _ => ExceptionCode::ServerDeviceFailure,
750 }
751}
752
753async fn send_exception(
754 socket: &mut TcpStream,
755 transaction_id: u16,
756 unit_id: UnitId,
757 function_code: u8,
758 exception_code: ExceptionCode,
759) -> Result<(), DataLinkError> {
760 let mut pdu = [0u8; 2];
761 let mut pdu_writer = Writer::new(&mut pdu);
762 ExceptionResponse {
763 function_code,
764 exception_code,
765 }
766 .encode(&mut pdu_writer)
767 .map_err(DataLinkError::Encode)?;
768
769 send_pdu(socket, transaction_id, unit_id, pdu_writer.as_written()).await
770}
771
772async fn send_pdu(
773 socket: &mut TcpStream,
774 transaction_id: u16,
775 unit_id: UnitId,
776 pdu: &[u8],
777) -> Result<(), DataLinkError> {
778 let mut frame = vec![0u8; tcp::MBAP_HEADER_LEN + pdu.len()];
779 let mut frame_writer = Writer::new(&mut frame);
780 tcp::encode_frame(&mut frame_writer, transaction_id, unit_id, pdu)?;
781
782 debug!(
783 correlation_id = transaction_id,
784 unit_id = unit_id.as_u8(),
785 pdu_len = pdu.len(),
786 "sending modbus tcp server response"
787 );
788 socket.write_all(frame_writer.as_written()).await?;
789 Ok(())
790}
791
792async fn send_rtu_exception(
793 socket: &mut TcpStream,
794 unit_id: UnitId,
795 function_code: u8,
796 exception_code: ExceptionCode,
797) -> Result<(), DataLinkError> {
798 let mut pdu = [0u8; 2];
799 let mut pdu_writer = Writer::new(&mut pdu);
800 ExceptionResponse {
801 function_code,
802 exception_code,
803 }
804 .encode(&mut pdu_writer)
805 .map_err(DataLinkError::Encode)?;
806
807 send_rtu_pdu(socket, unit_id, pdu_writer.as_written()).await
808}
809
810async fn send_rtu_pdu(socket: &mut TcpStream, unit_id: UnitId, pdu: &[u8]) -> Result<(), DataLinkError> {
811 let mut frame = vec![0u8; pdu.len() + 3];
812 let mut writer = Writer::new(&mut frame);
813 rtu_frame::encode_frame(&mut writer, unit_id, pdu)?;
814 socket.write_all(writer.as_written()).await?;
815 Ok(())
816}
817
#[cfg(test)]
mod tests {
    // End-to-end tests: each one runs a real server on an ephemeral localhost port and
    // talks to it over TCP.
    use super::{ModbusRtuOverTcpServer, ModbusService, ModbusTcpServer, ServiceError};
    use crate::{DataLink, ModbusTcpTransport};
    use rustmod_core::encoding::Writer;
    use rustmod_core::frame::rtu as rtu_frame;
    use rustmod_core::pdu::{DecodedRequest, ExceptionCode};
    use rustmod_core::UnitId;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    // Answers Read Holding Registers with a fixed PDU (function 0x03, byte count 2,
    // register value 0x002A). Every other request is rejected with IllegalFunction.
    struct EchoReadService;

    impl ModbusService for EchoReadService {
        fn handle(
            &self,
            _unit_id: UnitId,
            request: DecodedRequest<'_>,
            response_pdu: &mut [u8],
        ) -> Result<usize, ServiceError> {
            match request {
                DecodedRequest::ReadHoldingRegisters(_) => {
                    let bytes = [0x03u8, 0x02, 0x00, 0x2A];
                    response_pdu[..bytes.len()].copy_from_slice(&bytes);
                    Ok(bytes.len())
                }
                _ => Err(ServiceError::Exception(ExceptionCode::IllegalFunction)),
            }
        }
    }

    // Rejects every request with IllegalDataAddress.
    struct AlwaysExceptionService;

    impl ModbusService for AlwaysExceptionService {
        fn handle(
            &self,
            _unit_id: UnitId,
            _request: DecodedRequest<'_>,
            _response_pdu: &mut [u8],
        ) -> Result<usize, ServiceError> {
            Err(ServiceError::Exception(ExceptionCode::IllegalDataAddress))
        }
    }

    // A well-formed read (function 0x03, start 0, quantity 1) returns the service's PDU
    // unchanged.
    #[tokio::test]
    async fn tcp_server_handles_basic_read_request() {
        let server = ModbusTcpServer::bind("127.0.0.1:0", EchoReadService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let task = tokio::spawn(server.run());

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 32];
        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x03, 0x02, 0x00, 0x2A]);

        // Stop the accept loop; the cancellation JoinError is expected and ignored.
        task.abort();
        let _ = task.await;
    }

    // A ServiceError::Exception becomes an exception PDU: the function code with its
    // high bit set (0x83) followed by the code (0x02, IllegalDataAddress).
    #[tokio::test]
    async fn tcp_server_sends_exception_response() {
        let server = ModbusTcpServer::bind("127.0.0.1:0", AlwaysExceptionService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let task = tokio::spawn(server.run());

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 32];
        let len = transport
            .exchange(UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01], &mut response)
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x83, 0x02]);

        task.abort();
        let _ = task.await;
    }

    // Write Multiple Registers (0x10) with quantity 2 but a byte count of 3 fails to
    // decode. The expected 0x03 (IllegalDataValue) rather than the service's
    // IllegalFunction shows that the decoder, not the service, rejected it.
    #[tokio::test]
    async fn tcp_server_maps_decode_error_to_exception() {
        let server = ModbusTcpServer::bind("127.0.0.1:0", EchoReadService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let task = tokio::spawn(server.run());

        let transport = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut response = [0u8; 32];
        let len = transport
            .exchange(
                UnitId::new(1),
                &[0x10, 0x00, 0x00, 0x00, 0x02, 0x03, 0x12, 0x34, 0x56],
                &mut response,
            )
            .await
            .unwrap();
        assert_eq!(&response[..len], &[0x90, 0x03]);

        task.abort();
        let _ = task.await;
    }

    // Same read as the TCP test, framed as RTU. The 7-byte reply is the unit id, the
    // 4-byte PDU and the 2-byte CRC.
    #[tokio::test]
    async fn rtu_over_tcp_server_handles_basic_read_request() {
        let server = ModbusRtuOverTcpServer::bind("127.0.0.1:0", EchoReadService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let task = tokio::spawn(server.run());

        let mut stream = TcpStream::connect(addr).await.unwrap();
        let mut request = [0u8; 16];
        let mut writer = Writer::new(&mut request);
        rtu_frame::encode_frame(&mut writer, UnitId::new(1), &[0x03, 0x00, 0x00, 0x00, 0x01]).unwrap();
        stream.write_all(writer.as_written()).await.unwrap();

        let mut response = [0u8; 7];
        stream.read_exact(&mut response).await.unwrap();
        let (unit_id, pdu) = rtu_frame::decode_frame(&response).unwrap();
        assert_eq!(unit_id, UnitId::new(1));
        assert_eq!(pdu, &[0x03, 0x02, 0x00, 0x2A]);

        task.abort();
        let _ = task.await;
    }

    // Same malformed write as the TCP test, framed as RTU. The 5-byte reply is the unit
    // id, the 2-byte exception PDU and the 2-byte CRC.
    #[tokio::test]
    async fn rtu_over_tcp_server_maps_decode_error_to_exception() {
        let server = ModbusRtuOverTcpServer::bind("127.0.0.1:0", EchoReadService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let task = tokio::spawn(server.run());

        let mut stream = TcpStream::connect(addr).await.unwrap();
        let mut request = [0u8; 32];
        let mut writer = Writer::new(&mut request);
        rtu_frame::encode_frame(
            &mut writer,
            UnitId::new(1),
            &[0x10, 0x00, 0x00, 0x00, 0x02, 0x03, 0x12, 0x34, 0x56],
        )
        .unwrap();
        stream.write_all(writer.as_written()).await.unwrap();

        let mut response = [0u8; 5];
        stream.read_exact(&mut response).await.unwrap();
        let (unit_id, pdu) = rtu_frame::decode_frame(&response).unwrap();
        assert_eq!(unit_id, UnitId::new(1));
        assert_eq!(pdu, &[0x90, 0x03]);

        task.abort();
        let _ = task.await;
    }
}