1use std::{io, sync::Arc, time::Duration};
2
3use bytes::BytesMut;
4use tokio::{
5 io::{AsyncReadExt, AsyncWriteExt},
6 net::{TcpListener, TcpStream, ToSocketAddrs, tcp::OwnedReadHalf},
7 sync::Mutex,
8 task::block_in_place,
9 time::timeout,
10};
11use tokio_util::codec::Decoder;
12use tokio_util::sync::CancellationToken;
13
14use crate::XvcServer;
15use xvc_protocol::{
16 Message, OwnedMessage, Version, XvcInfo, error::ReadError, tokio_codec::MessageDecoder,
17};
18
/// Tunable settings for the XVC TCP server.
#[derive(Clone, Debug)]
pub struct Config {
    /// Size limit handed to the message decoder and advertised to clients in
    /// the `GetInfo` response.
    pub max_vector_size: u32,
    /// How long to wait for incoming client data before the connection is
    /// closed.
    pub read_write_timeout: Duration,
}
27
28impl Default for Config {
29 fn default() -> Self {
30 Self {
31 max_vector_size: 10 * 1024 * 1024,
32 read_write_timeout: Duration::from_secs(30),
33 }
34 }
35}
36
37#[derive(Debug)]
38pub struct Server<T: XvcServer> {
39 server: Arc<Mutex<T>>,
40 config: Config,
41}
42
43#[derive(Default)]
57pub struct Builder {
58 config: Config,
59}
60
61impl Builder {
62 pub fn new() -> Builder {
63 Builder::default()
64 }
65
66 pub fn max_vector_size(mut self, size: u32) -> Self {
68 self.config.max_vector_size = size;
69 self
70 }
71
72 pub fn rw_timeout(mut self, timeout: Duration) -> Self {
74 self.config.read_write_timeout = timeout;
75 self
76 }
77
78 pub fn build<T: XvcServer>(self, server: T) -> Server<T> {
80 Server::new(server, self.config)
81 }
82}
83
84impl<T: XvcServer> Server<T> {
85 pub fn new(server: T, config: Config) -> Server<T> {
87 Server {
88 server: Arc::new(Mutex::new(server)),
89 config,
90 }
91 }
92
93 pub async fn listen(&self, addr: impl ToSocketAddrs) -> io::Result<()>
99 where
100 T: Send + 'static,
101 {
102 let listener = TcpListener::bind(addr).await?;
103 self.listen_on(listener, CancellationToken::new()).await
104 }
105
106 pub async fn listen_on(
132 &self,
133 listener: TcpListener,
134 shutdown: CancellationToken,
135 ) -> io::Result<()>
136 where
137 T: Send + 'static,
138 {
139 log::info!("Server listening for connections");
140
141 loop {
142 tokio::select! {
143 _ = shutdown.cancelled() => {
144 log::info!("Shutdown signal received, stopping listener");
145 break;
146 }
147 result = listener.accept() => {
148 match result {
149 Ok((stream, addr)) => {
150 let guard = match Arc::clone(&self.server).try_lock_owned() {
151 Ok(guard) => guard,
152 Err(_) => {
153 log::warn!("Rejected concurrent client from {}: another client is already active", addr);
154 continue;
155 }
156 };
157 stream.set_nodelay(true)?;
158 log::info!("New client connection from {}", addr);
159 let config = self.config.clone();
160 tokio::spawn(async move {
161 if let Err(e) = handle_client(guard, config, stream).await {
162 log::error!("Client error: {}", e);
163 }
164 });
165 }
166 Err(e) => log::error!("Connection error: {}", e),
167 }
168 }
169 }
170 }
171
172 Ok(())
173 }
174}
175
176async fn handle_client<T>(
177 server: tokio::sync::OwnedMutexGuard<T>,
178 config: Config,
179 stream: TcpStream,
180) -> Result<(), ReadError>
181where
182 T: XvcServer + Send + 'static,
183{
184 let (mut read_half, mut write_half) = stream.into_split();
185 let mut buf = BytesMut::new();
186 let mut decoder = MessageDecoder::new(config.max_vector_size as usize);
187
188 loop {
189 match read_message(
190 &mut read_half,
191 &mut buf,
192 &mut decoder,
193 config.read_write_timeout,
194 )
195 .await
196 {
197 Ok(Some(msg)) => {
198 let response = block_in_place(|| compute_response(&*server, &config, msg))?;
199 write_half.write_all(&response).await?;
200 }
201 Ok(None) => break,
202 Err(e) => return Err(e),
203 }
204 }
205
206 Ok(())
207}
208
209async fn read_message(
212 read: &mut OwnedReadHalf,
213 buf: &mut BytesMut,
214 decoder: &mut MessageDecoder,
215 rw_timeout: Duration,
216) -> Result<Option<OwnedMessage>, ReadError> {
217 loop {
218 if let Some(msg) = decoder.decode(buf)? {
219 return Ok(Some(msg));
220 }
221
222 match timeout(rw_timeout, read.read_buf(buf)).await {
223 Ok(Ok(0)) => return Ok(None), Ok(Ok(_)) => {} Ok(Err(e)) => return Err(ReadError::from(e)),
226 Err(_elapsed) => {
227 log::warn!("Client read timeout, closing connection");
228 return Ok(None);
229 }
230 }
231 }
232}
233
234fn compute_response<T: XvcServer>(
235 server: &T,
236 config: &Config,
237 msg: OwnedMessage,
238) -> Result<Vec<u8>, ReadError> {
239 let mut buf = Vec::new();
240 match msg {
241 Message::GetInfo => {
242 log::info!("Received GetInfo message");
243 let info = XvcInfo::new(Version::V1_0, config.max_vector_size);
244 info.write_to(&mut buf)?;
245 log::debug!("Sent XVC info response");
246 }
247 Message::SetTck { period_ns } => {
248 log::debug!("Received SetTck message: period_ns={}", period_ns);
249 let ret_period = server.set_tck(period_ns);
250 log::debug!("Set TCK returned: period_ns={}", ret_period);
251 buf.extend_from_slice(&ret_period.to_le_bytes());
252 }
253 Message::Shift { num_bits, tms, tdi } => {
254 log::debug!(
255 "Received Shift message: num_bits={}, tms_len={}, tdi_len={}",
256 num_bits,
257 tms.len(),
258 tdi.len()
259 );
260 log::trace!("Shift TMS data: {:02x?}", &tms[..]);
261 log::trace!("Shift TDI data: {:02x?}", &tdi[..]);
262 let tdo = server.shift(num_bits, &tms, &tdi);
263 log::trace!("Shift result TDO data: {:02x?}", &tdo[..]);
264 buf.extend_from_slice(&tdo);
265 }
266 }
267 Ok(buf)
268}