Skip to main content

stackforge_core/layer/modbus/
mod.rs

1//! Modbus protocol layer implementation.
2//!
3//! Implements Modbus/TCP (MBAP), Modbus RTU, and Modbus ASCII frame formats.
4//!
5//! ## Modbus/TCP (MBAP) Header Format
6//!
7//! ```text
8//! Offset  Size  Field
9//! 0       2     Transaction ID (big-endian)
10//! 2       2     Protocol ID (0x0000 for Modbus)
11//! 4       2     Length (number of following bytes, including Unit ID)
12//! 6       1     Unit ID (slave address)
13//! 7       1     Function Code
14//! 8..N    var   Data (function-code dependent)
15//! ```
16//!
17//! ## Modbus RTU Frame Format
18//!
19//! ```text
20//! Offset  Size  Field
21//! 0       1     Slave Address
22//! 1       1     Function Code
23//! 2..N-2  var   Data
24//! N-2     2     CRC-16 (little-endian)
25//! ```
26//!
27//! ## Modbus ASCII Frame Format
28//!
29//! ```text
30//! ':'  + hex(SlaveAddr + FuncCode + Data + LRC) + CR + LF
31//! ```
32
33pub mod builder;
34pub mod crc;
35
36pub use builder::ModbusBuilder;
37pub use crc::{modbus_crc16, modbus_lrc, verify_crc16, verify_lrc};
38
39use crate::layer::field::{FieldError, FieldValue};
40use crate::layer::{Layer, LayerIndex, LayerKind};
41
/// Well-known TCP port for Modbus/TCP.
pub const MODBUS_TCP_PORT: u16 = 502;

/// Size of the MBAP header in bytes:
/// transaction ID (2) + protocol ID (2) + length (2) + unit ID (1).
pub const MODBUS_MBAP_HEADER_LEN: usize = 2 + 2 + 2 + 1;

