1#![forbid(unsafe_code)]
4
5pub mod points;
6pub mod sync;
7
8pub use points::{CoilPoints, RegisterPoints};
9pub use sync::{SyncClientError, SyncModbusTcpClient};
10
11use rustmod_core::encoding::{Reader, Writer};
12use rustmod_core::pdu::{
13 CustomRequest, ExceptionResponse, ReadCoilsRequest, ReadDiscreteInputsRequest,
14 ReadHoldingRegistersRequest, ReadInputRegistersRequest, ReadWriteMultipleRegistersRequest,
15 Request, Response, MaskWriteRegisterRequest, WriteMultipleCoilsRequest,
16 WriteMultipleRegistersRequest, WriteSingleCoilRequest, WriteSingleRegisterRequest,
17};
18use rustmod_core::{DecodeError, EncodeError};
19use rustmod_datalink::{DataLink, DataLinkError};
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::Duration;
22use thiserror::Error;
23use tokio::sync::Mutex;
24use tokio::time::{Instant, sleep, timeout};
25use tracing::{debug, warn};
26
27#[cfg(feature = "metrics")]
28use std::sync::Arc;
29
/// Decides which requests may be re-sent after a failed attempt
/// (a response timeout or a retryable transport error).
///
/// Re-sending a write whose first transmission did reach the device may
/// apply it twice, which is why the default only retries plain reads.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum RetryPolicy {
    /// Never retry; the first failure is returned to the caller.
    Never,
    /// Retry only read requests: coils, discrete inputs, holding registers
    /// and input registers.
    #[default]
    ReadOnly,
    /// Retry every request, including writes.
    All,
}
36
37#[derive(Debug, Clone, Copy)]
38pub struct ClientConfig {
39 pub response_timeout: Duration,
40 pub retry_count: u8,
41 pub throttle_delay: Option<Duration>,
42 pub retry_policy: RetryPolicy,
43}
44
45impl Default for ClientConfig {
46 fn default() -> Self {
47 Self {
48 response_timeout: Duration::from_secs(5),
49 retry_count: 3,
50 throttle_delay: None,
51 retry_policy: RetryPolicy::ReadOnly,
52 }
53 }
54}
55
56impl ClientConfig {
57 pub fn with_response_timeout(mut self, timeout: Duration) -> Self {
58 self.response_timeout = timeout;
59 self
60 }
61
62 pub fn with_retry_count(mut self, retry_count: u8) -> Self {
63 self.retry_count = retry_count;
64 self
65 }
66
67 pub fn with_throttle_delay(mut self, throttle_delay: Option<Duration>) -> Self {
68 self.throttle_delay = throttle_delay;
69 self
70 }
71
72 pub fn with_retry_policy(mut self, retry_policy: RetryPolicy) -> Self {
73 self.retry_policy = retry_policy;
74 self
75 }
76}
77
78#[derive(Debug, Error)]
79pub enum ClientError {
80 #[error("datalink error: {0}")]
81 DataLink(#[from] DataLinkError),
82 #[error("encode error: {0}")]
83 Encode(#[from] EncodeError),
84 #[error("decode error: {0}")]
85 Decode(#[from] DecodeError),
86 #[error("request timed out")]
87 Timeout,
88 #[error("modbus exception: {0:?}")]
89 Exception(ExceptionResponse),
90 #[error("invalid response: {0}")]
91 InvalidResponse(&'static str),
92}
93
/// Decoded reply to a Report Server ID request (function 0x11).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReportServerIdResponse {
    /// The first data byte following the byte count.
    pub server_id: u8,
    /// `true` when the run-indicator byte is non-zero.
    pub run_indicator_status: bool,
    /// Remaining device-specific bytes, left uninterpreted.
    pub additional_data: Vec<u8>,
}
100
/// One identification object from a Read Device Identification reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceIdentificationObject {
    /// Object id byte as reported by the device.
    pub object_id: u8,
    /// Raw object bytes, exactly as long as the object's length byte says.
    pub value: Vec<u8>,
}
106
107#[derive(Debug, Clone, PartialEq, Eq)]
108pub struct ReadDeviceIdentificationResponse {
109 pub read_device_id_code: u8,
110 pub conformity_level: u8,
111 pub more_follows: bool,
112 pub next_object_id: u8,
113 pub objects: Vec<DeviceIdentificationObject>,
114}
115
/// Lock-free counters updated while the client processes requests.
#[cfg(feature = "metrics")]
#[derive(Debug, Default)]
pub struct ClientMetrics {
    /// Logical requests handed to the transport (retries not counted).
    requests_total: AtomicU64,
    /// Attempts that returned response bytes from the transport.
    successful_responses: AtomicU64,
    /// Extra attempts made after a failed attempt.
    retries_total: AtomicU64,
    /// Attempts that exceeded the response timeout.
    timeouts_total: AtomicU64,
    /// Attempts that failed with a datalink error.
    transport_errors_total: AtomicU64,
    /// Responses that decoded as Modbus exceptions.
    exceptions_total: AtomicU64,
    /// Responses that failed to decode or carried trailing bytes.
    decode_errors_total: AtomicU64,
}
127
/// Plain copy of the client counters at one point in time.
#[cfg(feature = "metrics")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ClientMetricsSnapshot {
    /// Logical requests handed to the transport.
    pub requests_total: u64,
    /// Attempts that returned response bytes.
    pub successful_responses: u64,
    /// Extra attempts made after failures.
    pub retries_total: u64,
    /// Attempts that timed out.
    pub timeouts_total: u64,
    /// Attempts that failed with a datalink error.
    pub transport_errors_total: u64,
    /// Modbus exception responses received.
    pub exceptions_total: u64,
    /// Responses that could not be decoded cleanly.
    pub decode_errors_total: u64,
}
139
#[cfg(feature = "metrics")]
impl ClientMetrics {
    /// Copies every counter into a [`ClientMetricsSnapshot`].
    ///
    /// Each counter is loaded independently with `Relaxed` ordering, so the
    /// result is not an atomic view across all counters.
    fn snapshot(&self) -> ClientMetricsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        let Self {
            requests_total,
            successful_responses,
            retries_total,
            timeouts_total,
            transport_errors_total,
            exceptions_total,
            decode_errors_total,
        } = self;

        ClientMetricsSnapshot {
            requests_total: read(requests_total),
            successful_responses: read(successful_responses),
            retries_total: read(retries_total),
            timeouts_total: read(timeouts_total),
            transport_errors_total: read(transport_errors_total),
            exceptions_total: read(exceptions_total),
            decode_errors_total: read(decode_errors_total),
        }
    }
}
154
155pub struct ModbusClient<D: DataLink> {
156 datalink: D,
157 config: ClientConfig,
158 last_request_at: Mutex<Option<Instant>>,
159 request_counter: AtomicU64,
160 #[cfg(feature = "metrics")]
161 metrics: Arc<ClientMetrics>,
162}
163
164impl<D: DataLink> ModbusClient<D> {
165 pub fn new(datalink: D) -> Self {
166 Self::with_config(datalink, ClientConfig::default())
167 }
168
169 pub fn with_config(datalink: D, config: ClientConfig) -> Self {
170 Self {
171 datalink,
172 config,
173 last_request_at: Mutex::new(None),
174 request_counter: AtomicU64::new(1),
175 #[cfg(feature = "metrics")]
176 metrics: Arc::new(ClientMetrics::default()),
177 }
178 }
179
180 pub fn config(&self) -> ClientConfig {
181 self.config
182 }
183
184 #[cfg(feature = "metrics")]
185 pub fn metrics_snapshot(&self) -> ClientMetricsSnapshot {
186 self.metrics.snapshot()
187 }
188
189 fn next_correlation_id(&self) -> u64 {
190 self.request_counter.fetch_add(1, Ordering::Relaxed)
191 }
192
193 async fn apply_throttle(&self) {
194 let Some(delay) = self.config.throttle_delay else {
195 return;
196 };
197
198 let mut last = self.last_request_at.lock().await;
199 if let Some(previous) = *last {
200 let elapsed = previous.elapsed();
201 if elapsed < delay {
202 sleep(delay - elapsed).await;
203 }
204 }
205 *last = Some(Instant::now());
206 }
207
208 fn is_retryable(err: &DataLinkError) -> bool {
209 matches!(
210 err,
211 DataLinkError::Io(_)
212 | DataLinkError::Timeout
213 | DataLinkError::ConnectionClosed
214 )
215 }
216
217 fn request_is_retry_eligible(&self, request: &Request<'_>) -> bool {
218 match self.config.retry_policy {
219 RetryPolicy::Never => false,
220 RetryPolicy::All => true,
221 RetryPolicy::ReadOnly => matches!(
222 request,
223 Request::ReadCoils(_)
224 | Request::ReadDiscreteInputs(_)
225 | Request::ReadHoldingRegisters(_)
226 | Request::ReadInputRegisters(_)
227 ),
228 }
229 }
230
231 async fn exchange_raw(
232 &self,
233 correlation_id: u64,
234 unit_id: u8,
235 request_pdu: &[u8],
236 response_buf: &mut [u8],
237 retry_eligible: bool,
238 ) -> Result<usize, ClientError> {
239 self.apply_throttle().await;
240
241 #[cfg(feature = "metrics")]
242 self.metrics.requests_total.fetch_add(1, Ordering::Relaxed);
243
244 let attempts = usize::from(self.config.retry_count) + 1;
245 let mut last_err: Option<ClientError> = None;
246
247 for attempt in 1..=attempts {
248 let result = timeout(
249 self.config.response_timeout,
250 self.datalink.exchange(unit_id, request_pdu, response_buf),
251 )
252 .await;
253
254 match result {
255 Ok(Ok(len)) => {
256 debug!(
257 correlation_id,
258 unit_id,
259 attempt,
260 len,
261 "modbus request succeeded"
262 );
263 #[cfg(feature = "metrics")]
264 self.metrics
265 .successful_responses
266 .fetch_add(1, Ordering::Relaxed);
267 return Ok(len);
268 }
269 Ok(Err(err)) => {
270 #[cfg(feature = "metrics")]
271 self.metrics
272 .transport_errors_total
273 .fetch_add(1, Ordering::Relaxed);
274 if attempt < attempts && retry_eligible && Self::is_retryable(&err) {
275 warn!(
276 correlation_id,
277 unit_id,
278 attempt,
279 error = %err,
280 "retrying modbus request after transport error"
281 );
282 #[cfg(feature = "metrics")]
283 self.metrics.retries_total.fetch_add(1, Ordering::Relaxed);
284 last_err = Some(ClientError::DataLink(err));
285 continue;
286 }
287 return Err(ClientError::DataLink(err));
288 }
289 Err(_) => {
290 #[cfg(feature = "metrics")]
291 self.metrics.timeouts_total.fetch_add(1, Ordering::Relaxed);
292 if attempt < attempts && retry_eligible {
293 warn!(
294 correlation_id,
295 unit_id,
296 attempt,
297 "retrying modbus request after timeout"
298 );
299 #[cfg(feature = "metrics")]
300 self.metrics.retries_total.fetch_add(1, Ordering::Relaxed);
301 last_err = Some(ClientError::Timeout);
302 continue;
303 }
304 return Err(ClientError::Timeout);
305 }
306 }
307 }
308
309 Err(last_err.unwrap_or(ClientError::InvalidResponse(
310 "retry loop exhausted",
311 )))
312 }
313
314 async fn send_request<'a>(
315 &self,
316 unit_id: u8,
317 request: &Request<'_>,
318 response_storage: &'a mut [u8],
319 ) -> Result<Response<'a>, ClientError> {
320 let correlation_id = self.next_correlation_id();
321 let mut req_buf = [0u8; 260];
322 let mut writer = Writer::new(&mut req_buf);
323 request.encode(&mut writer)?;
324
325 debug!(
326 correlation_id,
327 unit_id,
328 function = request.function_code().as_u8(),
329 pdu_len = writer.as_written().len(),
330 "dispatching modbus request"
331 );
332 let retry_eligible = self.request_is_retry_eligible(request);
333
334 let response_len = self
335 .exchange_raw(
336 correlation_id,
337 unit_id,
338 writer.as_written(),
339 response_storage,
340 retry_eligible,
341 )
342 .await?;
343
344 let mut reader = Reader::new(&response_storage[..response_len]);
345 let response = match Response::decode(&mut reader) {
346 Ok(resp) => resp,
347 Err(err) => {
348 #[cfg(feature = "metrics")]
349 self.metrics
350 .decode_errors_total
351 .fetch_add(1, Ordering::Relaxed);
352 return Err(ClientError::Decode(err));
353 }
354 };
355
356 if !reader.is_empty() {
357 #[cfg(feature = "metrics")]
358 self.metrics
359 .decode_errors_total
360 .fetch_add(1, Ordering::Relaxed);
361 return Err(ClientError::InvalidResponse("trailing bytes in response"));
362 }
363
364 if let Response::Exception(ex) = response {
365 #[cfg(feature = "metrics")]
366 self.metrics.exceptions_total.fetch_add(1, Ordering::Relaxed);
367 return Err(ClientError::Exception(ex));
368 }
369
370 Ok(response)
371 }
372
373 pub async fn read_coils(
374 &self,
375 unit_id: u8,
376 start: u16,
377 quantity: u16,
378 ) -> Result<Vec<bool>, ClientError> {
379 let request = Request::ReadCoils(ReadCoilsRequest {
380 start_address: start,
381 quantity,
382 });
383
384 let mut response_buf = [0u8; 260];
385 let response = self
386 .send_request(unit_id, &request, &mut response_buf)
387 .await?;
388
389 match response {
390 Response::ReadCoils(data) => {
391 let count = usize::from(quantity);
392 if data.coil_status.len() * 8 < count {
393 return Err(ClientError::InvalidResponse(
394 "coil payload shorter than requested",
395 ));
396 }
397 Ok((0..count).filter_map(|idx| data.coil(idx)).collect())
398 }
399 _ => Err(ClientError::InvalidResponse("unexpected function response")),
400 }
401 }
402
403 pub async fn custom_request(
404 &self,
405 unit_id: u8,
406 function_code: u8,
407 payload: &[u8],
408 ) -> Result<Vec<u8>, ClientError> {
409 let request = Request::Custom(CustomRequest {
410 function_code,
411 data: payload,
412 });
413
414 let mut response_buf = [0u8; 260];
415 let response = self
416 .send_request(unit_id, &request, &mut response_buf)
417 .await?;
418
419 match response {
420 Response::Custom(custom) if custom.function_code == function_code => {
421 Ok(custom.data.to_vec())
422 }
423 Response::Custom(_) => {
424 Err(ClientError::InvalidResponse("custom response function mismatch"))
425 }
426 _ => Err(ClientError::InvalidResponse("unexpected function response")),
427 }
428 }
429
430 pub async fn report_server_id(&self, unit_id: u8) -> Result<ReportServerIdResponse, ClientError> {
431 let payload = self.custom_request(unit_id, 0x11, &[]).await?;
432 let Some((&byte_count, data)) = payload.split_first() else {
433 return Err(ClientError::InvalidResponse(
434 "report server id payload missing byte count",
435 ));
436 };
437 let byte_count = usize::from(byte_count);
438 if data.len() != byte_count || byte_count < 2 {
439 return Err(ClientError::InvalidResponse(
440 "report server id payload length mismatch",
441 ));
442 }
443
444 Ok(ReportServerIdResponse {
445 server_id: data[0],
446 run_indicator_status: data[1] != 0,
447 additional_data: data[2..].to_vec(),
448 })
449 }
450
451 pub async fn read_device_identification(
452 &self,
453 unit_id: u8,
454 read_device_id_code: u8,
455 object_id: u8,
456 ) -> Result<ReadDeviceIdentificationResponse, ClientError> {
457 let payload = self
458 .custom_request(unit_id, 0x2B, &[0x0E, read_device_id_code, object_id])
459 .await?;
460
461 if payload.len() < 6 {
462 return Err(ClientError::InvalidResponse(
463 "read device identification payload too short",
464 ));
465 }
466 if payload[0] != 0x0E {
467 return Err(ClientError::InvalidResponse(
468 "read device identification MEI type mismatch",
469 ));
470 }
471
472 let object_count = usize::from(payload[5]);
473 let mut cursor = 6usize;
474 let mut objects = Vec::with_capacity(object_count);
475 for _ in 0..object_count {
476 if payload.len().saturating_sub(cursor) < 2 {
477 return Err(ClientError::InvalidResponse(
478 "read device identification object header truncated",
479 ));
480 }
481 let id = payload[cursor];
482 let len = usize::from(payload[cursor + 1]);
483 cursor += 2;
484 let end = cursor
485 .checked_add(len)
486 .ok_or(ClientError::InvalidResponse(
487 "read device identification object length overflow",
488 ))?;
489 if end > payload.len() {
490 return Err(ClientError::InvalidResponse(
491 "read device identification object data truncated",
492 ));
493 }
494 objects.push(DeviceIdentificationObject {
495 object_id: id,
496 value: payload[cursor..end].to_vec(),
497 });
498 cursor = end;
499 }
500 if cursor != payload.len() {
501 return Err(ClientError::InvalidResponse(
502 "read device identification trailing data",
503 ));
504 }
505
506 Ok(ReadDeviceIdentificationResponse {
507 read_device_id_code: payload[1],
508 conformity_level: payload[2],
509 more_follows: payload[3] != 0,
510 next_object_id: payload[4],
511 objects,
512 })
513 }
514
515 pub async fn read_discrete_inputs(
516 &self,
517 unit_id: u8,
518 start: u16,
519 quantity: u16,
520 ) -> Result<Vec<bool>, ClientError> {
521 let request = Request::ReadDiscreteInputs(ReadDiscreteInputsRequest {
522 start_address: start,
523 quantity,
524 });
525
526 let mut response_buf = [0u8; 260];
527 let response = self
528 .send_request(unit_id, &request, &mut response_buf)
529 .await?;
530
531 match response {
532 Response::ReadDiscreteInputs(data) => {
533 let count = usize::from(quantity);
534 if data.input_status.len() * 8 < count {
535 return Err(ClientError::InvalidResponse(
536 "discrete input payload shorter than requested",
537 ));
538 }
539 Ok((0..count).filter_map(|idx| data.coil(idx)).collect())
540 }
541 _ => Err(ClientError::InvalidResponse("unexpected function response")),
542 }
543 }
544
545 pub async fn read_holding_registers(
546 &self,
547 unit_id: u8,
548 start: u16,
549 quantity: u16,
550 ) -> Result<Vec<u16>, ClientError> {
551 let request = Request::ReadHoldingRegisters(ReadHoldingRegistersRequest {
552 start_address: start,
553 quantity,
554 });
555
556 let mut response_buf = [0u8; 260];
557 let response = self
558 .send_request(unit_id, &request, &mut response_buf)
559 .await?;
560
561 match response {
562 Response::ReadHoldingRegisters(data) => {
563 let count = usize::from(quantity);
564 if data.register_count() < count {
565 return Err(ClientError::InvalidResponse(
566 "register payload shorter than requested",
567 ));
568 }
569 Ok((0..count).filter_map(|idx| data.register(idx)).collect())
570 }
571 _ => Err(ClientError::InvalidResponse("unexpected function response")),
572 }
573 }
574
575 pub async fn read_input_registers(
576 &self,
577 unit_id: u8,
578 start: u16,
579 quantity: u16,
580 ) -> Result<Vec<u16>, ClientError> {
581 let request = Request::ReadInputRegisters(ReadInputRegistersRequest {
582 start_address: start,
583 quantity,
584 });
585
586 let mut response_buf = [0u8; 260];
587 let response = self
588 .send_request(unit_id, &request, &mut response_buf)
589 .await?;
590
591 match response {
592 Response::ReadInputRegisters(data) => {
593 let count = usize::from(quantity);
594 if data.register_count() < count {
595 return Err(ClientError::InvalidResponse(
596 "register payload shorter than requested",
597 ));
598 }
599 Ok((0..count).filter_map(|idx| data.register(idx)).collect())
600 }
601 _ => Err(ClientError::InvalidResponse("unexpected function response")),
602 }
603 }
604
605 pub async fn write_single_coil(
606 &self,
607 unit_id: u8,
608 address: u16,
609 value: bool,
610 ) -> Result<(), ClientError> {
611 let request = Request::WriteSingleCoil(WriteSingleCoilRequest { address, value });
612
613 let mut response_buf = [0u8; 260];
614 let response = self
615 .send_request(unit_id, &request, &mut response_buf)
616 .await?;
617
618 match response {
619 Response::WriteSingleCoil(resp) if resp.address == address && resp.value == value => Ok(()),
620 Response::WriteSingleCoil(_) => {
621 Err(ClientError::InvalidResponse("write single coil echo mismatch"))
622 }
623 _ => Err(ClientError::InvalidResponse("unexpected function response")),
624 }
625 }
626
627 pub async fn write_single_register(
628 &self,
629 unit_id: u8,
630 address: u16,
631 value: u16,
632 ) -> Result<(), ClientError> {
633 let request = Request::WriteSingleRegister(WriteSingleRegisterRequest { address, value });
634
635 let mut response_buf = [0u8; 260];
636 let response = self
637 .send_request(unit_id, &request, &mut response_buf)
638 .await?;
639
640 match response {
641 Response::WriteSingleRegister(resp) if resp.address == address && resp.value == value => {
642 Ok(())
643 }
644 Response::WriteSingleRegister(_) => {
645 Err(ClientError::InvalidResponse("write single register echo mismatch"))
646 }
647 _ => Err(ClientError::InvalidResponse("unexpected function response")),
648 }
649 }
650
651 pub async fn mask_write_register(
652 &self,
653 unit_id: u8,
654 address: u16,
655 and_mask: u16,
656 or_mask: u16,
657 ) -> Result<(), ClientError> {
658 let request = Request::MaskWriteRegister(MaskWriteRegisterRequest {
659 address,
660 and_mask,
661 or_mask,
662 });
663
664 let mut response_buf = [0u8; 260];
665 let response = self
666 .send_request(unit_id, &request, &mut response_buf)
667 .await?;
668
669 match response {
670 Response::MaskWriteRegister(resp)
671 if resp.address == address && resp.and_mask == and_mask && resp.or_mask == or_mask =>
672 {
673 Ok(())
674 }
675 Response::MaskWriteRegister(_) => {
676 Err(ClientError::InvalidResponse("mask write register echo mismatch"))
677 }
678 _ => Err(ClientError::InvalidResponse("unexpected function response")),
679 }
680 }
681
682 pub async fn write_multiple_coils(
683 &self,
684 unit_id: u8,
685 start: u16,
686 values: &[bool],
687 ) -> Result<(), ClientError> {
688 let request_variant = WriteMultipleCoilsRequest {
689 start_address: start,
690 values,
691 };
692 let expected_qty = request_variant.quantity()?;
693
694 let request = Request::WriteMultipleCoils(request_variant);
695 let mut response_buf = [0u8; 260];
696 let response = self
697 .send_request(unit_id, &request, &mut response_buf)
698 .await?;
699
700 match response {
701 Response::WriteMultipleCoils(resp)
702 if resp.start_address == start && resp.quantity == expected_qty =>
703 {
704 Ok(())
705 }
706 Response::WriteMultipleCoils(_) => {
707 Err(ClientError::InvalidResponse("write multiple coils echo mismatch"))
708 }
709 _ => Err(ClientError::InvalidResponse("unexpected function response")),
710 }
711 }
712
713 pub async fn write_multiple_registers(
714 &self,
715 unit_id: u8,
716 start: u16,
717 values: &[u16],
718 ) -> Result<(), ClientError> {
719 let request_variant = WriteMultipleRegistersRequest {
720 start_address: start,
721 values,
722 };
723 let expected_qty = request_variant.quantity()?;
724
725 let request = Request::WriteMultipleRegisters(request_variant);
726 let mut response_buf = [0u8; 260];
727 let response = self
728 .send_request(unit_id, &request, &mut response_buf)
729 .await?;
730
731 match response {
732 Response::WriteMultipleRegisters(resp)
733 if resp.start_address == start && resp.quantity == expected_qty =>
734 {
735 Ok(())
736 }
737 Response::WriteMultipleRegisters(_) => {
738 Err(ClientError::InvalidResponse(
739 "write multiple registers echo mismatch",
740 ))
741 }
742 _ => Err(ClientError::InvalidResponse("unexpected function response")),
743 }
744 }
745
746 pub async fn read_write_multiple_registers(
747 &self,
748 unit_id: u8,
749 read_start: u16,
750 read_quantity: u16,
751 write_start: u16,
752 write_values: &[u16],
753 ) -> Result<Vec<u16>, ClientError> {
754 let request = Request::ReadWriteMultipleRegisters(ReadWriteMultipleRegistersRequest {
755 read_start_address: read_start,
756 read_quantity,
757 write_start_address: write_start,
758 values: write_values,
759 });
760
761 let mut response_buf = [0u8; 260];
762 let response = self
763 .send_request(unit_id, &request, &mut response_buf)
764 .await?;
765
766 match response {
767 Response::ReadWriteMultipleRegisters(data) => {
768 let count = usize::from(read_quantity);
769 if data.register_count() < count {
770 return Err(ClientError::InvalidResponse(
771 "read-write register payload shorter than requested",
772 ));
773 }
774 Ok((0..count).filter_map(|idx| data.register(idx)).collect())
775 }
776 _ => Err(ClientError::InvalidResponse("unexpected function response")),
777 }
778 }
779}
780
#[cfg(test)]
mod tests {
    use super::{ClientConfig, ClientError, ModbusClient, RetryPolicy};
    use async_trait::async_trait;
    use rustmod_datalink::{DataLink, DataLinkError};
    use std::collections::VecDeque;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio::sync::Mutex;
    use tokio::time::sleep;

