1use std::collections::VecDeque;
2
3use bytes::BytesMut;
4use rusty_tpkt::{ProtocolInformation, TpktConnection, TpktReader, TpktWriter};
5
6use crate::{
7 CotpConnectionParameters,
8 api::{CotpConnection, CotpError, CotpProtocolInformation, CotpReader, CotpResponder, CotpWriter},
9 packet::{
10 connection_confirm::ConnectionConfirm,
11 connection_request::ConnectionRequest,
12 data_transfer::DataTransfer,
13 parameters::{ConnectionClass, CotpParameter, TpduSize},
14 payload::TransportProtocolDataUnit,
15 },
16 parser::packet::TransportProtocolDataUnitParser,
17 serialiser::packet::serialise,
18};
19
20pub struct RustyCotpConnection<R: TpktReader, W: TpktWriter> {
24 reader: R,
25 writer: W,
26
27 max_payload_size: usize,
28 parser: TransportProtocolDataUnitParser,
29 connection_options: CotpConnectionParameters,
30 protocol_infomation_list: Vec<Box<dyn ProtocolInformation>>,
31}
32
33impl<R: TpktReader, W: TpktWriter> RustyCotpConnection<R, W> {
34 pub async fn initiate(connection: impl TpktConnection, options: CotpProtocolInformation, connection_options: CotpConnectionParameters) -> Result<RustyCotpConnection<impl TpktReader, impl TpktWriter>, CotpError> {
36 let mut protocol_infomation_list = connection.get_protocol_infomation_list().clone();
38 let local_calling_tsap = options.calling_tsap_id().cloned();
39
40 let source_reference: u16 = options.initiator_reference();
41 let parser = TransportProtocolDataUnitParser::new();
42 let (mut reader, mut writer) = connection.split().await?;
43
44 send_connection_request(&mut writer, source_reference, options).await?;
45 let connection_confirm = receive_connection_confirm(&mut reader, &parser).await?;
46 let (_, max_payload_size) = calculate_remote_size_payload(connection_confirm.parameters()).await?;
47
48 let remote_called_tsap = connection_confirm.parameters().iter().filter_map(|x| if let CotpParameter::CalledTsap(tsap) = x { Some(tsap.clone()) } else { None }).last();
49 protocol_infomation_list.push(Box::new(CotpProtocolInformation::new(source_reference, connection_confirm.destination_reference(), local_calling_tsap, remote_called_tsap)));
50
51 Ok(RustyCotpConnection::new(reader, writer, max_payload_size, protocol_infomation_list, connection_options).await)
52 }
53
54 async fn new(reader: R, writer: W, max_payload_size: usize, protocol_infomation_list: Vec<Box<dyn ProtocolInformation>>, connection_options: CotpConnectionParameters) -> RustyCotpConnection<R, W> {
55 RustyCotpConnection { reader, writer, max_payload_size, parser: TransportProtocolDataUnitParser::new(), protocol_infomation_list, connection_options }
56 }
57}
58
59impl<R: TpktReader, W: TpktWriter> CotpConnection for RustyCotpConnection<R, W> {
60 fn get_protocol_infomation_list(&self) -> &Vec<Box<dyn rusty_tpkt::ProtocolInformation>> {
61 &self.protocol_infomation_list
62 }
63
64 async fn split(self) -> Result<(impl CotpReader, impl CotpWriter), CotpError> {
65 let reader = self.reader;
66 let writer = self.writer;
67 Ok((RustyCotpReader::new(reader, self.parser, self.connection_options), RustyCotpWriter::new(writer, self.max_payload_size)))
68 }
69}
70
71pub struct RustyCotpResponder<R: TpktReader, W: TpktWriter> {
73 reader: R,
74 writer: W,
75 initiator_reference: u16,
76 max_payload_size: usize,
77 max_payload_indicator: TpduSize,
78 called_tsap_id: Option<Vec<u8>>,
79 calling_tsap_id: Option<Vec<u8>>,
80 connection_options: CotpConnectionParameters,
81 protocol_information_list: Vec<Box<dyn ProtocolInformation>>,
82}
83
84impl<R: TpktReader, W: TpktWriter> RustyCotpResponder<R, W> {
85 pub async fn new(tpkt_connection: impl TpktConnection, connection_options: CotpConnectionParameters) -> Result<(RustyCotpResponder<impl TpktReader, impl TpktWriter>, CotpProtocolInformation), CotpError> {
90 let parser = TransportProtocolDataUnitParser::new();
91 let mut protocol_information_list = tpkt_connection.get_protocol_infomation_list().clone();
92 let (mut reader, writer) = tpkt_connection.split().await?;
93
94 let connection_request = receive_connection_request(&mut reader, &parser).await?;
95 let (max_payload_indicator, max_payload_size) = calculate_remote_size_payload(connection_request.parameters()).await?;
96 verify_class_compatibility(&connection_request).await?;
97
98 let mut calling_tsap_id = None;
99 let mut called_tsap_id = None;
100 for parameter in connection_request.parameters() {
101 match parameter {
102 CotpParameter::CallingTsap(tsap_id) => calling_tsap_id = Some(tsap_id.clone()),
103 CotpParameter::CalledTsap(tsap_id) => called_tsap_id = Some(tsap_id.clone()),
104 _ => (),
105 }
106 }
107
108 let protocol_information = CotpProtocolInformation::new(connection_request.source_reference(), 0, calling_tsap_id.clone(), called_tsap_id.clone());
109 protocol_information_list.push(Box::new(protocol_information.clone()));
110
111 Ok((
112 RustyCotpResponder {
113 reader,
114 writer,
115 max_payload_size,
116 connection_options,
117 max_payload_indicator,
118 called_tsap_id: called_tsap_id,
119 calling_tsap_id: calling_tsap_id,
120 initiator_reference: connection_request.source_reference(),
121 protocol_information_list,
122 },
123 protocol_information,
124 ))
125 }
126}
127
128impl<R: TpktReader, W: TpktWriter> CotpResponder for RustyCotpResponder<R, W> {
129 async fn accept(mut self, options: CotpProtocolInformation) -> Result<impl CotpConnection, CotpError> {
130 send_connection_confirm(&mut self.writer, options.responder_reference(), self.initiator_reference, self.max_payload_indicator, self.calling_tsap_id, self.called_tsap_id).await?;
131 Ok(RustyCotpConnection::new(self.reader, self.writer, self.max_payload_size, self.protocol_information_list, self.connection_options).await)
132 }
133}
134
135pub struct RustyCotpReader<R: TpktReader> {
137 reader: R,
138 parser: TransportProtocolDataUnitParser,
139 connection_options: CotpConnectionParameters,
140
141 data_buffer: BytesMut,
142}
143
144impl<R: TpktReader> RustyCotpReader<R> {
145 fn new(reader: R, parser: TransportProtocolDataUnitParser, connection_options: CotpConnectionParameters) -> Self {
146 Self { reader, parser, data_buffer: BytesMut::new(), connection_options }
147 }
148}
149
150impl<R: TpktReader> CotpReader for RustyCotpReader<R> {
151 async fn recv(&mut self) -> Result<Option<Vec<u8>>, CotpError> {
152 loop {
153 let raw_data = match self.reader.recv().await? {
154 None => return Ok(None),
155 Some(raw_data) => raw_data,
156 };
157 let data_transfer = match self.parser.parse(raw_data.as_slice())? {
158 TransportProtocolDataUnit::ER(tpdu_error) => return Err(CotpError::ProtocolError(format!("Received an error from the remote host: {:?}", tpdu_error.reason()).into())),
160 TransportProtocolDataUnit::CR(_) => return Err(CotpError::ProtocolError("Received a Connection Request when expecting data.".into())),
161 TransportProtocolDataUnit::CC(_) => return Err(CotpError::ProtocolError("Received a Connection Config when expecting data.".into())),
162 TransportProtocolDataUnit::DR(_) => return Ok(None),
163 TransportProtocolDataUnit::DT(data_transfer) => data_transfer,
164 };
165
166 self.data_buffer.extend_from_slice(data_transfer.user_data());
171 if self.data_buffer.len() > self.connection_options.max_reassembled_payload_size {
172 let reassembled_size = self.data_buffer.len();
173 let max_reassembled_size = self.connection_options.max_reassembled_payload_size;
174 self.data_buffer.clear();
175 return Err(CotpError::ProtocolError(format!("Reassembled payload size {reassembled_size} exceeds maximum payload size {max_reassembled_size}")));
176 }
177 if data_transfer.end_of_transmission() {
178 let data = self.data_buffer.to_vec();
179 self.data_buffer.clear();
180 return Ok(Some(data));
181 }
182 }
183 }
184}
185
186pub struct RustyCotpWriter<W: TpktWriter> {
188 writer: W,
189 max_payload_size: usize,
190 chunks: VecDeque<Vec<u8>>,
191}
192
193impl<W: TpktWriter> RustyCotpWriter<W> {
194 fn new(writer: W, max_payload_size: usize) -> Self {
195 Self { writer, max_payload_size, chunks: VecDeque::new() }
196 }
197}
198
199impl<W: TpktWriter> CotpWriter for RustyCotpWriter<W> {
200 async fn send(&mut self, input: &mut VecDeque<Vec<u8>>) -> Result<(), CotpError> {
201 const HEADER_LENGTH: usize = 3;
202
203 while let Some(data_item) = input.pop_front() {
204 let chunks = data_item.as_slice().chunks(self.max_payload_size - HEADER_LENGTH);
205 let chunk_count = chunks.len();
206 for (chunk_index, chunk_data) in chunks.enumerate() {
207 let end_of_transmission = chunk_index + 1 >= chunk_count;
208 let tpdu = DataTransfer::new(end_of_transmission, chunk_data);
209 let tpdu_data = serialise(&TransportProtocolDataUnit::DT(tpdu))?;
210 self.chunks.push_back(tpdu_data);
211 }
212 }
213
214 while !self.chunks.is_empty() {
215 self.writer.send(&mut self.chunks).await?;
216 }
217
218 self.writer.send(&mut self.chunks).await?;
220 Ok(())
221 }
222}
223
224async fn verify_class_compatibility(connection_request: &ConnectionRequest) -> Result<(), CotpError> {
225 let empty_set = Vec::new();
226 let class_parameters = connection_request
227 .parameters()
228 .iter()
229 .filter_map(|p| match p {
230 CotpParameter::AlternativeClassParameter(x) => Some(x),
231 _ => None,
232 })
233 .last()
234 .unwrap_or(&empty_set);
235
236 match connection_request.preferred_class() {
238 ConnectionClass::Class0 => (),
239 ConnectionClass::Class1 => (),
240 ConnectionClass::Class2 if class_parameters.contains(&&ConnectionClass::Class0) => (),
241 ConnectionClass::Class3 if class_parameters.contains(&&ConnectionClass::Class0) => (),
242 ConnectionClass::Class3 if class_parameters.contains(&&ConnectionClass::Class1) => (),
243 ConnectionClass::Class4 if class_parameters.contains(&&ConnectionClass::Class0) => (),
244 ConnectionClass::Class4 if class_parameters.contains(&&ConnectionClass::Class1) => (),
245 _ => {
246 return Err(CotpError::ProtocolError(format!("Cannot downgrade connection request to Class 0 {:?} - {:?}", connection_request.preferred_class(), class_parameters)));
247 }
248 };
249 Ok(())
250}
251
252async fn receive_connection_request(reader: &mut impl TpktReader, parser: &TransportProtocolDataUnitParser) -> Result<ConnectionRequest, CotpError> {
253 let data = match reader.recv().await {
254 Ok(Some(x)) => x,
255 Ok(None) => return Err(CotpError::ProtocolError("The connection was closed before the COTP handshake was complete.".into())),
256 Err(e) => return Err(e.into()),
257 };
258 return Ok(match parser.parse(data.as_slice())? {
259 TransportProtocolDataUnit::CR(x) => x,
260 TransportProtocolDataUnit::CC(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a connextion confirm".into())),
261 TransportProtocolDataUnit::DR(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a disconnect reqeust".into())),
262 TransportProtocolDataUnit::DT(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a data transfer".into())),
263 TransportProtocolDataUnit::ER(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a error response".into())),
264 });
265}
266
267async fn calculate_remote_size_payload(parameters: &[CotpParameter]) -> Result<(TpduSize, usize), CotpError> {
268 let parameter: &TpduSize = parameters
269 .iter()
270 .filter_map(|p| match p {
271 CotpParameter::TpduLengthParameter(x) => Some(x),
272 _ => None,
273 })
274 .last()
275 .unwrap_or(&TpduSize::Size128);
276
277 Ok(match parameter {
278 TpduSize::Size8192 => return Err(CotpError::ProtocolError("The remote side selected an 8192 bytes COTP payload but Class 0 support a maximum for 2048 bytes.".into())),
279 TpduSize::Size4096 => return Err(CotpError::ProtocolError("The remote side selected an 4096 bytes COTP payload but Class 0 support a maximum for 2048 bytes.".into())),
280 TpduSize::Unknown(x) => return Err(CotpError::ProtocolError(format!("The requested TPDU size is unknown {:?}.", x).into())),
281 TpduSize::Size128 => (TpduSize::Size128, 128),
282 TpduSize::Size256 => (TpduSize::Size256, 256),
283 TpduSize::Size512 => (TpduSize::Size512, 512),
284 TpduSize::Size1024 => (TpduSize::Size1024, 1024),
285 TpduSize::Size2048 => (TpduSize::Size2048, 2048),
286 })
287}
288
289async fn send_connection_confirm<W: TpktWriter>(writer: &mut W, source_reference: u16, destination_reference: u16, size: TpduSize, calling_tsap_id: Option<Vec<u8>>, called_tsap_id: Option<Vec<u8>>) -> Result<(), CotpError> {
290 let mut parameters = vec![CotpParameter::TpduLengthParameter(size)];
291 if let Some(tsap_id) = calling_tsap_id {
292 parameters.push(CotpParameter::CallingTsap(tsap_id));
293 }
294 if let Some(tsap_id) = called_tsap_id {
295 parameters.push(CotpParameter::CalledTsap(tsap_id));
296 }
297
298 let payload = serialise(&TransportProtocolDataUnit::CC(ConnectionConfirm::new(0, source_reference, destination_reference, ConnectionClass::Class0, vec![], parameters, &[])))?;
299 Ok(writer.send(&mut VecDeque::from_iter(vec![payload].into_iter())).await?)
300}
301
302async fn send_connection_request(writer: &mut impl TpktWriter, source_reference: u16, options: CotpProtocolInformation) -> Result<(), CotpError> {
303 let mut parameters = vec![CotpParameter::TpduLengthParameter(TpduSize::Size2048)];
304 if let Some(calling_tsap) = options.calling_tsap_id() {
305 parameters.push(CotpParameter::CallingTsap(calling_tsap.clone()));
306 }
307 if let Some(called_tsap) = options.called_tsap_id() {
308 parameters.push(CotpParameter::CalledTsap(called_tsap.clone()));
309 }
310
311 let payload = serialise(&TransportProtocolDataUnit::CR(ConnectionRequest::new(source_reference, 0, ConnectionClass::Class0, vec![], parameters, &[])))?;
312 Ok(writer.send(&mut VecDeque::from_iter(vec![payload].into_iter())).await?)
313}
314
315async fn receive_connection_confirm(reader: &mut impl TpktReader, parser: &TransportProtocolDataUnitParser) -> Result<ConnectionConfirm, CotpError> {
316 let data = match reader.recv().await {
317 Ok(Some(x)) => x,
318 Ok(None) => return Err(CotpError::ProtocolError("The connection was closed before the COTP handshake was complete.".into())),
319 Err(e) => return Err(e.into()),
320 };
321 return Ok(match parser.parse(data.as_slice())? {
322 TransportProtocolDataUnit::CC(x) if x.preferred_class() != &ConnectionClass::Class0 => return Err(CotpError::ProtocolError("Remote failed to select COTP Class 0.".into())),
323 TransportProtocolDataUnit::CC(x) => x,
324 TransportProtocolDataUnit::CR(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a connection request".into())),
325 TransportProtocolDataUnit::DR(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a disconnect reqeust".into())),
326 TransportProtocolDataUnit::DT(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a data transfer".into())),
327 TransportProtocolDataUnit::ER(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a error response".into())),
328 });
329}