Skip to main content

snap7_client/plus_connection.rs

1use bytes::{Bytes, BytesMut};
2use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3
4use crate::proto::s7commplus::frame::{S7PlusFrame, Version};
5use crate::proto::s7commplus::session::{CreateObjectRequest, CreateObjectResponse};
6use crate::proto::tpkt::TpktFrame;
7
8use crate::error::Error;
9
10/// Result of a successful S7CommPlus CreateObject handshake.
11#[derive(Debug)]
12pub struct PlusConnection {
13    pub session_id: u32,
14    pub seqnum: u16,
15    pub version: Version,
16}
17
18/// Perform the S7CommPlus CreateObject handshake over `transport`.
19///
20/// Sends a `CreateObjectRequest` wrapped in an S7CommPlus V1 frame inside a
21/// TPKT envelope, then reads the `CreateObjectResponse` and returns a
22/// [`PlusConnection`] containing the negotiated `session_id`.
23/// Perform the S7CommPlus CreateObject handshake and return both the negotiated
24/// connection state and the transport, so the caller can store both.
25pub async fn plus_connect<T>(mut transport: T) -> Result<(PlusConnection, T), Error>
26where
27    T: AsyncRead + AsyncWrite + Unpin,
28{
29    // --- Build and send CreateObject request ---
30    let req = CreateObjectRequest::new(1);
31    let mut da_buf = BytesMut::new();
32    req.encode(&mut da_buf);
33
34    let plus_frame = S7PlusFrame {
35        version: Version::V1,
36        data: da_buf.freeze(),
37    };
38    let mut frame_buf = BytesMut::new();
39    plus_frame.encode(&mut frame_buf).map_err(Error::Proto)?;
40
41    let tpkt = TpktFrame {
42        payload: frame_buf.freeze(),
43    };
44    let mut out = BytesMut::new();
45    tpkt.encode(&mut out).map_err(Error::Proto)?;
46    transport.write_all(&out).await?;
47
48    // --- Read TPKT response: 4-byte header then payload ---
49    let mut hdr = [0u8; 4];
50    transport.read_exact(&mut hdr).await?;
51    let total = u16::from_be_bytes([hdr[2], hdr[3]]) as usize;
52    let payload_len = total.saturating_sub(4);
53    let mut payload = vec![0u8; payload_len];
54    transport.read_exact(&mut payload).await?;
55
56    // --- Decode S7CommPlus frame from TPKT payload ---
57    let mut b = Bytes::from(payload);
58    let s7plus_frame = S7PlusFrame::decode(&mut b).map_err(Error::Proto)?;
59
60    // --- Decode CreateObject response ---
61    let mut data = s7plus_frame.data.clone();
62    let resp = CreateObjectResponse::decode(&mut data).map_err(Error::Proto)?;
63
64    let conn = PlusConnection {
65        session_id: resp.session_id,
66        seqnum: 2, // seqnum 1 was consumed by the CreateObject request
67        version: s7plus_frame.version,
68    };
69    Ok((conn, transport))
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::s7commplus::session::OPCODE_RESPONSE;
    use bytes::BufMut;

    /// Build a complete TPKT-wrapped S7CommPlus V1 CreateObject response that
    /// carries `session_id`.
    fn build_create_object_response(session_id: u32) -> Vec<u8> {
        let mut da = BytesMut::new();
        da.put_u8(OPCODE_RESPONSE); // opcode
        da.put_u16(0x0000); // reserved
        da.put_u16(0x04CA); // FC
        da.put_u16(0x0000); // reserved
        da.put_u16(0x0001); // seqnum
        da.put_u32(session_id); // session_id
        da.put_u8(0x00); // transport_flags

        let plus_frame = S7PlusFrame {
            version: Version::V1,
            data: da.freeze(),
        };
        let mut frame_buf = BytesMut::new();
        plus_frame.encode(&mut frame_buf).unwrap();

        let tpkt = TpktFrame {
            payload: frame_buf.freeze(),
        };
        let mut tpkt_buf = BytesMut::new();
        tpkt.encode(&mut tpkt_buf).unwrap();
        tpkt_buf.to_vec()
    }

    #[tokio::test]
    async fn plus_connect_extracts_session_id() {
        let expected_sid = 0xCAFEBABE_u32;
        let response = build_create_object_response(expected_sid);

        let (mut peer, client_io) = tokio::io::duplex(4096);
        // Fake PLC: wait for the request, then answer with the full response.
        let peer_task = tokio::spawn(async move {
            let mut buf = vec![0u8; 4096];
            let _ = peer.read(&mut buf).await;
            peer.write_all(&response).await.unwrap();
        });

        let (conn, _transport) = plus_connect(client_io).await.unwrap();
        assert_eq!(conn.session_id, expected_sid);
        assert_eq!(conn.seqnum, 2);
        // Surface any panic from the fake PLC instead of silently dropping it.
        peer_task.await.unwrap();
    }

    #[tokio::test]
    async fn plus_connect_fails_when_peer_closes_mid_response() {
        let response = build_create_object_response(0x1234_5678);

        let (mut peer, client_io) = tokio::io::duplex(4096);
        // Fake PLC: send only part of the TPKT header, then hang up.
        let peer_task = tokio::spawn(async move {
            let mut buf = vec![0u8; 4096];
            let _ = peer.read(&mut buf).await;
            peer.write_all(&response[..2]).await.unwrap();
        });

        assert!(plus_connect(client_io).await.is_err());
        peer_task.await.unwrap();
    }
}