1use bytes::{Bytes, BytesMut};
2use snap7_client::proto::{
3 cotp::CotpPdu,
4 s7::{
5 header::{PduType, S7Header},
6 negotiate::{NegotiateRequest, NegotiateResponse},
7 },
8 tpkt::TpktFrame,
9};
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
12use crate::error::{Error, Result};
13
/// Largest S7 PDU length (480 bytes) this server will grant during negotiation.
///
/// A client asking for more is answered with this value instead; a client
/// asking for less gets exactly what it asked for.
const MAX_PDU_SIZE: u16 = 0x01E0;

/// Value placed in the `param_len` field of the S7 header of the negotiate
/// response sent by [`server_handshake`].
const NEGOTIATE_PARAM_LEN: u16 = 0x0008;
18
19pub async fn server_handshake<T>(mut transport: T) -> Result<u16>
29where
30 T: AsyncRead + AsyncWrite + Unpin,
31{
32 let cr = recv_tpkt_cotp(&mut transport).await?;
34 let src_ref = match cr {
35 CotpPdu::ConnectRequest { src_ref, .. } => src_ref,
36 _ => return Err(Error::NegotiationFailed),
37 };
38
39 let cc = CotpPdu::ConnectConfirm {
41 dst_ref: src_ref,
42 src_ref: 0x0001,
43 };
44 send_tpkt_cotp(&mut transport, &cc).await?;
45
46 let mut payload = recv_cotp_data(&mut transport).await?;
48 let req_header = S7Header::decode(&mut payload)?;
49 if req_header.pdu_type != PduType::Job {
50 return Err(Error::NegotiationFailed);
51 }
52 let neg_req = NegotiateRequest::decode(&mut payload)?;
53
54 let negotiated = neg_req.pdu_length.min(MAX_PDU_SIZE);
56 let resp_header = S7Header {
57 pdu_type: PduType::AckData,
58 reserved: 0,
59 pdu_ref: req_header.pdu_ref,
60 param_len: NEGOTIATE_PARAM_LEN,
61 data_len: 0,
62 error_class: Some(0),
63 error_code: Some(0),
64 };
65 let neg_resp = NegotiateResponse {
66 max_amq_calling: neg_req.max_amq_calling,
67 max_amq_called: neg_req.max_amq_called,
68 pdu_length: negotiated,
69 };
70 let mut s7_buf = BytesMut::new();
71 resp_header.encode(&mut s7_buf);
72 neg_resp.encode(&mut s7_buf);
73 send_cotp_data(&mut transport, s7_buf.freeze()).await?;
74
75 Ok(negotiated)
76}
77
78pub(crate) async fn recv_tpkt_cotp<T: AsyncRead + Unpin>(transport: &mut T) -> Result<CotpPdu> {
80 let mut header = [0u8; 4];
81 transport.read_exact(&mut header).await?;
82 if header[0] != 0x03 {
83 return Err(Error::NegotiationFailed);
84 }
85 let total = u16::from_be_bytes([header[2], header[3]]) as usize;
86 if total < 4 {
87 return Err(Error::NegotiationFailed);
88 }
89 let payload_len = total - 4;
90 let mut payload = vec![0u8; payload_len];
91 transport.read_exact(&mut payload).await?;
92 let mut b = Bytes::from(payload);
93 CotpPdu::decode(&mut b).map_err(Error::Proto)
94}
95
96pub(crate) async fn recv_cotp_data<T: AsyncRead + Unpin>(transport: &mut T) -> Result<Bytes> {
100 let pdu = recv_tpkt_cotp(transport).await?;
101 match pdu {
102 CotpPdu::Data { payload, .. } => Ok(payload),
103 _ => Err(Error::NegotiationFailed),
104 }
105}
106
107pub(crate) async fn send_tpkt_cotp<T: AsyncWrite + Unpin>(
109 transport: &mut T,
110 pdu: &CotpPdu,
111) -> Result<()> {
112 let mut cotp_buf = BytesMut::new();
113 pdu.encode(&mut cotp_buf);
114 let tpkt = TpktFrame {
115 payload: cotp_buf.freeze(),
116 };
117 let mut buf = BytesMut::new();
118 tpkt.encode(&mut buf)?;
119 transport.write_all(&buf).await?;
120 Ok(())
121}
122
123pub(crate) async fn send_cotp_data<T: AsyncWrite + Unpin>(
125 transport: &mut T,
126 payload: Bytes,
127) -> Result<()> {
128 let dt = CotpPdu::Data {
129 tpdu_nr: 0,
130 last: true,
131 payload,
132 };
133 send_tpkt_cotp(transport, &dt).await
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use snap7_client::proto::{
        cotp::CotpPdu,
        s7::{
            header::{PduType, S7Header},
            negotiate::{NegotiateRequest, NegotiateResponse},
        },
        tpkt::TpktFrame,
    };
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Wraps `cotp` in a TPKT frame and writes it to `writer`.
    async fn write_tpkt_cotp(writer: &mut (impl tokio::io::AsyncWrite + Unpin), cotp: &CotpPdu) {
        let mut cotp_buf = BytesMut::new();
        cotp.encode(&mut cotp_buf);
        let tpkt = TpktFrame {
            payload: cotp_buf.freeze(),
        };
        let mut buf = BytesMut::new();
        tpkt.encode(&mut buf).unwrap();
        writer.write_all(&buf).await.unwrap();
    }

    /// Reads one TPKT frame by hand (independently of the server's own
    /// `recv_tpkt_cotp`) and decodes its body as a COTP PDU.
    async fn read_tpkt_cotp(reader: &mut (impl tokio::io::AsyncRead + Unpin)) -> CotpPdu {
        let mut hdr = [0u8; 4];
        reader.read_exact(&mut hdr).await.unwrap();
        assert_eq!(hdr[0], 0x03, "server must emit TPKT version 3");
        let total = u16::from_be_bytes([hdr[2], hdr[3]]) as usize;
        let mut body = vec![0u8; total - 4];
        reader.read_exact(&mut body).await.unwrap();
        let mut b = Bytes::from(body);
        CotpPdu::decode(&mut b).unwrap()
    }

    /// Sends an S7 `Job` negotiate request (pdu_ref 1, AmQ 1/1) asking for
    /// `pdu_length` bytes.
    async fn write_negotiate_request(
        writer: &mut (impl tokio::io::AsyncWrite + Unpin),
        pdu_length: u16,
    ) {
        let header = S7Header {
            pdu_type: PduType::Job,
            reserved: 0,
            pdu_ref: 1,
            param_len: 8,
            data_len: 0,
            error_class: None,
            error_code: None,
        };
        let req = NegotiateRequest {
            max_amq_calling: 1,
            max_amq_called: 1,
            pdu_length,
        };
        let mut s7_buf = BytesMut::new();
        header.encode(&mut s7_buf);
        req.encode(&mut s7_buf);
        let dt = CotpPdu::Data {
            tpdu_nr: 0,
            last: true,
            payload: s7_buf.freeze(),
        };
        write_tpkt_cotp(writer, &dt).await;
    }

    /// The exact S7 payload the server should answer a request written by
    /// `write_negotiate_request` with, when it grants `pdu_length` bytes.
    fn expected_negotiate_response(pdu_length: u16) -> Bytes {
        let header = S7Header {
            pdu_type: PduType::AckData,
            reserved: 0,
            pdu_ref: 1,
            param_len: NEGOTIATE_PARAM_LEN,
            data_len: 0,
            error_class: Some(0),
            error_code: Some(0),
        };
        let resp = NegotiateResponse {
            max_amq_calling: 1,
            max_amq_called: 1,
            pdu_length,
        };
        let mut buf = BytesMut::new();
        header.encode(&mut buf);
        resp.encode(&mut buf);
        buf.freeze()
    }

    /// Runs a complete handshake in which the client asks for `requested`
    /// bytes. Returns the server's result and the S7 payload the client got back.
    async fn run_handshake(requested: u16) -> (Result<u16>, Bytes) {
        let (server_io, mut client_io) = tokio::io::duplex(4096);

        let client_task = tokio::spawn(async move {
            let cr = CotpPdu::ConnectRequest {
                dst_ref: 0x0000,
                src_ref: 0x0001,
                rack: 0,
                slot: 2,
            };
            write_tpkt_cotp(&mut client_io, &cr).await;

            let cc = read_tpkt_cotp(&mut client_io).await;
            assert!(
                matches!(cc, CotpPdu::ConnectConfirm { dst_ref: 0x0001, .. }),
                "expected ConnectConfirm addressed to the client's src_ref, got {cc:?}"
            );

            write_negotiate_request(&mut client_io, requested).await;

            match read_tpkt_cotp(&mut client_io).await {
                CotpPdu::Data { payload, .. } => payload,
                other => panic!("expected COTP Data carrying the negotiate response, got {other:?}"),
            }
        });

        let result = server_handshake(server_io).await;
        let response = client_task
            .await
            .unwrap_or_else(|e| panic!("client task failed ({e}); server result: {result:?}"));
        (result, response)
    }

    #[tokio::test]
    async fn handshake_completes_with_valid_client() {
        let (result, response) = run_handshake(480).await;
        assert!(
            result.is_ok(),
            "server_handshake returned error: {result:?}"
        );
        assert_eq!(result.unwrap(), 480);
        assert_eq!(response, expected_negotiate_response(480));
    }

    #[tokio::test]
    async fn handshake_clamps_oversized_pdu_request() {
        let (result, response) = run_handshake(960).await;
        assert_eq!(result.unwrap(), MAX_PDU_SIZE);
        assert_eq!(response, expected_negotiate_response(MAX_PDU_SIZE));
    }

    #[tokio::test]
    async fn handshake_grants_smaller_pdu_request() {
        let (result, response) = run_handshake(240).await;
        assert_eq!(result.unwrap(), 240);
        assert_eq!(response, expected_negotiate_response(240));
    }

    #[tokio::test]
    async fn handshake_fails_on_non_cr() {
        let (server_io, mut client_io) = tokio::io::duplex(4096);

        tokio::spawn(async move {
            let dt = CotpPdu::Data {
                tpdu_nr: 0,
                last: true,
                payload: Bytes::from_static(b"oops"),
            };
            write_tpkt_cotp(&mut client_io, &dt).await;
        });

        let result = server_handshake(server_io).await;
        assert!(result.is_err(), "expected error, got: {result:?}");
    }

    #[tokio::test]
    async fn handshake_fails_on_bad_tpkt_version() {
        let (server_io, mut client_io) = tokio::io::duplex(4096);

        tokio::spawn(async move {
            // Version byte 0x02 instead of 0x03; the server must reject the
            // header before trying to read a body.
            client_io.write_all(&[0x02, 0x00, 0x00, 0x04]).await.unwrap();
        });

        let result = server_handshake(server_io).await;
        assert!(
            matches!(result, Err(Error::NegotiationFailed)),
            "expected NegotiationFailed, got: {result:?}"
        );
    }
}
260}