1#![forbid(unsafe_code)]
36
37pub mod points;
38pub mod sync;
39
40pub use points::{CoilPoints, RegisterPoints};
41pub use sync::{SyncClientError, SyncModbusTcpClient};
42
43use rustmod_core::encoding::{Reader, Writer};
44use rustmod_core::pdu::{
45 CustomRequest, ExceptionResponse, ReadCoilsRequest, ReadDiscreteInputsRequest,
46 ReadHoldingRegistersRequest, ReadInputRegistersRequest, ReadWriteMultipleRegistersRequest,
47 Request, Response, MaskWriteRegisterRequest, WriteMultipleCoilsRequest,
48 WriteMultipleRegistersRequest, WriteSingleCoilRequest, WriteSingleRegisterRequest,
49};
50use rustmod_core::{DecodeError, EncodeError};
51pub use rustmod_datalink::UnitId;
52use rustmod_datalink::{DataLink, DataLinkError};
53use std::sync::Arc;
54use std::sync::atomic::{AtomicU64, Ordering};
55use std::time::Duration;
56use thiserror::Error;
57use tokio::sync::Mutex;
58use tokio::time::{Instant, sleep, timeout};
59use tracing::{debug, warn};
60
/// Selects which requests [`ModbusClient`] may transparently re-send after a
/// transport error or a response timeout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RetryPolicy {
    /// Never retry: the first failure is returned to the caller.
    Never,
    /// Retry only requests the client classifies as reads; writes are sent once.
    ReadOnly,
    /// Retry every request, including writes. A retried write may be applied
    /// more than once if the first attempt reached the server.
    All,
}

/// Timeout, retry and throttling settings for a [`ModbusClient`].
///
/// Start from [`ClientConfig::default()`] and adjust with the `with_*` methods.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct ClientConfig {
    /// Upper bound for one exchange attempt; every retry gets a fresh timeout.
    pub response_timeout: Duration,
    /// Additional attempts after the first (total attempts = `retry_count + 1`).
    pub retry_count: u8,
    /// Minimum spacing between consecutive requests sent through a client and
    /// its clones; `None` disables throttling.
    pub throttle_delay: Option<Duration>,
    /// Which requests are eligible for retry.
    pub retry_policy: RetryPolicy,
}

impl Default for ClientConfig {
    /// Five-second response timeout, three retries, no throttling and
    /// [`RetryPolicy::ReadOnly`].
    fn default() -> Self {
        let response_timeout = Duration::from_secs(5);
        let retry_policy = RetryPolicy::ReadOnly;
        Self {
            retry_policy,
            throttle_delay: None,
            retry_count: 3,
            response_timeout,
        }
    }
}

impl ClientConfig {
    /// Returns a copy with `response_timeout` replaced by `timeout`.
    #[must_use = "builder methods return a new value"]
    pub fn with_response_timeout(self, timeout: Duration) -> Self {
        Self {
            response_timeout: timeout,
            ..self
        }
    }

    /// Returns a copy with `retry_count` replaced.
    #[must_use = "builder methods return a new value"]
    pub fn with_retry_count(self, retry_count: u8) -> Self {
        Self { retry_count, ..self }
    }

    /// Returns a copy with `throttle_delay` replaced (`None` disables throttling).
    #[must_use = "builder methods return a new value"]
    pub fn with_throttle_delay(self, throttle_delay: Option<Duration>) -> Self {
        Self {
            throttle_delay,
            ..self
        }
    }

