1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use futures::future::FutureResult;
use log::{error, info};
use tokio::codec::{Decoder, Encoder, Framed};
use tokio::net::TcpListener;
use tokio::prelude::*;
use tokio_uds::UnixListener;

use std::error::Error;
use std::fmt::Debug;
use std::mem::size_of;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;

use super::error::AgentError;
use super::proto::message::Message;
use super::proto::{from_bytes, to_bytes};

/// Stateless codec for the agent wire format: each frame is a big-endian `u32`
/// payload length followed by that many payload bytes.
#[derive(Debug, Clone, Copy, Default)]
struct MessageCodec;

/// Decodes length-prefixed agent messages.
///
/// Each frame on the wire is a big-endian `u32` payload length followed by
/// exactly that many payload bytes, which are deserialized with `from_bytes`.
impl Decoder for MessageCodec {
    type Item = Message;
    type Error = AgentError;

    /// Attempts to decode one complete frame from the front of `src`.
    ///
    /// Returns `Ok(None)`, leaving `src` untouched, while the length header or
    /// the full payload has not yet been buffered. Otherwise the whole frame
    /// (header and payload) is removed from `src` and its payload deserialized.
    ///
    /// # Errors
    /// Propagates any deserialization error from `from_bytes`. The offending
    /// frame has already been removed from `src` at that point.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const HEADER_LEN: usize = size_of::<u32>();

        if src.len() < HEADER_LEN {
            return Ok(None);
        }

        let length = (&src[..HEADER_LEN]).read_u32::<BigEndian>()? as usize;

        // Compare against the bytes *after* the header so that a huge `length`
        // cannot overflow `HEADER_LEN + length` on 32-bit targets.
        if src.len() - HEADER_LEN < length {
            return Ok(None);
        }

        // Detach exactly this frame. The payload slice then excludes any bytes
        // of a following pipelined request, and a malformed frame is not left
        // at the head of the buffer to be re-parsed.
        let frame = src.split_to(HEADER_LEN + length);
        let message: Message = from_bytes(&frame[HEADER_LEN..])?;
        Ok(Some(message))
    }
}

/// Encodes agent messages as length-prefixed frames.
impl Encoder for MessageCodec {
    type Item = Message;
    type Error = AgentError;

    /// Serializes `item` and appends it to `dst` as one frame.
    ///
    /// The message is serialized once to obtain its payload. That payload is
    /// serialized again so that `to_bytes` emits it behind a length prefix,
    /// presumably the big-endian `u32` that the decoder reads. Confirm in `proto`.
    ///
    /// # Errors
    /// Propagates any serialization error from `to_bytes`.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let frame = to_bytes(&to_bytes(&item)?)?;
        // With bytes 0.4, `BufMut::put*` on `BytesMut` does not grow the buffer.
        // It asserts sufficient capacity and panics otherwise, and `FramedWrite`
        // does not reserve on our behalf, so make room first.
        dst.reserve(frame.len());
        dst.put_slice(&frame);
        Ok(())
    }
}

// Builds the accept-and-serve future for a listener.
//
// `$self` is the agent. It is moved into an `Arc` that is shared by every
// connection. `$socket` is a listener whose `incoming()` yields connections;
// this file passes both a Unix and a TCP listener. The expression evaluates to
// the future handed to `tokio::run` by the callers in this file.
macro_rules! handle_clients {
    ($self:ident, $socket:ident) => {{
        info!("Listening; socket = {:?}", $socket);
        // One shared agent; each connection clones its own `Arc` handle below.
        let arc_self = Arc::new($self);
        $socket
            .incoming()
            // NOTE(review): `for_each` stops at the first stream error, so a single
            // failed accept (e.g. fd exhaustion) ends the whole accept loop.
            .map_err(|e| error!("Failed to accept socket; error = {:?}", e))
            .for_each(move |socket| {
                // Apply `MessageCodec` framing to both directions of the connection.
                let (write, read) = Framed::new(socket, MessageCodec).split();
                let arc_self = arc_self.clone();
                // Each decoded request is passed to `handle_async`, and its reply is
                // written back on the same connection.
                let connection = write
                    .send_all(read.and_then(move |message| {
                        // Handler errors are logged and collapsed to `AgentError::User`.
                        // This presumably ends this connection's `send_all` instead of
                        // replying with a failure message. TODO confirm this is intended.
                        arc_self.handle_async(message).map_err(|e| {
                            error!("Error handling message; error = {:?}", e);
                            AgentError::User
                        })
                    }))
                    .map(|_| ())
                    .map_err(|e| error!("Error while handling message; error = {:?}", e));
                // Spawn the connection as its own task so the accept loop continues.
                tokio::spawn(connection)
            })
            // Presumably an identity conversion of the `()` error, kept to satisfy
            // the error type expected by the caller.
            .map_err(|e| e.into())
    }};
}

/// A request handler that can be served over a Unix or TCP socket.
///
/// Implementors provide [`Agent::handle`] (or override [`Agent::handle_async`]).
/// The `run_*` methods accept connections, feed each decoded [`Message`] to the
/// handler, and write the reply back on the same connection.
pub trait Agent: 'static + Sync + Send + Sized {
    /// Error produced by the handler. The server loop logs it with `Debug`.
    type Error: Debug + Send + Sync;

    /// Handles one request synchronously and produces the reply.
    fn handle(&self, message: Message) -> Result<Message, Self::Error>;

    /// Handles one request asynchronously.
    ///
    /// The default wraps [`Agent::handle`] in an already-resolved future.
    /// Override it for handlers that need to do non-blocking work.
    fn handle_async(
        &self,
        message: Message,
    ) -> Box<dyn Future<Item = Message, Error = Self::Error> + Send + Sync> {
        Box::new(FutureResult::from(self.handle(message)))
    }

    /// Serves clients on an already-bound Unix listener.
    ///
    /// Runs the accept loop with `tokio::run`, which blocks the calling thread.
    fn run_listener(self, socket: UnixListener) -> Result<(), Box<dyn Error + Send + Sync>> {
        tokio::run(handle_clients!(self, socket));
        Ok(())
    }

    /// Binds a Unix socket at `path` and serves clients on it.
    ///
    /// # Errors
    /// Returns an error if binding the socket fails.
    fn run_unix(self, path: impl AsRef<Path>) -> Result<(), Box<dyn Error + Send + Sync>> {
        self.run_listener(UnixListener::bind(path)?)
    }

    /// Binds a TCP listener at `addr` (for example `"127.0.0.1:2222"`) and
    /// serves clients on it.
    ///
    /// Runs the accept loop with `tokio::run`, which blocks the calling thread.
    ///
    /// # Errors
    /// Returns an error if `addr` does not parse as a `SocketAddr` or binding fails.
    fn run_tcp(self, addr: &str) -> Result<(), Box<dyn Error + Send + Sync>> {
        let socket = TcpListener::bind(&addr.parse::<SocketAddr>()?)?;
        tokio::run(handle_clients!(self, socket));
        Ok(())
    }
}