Skip to main content

snap7_server/
handshake.rs

1use bytes::{Bytes, BytesMut};
2use snap7_client::proto::{
3    cotp::CotpPdu,
4    s7::{
5        header::{PduType, S7Header},
6        negotiate::{NegotiateRequest, NegotiateResponse},
7    },
8    tpkt::TpktFrame,
9};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
12use crate::error::{Error, Result};
13
/// Largest S7 PDU size (in bytes) this server will agree to during negotiation.
const MAX_PDU_SIZE: u16 = 480;

/// Encoded length of the `NegotiateResponse` parameter block, written as the sum of its
/// fields: function (1) + reserved (1) + max_amq_calling (2) + max_amq_called (2) +
/// pdu_length (2).
const NEGOTIATE_PARAM_LEN: u16 = 1 + 1 + 2 + 2 + 2;
18
19/// Perform the server-side COTP/S7 handshake over an already-accepted transport.
20///
21/// Protocol steps:
22///   1. Receive COTP ConnectRequest (CR)
23///   2. Send COTP ConnectConfirm (CC)
24///   3. Receive S7 NegotiateRequest inside a COTP Data PDU
25///   4. Send S7 NegotiateResponse (AckData) with negotiated PDU size
26///
27/// Returns the negotiated PDU size (capped at [`MAX_PDU_SIZE`]).
28pub async fn server_handshake<T>(mut transport: T) -> Result<u16>
29where
30    T: AsyncRead + AsyncWrite + Unpin,
31{
32    // Step 1: receive CR
33    let cr = recv_tpkt_cotp(&mut transport).await?;
34    let src_ref = match cr {
35        CotpPdu::ConnectRequest { src_ref, .. } => src_ref,
36        _ => return Err(Error::NegotiationFailed),
37    };
38
39    // Step 2: send CC — dst_ref = client's src_ref, our src_ref = 0x0001
40    let cc = CotpPdu::ConnectConfirm {
41        dst_ref: src_ref,
42        src_ref: 0x0001,
43    };
44    send_tpkt_cotp(&mut transport, &cc).await?;
45
46    // Step 3: receive S7 NegotiateRequest in a COTP Data PDU
47    let mut payload = recv_cotp_data(&mut transport).await?;
48    let req_header = S7Header::decode(&mut payload)?;
49    if req_header.pdu_type != PduType::Job {
50        return Err(Error::NegotiationFailed);
51    }
52    let neg_req = NegotiateRequest::decode(&mut payload)?;
53
54    // Step 4: send S7 NegotiateResponse with negotiated (capped) PDU size
55    let negotiated = neg_req.pdu_length.min(MAX_PDU_SIZE);
56    let resp_header = S7Header {
57        pdu_type: PduType::AckData,
58        reserved: 0,
59        pdu_ref: req_header.pdu_ref,
60        param_len: NEGOTIATE_PARAM_LEN,
61        data_len: 0,
62        error_class: Some(0),
63        error_code: Some(0),
64    };
65    let neg_resp = NegotiateResponse {
66        max_amq_calling: neg_req.max_amq_calling,
67        max_amq_called: neg_req.max_amq_called,
68        pdu_length: negotiated,
69    };
70    let mut s7_buf = BytesMut::new();
71    resp_header.encode(&mut s7_buf);
72    neg_resp.encode(&mut s7_buf);
73    send_cotp_data(&mut transport, s7_buf.freeze()).await?;
74
75    Ok(negotiated)
76}
77
78/// Read one TPKT frame from `transport` and decode the contained COTP PDU.
79pub(crate) async fn recv_tpkt_cotp<T: AsyncRead + Unpin>(transport: &mut T) -> Result<CotpPdu> {
80    let mut header = [0u8; 4];
81    transport.read_exact(&mut header).await?;
82    if header[0] != 0x03 {
83        return Err(Error::NegotiationFailed);
84    }
85    let total = u16::from_be_bytes([header[2], header[3]]) as usize;
86    if total < 4 {
87        return Err(Error::NegotiationFailed);
88    }
89    let payload_len = total - 4;
90    let mut payload = vec![0u8; payload_len];
91    transport.read_exact(&mut payload).await?;
92    let mut b = Bytes::from(payload);
93    CotpPdu::decode(&mut b).map_err(Error::Proto)
94}
95
96/// Read one TPKT+COTP frame and extract the Data PDU payload.
97///
98/// Returns an error if the COTP PDU is not a Data variant.
99pub(crate) async fn recv_cotp_data<T: AsyncRead + Unpin>(transport: &mut T) -> Result<Bytes> {
100    let pdu = recv_tpkt_cotp(transport).await?;
101    match pdu {
102        CotpPdu::Data { payload, .. } => Ok(payload),
103        _ => Err(Error::NegotiationFailed),
104    }
105}
106
107/// Encode `pdu` into a TPKT frame and write it to `transport`.
108pub(crate) async fn send_tpkt_cotp<T: AsyncWrite + Unpin>(
109    transport: &mut T,
110    pdu: &CotpPdu,
111) -> Result<()> {
112    let mut cotp_buf = BytesMut::new();
113    pdu.encode(&mut cotp_buf);
114    let tpkt = TpktFrame {
115        payload: cotp_buf.freeze(),
116    };
117    let mut buf = BytesMut::new();
118    tpkt.encode(&mut buf)?;
119    transport.write_all(&buf).await?;
120    Ok(())
121}
122
123/// Wrap `payload` in a COTP Data PDU and write it as a TPKT frame.
124pub(crate) async fn send_cotp_data<T: AsyncWrite + Unpin>(
125    transport: &mut T,
126    payload: Bytes,
127) -> Result<()> {
128    let dt = CotpPdu::Data {
129        tpdu_nr: 0,
130        last: true,
131        payload,
132    };
133    send_tpkt_cotp(transport, &dt).await
134}
135
136// ---------------------------------------------------------------------------
137// Tests
138// ---------------------------------------------------------------------------
139
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use snap7_client::proto::{
        cotp::CotpPdu,
        s7::{
            header::{PduType, S7Header},
            negotiate::NegotiateRequest,
        },
        tpkt::TpktFrame,
    };
    use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};

    /// Write a COTP PDU wrapped in a TPKT frame to `writer`.
    async fn write_tpkt_cotp(writer: &mut (impl tokio::io::AsyncWrite + Unpin), cotp: &CotpPdu) {
        let mut cotp_buf = BytesMut::new();
        cotp.encode(&mut cotp_buf);
        let tpkt = TpktFrame {
            payload: cotp_buf.freeze(),
        };
        let mut buf = BytesMut::new();
        tpkt.encode(&mut buf).unwrap();
        writer.write_all(&buf).await.unwrap();
    }

    /// Read one TPKT frame from `reader` and decode the COTP PDU inside it.
    async fn read_tpkt_cotp(reader: &mut (impl tokio::io::AsyncRead + Unpin)) -> CotpPdu {
        let mut hdr = [0u8; 4];
        reader.read_exact(&mut hdr).await.unwrap();
        let total = u16::from_be_bytes([hdr[2], hdr[3]]) as usize;
        let mut body = vec![0u8; total - 4];
        reader.read_exact(&mut body).await.unwrap();
        let mut b = Bytes::from(body);
        CotpPdu::decode(&mut b).unwrap()
    }

    /// Write an S7 NegotiateRequest wrapped in COTP Data + TPKT to `writer`.
    async fn write_negotiate_request(
        writer: &mut (impl tokio::io::AsyncWrite + Unpin),
        pdu_length: u16,
    ) {
        let header = S7Header {
            pdu_type: PduType::Job,
            reserved: 0,
            pdu_ref: 1,
            param_len: 8,
            data_len: 0,
            error_class: None,
            error_code: None,
        };
        let req = NegotiateRequest {
            max_amq_calling: 1,
            max_amq_called: 1,
            pdu_length,
        };
        let mut s7_buf = BytesMut::new();
        header.encode(&mut s7_buf);
        req.encode(&mut s7_buf);
        let dt = CotpPdu::Data {
            tpdu_nr: 0,
            last: true,
            payload: s7_buf.freeze(),
        };
        write_tpkt_cotp(writer, &dt).await;
    }

    /// Play the client side of a full handshake that requests `pdu_length`.
    ///
    /// Asserts that the server answers the CR with a ConnectConfirm and the NegotiateRequest
    /// with a COTP Data PDU.
    async fn run_client(mut client_io: DuplexStream, pdu_length: u16) {
        let cr = CotpPdu::ConnectRequest {
            dst_ref: 0x0000,
            src_ref: 0x0001,
            rack: 0,
            slot: 2,
        };
        write_tpkt_cotp(&mut client_io, &cr).await;

        let cc = read_tpkt_cotp(&mut client_io).await;
        assert!(
            matches!(cc, CotpPdu::ConnectConfirm { .. }),
            "expected ConnectConfirm, got {cc:?}"
        );

        write_negotiate_request(&mut client_io, pdu_length).await;

        let resp = read_tpkt_cotp(&mut client_io).await;
        assert!(
            matches!(resp, CotpPdu::Data { .. }),
            "expected Data PDU carrying the NegotiateResponse, got {resp:?}"
        );
    }

    /// Run `server_handshake` against [`run_client`] and return the server's result.
    async fn handshake_with_client_requesting(pdu_length: u16) -> crate::error::Result<u16> {
        let (server_io, client_io) = tokio::io::duplex(4096);
        let client_task = tokio::spawn(run_client(client_io, pdu_length));
        let result = server_handshake(server_io).await;
        // On server failure the client only hits a secondary EOF panic; skip joining it so
        // the caller's assertion reports the server's actual error.
        if result.is_ok() {
            client_task.await.unwrap();
        }
        result
    }

    #[tokio::test]
    async fn handshake_completes_with_valid_client() {
        let result = handshake_with_client_requesting(480).await;
        assert!(
            result.is_ok(),
            "server_handshake returned error: {result:?}"
        );
        assert_eq!(result.unwrap(), 480);
    }

    #[tokio::test]
    async fn handshake_caps_pdu_size_at_server_max() {
        let result = handshake_with_client_requesting(960).await;
        assert!(
            result.is_ok(),
            "server_handshake returned error: {result:?}"
        );
        assert_eq!(result.unwrap(), MAX_PDU_SIZE);
    }

    #[tokio::test]
    async fn handshake_keeps_smaller_client_pdu_size() {
        let result = handshake_with_client_requesting(240).await;
        assert!(
            result.is_ok(),
            "server_handshake returned error: {result:?}"
        );
        assert_eq!(result.unwrap(), 240);
    }

    #[tokio::test]
    async fn handshake_fails_on_non_cr() {
        let (server_io, mut client_io) = tokio::io::duplex(4096);

        tokio::spawn(async move {
            // Send a Data PDU instead of a ConnectRequest
            let dt = CotpPdu::Data {
                tpdu_nr: 0,
                last: true,
                payload: Bytes::from_static(b"oops"),
            };
            write_tpkt_cotp(&mut client_io, &dt).await;
        });

        let result = server_handshake(server_io).await;
        assert!(result.is_err(), "expected error, got: {result:?}");
    }

    #[tokio::test]
    async fn handshake_fails_on_bad_tpkt_version() {
        let (server_io, mut client_io) = tokio::io::duplex(4096);

        tokio::spawn(async move {
            // TPKT version must be 0x03; send 0x04 in an otherwise header-only frame.
            let _ = client_io.write_all(&[0x04, 0x00, 0x00, 0x04]).await;
        });

        let result = server_handshake(server_io).await;
        assert!(result.is_err(), "expected error, got: {result:?}");
    }
}