1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
//! # TCP binder
//!
//! This module contains the implementation of the TCP server binder,
//! based on [`tokio::net::TcpListener`] and [`tokio::net::TcpStream`].

use async_trait::async_trait;
use log::debug;
use std::io;
use tokio::{
    io::{AsyncBufReadExt, AsyncWriteExt},
    net::TcpListener,
};

use crate::{
    request::{Request, RequestReader},
    response::{Response, ResponseWriter},
    tcp::TcpHandler,
    timer::ThreadSafeTimer,
};

use super::{ServerBind, ServerStream};

/// The TCP server binder.
///
/// This [`ServerBind`]er uses the TCP protocol to bind a listener, to
/// read requests and write responses.
///
/// The binder itself only stores the address to listen on; the
/// listener is created when the binder is bound.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TcpBind {
    /// The TCP host of the listener.
    pub host: String,

    /// The TCP port of the listener.
    pub port: u16,
}

impl TcpBind {
    /// Build a TCP binder listening on `host:port`.
    ///
    /// The host is converted to an owned [`String`] and the binder is
    /// returned boxed as a [`ServerBind`] trait object.
    pub fn new(host: impl ToString, port: u16) -> Box<dyn ServerBind> {
        let host = host.to_string();
        let binder = Self { host, port };
        Box::new(binder)
    }
}

#[async_trait]
impl ServerBind for TcpBind {
    /// Bind a TCP listener on the configured host and port, then
    /// serve incoming connections forever.
    ///
    /// Each accepted stream is wrapped in a [`TcpHandler`] and handled
    /// to completion before the next connection is accepted. Failures
    /// to accept or to handle a connection are only logged at debug
    /// level; they never stop the loop.
    ///
    /// # Errors
    ///
    /// Returns an error only when the listener cannot be bound.
    async fn bind(&self, timer: ThreadSafeTimer) -> io::Result<()> {
        let addr = (self.host.as_str(), self.port);
        let listener = TcpListener::bind(addr).await?;

        loop {
            // Accept failures are not fatal: log and wait for the next client.
            let stream = match listener.accept().await {
                Ok((stream, _)) => stream,
                Err(err) => {
                    debug!("cannot get stream from client");
                    debug!("{err:?}");
                    continue;
                }
            };

            let mut handler = TcpHandler::from(stream);
            let outcome = handler.handle(timer.clone()).await;
            if let Err(err) = outcome {
                debug!("cannot handle request");
                debug!("{err:?}");
            }
        }
    }
}

#[async_trait]
impl RequestReader for TcpHandler {
    /// Read one line from the handler's reader and parse it into a
    /// [`Request`].
    ///
    /// The first whitespace-separated token selects the request
    /// (`start`, `get`, `set`, `pause`, `resume` or `stop`). `set`
    /// additionally requires a second token that parses as a `usize`.
    /// Any further tokens on the line are ignored.
    ///
    /// # Errors
    ///
    /// Propagates the error of the underlying line read. Returns an
    /// [`io::ErrorKind::InvalidInput`] error when the line contains no
    /// token, an unknown request, or a missing or invalid `set`
    /// duration.
    async fn read(&mut self) -> io::Result<Request> {
        let mut line = String::new();
        self.reader.read_line(&mut line).await?;

        // Every parse failure is reported with the same error kind.
        let invalid = |message: String| io::Error::new(io::ErrorKind::InvalidInput, message);

        let mut words = line.split_whitespace();
        let command = words
            .next()
            .ok_or_else(|| invalid("missing request".to_owned()))?;

        let request = match command {
            "start" => Request::Start,
            "get" => Request::Get,
            "set" => {
                let raw = words
                    .next()
                    .ok_or_else(|| invalid("missing duration".to_owned()))?;
                let duration = raw
                    .parse::<usize>()
                    .map_err(|err| invalid(format!("invalid duration: {err}")))?;
                Request::Set(duration)
            }
            "pause" => Request::Pause,
            "resume" => Request::Resume,
            "stop" => Request::Stop,
            req => return Err(invalid(format!("invalid request: {req}"))),
        };

        Ok(request)
    }
}

#[async_trait]
impl ResponseWriter for TcpHandler {
    async fn write(&mut self, res: Response) -> io::Result<()> {
        let res = match res {
            Response::Ok => format!("ok\n"),
            Response::Timer(timer) => {
                format!("timer {}\n", serde_json::to_string(&timer).unwrap())
            }
        };

        self.writer.write_all(res.as_bytes()).await?;

        Ok(())
    }
}