    /// Scripted results handed out by `MockLink`, one per `exchange` call.
    type MockQueue = VecDeque<Result<Vec<u8>, DataLinkError>>;

    /// Scripted datalink: each `exchange` pops the next queued result and
    /// counts the call. Clones share the queue and the call counter, so a
    /// clone kept by the test can observe calls made through the client.
    #[derive(Clone, Default)]
    struct MockLink {
        responses: Arc<Mutex<MockQueue>>,
        calls: Arc<AtomicUsize>,
    }

    impl MockLink {
        /// Builds a link that will return `responses` in order.
        fn with_responses(responses: Vec<Result<Vec<u8>, DataLinkError>>) -> Self {
            Self {
                responses: Arc::new(Mutex::new(responses.into())),
                calls: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Number of `exchange` calls made so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl DataLink for MockLink {
        /// Returns the next scripted result. An exhausted queue yields
        /// `InvalidResponse`, and a scripted frame larger than the caller's
        /// buffer yields `ResponseBufferTooSmall`.
        async fn exchange(
            &self,
            _unit_id: u8,
            _request_pdu: &[u8],
            response_pdu: &mut [u8],
        ) -> Result<usize, DataLinkError> {
            self.calls.fetch_add(1, Ordering::Relaxed);
            let mut guard = self.responses.lock().await;
            let next = guard
                .pop_front()
                .ok_or(DataLinkError::InvalidResponse("no mock response"))?;
            let bytes = next?;
            if bytes.len() > response_pdu.len() {
                return Err(DataLinkError::ResponseBufferTooSmall {
                    needed: bytes.len(),
                    available: response_pdu.len(),
                });
            }
            response_pdu[..bytes.len()].copy_from_slice(&bytes);
            Ok(bytes.len())
        }
    }

    /// Link whose first call fails with `ConnectionClosed`. Every later call
    /// sleeps 50 ms before answering a one-register read, which lets a test
    /// force a timeout on the retry.
    #[derive(Clone, Default)]
    struct ConnectionClosedThenSlowLink {
        calls: Arc<AtomicUsize>,
    }

    impl ConnectionClosedThenSlowLink {
        /// Number of `exchange` calls made so far.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::Relaxed)
        }
    }

    #[async_trait]
    impl DataLink for ConnectionClosedThenSlowLink {
        async fn exchange(
            &self,
            _unit_id: u8,
            _request_pdu: &[u8],
            response_pdu: &mut [u8],
        ) -> Result<usize, DataLinkError> {
            let call = self.calls.fetch_add(1, Ordering::Relaxed);
            if call == 0 {
                return Err(DataLinkError::ConnectionClosed);
            }

            sleep(Duration::from_millis(50)).await;
            // Function 0x03, byte count 2, register value 0x002A.
            response_pdu[..4].copy_from_slice(&[0x03, 0x02, 0x00, 0x2A]);
            Ok(4)
        }
    }

    #[tokio::test]
    async fn read_holding_registers_success() {
        // Function 0x03, byte count 4, registers 0x1234 and 0xABCD.
        let link = MockLink::with_responses(vec![Ok(vec![
            0x03, 0x04, 0x12, 0x34, 0xAB, 0xCD,
        ])]);
        let client = ModbusClient::new(link);

        let values = client.read_holding_registers(1, 0, 2).await.unwrap();
        assert_eq!(values, vec![0x1234, 0xABCD]);
    }

    #[tokio::test]
    async fn exception_is_mapped() {
        // 0x83 = function 0x03 with the exception bit set, exception code 0x02.
        let link = MockLink::with_responses(vec![Ok(vec![0x83, 0x02])]);
        let client = ModbusClient::new(link);

        let err = client.read_holding_registers(1, 0, 1).await.unwrap_err();
        assert!(matches!(err, ClientError::Exception(_)));
    }

    #[tokio::test]
    async fn custom_request_roundtrip() {
        // Reply to custom function 0x41 whose data is [0x12, 0x34].
        let link = MockLink::with_responses(vec![Ok(vec![0x41, 0x12, 0x34])]);
        let client = ModbusClient::new(link);

        let payload = client.custom_request(1, 0x41, &[0xAA]).await.unwrap();
        assert_eq!(payload, vec![0x12, 0x34]);
    }

    #[tokio::test]
    async fn report_server_id_parses_payload() {
        // Function 0x11, byte count 3, server id 0x2A, run indicator 0xFF,
        // one additional byte 0x10.
        let link = MockLink::with_responses(vec![Ok(vec![0x11, 0x03, 0x2A, 0xFF, 0x10])]);
        let client = ModbusClient::new(link);

        let report = client.report_server_id(1).await.unwrap();
        assert_eq!(report.server_id, 0x2A);
        assert!(report.run_indicator_status);
        assert_eq!(report.additional_data, vec![0x10]);
    }

    #[tokio::test]
    async fn read_device_identification_parses_objects() {
        // Function 0x2B, MEI 0x0E, code 0x01, conformity 0x01, more-follows 0,
        // next id 0, two objects: (0x00, "rust-mo") and (0x01, "0.1").
        let link = MockLink::with_responses(vec![Ok(vec![
            0x2B, 0x0E, 0x01, 0x01, 0x00, 0x00, 0x02, 0x00, 0x07, b'r', b'u', b's', b't', b'-',
            b'm', b'o', 0x01, 0x03, b'0', b'.', b'1',
        ])]);
        let client = ModbusClient::new(link);

        let response = client.read_device_identification(1, 0x01, 0x00).await.unwrap();
        assert_eq!(response.read_device_id_code, 0x01);
        assert_eq!(response.conformity_level, 0x01);
        assert!(!response.more_follows);
        assert_eq!(response.next_object_id, 0x00);
        assert_eq!(response.objects.len(), 2);
        assert_eq!(response.objects[0].object_id, 0x00);
        assert_eq!(response.objects[0].value, b"rust-mo".to_vec());
        assert_eq!(response.objects[1].object_id, 0x01);
        assert_eq!(response.objects[1].value, b"0.1".to_vec());
    }

    #[tokio::test]
    async fn read_device_identification_rejects_wrong_mei_type() {
        // Same header shape as above but MEI type 0x0D instead of 0x0E.
        let link = MockLink::with_responses(vec![Ok(vec![
            0x2B, 0x0D, 0x01, 0x01, 0x00, 0x00, 0x00,
        ])]);
        let client = ModbusClient::new(link);

        let err = client
            .read_device_identification(1, 0x01, 0x00)
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            ClientError::InvalidResponse("read device identification MEI type mismatch")
        ));
    }

