Skip to main content

rusty_cotp/
service.rs

1use std::collections::VecDeque;
2
3use bytes::BytesMut;
4use rusty_tpkt::{TpktConnection, TpktReader, TpktRecvResult, TpktWriter};
5
6use crate::{
7    CotpAcceptInformation,
8    api::{CotpConnectInformation, CotpConnection, CotpError, CotpReader, CotpRecvResult, CotpResponder, CotpWriter},
9    packet::{
10        connection_confirm::ConnectionConfirm,
11        connection_request::ConnectionRequest,
12        data_transfer::DataTransfer,
13        parameters::{ConnectionClass, CotpParameter, TpduSize},
14        payload::TransportProtocolDataUnit,
15    },
16    parser::packet::TransportProtocolDataUnitParser,
17    serialiser::packet::serialise,
18};
19
20pub struct TcpCotpConnection<R: TpktReader, W: TpktWriter> {
21    reader: R,
22    writer: W,
23
24    max_payload_size: usize,
25    parser: TransportProtocolDataUnitParser,
26}
27
28impl<R: TpktReader, W: TpktWriter> TcpCotpConnection<R, W> {
29    pub async fn initiate(connection: impl TpktConnection, options: CotpConnectInformation) -> Result<TcpCotpConnection<impl TpktReader, impl TpktWriter>, CotpError> {
30        let source_reference: u16 = options.initiator_reference;
31        let parser = TransportProtocolDataUnitParser::new();
32        let (mut reader, mut writer) = connection.split().await?;
33
34        send_connection_request(&mut writer, source_reference, options).await?;
35        let connection_confirm = receive_connection_confirm(&mut reader, &parser).await?;
36        let (_, max_payload_size) = calculate_remote_size_payload(connection_confirm.parameters()).await?;
37
38        Ok(TcpCotpConnection::new(reader, writer, max_payload_size).await)
39    }
40
41    async fn new(reader: R, writer: W, max_payload_size: usize) -> TcpCotpConnection<R, W> {
42        TcpCotpConnection {
43            reader,
44            writer,
45            max_payload_size,
46            parser: TransportProtocolDataUnitParser::new(),
47        }
48    }
49}
50
51impl<R: TpktReader, W: TpktWriter> CotpConnection for TcpCotpConnection<R, W> {
52    async fn split(self) -> Result<(impl CotpReader, impl CotpWriter), CotpError> {
53        let reader = self.reader;
54        let writer = self.writer;
55        Ok((TcpCotpReader::new(reader, self.parser), TcpCotpWriter::new(writer, self.max_payload_size)))
56    }
57}
58
59pub struct TcpCotpAcceptor<R: TpktReader, W: TpktWriter> {
60    reader: R,
61    writer: W,
62    initiator_reference: u16,
63    max_payload_size: usize,
64    max_payload_indicator: TpduSize,
65    called_tsap_id: Option<Vec<u8>>,
66    calling_tsap_id: Option<Vec<u8>>,
67}
68
69impl<R: TpktReader, W: TpktWriter> TcpCotpAcceptor<R, W> {
70    pub async fn new(tpkt_connection: impl TpktConnection) -> Result<(TcpCotpAcceptor<impl TpktReader, impl TpktWriter>, CotpConnectInformation), CotpError> {
71        let parser = TransportProtocolDataUnitParser::new();
72        let (mut reader, writer) = tpkt_connection.split().await?;
73
74        let connection_request = receive_connection_request(&mut reader, &parser).await?;
75        let (max_payload_indicator, max_payload_size) = calculate_remote_size_payload(connection_request.parameters()).await?;
76        verify_class_compatibility(&connection_request).await?;
77
78        let mut calling_tsap_id = None;
79        let mut called_tsap_id = None;
80        for parameter in connection_request.parameters() {
81            match parameter {
82                CotpParameter::CallingTsap(tsap_id) => calling_tsap_id = Some(tsap_id.clone()),
83                CotpParameter::CalledTsap(tsap_id) => called_tsap_id = Some(tsap_id.clone()),
84                _ => (),
85            }
86        }
87
88        Ok((
89            TcpCotpAcceptor {
90                reader,
91                writer,
92                max_payload_size,
93                max_payload_indicator,
94                called_tsap_id: called_tsap_id.clone(),
95                calling_tsap_id: calling_tsap_id.clone(),
96                initiator_reference: connection_request.source_reference(),
97            },
98            CotpConnectInformation {
99                calling_tsap_id,
100                called_tsap_id,
101                initiator_reference: connection_request.source_reference(),
102            },
103        ))
104    }
105}
106
107impl<R: TpktReader, W: TpktWriter> CotpResponder for TcpCotpAcceptor<R, W> {
108    async fn accept(mut self, options: CotpAcceptInformation) -> Result<impl CotpConnection, CotpError> {
109        send_connection_confirm(&mut self.writer, options.responder_reference, self.initiator_reference, self.max_payload_indicator, self.calling_tsap_id, self.called_tsap_id).await?;
110        Ok(TcpCotpConnection::new(self.reader, self.writer, self.max_payload_size).await)
111    }
112}
113
114pub struct TcpCotpReader<R: TpktReader> {
115    // Not caring about the size of the payload we receive.
116    reader: R,
117    parser: TransportProtocolDataUnitParser,
118
119    data_buffer: BytesMut,
120}
121
122impl<R: TpktReader> TcpCotpReader<R> {
123    pub fn new(reader: R, parser: TransportProtocolDataUnitParser) -> Self {
124        Self {
125            reader,
126            parser,
127            data_buffer: BytesMut::new(),
128        }
129    }
130}
131
132impl<R: TpktReader> CotpReader for TcpCotpReader<R> {
133    async fn recv(&mut self) -> Result<CotpRecvResult, CotpError> {
134        loop {
135            // I don't really care to check max size. It is 2025.
136            let raw_data = match self.reader.recv().await? {
137                TpktRecvResult::Closed => return Ok(CotpRecvResult::Closed),
138                TpktRecvResult::Data(raw_data) => raw_data,
139            };
140            let data_transfer = match self.parser.parse(raw_data.as_slice())? {
141                // Choosing the standards based option of reporting the TPDU error locally but not sending an error.
142                TransportProtocolDataUnit::ER(tpdu_error) => return Err(CotpError::ProtocolError(format!("Received an error from the remote host: {:?}", tpdu_error.reason()).into())),
143                TransportProtocolDataUnit::CR(_) => return Err(CotpError::ProtocolError("Received a Connection Request when expecting data.".into())),
144                TransportProtocolDataUnit::CC(_) => return Err(CotpError::ProtocolError("Received a Connection Config when expecting data.".into())),
145                TransportProtocolDataUnit::DR(_) => return Ok(CotpRecvResult::Closed),
146                TransportProtocolDataUnit::DT(data_transfer) => data_transfer,
147            };
148            // I do not really care about the source and destination reference here. It is over a TCP stream. I'd rather keep it relaxed and avoid interop issues.
149
150            self.data_buffer.extend_from_slice(data_transfer.user_data());
151            if data_transfer.end_of_transmission() {
152                let data = self.data_buffer.to_vec();
153                self.data_buffer.clear();
154                return Ok(CotpRecvResult::Data(data));
155            }
156        }
157    }
158}
159
160pub struct TcpCotpWriter<W: TpktWriter> {
161    writer: W,
162    max_payload_size: usize,
163    chunks: VecDeque<Vec<u8>>,
164}
165
166impl<W: TpktWriter> TcpCotpWriter<W> {
167    pub fn new(writer: W, max_payload_size: usize) -> Self {
168        Self {
169            writer,
170            max_payload_size,
171            chunks: VecDeque::new(),
172        }
173    }
174}
175
176impl<W: TpktWriter> CotpWriter for TcpCotpWriter<W> {
177    async fn send(&mut self, data: &[u8]) -> Result<(), CotpError> {
178        const HEADER_LENGTH: usize = 3;
179
180        let chunks = data.chunks(self.max_payload_size - HEADER_LENGTH);
181        let chunk_count = chunks.len();
182        for (chunk_index, chunk_data) in chunks.enumerate() {
183            let end_of_transmission = chunk_index + 1 >= chunk_count;
184            let tpdu = DataTransfer::new(end_of_transmission, chunk_data);
185            let tpdu_data = serialise(&TransportProtocolDataUnit::DT(tpdu))?;
186            self.chunks.push_back(tpdu_data);
187        }
188        self.continue_send().await
189    }
190
191    async fn continue_send(&mut self) -> Result<(), CotpError> {
192        while let Some(data) = self.chunks.pop_front() {
193            self.writer.send(data.as_slice()).await?;
194        }
195        Ok(())
196    }
197}
198
199async fn verify_class_compatibility(connection_request: &ConnectionRequest) -> Result<(), CotpError> {
200    let empty_set = Vec::new();
201    let class_parameters = connection_request
202        .parameters()
203        .iter()
204        .filter_map(|p| match p {
205            CotpParameter::AlternativeClassParameter(x) => Some(x),
206            _ => None,
207        })
208        .last()
209        .unwrap_or(&empty_set);
210
211    // Verify we can downgrade to Class 0
212    match connection_request.preferred_class() {
213        ConnectionClass::Class0 => (),
214        ConnectionClass::Class1 => (),
215        ConnectionClass::Class2 if class_parameters.contains(&&ConnectionClass::Class0) => (),
216        ConnectionClass::Class3 if class_parameters.contains(&&ConnectionClass::Class0) => (),
217        ConnectionClass::Class3 if class_parameters.contains(&&ConnectionClass::Class1) => (),
218        ConnectionClass::Class4 if class_parameters.contains(&&ConnectionClass::Class0) => (),
219        ConnectionClass::Class4 if class_parameters.contains(&&ConnectionClass::Class1) => (),
220        _ => {
221            return Err(CotpError::ProtocolError(format!(
222                "Cannot downgrade connection request to Class 0 {:?} - {:?}",
223                connection_request.preferred_class(),
224                class_parameters
225            )));
226        }
227    };
228    Ok(())
229}
230
231async fn receive_connection_request(reader: &mut impl TpktReader, parser: &TransportProtocolDataUnitParser) -> Result<ConnectionRequest, CotpError> {
232    let data = match reader.recv().await {
233        Ok(TpktRecvResult::Data(x)) => x,
234        Ok(TpktRecvResult::Closed) => return Err(CotpError::ProtocolError("The connection was closed before the COTP handshake was complete.".into())),
235        Err(e) => return Err(e.into()),
236    };
237    return Ok(match parser.parse(data.as_slice())? {
238        TransportProtocolDataUnit::CR(x) => x,
239        TransportProtocolDataUnit::CC(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a connextion confirm".into())),
240        TransportProtocolDataUnit::DR(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a disconnect reqeust".into())),
241        TransportProtocolDataUnit::DT(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a data transfer".into())),
242        TransportProtocolDataUnit::ER(_) => return Err(CotpError::ProtocolError("Expected connection request on handshake but got a error response".into())),
243    });
244}
245
246async fn calculate_remote_size_payload(parameters: &[CotpParameter]) -> Result<(TpduSize, usize), CotpError> {
247    let parameter: &TpduSize = parameters
248        .iter()
249        .filter_map(|p| match p {
250            CotpParameter::TpduLengthParameter(x) => Some(x),
251            _ => None,
252        })
253        .last()
254        .unwrap_or(&TpduSize::Size128);
255
256    Ok(match parameter {
257        TpduSize::Size8192 => return Err(CotpError::ProtocolError("The remote side selected an 8192 bytes COTP payload but Class 0 support a maximum for 2048 bytes.".into())),
258        TpduSize::Size4096 => return Err(CotpError::ProtocolError("The remote side selected an 4096 bytes COTP payload but Class 0 support a maximum for 2048 bytes.".into())),
259        TpduSize::Unknown(x) => return Err(CotpError::ProtocolError(format!("The requested TPDU size is unknown {:?}.", x).into())),
260        TpduSize::Size128 => (TpduSize::Size128, 128),
261        TpduSize::Size256 => (TpduSize::Size256, 256),
262        TpduSize::Size512 => (TpduSize::Size512, 512),
263        TpduSize::Size1024 => (TpduSize::Size1024, 1024),
264        TpduSize::Size2048 => (TpduSize::Size2048, 2048),
265    })
266}
267
268async fn send_connection_confirm<W: TpktWriter>(writer: &mut W, source_reference: u16, destination_reference: u16, size: TpduSize, calling_tsap_id: Option<Vec<u8>>, called_tsap_id: Option<Vec<u8>>) -> Result<(), CotpError> {
269    let mut parameters = vec![CotpParameter::TpduLengthParameter(size)];
270    if let Some(tsap_id) = calling_tsap_id {
271        parameters.push(CotpParameter::CallingTsap(tsap_id));
272    }
273    if let Some(tsap_id) = called_tsap_id {
274        parameters.push(CotpParameter::CalledTsap(tsap_id));
275    }
276
277    let payload = serialise(&TransportProtocolDataUnit::CC(ConnectionConfirm::new(
278        0,
279        source_reference,
280        destination_reference,
281        ConnectionClass::Class0,
282        vec![],
283        parameters,
284        &[],
285    )))?;
286    Ok(writer.send(&payload.as_slice()).await?)
287}
288
289async fn send_connection_request(writer: &mut impl TpktWriter, source_reference: u16, options: CotpConnectInformation) -> Result<(), CotpError> {
290    let mut parameters = vec![CotpParameter::TpduLengthParameter(TpduSize::Size2048)];
291    if let Some(calling_tsap) = options.calling_tsap_id {
292        parameters.push(CotpParameter::CallingTsap(calling_tsap));
293    }
294    if let Some(called_tsap) = options.called_tsap_id {
295        parameters.push(CotpParameter::CalledTsap(called_tsap));
296    }
297
298    let payload = serialise(&TransportProtocolDataUnit::CR(ConnectionRequest::new(source_reference, 0, ConnectionClass::Class0, vec![], parameters, &[])))?;
299    Ok(writer.send(&payload.as_slice()).await?)
300}
301
302async fn receive_connection_confirm(reader: &mut impl TpktReader, parser: &TransportProtocolDataUnitParser) -> Result<ConnectionConfirm, CotpError> {
303    let data = match reader.recv().await {
304        Ok(TpktRecvResult::Data(x)) => x,
305        Ok(TpktRecvResult::Closed) => return Err(CotpError::ProtocolError("The connection was closed before the COTP handshake was complete.".into())),
306        Err(e) => return Err(e.into()),
307    };
308    return Ok(match parser.parse(data.as_slice())? {
309        TransportProtocolDataUnit::CC(x) if x.preferred_class() != &ConnectionClass::Class0 => return Err(CotpError::ProtocolError("Remote failed to select COTP Class 0.".into())),
310        TransportProtocolDataUnit::CC(x) => x,
311        TransportProtocolDataUnit::CR(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a connection request".into())),
312        TransportProtocolDataUnit::DR(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a disconnect reqeust".into())),
313        TransportProtocolDataUnit::DT(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a data transfer".into())),
314        TransportProtocolDataUnit::ER(_) => return Err(CotpError::ProtocolError("Expected connection confirmed on handshake but got a error response".into())),
315    });
316}