    /// Returns a copy with `retry_policy` replaced.
    #[must_use = "builder methods return a new value"]
    pub fn with_retry_policy(self, retry_policy: RetryPolicy) -> Self {
        Self {
            retry_policy,
            ..self
        }
    }
}
130
/// Errors returned by [`ModbusClient`] operations.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ClientError {
    /// The underlying data link failed (for example an I/O error or a closed
    /// connection) and the failure was not, or could no longer be, retried.
    #[error("datalink error: {0}")]
    DataLink(#[from] DataLinkError),
    /// The request could not be encoded into a PDU.
    #[error("encode error: {0}")]
    Encode(#[from] EncodeError),
    /// The response bytes could not be decoded into a response PDU.
    #[error("decode error: {0}")]
    Decode(#[from] DecodeError),
    /// The last attempt did not complete within `ClientConfig::response_timeout`.
    #[error("request timed out")]
    Timeout,
    /// The server answered with a Modbus exception response.
    #[error("modbus exception: {0}")]
    Exception(ExceptionResponse),
    /// The response decoded but did not match the request that was sent.
    #[error("invalid response: {0}")]
    InvalidResponse(InvalidResponseKind),
}
154
/// Describes how a decoded response failed to match the request that was sent.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum InvalidResponseKind {
    /// Bytes remained in the response PDU after the response was decoded.
    TrailingBytes,
    /// The response type does not correspond to the request's function.
    FunctionMismatch,
    /// A write response did not echo back what was written.
    EchoMismatch,
    /// The response carried a different amount of data than was requested.
    PayloadLengthMismatch,
    /// The response payload ended before all expected data was present.
    PayloadTruncated,
    /// Any other problem, described by a static message.
    Other(&'static str),
}

impl std::fmt::Display for InvalidResponseKind {
    /// Writes a short lower-case description; `Other` writes its message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text: &str = match self {
            Self::TrailingBytes => "trailing bytes in response",
            Self::FunctionMismatch => "unexpected function response",
            Self::EchoMismatch => "echo mismatch",
            Self::PayloadLengthMismatch => "payload length mismatch",
            Self::PayloadTruncated => "payload truncated",
            Self::Other(msg) => *msg,
        };
        f.write_str(text)
    }
}
185
/// Parsed reply returned by [`ModbusClient::report_server_id`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReportServerIdResponse {
    /// First data byte after the byte count.
    pub server_id: u8,
    /// `true` when the second data byte (run indicator) is non-zero.
    pub run_indicator_status: bool,
    /// Every byte after the run indicator, copied verbatim.
    pub additional_data: Vec<u8>,
}
196
/// One object from a Read Device Identification reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceIdentificationObject {
    /// Object id byte as sent by the server.
    pub object_id: u8,
    /// Raw object value bytes; not decoded as text.
    pub value: Vec<u8>,
}
205
/// Parsed reply returned by [`ModbusClient::read_device_identification`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReadDeviceIdentificationResponse {
    /// Read device id code echoed by the server.
    pub read_device_id_code: u8,
    /// Conformity level byte reported by the server.
    pub conformity_level: u8,
    /// `true` when the "more follows" byte is non-zero, i.e. another request
    /// starting at `next_object_id` is needed to fetch the remaining objects.
    pub more_follows: bool,
    /// Object id to request next when `more_follows` is set.
    pub next_object_id: u8,
    /// Objects contained in this reply, in wire order.
    pub objects: Vec<DeviceIdentificationObject>,
}
220
#[cfg(feature = "metrics")]
/// Lock-free request counters shared by a [`ModbusClient`] and all of its clones.
#[derive(Debug, Default)]
pub struct ClientMetrics {
    /// Logical requests handed to the transport (retries are not counted again).
    requests_total: AtomicU64,
    /// Exchanges that returned response bytes. Counted before decoding, so
    /// exception replies and malformed replies are included.
    successful_responses: AtomicU64,
    /// Attempts re-sent because of a retryable failure.
    retries_total: AtomicU64,
    /// Attempts that hit `response_timeout`.
    timeouts_total: AtomicU64,
    /// Attempts that failed with a data-link error.
    transport_errors_total: AtomicU64,
    /// Replies that were Modbus exception responses.
    exceptions_total: AtomicU64,
    /// Replies that failed to decode or carried trailing bytes.
    decode_errors_total: AtomicU64,
}
233
#[cfg(feature = "metrics")]
/// Plain-value copy of [`ClientMetrics`], as returned by
/// [`ModbusClient::metrics_snapshot`]. Field meanings match the live counters.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ClientMetricsSnapshot {
    pub requests_total: u64,
    pub successful_responses: u64,
    pub retries_total: u64,
    pub timeouts_total: u64,
    pub transport_errors_total: u64,
    pub exceptions_total: u64,
    pub decode_errors_total: u64,
}
246
#[cfg(feature = "metrics")]
impl ClientMetrics {
    /// Copies every counter with relaxed loads.
    ///
    /// The counters are read one after another, so the snapshot is not an
    /// atomic view across fields while requests are in flight.
    fn snapshot(&self) -> ClientMetricsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        ClientMetricsSnapshot {
            requests_total: read(&self.requests_total),
            successful_responses: read(&self.successful_responses),
            retries_total: read(&self.retries_total),
            timeouts_total: read(&self.timeouts_total),
            transport_errors_total: read(&self.transport_errors_total),
            exceptions_total: read(&self.exceptions_total),
            decode_errors_total: read(&self.decode_errors_total),
        }
    }
}
261
/// Asynchronous Modbus client layered over a [`DataLink`] transport.
///
/// Cloning is cheap: clones share the transport, the throttle state, the
/// correlation-id counter and (with the `metrics` feature) the counters, while
/// the `Copy` configuration is duplicated.
pub struct ModbusClient<D: DataLink> {
    /// Transport used for every exchange; shared between clones.
    datalink: Arc<D>,
    /// Timeout, retry and throttle settings.
    config: ClientConfig,
    /// Instant recorded by the throttle for the previous request; only read and
    /// written when `config.throttle_delay` is set.
    last_request_at: Arc<Mutex<Option<Instant>>>,
    /// Source of correlation ids (starting at 1) attached to tracing output.
    request_counter: Arc<AtomicU64>,
    #[cfg(feature = "metrics")]
    metrics: Arc<ClientMetrics>,
}
274
275impl<D: DataLink> Clone for ModbusClient<D> {
276 fn clone(&self) -> Self {
277 Self {
278 datalink: Arc::clone(&self.datalink),
279 config: self.config,
280 last_request_at: Arc::clone(&self.last_request_at),
281 request_counter: Arc::clone(&self.request_counter),
282 #[cfg(feature = "metrics")]
283 metrics: Arc::clone(&self.metrics),
284 }
285 }
286}
287
288impl<D: DataLink> ModbusClient<D> {
289 #[must_use]
291 pub fn new(datalink: D) -> Self {
292 Self::with_config(datalink, ClientConfig::default())
293 }
294
295 #[must_use]
297 pub fn with_config(datalink: D, config: ClientConfig) -> Self {
298 Self {
299 datalink: Arc::new(datalink),
300 config,
301 last_request_at: Arc::new(Mutex::new(None)),
302 request_counter: Arc::new(AtomicU64::new(1)),
303 #[cfg(feature = "metrics")]
304 metrics: Arc::new(ClientMetrics::default()),
305 }
306 }
307
308 pub fn config(&self) -> ClientConfig {
310 self.config
311 }
312
313 pub fn is_connected(&self) -> bool {
315 self.datalink.is_connected()
316 }
317
318 #[cfg(feature = "metrics")]
320 pub fn metrics_snapshot(&self) -> ClientMetricsSnapshot {
321 self.metrics.snapshot()
322 }
323
324 fn next_correlation_id(&self) -> u64 {
325 self.request_counter.fetch_add(1, Ordering::Relaxed)
326 }
327
328 async fn apply_throttle(&self) {
329 let Some(delay) = self.config.throttle_delay else {
330 return;
331 };
332
333 let mut last = self.last_request_at.lock().await;
334 if let Some(previous) = *last {
335 let elapsed = previous.elapsed();
336 if elapsed < delay {
337 sleep(delay - elapsed).await;
338 }
339 }
340 *last = Some(Instant::now());
341 }
342
343 fn is_retryable(err: &DataLinkError) -> bool {
344 matches!(
345 err,
346 DataLinkError::Io(_)
347 | DataLinkError::Timeout
348 | DataLinkError::ConnectionClosed
349 )
350 }
351
352 fn request_is_retry_eligible(&self, request: &Request<'_>) -> bool {
353 match self.config.retry_policy {
354 RetryPolicy::Never => false,
355 RetryPolicy::All => true,
356 RetryPolicy::ReadOnly => matches!(
357 request,
358 Request::ReadCoils(_)
359 | Request::ReadDiscreteInputs(_)
360 | Request::ReadHoldingRegisters(_)
361 | Request::ReadInputRegisters(_)
362 | Request::ReadExceptionStatus(_)
363 | Request::Diagnostics(_)
364 | Request::ReadFifoQueue(_)
365 | Request::Custom(CustomRequest { function_code: 0x11, .. })
366 | Request::Custom(CustomRequest { function_code: 0x2B, .. })
367 ),
368 }
369 }
370
371 async fn exchange_raw(
372 &self,
373 correlation_id: u64,
374 unit_id: UnitId,
375 request_pdu: &[u8],
376 response_buf: &mut [u8],
377 retry_eligible: bool,
378 ) -> Result<usize, ClientError> {
379 self.apply_throttle().await;
380
381 #[cfg(feature = "metrics")]
382 self.metrics.requests_total.fetch_add(1, Ordering::Relaxed);
383
384 let attempts = usize::from(self.config.retry_count) + 1;
385 let mut last_err: Option<ClientError> = None;
386
387 for attempt in 1..=attempts {
388 let result = timeout(
389 self.config.response_timeout,
390 self.datalink.exchange(unit_id, request_pdu, response_buf),
391 )
392 .await;
393
394 match result {
395 Ok(Ok(len)) => {
396 debug!(
397 correlation_id,
398 unit_id = unit_id.as_u8(),
399 attempt,
400 len,
401 "modbus request succeeded"
402 );
403 #[cfg(feature = "metrics")]
404 self.metrics
405 .successful_responses
406 .fetch_add(1, Ordering::Relaxed);
407 return Ok(len);
408 }
409 Ok(Err(err)) => {
410 #[cfg(feature = "metrics")]
411 self.metrics
412 .transport_errors_total
413 .fetch_add(1, Ordering::Relaxed);
414 if attempt < attempts && retry_eligible && Self::is_retryable(&err) {
415 warn!(
416 correlation_id,
417 unit_id = unit_id.as_u8(),
418 attempt,
419 error = %err,
420 "retrying modbus request after transport error"
421 );
422 if let Err(reconnect_err) = self.datalink.reconnect().await {
423 debug!(
424 correlation_id,
425 unit_id = unit_id.as_u8(),
426 error = %reconnect_err,
427 "reconnect attempt failed"
428 );
429 }
430 #[cfg(feature = "metrics")]
431 self.metrics.retries_total.fetch_add(1, Ordering::Relaxed);
432 last_err = Some(ClientError::DataLink(err));
433 continue;
434 }
435 return Err(ClientError::DataLink(err));
436 }
437 Err(_) => {
438 #[cfg(feature = "metrics")]
439 self.metrics.timeouts_total.fetch_add(1, Ordering::Relaxed);
440 if attempt < attempts && retry_eligible {
441 warn!(
442 correlation_id,
443 unit_id = unit_id.as_u8(),
444 attempt,
445 "retrying modbus request after timeout"
446 );
447 #[cfg(feature = "metrics")]
448 self.metrics.retries_total.fetch_add(1, Ordering::Relaxed);
449 last_err = Some(ClientError::Timeout);
450 continue;
451 }
452 return Err(ClientError::Timeout);
453 }
454 }
455 }
456
457 Err(last_err.unwrap_or(ClientError::InvalidResponse(
458 InvalidResponseKind::Other("retry loop exhausted"),
459 )))
460 }
461
462 async fn send_request<'a>(
463 &self,
464 unit_id: UnitId,
465 request: &Request<'_>,
466 response_storage: &'a mut [u8],
467 ) -> Result<Response<'a>, ClientError> {
468 let correlation_id = self.next_correlation_id();
469 let mut req_buf = [0u8; 260];
470 let mut writer = Writer::new(&mut req_buf);
471 request.encode(&mut writer)?;
472
473 debug!(
474 correlation_id,
475 unit_id = unit_id.as_u8(),
476 function = request.function_code().as_u8(),
477 pdu_len = writer.as_written().len(),
478 "dispatching modbus request"
479 );
480 let retry_eligible = self.request_is_retry_eligible(request);
481
482 let response_len = self
483 .exchange_raw(
484 correlation_id,
485 unit_id,
486 writer.as_written(),
487 response_storage,
488 retry_eligible,
489 )
490 .await?;
491
492 let mut reader = Reader::new(&response_storage[..response_len]);
493 let response = match Response::decode(&mut reader) {
494 Ok(resp) => resp,
495 Err(err) => {
496 #[cfg(feature = "metrics")]
497 self.metrics
498 .decode_errors_total
499 .fetch_add(1, Ordering::Relaxed);
500 return Err(ClientError::Decode(err));
501 }
502 };
503
504 if !reader.is_empty() {
505 #[cfg(feature = "metrics")]
506 self.metrics
507 .decode_errors_total
508 .fetch_add(1, Ordering::Relaxed);
509 return Err(ClientError::InvalidResponse(InvalidResponseKind::TrailingBytes));
510 }
511
512 if let Response::Exception(ex) = response {
513 #[cfg(feature = "metrics")]
514 self.metrics.exceptions_total.fetch_add(1, Ordering::Relaxed);
515 return Err(ClientError::Exception(ex));
516 }
517
518 Ok(response)
519 }
520
521 pub async fn read_coils(
523 &self,
524 unit_id: UnitId,
525 start: u16,
526 quantity: u16,
527 ) -> Result<Vec<bool>, ClientError> {
528 let request = Request::ReadCoils(ReadCoilsRequest {
529 start_address: start,
530 quantity,
531 });
532
533 let mut response_buf = [0u8; 260];
534 let response = self
535 .send_request(unit_id, &request, &mut response_buf)
536 .await?;
537
538 match response {
539 Response::ReadCoils(data) => {
540 let count = usize::from(quantity);
541 let expected_bytes = count.div_ceil(8);
542 if data.coil_status.len() != expected_bytes {
543 return Err(ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch));
544 }
545 Ok((0..count).filter_map(|idx| data.coil(idx)).collect())
546 }
547 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
548 }
549 }
550
551 pub async fn custom_request(
553 &self,
554 unit_id: UnitId,
555 function_code: u8,
556 payload: &[u8],
557 ) -> Result<Vec<u8>, ClientError> {
558 let request = Request::Custom(CustomRequest {
559 function_code,
560 data: payload,
561 });
562
563 let mut response_buf = [0u8; 260];
564 let response = self
565 .send_request(unit_id, &request, &mut response_buf)
566 .await?;
567
568 match response {
569 Response::Custom(custom) if custom.function_code == function_code => {
570 Ok(custom.data.to_vec())
571 }
572 Response::Custom(_) => {
573 Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch))
574 }
575 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
576 }
577 }
578
579 pub async fn report_server_id(&self, unit_id: UnitId) -> Result<ReportServerIdResponse, ClientError> {
581 let payload = self.custom_request(unit_id, 0x11, &[]).await?;
582 let Some((&byte_count, data)) = payload.split_first() else {
583 return Err(ClientError::InvalidResponse(InvalidResponseKind::Other("report server id payload missing byte count")));
584 };
585 let byte_count = usize::from(byte_count);
586 if data.len() != byte_count || byte_count < 2 {
587 return Err(ClientError::InvalidResponse(InvalidResponseKind::Other("report server id payload length mismatch")));
588 }
589
590 Ok(ReportServerIdResponse {
591 server_id: data[0],
592 run_indicator_status: data[1] != 0,
593 additional_data: data[2..].to_vec(),
594 })
595 }
596
597 pub async fn read_device_identification(
599 &self,
600 unit_id: UnitId,
601 read_device_id_code: u8,
602 object_id: u8,
603 ) -> Result<ReadDeviceIdentificationResponse, ClientError> {
604 let payload = self
605 .custom_request(unit_id, 0x2B, &[0x0E, read_device_id_code, object_id])
606 .await?;
607
608 if payload.len() < 6 {
609 return Err(ClientError::InvalidResponse(InvalidResponseKind::Other("read device identification payload too short")));
610 }
611 if payload[0] != 0x0E {
612 return Err(ClientError::InvalidResponse(InvalidResponseKind::Other("read device identification MEI type mismatch")));
613 }
614
615 let object_count = usize::from(payload[5]);
616 let mut cursor = 6usize;
617 let mut objects = Vec::with_capacity(object_count);
618 for _ in 0..object_count {
619 if payload.len().saturating_sub(cursor) < 2 {
620 return Err(ClientError::InvalidResponse(InvalidResponseKind::Other("read device identification object header truncated")));
621 }
622 let id = payload[cursor];
623 let len = usize::from(payload[cursor + 1]);
624 cursor += 2;
625 let end = cursor
626 .checked_add(len)
627 .ok_or(ClientError::InvalidResponse(InvalidResponseKind::Other("read device identification object length overflow")))?;
628 if end > payload.len() {
629 return Err(ClientError::InvalidResponse(InvalidResponseKind::Other("read device identification object data truncated")));
630 }
631 objects.push(DeviceIdentificationObject {
632 object_id: id,
633 value: payload[cursor..end].to_vec(),
634 });
635 cursor = end;
636 }
637 if cursor != payload.len() {
638 return Err(ClientError::InvalidResponse(InvalidResponseKind::Other("read device identification trailing data")));
639 }
640
641 Ok(ReadDeviceIdentificationResponse {
642 read_device_id_code: payload[1],
643 conformity_level: payload[2],
644 more_follows: payload[3] != 0,
645 next_object_id: payload[4],
646 objects,
647 })
648 }
649
650 pub async fn read_discrete_inputs(
652 &self,
653 unit_id: UnitId,
654 start: u16,
655 quantity: u16,
656 ) -> Result<Vec<bool>, ClientError> {
657 let request = Request::ReadDiscreteInputs(ReadDiscreteInputsRequest {
658 start_address: start,
659 quantity,
660 });
661
662 let mut response_buf = [0u8; 260];
663 let response = self
664 .send_request(unit_id, &request, &mut response_buf)
665 .await?;
666
667 match response {
668 Response::ReadDiscreteInputs(data) => {
669 let count = usize::from(quantity);
670 let expected_bytes = count.div_ceil(8);
671 if data.input_status.len() != expected_bytes {
672 return Err(ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch));
673 }
674 Ok((0..count).filter_map(|idx| data.coil(idx)).collect())
675 }
676 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
677 }
678 }
679
680 pub async fn read_holding_registers(
682 &self,
683 unit_id: UnitId,
684 start: u16,
685 quantity: u16,
686 ) -> Result<Vec<u16>, ClientError> {
687 let request = Request::ReadHoldingRegisters(ReadHoldingRegistersRequest {
688 start_address: start,
689 quantity,
690 });
691
692 let mut response_buf = [0u8; 260];
693 let response = self
694 .send_request(unit_id, &request, &mut response_buf)
695 .await?;
696
697 match response {
698 Response::ReadHoldingRegisters(data) => {
699 let count = usize::from(quantity);
700 if data.register_count() != count {
701 return Err(ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch));
702 }
703 Ok((0..count).filter_map(|idx| data.register(idx)).collect())
704 }
705 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
706 }
707 }
708
709 pub async fn read_input_registers(
711 &self,
712 unit_id: UnitId,
713 start: u16,
714 quantity: u16,
715 ) -> Result<Vec<u16>, ClientError> {
716 let request = Request::ReadInputRegisters(ReadInputRegistersRequest {
717 start_address: start,
718 quantity,
719 });
720
721 let mut response_buf = [0u8; 260];
722 let response = self
723 .send_request(unit_id, &request, &mut response_buf)
724 .await?;
725
726 match response {
727 Response::ReadInputRegisters(data) => {
728 let count = usize::from(quantity);
729 if data.register_count() != count {
730 return Err(ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch));
731 }
732 Ok((0..count).filter_map(|idx| data.register(idx)).collect())
733 }
734 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
735 }
736 }
737
738 pub async fn write_single_coil(
740 &self,
741 unit_id: UnitId,
742 address: u16,
743 value: bool,
744 ) -> Result<(), ClientError> {
745 let request = Request::WriteSingleCoil(WriteSingleCoilRequest { address, value });
746
747 let mut response_buf = [0u8; 260];
748 let response = self
749 .send_request(unit_id, &request, &mut response_buf)
750 .await?;
751
752 match response {
753 Response::WriteSingleCoil(resp) if resp.address == address && resp.value == value => Ok(()),
754 Response::WriteSingleCoil(_) => {
755 Err(ClientError::InvalidResponse(InvalidResponseKind::EchoMismatch))
756 }
757 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
758 }
759 }
760
761 pub async fn write_single_register(
763 &self,
764 unit_id: UnitId,
765 address: u16,
766 value: u16,
767 ) -> Result<(), ClientError> {
768 let request = Request::WriteSingleRegister(WriteSingleRegisterRequest { address, value });
769
770 let mut response_buf = [0u8; 260];
771 let response = self
772 .send_request(unit_id, &request, &mut response_buf)
773 .await?;
774
775 match response {
776 Response::WriteSingleRegister(resp) if resp.address == address && resp.value == value => {
777 Ok(())
778 }
779 Response::WriteSingleRegister(_) => {
780 Err(ClientError::InvalidResponse(InvalidResponseKind::EchoMismatch))
781 }
782 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
783 }
784 }
785
786 pub async fn mask_write_register(
788 &self,
789 unit_id: UnitId,
790 address: u16,
791 and_mask: u16,
792 or_mask: u16,
793 ) -> Result<(), ClientError> {
794 let request = Request::MaskWriteRegister(MaskWriteRegisterRequest {
795 address,
796 and_mask,
797 or_mask,
798 });
799
800 let mut response_buf = [0u8; 260];
801 let response = self
802 .send_request(unit_id, &request, &mut response_buf)
803 .await?;
804
805 match response {
806 Response::MaskWriteRegister(resp)
807 if resp.address == address && resp.and_mask == and_mask && resp.or_mask == or_mask =>
808 {
809 Ok(())
810 }
811 Response::MaskWriteRegister(_) => {
812 Err(ClientError::InvalidResponse(InvalidResponseKind::EchoMismatch))
813 }
814 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
815 }
816 }
817
818 pub async fn write_multiple_coils(
820 &self,
821 unit_id: UnitId,
822 start: u16,
823 values: &[bool],
824 ) -> Result<(), ClientError> {
825 let request_variant = WriteMultipleCoilsRequest {
826 start_address: start,
827 values,
828 };
829 let expected_qty = request_variant.quantity()?;
830
831 let request = Request::WriteMultipleCoils(request_variant);
832 let mut response_buf = [0u8; 260];
833 let response = self
834 .send_request(unit_id, &request, &mut response_buf)
835 .await?;
836
837 match response {
838 Response::WriteMultipleCoils(resp)
839 if resp.start_address == start && resp.quantity == expected_qty =>
840 {
841 Ok(())
842 }
843 Response::WriteMultipleCoils(_) => {
844 Err(ClientError::InvalidResponse(InvalidResponseKind::EchoMismatch))
845 }
846 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
847 }
848 }
849
850 pub async fn write_multiple_registers(
852 &self,
853 unit_id: UnitId,
854 start: u16,
855 values: &[u16],
856 ) -> Result<(), ClientError> {
857 let request_variant = WriteMultipleRegistersRequest {
858 start_address: start,
859 values,
860 };
861 let expected_qty = request_variant.quantity()?;
862
863 let request = Request::WriteMultipleRegisters(request_variant);
864 let mut response_buf = [0u8; 260];
865 let response = self
866 .send_request(unit_id, &request, &mut response_buf)
867 .await?;
868
869 match response {
870 Response::WriteMultipleRegisters(resp)
871 if resp.start_address == start && resp.quantity == expected_qty =>
872 {
873 Ok(())
874 }
875 Response::WriteMultipleRegisters(_) => {
876 Err(ClientError::InvalidResponse(InvalidResponseKind::EchoMismatch))
877 }
878 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
879 }
880 }
881
882 pub async fn read_write_multiple_registers(
884 &self,
885 unit_id: UnitId,
886 read_start: u16,
887 read_quantity: u16,
888 write_start: u16,
889 write_values: &[u16],
890 ) -> Result<Vec<u16>, ClientError> {
891 let request = Request::ReadWriteMultipleRegisters(ReadWriteMultipleRegistersRequest {
892 read_start_address: read_start,
893 read_quantity,
894 write_start_address: write_start,
895 values: write_values,
896 });
897
898 let mut response_buf = [0u8; 260];
899 let response = self
900 .send_request(unit_id, &request, &mut response_buf)
901 .await?;
902
903 match response {
904 Response::ReadWriteMultipleRegisters(data) => {
905 let count = usize::from(read_quantity);
906 if data.register_count() != count {
907 return Err(ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch));
908 }
909 Ok((0..count).filter_map(|idx| data.register(idx)).collect())
910 }
911 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
912 }
913 }
914
915 pub async fn read_coils_raw(
918 &self,
919 unit_id: UnitId,
920 start: u16,
921 quantity: u16,
922 ) -> Result<(Vec<u8>, u16), ClientError> {
923 let request = Request::ReadCoils(ReadCoilsRequest {
924 start_address: start,
925 quantity,
926 });
927
928 let mut response_buf = [0u8; 260];
929 let response = self
930 .send_request(unit_id, &request, &mut response_buf)
931 .await?;
932
933 match response {
934 Response::ReadCoils(data) => {
935 let expected_bytes = usize::from(quantity).div_ceil(8);
936 if data.coil_status.len() != expected_bytes {
937 return Err(ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch));
938 }
939 Ok((data.coil_status.to_vec(), quantity))
940 }
941 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
942 }
943 }
944
945 pub async fn read_discrete_inputs_raw(
947 &self,
948 unit_id: UnitId,
949 start: u16,
950 quantity: u16,
951 ) -> Result<(Vec<u8>, u16), ClientError> {
952 let request = Request::ReadDiscreteInputs(ReadDiscreteInputsRequest {
953 start_address: start,
954 quantity,
955 });
956
957 let mut response_buf = [0u8; 260];
958 let response = self
959 .send_request(unit_id, &request, &mut response_buf)
960 .await?;
961
962 match response {
963 Response::ReadDiscreteInputs(data) => {
964 let expected_bytes = usize::from(quantity).div_ceil(8);
965 if data.input_status.len() != expected_bytes {
966 return Err(ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch));
967 }
968 Ok((data.input_status.to_vec(), quantity))
969 }
970 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
971 }
972 }
973
974 pub async fn read_exception_status(
976 &self,
977 unit_id: UnitId,
978 ) -> Result<u8, ClientError> {
979 use rustmod_core::pdu::ReadExceptionStatusRequest;
980 let request = Request::ReadExceptionStatus(ReadExceptionStatusRequest);
981
982 let mut response_buf = [0u8; 260];
983 let response = self
984 .send_request(unit_id, &request, &mut response_buf)
985 .await?;
986
987 match response {
988 Response::ReadExceptionStatus(data) => Ok(data.data),
989 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
990 }
991 }
992
993 pub async fn diagnostics(
995 &self,
996 unit_id: UnitId,
997 sub_function: u16,
998 data: u16,
999 ) -> Result<(u16, u16), ClientError> {
1000 use rustmod_core::pdu::DiagnosticsRequest;
1001 let request = Request::Diagnostics(DiagnosticsRequest { sub_function, data });
1002
1003 let mut response_buf = [0u8; 260];
1004 let response = self
1005 .send_request(unit_id, &request, &mut response_buf)
1006 .await?;
1007
1008 match response {
1009 Response::Diagnostics(resp) => Ok((resp.sub_function, resp.data)),
1010 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
1011 }
1012 }
1013
1014 pub async fn read_fifo_queue(
1016 &self,
1017 unit_id: UnitId,
1018 address: u16,
1019 ) -> Result<Vec<u16>, ClientError> {
1020 use rustmod_core::pdu::ReadFifoQueueRequest;
1021 let request = Request::ReadFifoQueue(ReadFifoQueueRequest {
1022 fifo_pointer_address: address,
1023 });
1024
1025 let mut response_buf = [0u8; 260];
1026 let response = self
1027 .send_request(unit_id, &request, &mut response_buf)
1028 .await?;
1029
1030 match response {
1031 Response::ReadFifoQueue(data) => {
1032 Ok((0..data.fifo_count())
1033 .filter_map(|idx| data.value(idx))
1034 .collect())
1035 }
1036 _ => Err(ClientError::InvalidResponse(InvalidResponseKind::FunctionMismatch)),
1037 }
1038 }
1039}
1040
#[cfg(test)]
const _: () = {
    // Compile-time guarantee: a TCP-backed client can be shared across tasks and threads.
    fn _require_send_sync<T: Send + Sync>() {}
    fn _check_tcp_client() {
        _require_send_sync::<ModbusClient<rustmod_datalink::ModbusTcpTransport>>();
    }
};
1048
1049#[cfg(test)]
1050mod tests {
1051 use super::{ClientConfig, ClientError, InvalidResponseKind, ModbusClient, RetryPolicy, UnitId};
1052 use async_trait::async_trait;
1053 use rustmod_datalink::{DataLink, DataLinkError};
1054 use std::collections::VecDeque;
1055 use std::sync::Arc;
1056 use std::sync::atomic::{AtomicUsize, Ordering};
1057 use std::time::Duration;
1058 use tokio::sync::Mutex;
1059 use tokio::time::sleep;
1060
1061 type MockQueue = VecDeque<Result<Vec<u8>, DataLinkError>>;
1062
1063 #[derive(Clone, Default)]
1064 struct MockLink {
1065 responses: Arc<Mutex<MockQueue>>,
1066 calls: Arc<AtomicUsize>,
1067 }
1068
1069 impl MockLink {
1070 fn with_responses(responses: Vec<Result<Vec<u8>, DataLinkError>>) -> Self {
1071 Self {
1072 responses: Arc::new(Mutex::new(responses.into())),
1073 calls: Arc::new(AtomicUsize::new(0)),
1074 }
1075 }
1076
1077 fn call_count(&self) -> usize {
1078 self.calls.load(Ordering::Relaxed)
1079 }
1080 }
1081
1082 #[async_trait]
1083 impl DataLink for MockLink {
1084 async fn exchange(
1085 &self,
1086 _unit_id: UnitId,
1087 _request_pdu: &[u8],
1088 response_pdu: &mut [u8],
1089 ) -> Result<usize, DataLinkError> {
1090 self.calls.fetch_add(1, Ordering::Relaxed);
1091 let mut guard = self.responses.lock().await;
1092 let next = guard
1093 .pop_front()
1094 .ok_or(DataLinkError::InvalidResponse("no mock response"))?;
1095 let bytes = next?;
1096 if bytes.len() > response_pdu.len() {
1097 return Err(DataLinkError::ResponseBufferTooSmall {
1098 needed: bytes.len(),
1099 available: response_pdu.len(),
1100 });
1101 }
1102 response_pdu[..bytes.len()].copy_from_slice(&bytes);
1103 Ok(bytes.len())
1104 }
1105 }
1106
1107 #[derive(Clone, Default)]
1108 struct ConnectionClosedThenSlowLink {
1109 calls: Arc<AtomicUsize>,
1110 }
1111
1112 impl ConnectionClosedThenSlowLink {
1113 fn call_count(&self) -> usize {
1114 self.calls.load(Ordering::Relaxed)
1115 }
1116 }
1117
1118 #[async_trait]
1119 impl DataLink for ConnectionClosedThenSlowLink {
1120 async fn exchange(
1121 &self,
1122 _unit_id: UnitId,
1123 _request_pdu: &[u8],
1124 response_pdu: &mut [u8],
1125 ) -> Result<usize, DataLinkError> {
1126 let call = self.calls.fetch_add(1, Ordering::Relaxed);
1127 if call == 0 {
1128 return Err(DataLinkError::ConnectionClosed);
1129 }
1130
1131 sleep(Duration::from_millis(50)).await;
1132 response_pdu[..4].copy_from_slice(&[0x03, 0x02, 0x00, 0x2A]);
1133 Ok(4)
1134 }
1135 }
1136
1137 #[tokio::test]
1138 async fn read_holding_registers_success() {
1139 let link = MockLink::with_responses(vec![Ok(vec![
1140 0x03, 0x04, 0x12, 0x34, 0xAB, 0xCD,
1141 ])]);
1142 let client = ModbusClient::new(link);
1143
1144 let values = client.read_holding_registers(UnitId::new(1), 0, 2).await.unwrap();
1145 assert_eq!(values, vec![0x1234, 0xABCD]);
1146 }
1147
1148 #[tokio::test]
1149 async fn exception_is_mapped() {
1150 let link = MockLink::with_responses(vec![Ok(vec![0x83, 0x02])]);
1151 let client = ModbusClient::new(link);
1152
1153 let err = client.read_holding_registers(UnitId::new(1), 0, 1).await.unwrap_err();
1154 assert!(matches!(err, ClientError::Exception(_)));
1155 }
1156
1157 #[tokio::test]
1158 async fn custom_request_roundtrip() {
1159 let link = MockLink::with_responses(vec![Ok(vec![0x41, 0x12, 0x34])]);
1160 let client = ModbusClient::new(link);
1161
1162 let payload = client.custom_request(UnitId::new(1), 0x41, &[0xAA]).await.unwrap();
1163 assert_eq!(payload, vec![0x12, 0x34]);
1164 }
1165
1166 #[tokio::test]
1167 async fn report_server_id_parses_payload() {
1168 let link = MockLink::with_responses(vec![Ok(vec![0x11, 0x03, 0x2A, 0xFF, 0x10])]);
1169 let client = ModbusClient::new(link);
1170
1171 let report = client.report_server_id(UnitId::new(1)).await.unwrap();
1172 assert_eq!(report.server_id, 0x2A);
1173 assert!(report.run_indicator_status);
1174 assert_eq!(report.additional_data, vec![0x10]);
1175 }
1176
1177 #[tokio::test]
1178 async fn read_device_identification_parses_objects() {
1179 let link = MockLink::with_responses(vec![Ok(vec![
1180 0x2B, 0x0E, 0x01, 0x01, 0x00, 0x00, 0x02, 0x00, 0x07, b'r', b'u', b's', b't', b'-',
1181 b'm', b'o', 0x01, 0x03, b'0', b'.', b'1',
1182 ])]);
1183 let client = ModbusClient::new(link);
1184
1185 let response = client.read_device_identification(UnitId::new(1), 0x01, 0x00).await.unwrap();
1186 assert_eq!(response.read_device_id_code, 0x01);
1187 assert_eq!(response.conformity_level, 0x01);
1188 assert!(!response.more_follows);
1189 assert_eq!(response.next_object_id, 0x00);
1190 assert_eq!(response.objects.len(), 2);
1191 assert_eq!(response.objects[0].object_id, 0x00);
1192 assert_eq!(response.objects[0].value, b"rust-mo".to_vec());
1193 assert_eq!(response.objects[1].object_id, 0x01);
1194 assert_eq!(response.objects[1].value, b"0.1".to_vec());
1195 }
1196
1197 #[tokio::test]
1198 async fn read_device_identification_rejects_wrong_mei_type() {
1199 let link = MockLink::with_responses(vec![Ok(vec![
1200 0x2B, 0x0D, 0x01, 0x01, 0x00, 0x00, 0x00,
1201 ])]);
1202 let client = ModbusClient::new(link);
1203
1204 let err = client
1205 .read_device_identification(UnitId::new(1), 0x01, 0x00)
1206 .await
1207 .unwrap_err();
1208 assert!(matches!(
1209 err,
1210 ClientError::InvalidResponse(InvalidResponseKind::Other("read device identification MEI type mismatch"))
1211 ));
1212 }
1213
1214 #[tokio::test]
1215 async fn retries_after_connection_closed() {
1216 let link = MockLink::with_responses(vec![
1217 Err(DataLinkError::ConnectionClosed),
1218 Ok(vec![0x03, 0x02, 0x00, 0x2A]),
1219 ]);
1220 let link_for_assert = link.clone();
1221
1222 let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));
1223
1224 let values = client.read_holding_registers(UnitId::new(1), 0, 1).await.unwrap();
1225 assert_eq!(values, vec![42]);
1226 assert_eq!(link_for_assert.call_count(), 2);
1227 }
1228
1229 #[tokio::test]
1230 async fn write_is_not_retried_by_default() {
1231 let link = MockLink::with_responses(vec![
1232 Err(DataLinkError::ConnectionClosed),
1233 Ok(vec![0x06, 0x00, 0x01, 0x00, 0x2A]),
1234 ]);
1235 let link_for_assert = link.clone();
1236
1237 let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));
1238 let err = client.write_single_register(UnitId::new(1), 1, 42).await.unwrap_err();
1239
1240 assert!(matches!(
1241 err,
1242 ClientError::DataLink(DataLinkError::ConnectionClosed)
1243 ));
1244 assert_eq!(link_for_assert.call_count(), 1);
1245 }
1246
1247 #[tokio::test]
1248 async fn response_buffer_too_small_is_not_retried() {
1249 let link = MockLink::with_responses(vec![
1250 Err(DataLinkError::ResponseBufferTooSmall {
1251 needed: 300,
1252 available: 260,
1253 }),
1254 Ok(vec![0x03, 0x02, 0x00, 0x2A]),
1255 ]);
1256 let link_for_assert = link.clone();
1257
1258 let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));
1259 let err = client.read_holding_registers(UnitId::new(1), 0, 1).await.unwrap_err();
1260
1261 assert!(matches!(
1262 err,
1263 ClientError::DataLink(DataLinkError::ResponseBufferTooSmall { .. })
1264 ));
1265 assert_eq!(link_for_assert.call_count(), 1);
1266 }
1267
1268 #[tokio::test]
1269 async fn write_can_retry_when_policy_is_all() {
1270 let link = MockLink::with_responses(vec![
1271 Err(DataLinkError::ConnectionClosed),
1272 Ok(vec![0x06, 0x00, 0x01, 0x00, 0x2A]),
1273 ]);
1274 let link_for_assert = link.clone();
1275
1276 let config = ClientConfig::default()
1277 .with_retry_count(1)
1278 .with_retry_policy(RetryPolicy::All);
1279 let client = ModbusClient::with_config(link, config);
1280 client.write_single_register(UnitId::new(1), 1, 42).await.unwrap();
1281
1282 assert_eq!(link_for_assert.call_count(), 2);
1283 }
1284
1285 #[tokio::test]
1286 async fn final_timeout_is_reported_over_previous_transport_error() {
1287 let link = ConnectionClosedThenSlowLink::default();
1288 let link_for_assert = link.clone();
1289
1290 let config = ClientConfig::default()
1291 .with_retry_count(1)
1292 .with_response_timeout(Duration::from_millis(10));
1293 let client = ModbusClient::with_config(link, config);
1294
1295 let err = client.read_holding_registers(UnitId::new(1), 0, 1).await.unwrap_err();
1296 assert!(matches!(err, ClientError::Timeout));
1297 assert_eq!(link_for_assert.call_count(), 2);
1298 }
1299
1300 #[tokio::test]
1301 async fn mask_write_register_success() {
1302 let link = MockLink::with_responses(vec![Ok(vec![0x16, 0x00, 0x04, 0xFF, 0x00, 0x00, 0x12])]);
1303 let client = ModbusClient::new(link);
1304 client
1305 .mask_write_register(UnitId::new(1), 0x0004, 0xFF00, 0x0012)
1306 .await
1307 .unwrap();
1308 }
1309
1310 #[tokio::test]
1311 async fn read_write_multiple_registers_success() {
1312 let link = MockLink::with_responses(vec![Ok(vec![0x17, 0x04, 0x12, 0x34, 0xAB, 0xCD])]);
1313 let client = ModbusClient::new(link);
1314
1315 let values = client
1316 .read_write_multiple_registers(UnitId::new(1), 0x0010, 2, 0x0020, &[0x0102, 0x0304])
1317 .await
1318 .unwrap();
1319 assert_eq!(values, vec![0x1234, 0xABCD]);
1320 }
1321
1322 #[tokio::test]
1323 async fn read_coils_rejects_truncated_payload() {
1324 let link = MockLink::with_responses(vec![Ok(vec![0x01, 0x01, 0b0000_1111])]);
1325 let client = ModbusClient::new(link);
1326 let err = client.read_coils(UnitId::new(1), 0, 9).await.unwrap_err();
1327 assert!(matches!(
1328 err,
1329 ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch)
1330 ));
1331 }
1332
1333 #[tokio::test]
1334 async fn read_discrete_inputs_rejects_truncated_payload() {
1335 let link = MockLink::with_responses(vec![Ok(vec![0x02, 0x01, 0b0000_1111])]);
1336 let client = ModbusClient::new(link);
1337 let err = client.read_discrete_inputs(UnitId::new(1), 0, 9).await.unwrap_err();
1338 assert!(matches!(
1339 err,
1340 ClientError::InvalidResponse(InvalidResponseKind::PayloadLengthMismatch)
1341 ));
1342 }
1343
#[cfg(feature = "metrics")]
#[tokio::test]
/// One successful read bumps the request and success counters and leaves the
/// exception counter at zero.
async fn metrics_count_success() {
    // FC 0x03, byte count 2, a single register holding 42.
    let link = MockLink::with_responses(vec![Ok(vec![0x03, 0x02, 0x00, 0x2A])]);
    let client = ModbusClient::new(link);

    // Assert the decoded value instead of discarding it, so a decode regression
    // cannot hide behind a "successful" counter bump.
    let values = client.read_holding_registers(UnitId::new(1), 0, 1).await.unwrap();
    assert_eq!(values, vec![42]);

    let metrics = client.metrics_snapshot();
    assert_eq!(metrics.requests_total, 1);
    assert_eq!(metrics.successful_responses, 1);
    assert_eq!(metrics.exceptions_total, 0);
}
1357}