1use alloc::vec;
2use cobs::encode;
3use delegate::delegate;
4use std::io::Write;
5use std::net::SocketAddr;
6use std::net::TcpListener;
7use std::net::TcpStream;
8use std::vec::Vec;
9
10use crate::encoding::parse_buffer_for_cobs_encoded_packets;
11use crate::tmtc::ReceivesTc;
12use crate::tmtc::TmPacketSource;
13
14use crate::hal::std::tcp_server::{
15 ConnectionResult, ServerConfig, TcpTcParser, TcpTmSender, TcpTmtcError, TcpTmtcGenericServer,
16};
17
/// Telecommand parser for a TCP byte stream in which every packet is COBS-encoded and framed
/// by `0` sentinel bytes.
///
/// The parser itself is stateless; all parsing state lives in the buffers handed to it.
#[derive(Debug, Default)]
pub struct CobsTcParser {}
21
22impl<TmError, TcError: 'static> TcpTcParser<TmError, TcError> for CobsTcParser {
23 fn handle_tc_parsing(
24 &mut self,
25 tc_buffer: &mut [u8],
26 tc_receiver: &mut (impl ReceivesTc<Error = TcError> + ?Sized),
27 conn_result: &mut ConnectionResult,
28 current_write_idx: usize,
29 next_write_idx: &mut usize,
30 ) -> Result<(), TcpTmtcError<TmError, TcError>> {
31 conn_result.num_received_tcs += parse_buffer_for_cobs_encoded_packets(
32 &mut tc_buffer[..current_write_idx],
33 tc_receiver.upcast_mut(),
34 next_write_idx,
35 )
36 .map_err(|e| TcpTmtcError::TcError(e))?;
37 Ok(())
38 }
39}
40
41pub struct CobsTmSender {
43 tm_encoding_buffer: Vec<u8>,
44}
45
46impl CobsTmSender {
47 fn new(tm_buffer_size: usize) -> Self {
48 Self {
49 tm_encoding_buffer: vec![0; cobs::max_encoding_length(tm_buffer_size)],
52 }
53 }
54}
55
56impl<TmError, TcError> TcpTmSender<TmError, TcError> for CobsTmSender {
57 fn handle_tm_sending(
58 &mut self,
59 tm_buffer: &mut [u8],
60 tm_source: &mut (impl TmPacketSource<Error = TmError> + ?Sized),
61 conn_result: &mut ConnectionResult,
62 stream: &mut TcpStream,
63 ) -> Result<bool, TcpTmtcError<TmError, TcError>> {
64 let mut tm_was_sent = false;
65 loop {
66 let read_tm_len = tm_source
69 .retrieve_packet(tm_buffer)
70 .map_err(|e| TcpTmtcError::TmError(e))?;
71
72 if read_tm_len == 0 {
73 return Ok(tm_was_sent);
74 }
75 tm_was_sent = true;
76 conn_result.num_sent_tms += 1;
77
78 let mut current_idx = 0;
80 self.tm_encoding_buffer[current_idx] = 0;
81 current_idx += 1;
82 current_idx += encode(
83 &tm_buffer[..read_tm_len],
84 &mut self.tm_encoding_buffer[current_idx..],
85 );
86 self.tm_encoding_buffer[current_idx] = 0;
87 current_idx += 1;
88 stream.write_all(&self.tm_encoding_buffer[..current_idx])?;
89 }
90 }
91}
92
93pub struct TcpTmtcInCobsServer<
113 TmError,
114 TcError: 'static,
115 TmSource: TmPacketSource<Error = TmError>,
116 TcReceiver: ReceivesTc<Error = TcError>,
117> {
118 generic_server:
119 TcpTmtcGenericServer<TmError, TcError, TmSource, TcReceiver, CobsTmSender, CobsTcParser>,
120}
121
122impl<
123 TmError: 'static,
124 TcError: 'static,
125 TmSource: TmPacketSource<Error = TmError>,
126 TcReceiver: ReceivesTc<Error = TcError>,
127 > TcpTmtcInCobsServer<TmError, TcError, TmSource, TcReceiver>
128{
129 pub fn new(
140 cfg: ServerConfig,
141 tm_source: TmSource,
142 tc_receiver: TcReceiver,
143 ) -> Result<Self, std::io::Error> {
144 Ok(Self {
145 generic_server: TcpTmtcGenericServer::new(
146 cfg,
147 CobsTcParser::default(),
148 CobsTmSender::new(cfg.tm_buffer_size),
149 tm_source,
150 tc_receiver,
151 )?,
152 })
153 }
154
155 delegate! {
156 to self.generic_server {
157 pub fn listener(&mut self) -> &mut TcpListener;
158
159 pub fn local_addr(&self) -> std::io::Result<SocketAddr>;
162
163 pub fn handle_next_connection(
165 &mut self,
166 ) -> Result<ConnectionResult, TcpTmtcError<TmError, TcError>>;
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use core::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };
    use std::{
        io::{Read, Write},
        net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream},
        thread,
    };

    use crate::{
        encoding::tests::{INVERTED_PACKET, SIMPLE_PACKET},
        hal::std::tcp_server::{
            tests::{SyncTcCacher, SyncTmSource},
            ServerConfig,
        },
    };
    use alloc::sync::Arc;
    use cobs::encode;

    use super::TcpTmtcInCobsServer;

    /// Appends `SIMPLE_PACKET` as a sentinel-framed COBS frame at `current_idx`.
    fn encode_simple_packet(encoded_buf: &mut [u8], current_idx: &mut usize) {
        encode_packet(&SIMPLE_PACKET, encoded_buf, current_idx)
    }

    /// Appends `INVERTED_PACKET` as a sentinel-framed COBS frame at `current_idx`.
    fn encode_inverted_packet(encoded_buf: &mut [u8], current_idx: &mut usize) {
        encode_packet(&INVERTED_PACKET, encoded_buf, current_idx)
    }

    /// Writes `0 | COBS(packet) | 0` into `encoded_buf` starting at `current_idx` and advances
    /// `current_idx` past the written frame.
    fn encode_packet(packet: &[u8], encoded_buf: &mut [u8], current_idx: &mut usize) {
        encoded_buf[*current_idx] = 0;
        *current_idx += 1;
        *current_idx += encode(packet, &mut encoded_buf[*current_idx..]);
        encoded_buf[*current_idx] = 0;
        *current_idx += 1;
    }

    /// Polls `conn_handled` up to three times, sleeping 5 ms between polls. Panics if the
    /// server thread never reported that it finished handling the connection.
    fn wait_for_connection_handled(conn_handled: &AtomicBool) {
        for _ in 0..3 {
            if !conn_handled.load(Ordering::Relaxed) {
                thread::sleep(Duration::from_millis(5));
            }
        }
        if !conn_handled.load(Ordering::Relaxed) {
            panic!("connection was not handled properly");
        }
    }

    fn generic_tmtc_server(
        addr: &SocketAddr,
        tc_receiver: SyncTcCacher,
        tm_source: SyncTmSource,
    ) -> TcpTmtcInCobsServer<(), (), SyncTmSource, SyncTcCacher> {
        TcpTmtcInCobsServer::new(
            ServerConfig::new(*addr, Duration::from_millis(2), 1024, 1024),
            tm_source,
            tc_receiver,
        )
        .expect("TCP server generation failed")
    }

    #[test]
    fn test_server_basic_no_tm() {
        let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0);
        let tc_receiver = SyncTcCacher::default();
        let tm_source = SyncTmSource::default();
        let mut tcp_server = generic_tmtc_server(&auto_port_addr, tc_receiver.clone(), tm_source);
        let dest_addr = tcp_server
            .local_addr()
            .expect("retrieving dest addr failed");
        let conn_handled: Arc<AtomicBool> = Default::default();
        let set_if_done = conn_handled.clone();
        thread::spawn(move || {
            let conn_result = tcp_server
                .handle_next_connection()
                .unwrap_or_else(|e| panic!("handling connection failed: {:?}", e));
            assert_eq!(conn_result.num_received_tcs, 1);
            assert_eq!(conn_result.num_sent_tms, 0);
            set_if_done.store(true, Ordering::Relaxed);
        });
        let mut encoded_buf: [u8; 16] = [0; 16];
        let mut current_idx = 0;
        encode_simple_packet(&mut encoded_buf, &mut current_idx);
        let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed");
        stream
            .write_all(&encoded_buf[..current_idx])
            .expect("writing to TCP server failed");
        drop(stream);

        wait_for_connection_handled(&conn_handled);
        let mut tc_queue = tc_receiver
            .tc_queue
            .lock()
            .expect("locking tc queue failed");
        assert_eq!(tc_queue.len(), 1);
        assert_eq!(tc_queue.pop_front().unwrap(), &SIMPLE_PACKET);
        drop(tc_queue);
    }

    #[test]
    fn test_server_basic_multi_tm_multi_tc() {
        let auto_port_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0);
        let tc_receiver = SyncTcCacher::default();
        let mut tm_source = SyncTmSource::default();
        tm_source.add_tm(&INVERTED_PACKET);
        tm_source.add_tm(&SIMPLE_PACKET);
        let mut tcp_server =
            generic_tmtc_server(&auto_port_addr, tc_receiver.clone(), tm_source.clone());
        let dest_addr = tcp_server
            .local_addr()
            .expect("retrieving dest addr failed");
        let conn_handled: Arc<AtomicBool> = Default::default();
        let set_if_done = conn_handled.clone();
        thread::spawn(move || {
            let conn_result = tcp_server
                .handle_next_connection()
                .unwrap_or_else(|e| panic!("handling connection failed: {:?}", e));
            assert_eq!(conn_result.num_received_tcs, 2, "Not enough TCs received");
            assert_eq!(conn_result.num_sent_tms, 2, "Not enough TMs received");
            set_if_done.store(true, Ordering::Relaxed);
        });
        let mut encoded_buf: [u8; 32] = [0; 32];
        let mut current_idx = 0;
        encode_simple_packet(&mut encoded_buf, &mut current_idx);
        encode_inverted_packet(&mut encoded_buf, &mut current_idx);
        let mut stream = TcpStream::connect(dest_addr).expect("connecting to TCP server failed");
        stream
            .set_read_timeout(Some(Duration::from_millis(10)))
            .expect("setting read timeout failed");
        stream
            .write_all(&encoded_buf[..current_idx])
            .expect("writing to TCP server failed");
        stream
            .shutdown(std::net::Shutdown::Write)
            .expect("shutting down write failed");

        // Two 5-byte TMs, each sent as an 8-byte frame (start sentinel, 6 COBS bytes, end
        // sentinel). `read_exact` accumulates partial reads, so the frames are verified even
        // if the server's writes arrive split across several reads.
        let mut read_buf: [u8; 16] = [0; 16];
        stream.read_exact(&mut read_buf).expect("read failed");

        current_idx = 0;
        assert_eq!(read_buf[current_idx], 0, "invalid sentinel start byte");
        current_idx += 1;
        let mut dec_report = cobs::decode_in_place_report(&mut read_buf[current_idx..])
            .expect("COBS decoding failed");
        assert_eq!(dec_report.dst_used, 5);
        // Decoding happens in place, so the packet now starts at `current_idx`.
        assert_eq!(
            &read_buf[current_idx..current_idx + INVERTED_PACKET.len()],
            &INVERTED_PACKET
        );
        current_idx += dec_report.src_used;
        assert_eq!(read_buf[current_idx], 0, "invalid sentinel end byte");
        current_idx += 1;

        assert_eq!(read_buf[current_idx], 0, "invalid sentinel start byte");
        current_idx += 1;
        dec_report = cobs::decode_in_place_report(&mut read_buf[current_idx..])
            .expect("COBS decoding failed");
        assert_eq!(dec_report.dst_used, 5);
        assert_eq!(
            &read_buf[current_idx..current_idx + SIMPLE_PACKET.len()],
            &SIMPLE_PACKET
        );
        current_idx += dec_report.src_used;
        assert_eq!(read_buf[current_idx], 0, "invalid sentinel end byte");
        drop(stream);

        wait_for_connection_handled(&conn_handled);
        let mut tc_queue = tc_receiver
            .tc_queue
            .lock()
            .expect("locking tc queue failed");
        assert_eq!(tc_queue.len(), 2);
        assert_eq!(tc_queue.pop_front().unwrap(), &SIMPLE_PACKET);
        assert_eq!(tc_queue.pop_front().unwrap(), &INVERTED_PACKET);
        drop(tc_queue);
    }
}