1use std::collections::VecDeque;
2
3use bytes::BytesMut;
4use rusty_tpkt::{TpktConnection, TpktReader, TpktRecvResult, TpktWriter};
5
6use crate::{
7 CotpAcceptInformation,
8 api::{CotpConnectInformation, CotpConnection, CotpError, CotpReader, CotpRecvResult, CotpResponder, CotpWriter},
9 packet::{
10 connection_confirm::ConnectionConfirm,
11 connection_request::ConnectionRequest,
12 data_transfer::DataTransfer,
13 parameters::{ConnectionClass, CotpParameter, TpduSize},
14 payload::TransportProtocolDataUnit,
15 },
16 parser::packet::TransportProtocolDataUnitParser,
17 serialiser::packet::serialise,
18};
19
20pub struct TcpCotpConnection<R: TpktReader, W: TpktWriter> {
21 reader: R,
22 writer: W,
23
24 max_payload_size: usize,
25 parser: TransportProtocolDataUnitParser,
26}
27
28impl<R: TpktReader, W: TpktWriter> TcpCotpConnection<R, W> {
29 pub async fn initiate(connection: impl TpktConnection, options: CotpConnectInformation) -> Result<TcpCotpConnection<impl TpktReader, impl TpktWriter>, CotpError> {
30 let source_reference: u16 = options.initiator_reference;
31 let parser = TransportProtocolDataUnitParser::new();
32 let (mut reader, mut writer) = connection.split().await?;
33
34 send_connection_request(&mut writer, source_reference, options).await?;
35 let connection_confirm = receive_connection_confirm(&mut reader, &parser).await?;
36 let (_, max_payload_size) = calculate_remote_size_payload(connection_confirm.parameters()).await?;
37
38 Ok(TcpCotpConnection::new(reader, writer, max_payload_size).await)
39 }
40
41 async fn new(reader: R, writer: W, max_payload_size: usize) -> TcpCotpConnection<R, W> {
42 TcpCotpConnection { reader, writer, max_payload_size, parser: TransportProtocolDataUnitParser::new() }
43 }
44}
45
46impl<R: TpktReader, W: TpktWriter> CotpConnection for TcpCotpConnection<R, W> {
47 async fn split(self) -> Result<(impl CotpReader, impl CotpWriter), CotpError> {
48 let reader = self.reader;
49 let writer = self.writer;
50 Ok((TcpCotpReader::new(reader, self.parser), TcpCotpWriter::new(writer, self.max_payload_size)))
51 }
52}
53
54pub struct TcpCotpAcceptor<R: TpktReader, W: TpktWriter> {
55 reader: R,
56 writer: W,
57 initiator_reference: u16,
58 max_payload_size: usize,
59 max_payload_indicator: TpduSize,
60 called_tsap_id: Option<Vec<u8>>,
61 calling_tsap_id: Option<Vec<u8>>,
62}
63
64impl<R: TpktReader, W: TpktWriter> TcpCotpAcceptor<R, W> {
65 pub async fn new(tpkt_connection: impl TpktConnection) -> Result<(TcpCotpAcceptor<impl TpktReader, impl TpktWriter>, CotpConnectInformation), CotpError> {
66 let parser = TransportProtocolDataUnitParser::new();
67 let (mut reader, writer) = tpkt_connection.split().await?;
68
69 let connection_request = receive_connection_request(&mut reader, &parser).await?;
70 let (max_payload_indicator, max_payload_size) = calculate_remote_size_payload(connection_request.parameters()).await?;
71 verify_class_compatibility(&connection_request).await?;
72
73 let mut calling_tsap_id = None;
74 let mut called_tsap_id = None;
75 for parameter in connection_request.parameters() {
76 match parameter {
77 CotpParameter::CallingTsap(tsap_id) => calling_tsap_id = Some(tsap_id.clone()),
78 CotpParameter::CalledTsap(tsap_id) => called_tsap_id = Some(tsap_id.clone()),
79 _ => (),
80 }
81 }
82
83 Ok((
84 TcpCotpAcceptor { reader, writer, max_payload_size, max_payload_indicator, called_tsap_id: called_tsap_id.clone(), calling_tsap_id: calling_tsap_id.clone(), initiator_reference: connection_request.source_reference() },
85 CotpConnectInformation { calling_tsap_id, called_tsap_id, initiator_reference: connection_request.source_reference() },
86 ))
87 }
88}
89
90impl<R: TpktReader, W: TpktWriter> CotpResponder for TcpCotpAcceptor<R, W> {
91 async fn accept(mut self, options: CotpAcceptInformation) -> Result<impl CotpConnection, CotpError> {
92 send_connection_confirm(&mut self.writer, options.responder_reference, self.initiator_reference, self.max_payload_indicator, self.calling_tsap_id, self.called_tsap_id).await?;
93 Ok(TcpCotpConnection::new(self.reader, self.writer, self.max_payload_size).await)
94 }
95}
96
97pub struct TcpCotpReader<R: TpktReader> {
98 reader: R,
100 parser: TransportProtocolDataUnitParser,
101
102 data_buffer: BytesMut,
103}
104
105impl<R: TpktReader> TcpCotpReader<R> {
106 pub fn new(reader: R, parser: TransportProtocolDataUnitParser) -> Self {
107 Self { reader, parser, data_buffer: BytesMut::new() }
108 }
109}
110
111impl<R: TpktReader> CotpReader for TcpCotpReader<R> {
112 async fn recv(&mut self) -> Result<CotpRecvResult, CotpError> {
113 loop {
114 let raw_data = match self.reader.recv().await? {
116 TpktRecvResult::Closed => return Ok(CotpRecvResult::Closed),
117 TpktRecvResult::Data(raw_data) => raw_data,
118 };
119 let data_transfer = match self.parser.parse(raw_data.as_slice())? {
120 TransportProtocolDataUnit::ER(tpdu_error) => return Err(CotpError::ProtocolError(format!("Received an error from the remote host: {:?}", tpdu_error.reason()).into())),
122 TransportProtocolDataUnit::CR(_) => return Err(CotpError::ProtocolError("Received a Connection Request when expecting data.".into())),
123 TransportProtocolDataUnit::CC(_) => return Err(CotpError::ProtocolError("Received a Connection Config when expecting data.".into())),
124 TransportProtocolDataUnit::DR(_) => return Ok(CotpRecvResult::Closed),
125 TransportProtocolDataUnit::DT(data_transfer) => data_transfer,
126 };
127 self.data_buffer.extend_from_slice(data_transfer.user_data());
130 if data_transfer.end_of_transmission() {
131 let data = self.data_buffer.to_vec();
132 self.data_buffer.clear();
133 return Ok(CotpRecvResult::Data(data));
134 }
135 }
136 }
137}
138
139pub struct TcpCotpWriter<W: TpktWriter> {
140 writer: W,
141 max_payload_size: usize,
142 chunks: VecDeque<Vec<u8>>,
143}
144
145impl<W: TpktWriter> TcpCotpWriter<W> {
146 pub fn new(writer: W, max_payload_size: usize) -> Self {
147 Self { writer, max_payload_size, chunks: VecDeque::new() }
148 }
149}
150
151impl<W: TpktWriter> CotpWriter for TcpCotpWriter<W> {
152 async fn send(&mut self, input: &mut VecDeque<Vec<u8>>) -> Result<(), CotpError> {
153 const HEADER_LENGTH: usize = 3;
154
155 while let Some(data_item) = input.pop_front() {
156 let chunks = data_item.as_slice().chunks(self.max_payload_size - HEADER_LENGTH);
157 let chunk_count = chunks.len();
158 for (chunk_index, chunk_data) in chunks.enumerate() {
159 let end_of_transmission = chunk_index + 1 >= chunk_count;
160 let tpdu = DataTransfer::new(end_of_transmission, chunk_data);
161 let tpdu_data = serialise(&TransportProtocolDataUnit::DT(tpdu))?;
162 self.chunks.push_back(tpdu_data);
163 }
164 }
165
166 while !self.chunks.is_empty() {
167 self.writer.send(&mut self.chunks).await?;
168 }
169
170 self.writer.send(&mut self.chunks).await?;
172 Ok(())
173 }
174}
175
176async fn verify_class_compatibility(connection_request: &ConnectionRequest) -> Result<(), CotpError> {
177 let empty_set = Vec::new();
178 let class_parameters = connection_request
179 .parameters()
180 .iter()
181 .filter_map(|p| match p {
182 CotpParameter::AlternativeClassParameter(x) => Some(x),
183 _ => None,
184 })
185 .last()
186 .unwrap_or(&empty_set);
187
188 match connection_request.preferred_class() {
190 ConnectionClass::Class0 => (),
191 ConnectionClass::Class1 => (),
192 ConnectionClass::Class2 if class_parameters.contains(&&ConnectionClass::Class0) => (),
193 ConnectionClass::Class3 if class_parameters.contains(&&ConnectionClass::Class0) => (),
194 ConnectionClass::Class3 if class_parameters.contains(&&ConnectionClass::Class1) => (),
195 ConnectionClass::Class4 if class_parameters.contains(&&ConnectionClass::Class0) => (),
196 ConnectionClass::Class4 if class_parameters.contains(&&ConnectionClass::Class1) => (),
197 _ => {
198 return Err(CotpError::ProtocolError(format!("Cannot downgrade connection request to Class 0 {:?} - {:?}", connection_request.preferred_class(), class_parameters)));
199 }
200 };
201 Ok(())
202}
203
204async fn receive_connection_request(reader: &mut impl TpktReader, parser: &TransportProtocolDataUnitParser) -> Result<ConnectionRequest, CotpError> {
205 let data = match reader.recv().await {
206 Ok(TpktRecvResult::Data(x)) => x,
207 Ok(TpktRecvResult::Closed) => return Err(CotpError::ProtocolError("The connection was closed before the COTP handshake was complete.".into())),
208 Err(e) => return Err(e.into()),
209 };
210 return Ok(match parser.parse(data.as_slice())? {
211 TransportProtocolDataUnit::CR(x) => x,
212 TransportProtocolDataUnit::CC(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a connextion confirm".into())),
213 TransportProtocolDataUnit::DR(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a disconnect reqeust".into())),
214 TransportProtocolDataUnit::DT(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a data transfer".into())),
215 TransportProtocolDataUnit::ER(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a error response".into())),
216 });
217}
218
219async fn calculate_remote_size_payload(parameters: &[CotpParameter]) -> Result<(TpduSize, usize), CotpError> {
220 let parameter: &TpduSize = parameters
221 .iter()
222 .filter_map(|p| match p {
223 CotpParameter::TpduLengthParameter(x) => Some(x),
224 _ => None,
225 })
226 .last()
227 .unwrap_or(&TpduSize::Size128);
228
229 Ok(match parameter {
230 TpduSize::Size8192 => return Err(CotpError::ProtocolError("The remote side selected an 8192 bytes COTP payload but Class 0 support a maximum for 2048 bytes.".into())),
231 TpduSize::Size4096 => return Err(CotpError::ProtocolError("The remote side selected an 4096 bytes COTP payload but Class 0 support a maximum for 2048 bytes.".into())),
232 TpduSize::Unknown(x) => return Err(CotpError::ProtocolError(format!("The requested TPDU size is unknown {:?}.", x).into())),
233 TpduSize::Size128 => (TpduSize::Size128, 128),
234 TpduSize::Size256 => (TpduSize::Size256, 256),
235 TpduSize::Size512 => (TpduSize::Size512, 512),
236 TpduSize::Size1024 => (TpduSize::Size1024, 1024),
237 TpduSize::Size2048 => (TpduSize::Size2048, 2048),
238 })
239}
240
241async fn send_connection_confirm<W: TpktWriter>(writer: &mut W, source_reference: u16, destination_reference: u16, size: TpduSize, calling_tsap_id: Option<Vec<u8>>, called_tsap_id: Option<Vec<u8>>) -> Result<(), CotpError> {
242 let mut parameters = vec![CotpParameter::TpduLengthParameter(size)];
243 if let Some(tsap_id) = calling_tsap_id {
244 parameters.push(CotpParameter::CallingTsap(tsap_id));
245 }
246 if let Some(tsap_id) = called_tsap_id {
247 parameters.push(CotpParameter::CalledTsap(tsap_id));
248 }
249
250 let payload = serialise(&TransportProtocolDataUnit::CC(ConnectionConfirm::new(0, source_reference, destination_reference, ConnectionClass::Class0, vec![], parameters, &[])))?;
251 Ok(writer.send(&mut VecDeque::from_iter(vec![payload].into_iter())).await?)
252}
253
254async fn send_connection_request(writer: &mut impl TpktWriter, source_reference: u16, options: CotpConnectInformation) -> Result<(), CotpError> {
255 let mut parameters = vec![CotpParameter::TpduLengthParameter(TpduSize::Size2048)];
256 if let Some(calling_tsap) = options.calling_tsap_id {
257 parameters.push(CotpParameter::CallingTsap(calling_tsap));
258 }
259 if let Some(called_tsap) = options.called_tsap_id {
260 parameters.push(CotpParameter::CalledTsap(called_tsap));
261 }
262
263 let payload = serialise(&TransportProtocolDataUnit::CR(ConnectionRequest::new(source_reference, 0, ConnectionClass::Class0, vec![], parameters, &[])))?;
264 Ok(writer.send(&mut VecDeque::from_iter(vec![payload].into_iter())).await?)
265}
266
267async fn receive_connection_confirm(reader: &mut impl TpktReader, parser: &TransportProtocolDataUnitParser) -> Result<ConnectionConfirm, CotpError> {
268 let data = match reader.recv().await {
269 Ok(TpktRecvResult::Data(x)) => x,
270 Ok(TpktRecvResult::Closed) => return Err(CotpError::ProtocolError("The connection was closed before the COTP handshake was complete.".into())),
271 Err(e) => return Err(e.into()),
272 };
273 return Ok(match parser.parse(data.as_slice())? {
274 TransportProtocolDataUnit::CC(x) if x.preferred_class() != &ConnectionClass::Class0 => return Err(CotpError::ProtocolError("Remote failed to select COTP Class 0.".into())),
275 TransportProtocolDataUnit::CC(x) => x,
276 TransportProtocolDataUnit::CR(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a connection request".into())),
277 TransportProtocolDataUnit::DR(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a disconnect reqeust".into())),
278 TransportProtocolDataUnit::DT(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a data transfer".into())),
279 TransportProtocolDataUnit::ER(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a error response".into())),
280 });
281}