    #[tokio::test]
    async fn retries_after_connection_closed() {
        // A read (retry-eligible by default) fails once, then succeeds.
        let link = MockLink::with_responses(vec![
            Err(DataLinkError::ConnectionClosed),
            Ok(vec![0x03, 0x02, 0x00, 0x2A]),
        ]);
        let link_for_assert = link.clone();

        let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));

        let values = client.read_holding_registers(1, 0, 1).await.unwrap();
        assert_eq!(values, vec![42]);
        assert_eq!(link_for_assert.call_count(), 2);
    }

    #[tokio::test]
    async fn write_is_not_retried_by_default() {
        // The default ReadOnly policy must surface the first failure of a
        // write without consuming the queued success.
        let link = MockLink::with_responses(vec![
            Err(DataLinkError::ConnectionClosed),
            Ok(vec![0x06, 0x00, 0x01, 0x00, 0x2A]),
        ]);
        let link_for_assert = link.clone();

        let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));
        let err = client.write_single_register(1, 1, 42).await.unwrap_err();

        assert!(matches!(
            err,
            ClientError::DataLink(DataLinkError::ConnectionClosed)
        ));
        assert_eq!(link_for_assert.call_count(), 1);
    }

    #[tokio::test]
    async fn response_buffer_too_small_is_not_retried() {
        // ResponseBufferTooSmall is not a retryable transport error, even for
        // a retry-eligible read.
        let link = MockLink::with_responses(vec![
            Err(DataLinkError::ResponseBufferTooSmall {
                needed: 300,
                available: 260,
            }),
            Ok(vec![0x03, 0x02, 0x00, 0x2A]),
        ]);
        let link_for_assert = link.clone();

        let client = ModbusClient::with_config(link, ClientConfig::default().with_retry_count(1));
        let err = client.read_holding_registers(1, 0, 1).await.unwrap_err();

        assert!(matches!(
            err,
            ClientError::DataLink(DataLinkError::ResponseBufferTooSmall { .. })
        ));
        assert_eq!(link_for_assert.call_count(), 1);
    }

    #[tokio::test]
    async fn write_can_retry_when_policy_is_all() {
        // With RetryPolicy::All the same write script as above succeeds on
        // the second attempt.
        let link = MockLink::with_responses(vec![
            Err(DataLinkError::ConnectionClosed),
            Ok(vec![0x06, 0x00, 0x01, 0x00, 0x2A]),
        ]);
        let link_for_assert = link.clone();

        let config = ClientConfig::default()
            .with_retry_count(1)
            .with_retry_policy(RetryPolicy::All);
        let client = ModbusClient::with_config(link, config);
        client.write_single_register(1, 1, 42).await.unwrap();

        assert_eq!(link_for_assert.call_count(), 2);
    }

    #[tokio::test]
    async fn final_timeout_is_reported_over_previous_transport_error() {
        // First attempt: ConnectionClosed. Retry: the link sleeps 50 ms, past
        // the 10 ms timeout. The error from the last attempt must win.
        let link = ConnectionClosedThenSlowLink::default();
        let link_for_assert = link.clone();

        let config = ClientConfig::default()
            .with_retry_count(1)
            .with_response_timeout(Duration::from_millis(10));
        let client = ModbusClient::with_config(link, config);

        let err = client.read_holding_registers(1, 0, 1).await.unwrap_err();
        assert!(matches!(err, ClientError::Timeout));
        assert_eq!(link_for_assert.call_count(), 2);
    }

    #[tokio::test]
    async fn mask_write_register_success() {
        // Function 0x16 echo: address 0x0004, AND mask 0xFF00, OR mask 0x0012.
        let link = MockLink::with_responses(vec![Ok(vec![0x16, 0x00, 0x04, 0xFF, 0x00, 0x00, 0x12])]);
        let client = ModbusClient::new(link);
        client
            .mask_write_register(1, 0x0004, 0xFF00, 0x0012)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn read_write_multiple_registers_success() {
        // Function 0x17, byte count 4, read-back registers 0x1234 and 0xABCD.
        let link = MockLink::with_responses(vec![Ok(vec![0x17, 0x04, 0x12, 0x34, 0xAB, 0xCD])]);
        let client = ModbusClient::new(link);

        let values = client
            .read_write_multiple_registers(1, 0x0010, 2, 0x0020, &[0x0102, 0x0304])
            .await
            .unwrap();
        assert_eq!(values, vec![0x1234, 0xABCD]);
    }

    #[tokio::test]
    async fn read_coils_rejects_truncated_payload() {
        // One status byte holds at most 8 coils, but 9 were requested.
        let link = MockLink::with_responses(vec![Ok(vec![0x01, 0x01, 0b0000_1111])]);
        let client = ModbusClient::new(link);
        let err = client.read_coils(1, 0, 9).await.unwrap_err();
        assert!(matches!(
            err,
            ClientError::InvalidResponse("coil payload shorter than requested")
        ));
    }

    #[tokio::test]
    async fn read_discrete_inputs_rejects_truncated_payload() {
        // One status byte holds at most 8 inputs, but 9 were requested.
        let link = MockLink::with_responses(vec![Ok(vec![0x02, 0x01, 0b0000_1111])]);
        let client = ModbusClient::new(link);
        let err = client.read_discrete_inputs(1, 0, 9).await.unwrap_err();
        assert!(matches!(
            err,
            ClientError::InvalidResponse("discrete input payload shorter than requested")
        ));
    }

    // A single successful read is counted once as a request and once as a
    // success, with no exceptions.
    #[cfg(feature = "metrics")]
    #[tokio::test]
    async fn metrics_count_success() {
        let link = MockLink::with_responses(vec![Ok(vec![0x03, 0x02, 0x00, 0x2A])]);
        let client = ModbusClient::new(link);

        let _ = client.read_holding_registers(1, 0, 1).await.unwrap();
        let metrics = client.metrics_snapshot();

        assert_eq!(metrics.requests_total, 1);
        assert_eq!(metrics.successful_responses, 1);
        assert_eq!(metrics.exceptions_total, 0);
    }
}