/// Smallest Modbus/TCP message we can decode: the MBAP header followed by
/// the one-byte function code.
pub const MODBUS_MIN_HEADER_LEN: usize = MODBUS_MBAP_HEADER_LEN + 1;
50
/// Modbus function code constants.
///
/// Exception responses echo these codes with bit 7 (`0x80`) set.
pub mod func_code {
    /// Read Coils.
    pub const READ_COILS: u8 = 0x01;
    /// Read Discrete Inputs.
    pub const READ_DISCRETE_INPUTS: u8 = 0x02;
    /// Read Holding Registers.
    pub const READ_HOLDING_REGISTERS: u8 = 0x03;
    /// Read Input Registers.
    pub const READ_INPUT_REGISTERS: u8 = 0x04;
    /// Write Single Coil.
    pub const WRITE_SINGLE_COIL: u8 = 0x05;
    /// Write Single Register.
    pub const WRITE_SINGLE_REGISTER: u8 = 0x06;
    /// Read Exception Status.
    pub const READ_EXCEPTION_STATUS: u8 = 0x07;
    /// Diagnostics (carries a 16-bit sub-function).
    pub const DIAGNOSTICS: u8 = 0x08;
    /// Get Comm Event Counter.
    pub const GET_COMM_EVENT_COUNTER: u8 = 0x0b;
    /// Get Comm Event Log.
    pub const GET_COMM_EVENT_LOG: u8 = 0x0c;
    /// Write Multiple Coils.
    pub const WRITE_MULTIPLE_COILS: u8 = 0x0f;
    /// Write Multiple Registers.
    pub const WRITE_MULTIPLE_REGISTERS: u8 = 0x10;
    /// Report Slave ID.
    pub const REPORT_SLAVE_ID: u8 = 0x11;
    /// Read File Record.
    pub const READ_FILE_RECORD: u8 = 0x14;
    /// Write File Record.
    pub const WRITE_FILE_RECORD: u8 = 0x15;
    /// Mask Write Register.
    pub const MASK_WRITE_REGISTER: u8 = 0x16;
    /// Read/Write Multiple Registers.
    pub const READ_WRITE_MULTIPLE_REGISTERS: u8 = 0x17;
    /// Read FIFO Queue.
    pub const READ_FIFO_QUEUE: u8 = 0x18;
    /// Encapsulated Interface Transport.
    pub const ENCAP_INTERFACE_TRANSPORT: u8 = 0x2b;
}
73
/// Modbus exception code constants (the byte that follows an exception
/// function code, i.e. one with bit 7 set).
pub mod except_code {
    /// Illegal Function.
    pub const ILLEGAL_FUNCTION: u8 = 0x01;
    /// Illegal Data Address.
    pub const ILLEGAL_DATA_ADDRESS: u8 = 0x02;
    /// Illegal Data Value.
    pub const ILLEGAL_DATA_VALUE: u8 = 0x03;
    /// Server Device Failure.
    pub const SERVER_DEVICE_FAILURE: u8 = 0x04;
    /// Acknowledge.
    pub const ACKNOWLEDGE: u8 = 0x05;
    /// Server Device Busy.
    pub const SERVER_DEVICE_BUSY: u8 = 0x06;
    /// Memory Parity Error.
    pub const MEMORY_PARITY_ERROR: u8 = 0x08;
    /// Gateway Path Unavailable.
    pub const GATEWAY_PATH_UNAVAILABLE: u8 = 0x0a;
    /// Gateway Target Device Failed to Respond.
    pub const GATEWAY_TARGET_FAILED: u8 = 0x0b;
}
86
/// Field names exported for Python/generic access, in display order.
///
/// Every name listed here is understood by `ModbusLayer::get_field`.
pub static MODBUS_FIELD_NAMES: &[&str] = &[
    // MBAP header and function code
    "trans_id", "proto_id", "length", "unit_id", "func_code",
    // PDU fields; which ones are meaningful depends on the function code
    "except_code", "start_addr", "quantity", "byte_count", "output_value",
    "register_val", "coil_status", "sub_func", "ref_addr", "and_mask", "or_mask",
    // Raw bytes after the function code
    "data",
];
107
/// Wire framing used to carry a Modbus PDU.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModbusFrameType {
    /// Modbus/TCP: the PDU is prefixed by the MBAP header.
    Tcp,
    /// Modbus RTU: binary serial framing protected by a CRC-16.
    Rtu,
    /// Modbus ASCII: `':'`, hex-encoded bytes plus LRC, then CR LF.
    Ascii,
}
118
119/// Return a human-readable name for a Modbus function code.
120#[must_use]
121pub fn func_code_name(fc: u8) -> &'static str {
122    match fc & 0x7F {
123        func_code::READ_COILS => "Read Coils",
124        func_code::READ_DISCRETE_INPUTS => "Read Discrete Inputs",
125        func_code::READ_HOLDING_REGISTERS => "Read Holding Registers",
126        func_code::READ_INPUT_REGISTERS => "Read Input Registers",
127        func_code::WRITE_SINGLE_COIL => "Write Single Coil",
128        func_code::WRITE_SINGLE_REGISTER => "Write Single Register",
129        func_code::READ_EXCEPTION_STATUS => "Read Exception Status",
130        func_code::DIAGNOSTICS => "Diagnostics",
131        func_code::GET_COMM_EVENT_COUNTER => "Get Comm Event Counter",
132        func_code::GET_COMM_EVENT_LOG => "Get Comm Event Log",
133        func_code::WRITE_MULTIPLE_COILS => "Write Multiple Coils",
134        func_code::WRITE_MULTIPLE_REGISTERS => "Write Multiple Registers",
135        func_code::REPORT_SLAVE_ID => "Report Slave ID",
136        func_code::READ_FILE_RECORD => "Read File Record",
137        func_code::WRITE_FILE_RECORD => "Write File Record",
138        func_code::MASK_WRITE_REGISTER => "Mask Write Register",
139        func_code::READ_WRITE_MULTIPLE_REGISTERS => "Read/Write Multiple Registers",
140        func_code::READ_FIFO_QUEUE => "Read FIFO Queue",
141        func_code::ENCAP_INTERFACE_TRANSPORT => "Encapsulated Interface Transport",
142        _ => "Unknown",
143    }
144}
145
146/// Return a human-readable name for a Modbus exception code.
147#[must_use]
148pub fn except_code_name(ec: u8) -> &'static str {
149    match ec {
150        except_code::ILLEGAL_FUNCTION => "Illegal Function",
151        except_code::ILLEGAL_DATA_ADDRESS => "Illegal Data Address",
152        except_code::ILLEGAL_DATA_VALUE => "Illegal Data Value",
153        except_code::SERVER_DEVICE_FAILURE => "Server Device Failure",
154        except_code::ACKNOWLEDGE => "Acknowledge",
155        except_code::SERVER_DEVICE_BUSY => "Server Device Busy",
156        except_code::MEMORY_PARITY_ERROR => "Memory Parity Error",
157        except_code::GATEWAY_PATH_UNAVAILABLE => "Gateway Path Unavailable",
158        except_code::GATEWAY_TARGET_FAILED => "Gateway Target Device Failed",
159        _ => "Unknown",
160    }
161}
162
163/// Check if a TCP payload looks like a Modbus/TCP (MBAP) message.
164///
165/// Validates:
166/// 1. At least 8 bytes (MBAP header + function code)
167/// 2. Protocol ID at offset 2 is 0x0000
168/// 3. Length field at offset 4 is sensible (>= 2, <= remaining)
169#[must_use]
170pub fn is_modbus_tcp_payload(buf: &[u8]) -> bool {
171    if buf.len() < MODBUS_MIN_HEADER_LEN {
172        return false;
173    }
174    // Protocol ID must be 0x0000 for Modbus
175    let proto_id = u16::from_be_bytes([buf[2], buf[3]]);
176    if proto_id != 0x0000 {
177        return false;
178    }
179    // Length field: number of bytes after the first 6 bytes (unitId + funcCode + data)
180    let length = u16::from_be_bytes([buf[4], buf[5]]) as usize;
181    // Must be at least 2 (unitId + funcCode) and not exceed remaining data
182    if length < 2 {
183        return false;
184    }
185    // Sanity: length should not claim more data than we have
186    if 6 + length > buf.len() + 256 {
187        // Allow some slack for truncated captures
188        return false;
189    }
190    true
191}
192
193// ============================================================================
194// ModbusLayer -- zero-copy view
195// ============================================================================
196
197/// Modbus layer -- a zero-copy view into a packet buffer.
198///
199/// By default assumes Modbus/TCP (MBAP) framing since that is what appears
200/// on the wire over TCP port 502.
201#[derive(Debug, Clone)]
202pub struct ModbusLayer {
203    pub index: LayerIndex,
204    pub frame_type: ModbusFrameType,
205}
206
207impl ModbusLayer {
208    /// Create a new Modbus layer from a layer index (defaults to TCP framing).
209    #[must_use]
210    pub fn new(index: LayerIndex) -> Self {
211        Self {
212            index,
213            frame_type: ModbusFrameType::Tcp,
214        }
215    }
216
217    /// Create a Modbus layer with explicit frame type.
218    #[must_use]
219    pub fn with_frame_type(index: LayerIndex, frame_type: ModbusFrameType) -> Self {
220        Self { index, frame_type }
221    }
222
223    /// Return a slice of the buffer corresponding to this layer.
224    fn slice<'a>(&self, buf: &'a [u8]) -> &'a [u8] {
225        self.index.slice(buf)
226    }
227
228    // ========================================================================
229    // MBAP Header Accessors (Modbus/TCP)
230    // ========================================================================
231
232    /// Get the Transaction ID (MBAP bytes 0-1).
233    pub fn trans_id(&self, buf: &[u8]) -> Result<u16, FieldError> {
234        let s = self.slice(buf);
235        if s.len() < 2 {
236            return Err(FieldError::BufferTooShort {
237                offset: self.index.start,
238                need: 2,
239                have: s.len(),
240            });
241        }
242        Ok(u16::from_be_bytes([s[0], s[1]]))
243    }
244
245    /// Get the Protocol ID (MBAP bytes 2-3; should be 0x0000).
246    pub fn proto_id(&self, buf: &[u8]) -> Result<u16, FieldError> {
247        let s = self.slice(buf);
248        if s.len() < 4 {
249            return Err(FieldError::BufferTooShort {
250                offset: self.index.start + 2,
251                need: 2,
252                have: s.len().saturating_sub(2),
253            });
254        }
255        Ok(u16::from_be_bytes([s[2], s[3]]))
256    }
257
258    /// Get the Length field (MBAP bytes 4-5).
259    pub fn length(&self, buf: &[u8]) -> Result<u16, FieldError> {
260        let s = self.slice(buf);
261        if s.len() < 6 {
262            return Err(FieldError::BufferTooShort {
263                offset: self.index.start + 4,
264                need: 2,
265                have: s.len().saturating_sub(4),
266            });
267        }
268        Ok(u16::from_be_bytes([s[4], s[5]]))
269    }
270
271    /// Get the Unit ID (MBAP byte 6).
272    pub fn unit_id(&self, buf: &[u8]) -> Result<u8, FieldError> {
273        let s = self.slice(buf);
274        if s.len() < 7 {
275            return Err(FieldError::BufferTooShort {
276                offset: self.index.start + 6,
277                need: 1,
278                have: s.len().saturating_sub(6),
279            });
280        }
281        Ok(s[6])
282    }
283
284    /// Get the Function Code (MBAP byte 7).
285    pub fn func_code(&self, buf: &[u8]) -> Result<u8, FieldError> {
286        let s = self.slice(buf);
287        if s.len() < 8 {
288            return Err(FieldError::BufferTooShort {
289                offset: self.index.start + 7,
290                need: 1,
291                have: s.len().saturating_sub(7),
292            });
293        }
294        Ok(s[7])
295    }
296
297    /// Check if this is an exception response (function code has bit 7 set).
298    #[must_use]
299    pub fn is_error(&self, buf: &[u8]) -> bool {
300        self.func_code(buf)
301            .map(|fc| fc & 0x80 != 0)
302            .unwrap_or(false)
303    }
304
305    /// Get the exception code (byte 8, only valid for error responses).
306    pub fn except_code(&self, buf: &[u8]) -> Result<u8, FieldError> {
307        let s = self.slice(buf);
308        if s.len() < 9 {
309            return Err(FieldError::BufferTooShort {
310                offset: self.index.start + 8,
311                need: 1,
312                have: s.len().saturating_sub(8),
313            });
314        }
315        Ok(s[8])
316    }
317
318    // ========================================================================
319    // PDU Data Accessors
320    // ========================================================================
321
322    /// Get the Start Address (bytes 8-9 for request PDUs like 0x01-0x06).
323    pub fn start_addr(&self, buf: &[u8]) -> Result<u16, FieldError> {
324        let s = self.slice(buf);
325        if s.len() < 10 {
326            return Err(FieldError::BufferTooShort {
327                offset: self.index.start + 8,
328                need: 2,
329                have: s.len().saturating_sub(8),
330            });
331        }
332        Ok(u16::from_be_bytes([s[8], s[9]]))
333    }
334
335    /// Get the Quantity field (bytes 10-11 for request PDUs like 0x01-0x04).
336    pub fn quantity(&self, buf: &[u8]) -> Result<u16, FieldError> {
337        let s = self.slice(buf);
338        if s.len() < 12 {
339            return Err(FieldError::BufferTooShort {
340                offset: self.index.start + 10,
341                need: 2,
342                have: s.len().saturating_sub(10),
343            });
344        }
345        Ok(u16::from_be_bytes([s[10], s[11]]))
346    }
347
348    /// Get the Byte Count field (byte 8 for response PDUs like 0x01-0x04).
349    pub fn byte_count(&self, buf: &[u8]) -> Result<u8, FieldError> {
350        let s = self.slice(buf);
351        if s.len() < 9 {
352            return Err(FieldError::BufferTooShort {
353                offset: self.index.start + 8,
354                need: 1,
355                have: s.len().saturating_sub(8),
356            });
357        }
358        Ok(s[8])
359    }
360
361    /// Get the Output Value for Write Single Coil/Register (bytes 10-11).
362    pub fn output_value(&self, buf: &[u8]) -> Result<u16, FieldError> {
363        self.quantity(buf) // same offset (bytes 10-11)
364    }
365
366    /// Get the Register Value for Write Single Register (bytes 10-11).
367    pub fn register_val(&self, buf: &[u8]) -> Result<u16, FieldError> {
368        self.quantity(buf) // same offset (bytes 10-11)
369    }
370
371    /// Get the Sub-function code for Diagnostics (0x08) (bytes 8-9).
372    pub fn sub_func(&self, buf: &[u8]) -> Result<u16, FieldError> {
373        self.start_addr(buf) // same offset (bytes 8-9)
374    }
375
376    /// Get the Reference Address for Mask Write Register (0x16) (bytes 8-9).
377    pub fn ref_addr(&self, buf: &[u8]) -> Result<u16, FieldError> {
378        self.start_addr(buf) // same offset
379    }
380
381    /// Get the AND mask for Mask Write Register (0x16) (bytes 10-11).
382    pub fn and_mask(&self, buf: &[u8]) -> Result<u16, FieldError> {
383        self.quantity(buf) // same offset
384    }
385
386    /// Get the OR mask for Mask Write Register (0x16) (bytes 12-13).
387    pub fn or_mask(&self, buf: &[u8]) -> Result<u16, FieldError> {
388        let s = self.slice(buf);
389        if s.len() < 14 {
390            return Err(FieldError::BufferTooShort {
391                offset: self.index.start + 12,
392                need: 2,
393                have: s.len().saturating_sub(12),
394            });
395        }
396        Ok(u16::from_be_bytes([s[12], s[13]]))
397    }
398
399    /// Get the raw data bytes after the function code (bytes 8..end).
400    pub fn data(&self, buf: &[u8]) -> Result<Vec<u8>, FieldError> {
401        let s = self.slice(buf);
402        if s.len() < 8 {
403            return Err(FieldError::BufferTooShort {
404                offset: self.index.start + 8,
405                need: 1,
406                have: 0,
407            });
408        }
409        Ok(s[8..].to_vec())
410    }
411
412    // ========================================================================
413    // Setters
414    // ========================================================================
415
416    /// Set the Transaction ID.
417    pub fn set_trans_id(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
418        let off = self.index.start;
419        if buf.len() < off + 2 {
420            return Err(FieldError::BufferTooShort {
421                offset: off,
422                need: 2,
423                have: buf.len().saturating_sub(off),
424            });
425        }
426        buf[off..off + 2].copy_from_slice(&value.to_be_bytes());
427        Ok(())
428    }
429
430    /// Set the Unit ID.
431    pub fn set_unit_id(&self, buf: &mut [u8], value: u8) -> Result<(), FieldError> {
432        let off = self.index.start + 6;
433        if buf.len() < off + 1 {
434            return Err(FieldError::BufferTooShort {
435                offset: off,
436                need: 1,
437                have: buf.len().saturating_sub(off),
438            });
439        }
440        buf[off] = value;
441        Ok(())
442    }
443
444    /// Set the Function Code.
445    pub fn set_func_code(&self, buf: &mut [u8], value: u8) -> Result<(), FieldError> {
446        let off = self.index.start + 7;
447        if buf.len() < off + 1 {
448            return Err(FieldError::BufferTooShort {
449                offset: off,
450                need: 1,
451                have: buf.len().saturating_sub(off),
452            });
453        }
454        buf[off] = value;
455        Ok(())
456    }
457
458    /// Set the Start Address (bytes 8-9).
459    pub fn set_start_addr(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
460        let off = self.index.start + 8;
461        if buf.len() < off + 2 {
462            return Err(FieldError::BufferTooShort {
463                offset: off,
464                need: 2,
465                have: buf.len().saturating_sub(off),
466            });
467        }
468        buf[off..off + 2].copy_from_slice(&value.to_be_bytes());
469        Ok(())
470    }
471
472    /// Set the Quantity (bytes 10-11).
473    pub fn set_quantity(&self, buf: &mut [u8], value: u16) -> Result<(), FieldError> {
474        let off = self.index.start + 10;
475        if buf.len() < off + 2 {
476            return Err(FieldError::BufferTooShort {
477                offset: off,
478                need: 2,
479                have: buf.len().saturating_sub(off),
480            });
481        }
482        buf[off..off + 2].copy_from_slice(&value.to_be_bytes());
483        Ok(())
484    }
485
486    // ========================================================================
487    // Compute header length
488    // ========================================================================
489
490    fn compute_header_len(&self, buf: &[u8]) -> usize {
491        let s = self.slice(buf);
492        if s.len() < MODBUS_MIN_HEADER_LEN {
493            return MODBUS_MIN_HEADER_LEN;
494        }
495        // For Modbus/TCP, the total message size = 6 (MBAP header) + length field
496        let length = u16::from_be_bytes([s[4], s[5]]) as usize;
497        let total = 6 + length;
498        let max = self.index.end - self.index.start;
499        total.min(max)
500    }
501
502    // ========================================================================
503    // Summary
504    // ========================================================================
505
506    /// Generate a one-line summary of this Modbus layer.
507    #[must_use]
508    pub fn summary(&self, buf: &[u8]) -> String {
509        let fc = match self.func_code(buf) {
510            Ok(v) => v,
511            Err(_) => return "Modbus".to_string(),
512        };
513
514        let fc_name = func_code_name(fc);
515
516        if self.is_error(buf) {
517            let ec = self.except_code(buf).map_or_else(
518                |_| "?".to_string(),
519                |v| format!("{} ({})", v, except_code_name(v)),
520            );
521            return format!("Modbus Error fc={fc:#04x} except={ec}");
522        }
523
524        let tid = self
525            .trans_id(buf)
526            .map_or_else(|_| "?".to_string(), |v| v.to_string());
527        let uid = self
528            .unit_id(buf)
529            .map_or_else(|_| "?".to_string(), |v| format!("{v:#04x}"));
530
531        format!("Modbus {fc_name} trans_id={tid} unit_id={uid}")
532    }
533
534    // ========================================================================
535    // Field Access API
536    // ========================================================================
537
538    /// Get the field names for this layer.
539    #[must_use]
540    pub fn field_names() -> &'static [&'static str] {
541        MODBUS_FIELD_NAMES
542    }
543
544    /// Get a field value by name.
545    pub fn get_field(&self, buf: &[u8], name: &str) -> Option<Result<FieldValue, FieldError>> {
546        match name {
547            "trans_id" => Some(self.trans_id(buf).map(FieldValue::U16)),
548            "proto_id" => Some(self.proto_id(buf).map(FieldValue::U16)),
549            "length" => Some(self.length(buf).map(FieldValue::U16)),
550            "unit_id" => Some(self.unit_id(buf).map(FieldValue::U8)),
551            "func_code" => Some(self.func_code(buf).map(FieldValue::U8)),
552            "except_code" => {
553                if self.is_error(buf) {
554                    Some(self.except_code(buf).map(FieldValue::U8))
555                } else {
556                    Some(Ok(FieldValue::U8(0)))
557                }
558            },
559            "start_addr" => Some(self.start_addr(buf).map(FieldValue::U16)),
560            "quantity" => Some(self.quantity(buf).map(FieldValue::U16)),
561            "byte_count" => Some(self.byte_count(buf).map(FieldValue::U8)),
562            "output_value" => Some(self.output_value(buf).map(FieldValue::U16)),
563            "register_val" => Some(self.register_val(buf).map(FieldValue::U16)),
564            "coil_status" => Some(self.byte_count(buf).map(FieldValue::U8)),
565            "sub_func" => Some(self.sub_func(buf).map(FieldValue::U16)),
566            "ref_addr" => Some(self.ref_addr(buf).map(FieldValue::U16)),
567            "and_mask" => Some(self.and_mask(buf).map(FieldValue::U16)),
568            "or_mask" => Some(self.or_mask(buf).map(FieldValue::U16)),
569            "data" => Some(self.data(buf).map(FieldValue::Bytes)),
570            _ => None,
571        }
572    }
573
574    /// Set a field value by name.
575    pub fn set_field(
576        &self,
577        buf: &mut [u8],
578        name: &str,
579        value: FieldValue,
580    ) -> Option<Result<(), FieldError>> {
581        match name {
582            "trans_id" => {
583                if let FieldValue::U16(v) = value {
584                    Some(self.set_trans_id(buf, v))
585                } else {
586                    Some(Err(FieldError::InvalidValue(format!(
587                        "trans_id: expected U16, got {value:?}"
588                    ))))
589                }
590            },
591            "unit_id" => {
592                if let FieldValue::U8(v) = value {
593                    Some(self.set_unit_id(buf, v))
594                } else {
595                    Some(Err(FieldError::InvalidValue(format!(
596                        "unit_id: expected U8, got {value:?}"
597                    ))))
598                }
599            },
600            "func_code" => {
601                if let FieldValue::U8(v) = value {
602                    Some(self.set_func_code(buf, v))
603                } else {
604                    Some(Err(FieldError::InvalidValue(format!(
605                        "func_code: expected U8, got {value:?}"
606                    ))))
607                }
608            },
609            "start_addr" => {
610                if let FieldValue::U16(v) = value {
611                    Some(self.set_start_addr(buf, v))
612                } else {
613                    Some(Err(FieldError::InvalidValue(format!(
614                        "start_addr: expected U16, got {value:?}"
615                    ))))
616                }
617            },
618            "quantity" => {
619                if let FieldValue::U16(v) = value {
620                    Some(self.set_quantity(buf, v))
621                } else {
622                    Some(Err(FieldError::InvalidValue(format!(
623                        "quantity: expected U16, got {value:?}"
624                    ))))
625                }
626            },
627            _ => None,
628        }
629    }
630}
631
632impl Layer for ModbusLayer {
633    fn kind(&self) -> LayerKind {
634        LayerKind::Modbus
635    }
636
637    fn summary(&self, data: &[u8]) -> String {
638        self.summary(data)
639    }
640
641    fn header_len(&self, data: &[u8]) -> usize {
642        self.compute_header_len(data)
643    }
644
645    fn field_names(&self) -> &'static [&'static str] {
646        MODBUS_FIELD_NAMES
647    }
648}
649
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap `buf` in a Modbus view that spans all of it.
    fn make_layer(buf: &[u8]) -> ModbusLayer {
        ModbusLayer::new(LayerIndex::new(LayerKind::Modbus, 0, buf.len()))
    }

    /// Assemble a Modbus/TCP frame (protocol ID 0). The MBAP Length field is
    /// derived from `pdu` (function code + data) plus the Unit ID byte.
    fn mbap_frame(trans_id: u16, unit_id: u8, pdu: &[u8]) -> Vec<u8> {
        assert!(pdu.len() < usize::from(u16::MAX), "PDU too large for MBAP");
        let length = (pdu.len() + 1) as u16;
        let mut frame = Vec::with_capacity(MODBUS_MBAP_HEADER_LEN + pdu.len());
        frame.extend_from_slice(&trans_id.to_be_bytes());
        frame.extend_from_slice(&0u16.to_be_bytes()); // proto_id
        frame.extend_from_slice(&length.to_be_bytes());
        frame.push(unit_id);
        frame.extend_from_slice(pdu);
        frame
    }

    /// Read Coils request: trans_id=1, unit_id=1, start=0x0000, qty=10 (length=6).
    fn read_coils_request() -> Vec<u8> {
        mbap_frame(0x0001, 0x01, &[0x01, 0x00, 0x00, 0x00, 0x0A])
    }

    /// Exception response to Read Coils: fc=0x81, except=0x02 Illegal Data
    /// Address (length=3).
    fn error_response() -> Vec<u8> {
        mbap_frame(0x0001, 0x01, &[0x81, 0x02])
    }

    #[test]
    fn test_read_coils_request_fields() {
        let frame = read_coils_request();
        let layer = make_layer(&frame);

        assert_eq!(layer.trans_id(&frame).unwrap(), 1);
        assert_eq!(layer.proto_id(&frame).unwrap(), 0);
        assert_eq!(layer.length(&frame).unwrap(), 6);
        assert_eq!(layer.unit_id(&frame).unwrap(), 1);
        assert_eq!(layer.func_code(&frame).unwrap(), 0x01);
        assert!(!layer.is_error(&frame));
        assert_eq!(layer.start_addr(&frame).unwrap(), 0);
        assert_eq!(layer.quantity(&frame).unwrap(), 10);
    }

    #[test]
    fn test_error_response_fields() {
        let frame = error_response();
        let layer = make_layer(&frame);

        assert_eq!(layer.trans_id(&frame).unwrap(), 1);
        assert_eq!(layer.func_code(&frame).unwrap(), 0x81);
        assert!(layer.is_error(&frame));
        assert_eq!(layer.except_code(&frame).unwrap(), 0x02);
    }

    #[test]
    fn test_summary_request() {
        let frame = read_coils_request();
        let text = make_layer(&frame).summary(&frame);
        assert!(text.contains("Read Coils"));
        assert!(text.contains("trans_id=1"));
    }

    #[test]
    fn test_summary_error() {
        let frame = error_response();
        let text = make_layer(&frame).summary(&frame);
        assert!(text.contains("Error"));
        assert!(text.contains("Illegal Data Address"));
    }

    #[test]
    fn test_get_field_api() {
        let frame = read_coils_request();
        let layer = make_layer(&frame);

        let expected = [
            ("trans_id", FieldValue::U16(1)),
            ("func_code", FieldValue::U8(0x01)),
            ("start_addr", FieldValue::U16(0)),
        ];
        for (name, want) in expected.iter() {
            assert_eq!(&layer.get_field(&frame, name).unwrap().unwrap(), want);
        }
        assert!(layer.get_field(&frame, "nonexistent").is_none());
    }

    #[test]
    fn test_set_field_api() {
        let mut frame = read_coils_request();
        let layer = make_layer(&frame);

        layer
            .set_field(&mut frame, "trans_id", FieldValue::U16(42))
            .unwrap()
            .unwrap();
        assert_eq!(layer.trans_id(&frame).unwrap(), 42);

        layer
            .set_field(&mut frame, "unit_id", FieldValue::U8(0xFF))
            .unwrap()
            .unwrap();
        assert_eq!(layer.unit_id(&frame).unwrap(), 0xFF);
    }

    #[test]
    fn test_detection_valid() {
        assert!(is_modbus_tcp_payload(&read_coils_request()));
    }

    #[test]
    fn test_detection_bad_proto_id() {
        let mut frame = read_coils_request();
        frame[2] = 0x01; // corrupt protocol ID
        assert!(!is_modbus_tcp_payload(&frame));
    }

    #[test]
    fn test_detection_too_short() {
        assert!(!is_modbus_tcp_payload(&[0x00; 7]));
        assert!(!is_modbus_tcp_payload(&[]));
    }

    #[test]
    fn test_detection_bad_length() {
        let mut frame = read_coils_request();
        // length = 1, too small (must be >= 2)
        frame[4..6].copy_from_slice(&1u16.to_be_bytes());
        assert!(!is_modbus_tcp_payload(&frame));
    }

    #[test]
    fn test_header_len() {
        let frame = read_coils_request();
        // 6 bytes before the Unit ID + Length field value 6 = 12
        assert_eq!(make_layer(&frame).compute_header_len(&frame), 12);
    }

    #[test]
    fn test_func_code_name() {
        let cases = [
            (0x01, "Read Coils"),
            (0x03, "Read Holding Registers"),
            (0x10, "Write Multiple Registers"),
            (0x81, "Read Coils"), // error bit stripped
            (0xFF, "Unknown"),
        ];
        for &(code, name) in cases.iter() {
            assert_eq!(func_code_name(code), name);
        }
    }

    #[test]
    fn test_except_code_name() {
        let cases = [
            (0x01, "Illegal Function"),
            (0x02, "Illegal Data Address"),
            (0xFF, "Unknown"),
        ];
        for &(code, name) in cases.iter() {
            assert_eq!(except_code_name(code), name);
        }
    }

    #[test]
    fn test_write_single_coil() {
        // fc=0x05, coil address 0x0013, value 0xFF00 (ON)
        let frame = mbap_frame(0x0002, 0x01, &[0x05, 0x00, 0x13, 0xFF, 0x00]);
        let layer = make_layer(&frame);
        assert_eq!(layer.func_code(&frame).unwrap(), 0x05);
        assert_eq!(layer.start_addr(&frame).unwrap(), 0x0013);
        assert_eq!(layer.output_value(&frame).unwrap(), 0xFF00);
    }

    #[test]
    fn test_mask_write_register() {
        // fc=0x16, ref_addr=0x0004, and=0x00F2, or=0x0025 (length = 1+1+2+2+2 = 8)
        let frame = mbap_frame(0x0001, 0x01, &[0x16, 0x00, 0x04, 0x00, 0xF2, 0x00, 0x25]);
        let layer = make_layer(&frame);
        assert_eq!(layer.func_code(&frame).unwrap(), 0x16);
        assert_eq!(layer.ref_addr(&frame).unwrap(), 0x0004);
        assert_eq!(layer.and_mask(&frame).unwrap(), 0x00F2);
        assert_eq!(layer.or_mask(&frame).unwrap(), 0x0025);
    }

    #[test]
    fn test_layer_trait() {
        let frame = read_coils_request();
        let layer = make_layer(&frame);

        assert_eq!(layer.kind(), LayerKind::Modbus);
        let names = layer.field_names();
        assert!(!names.is_empty());
        assert!(names.contains(&"func_code"));
    }
}