1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright 2023 Remi Bernotavicius

use derive_more::From;
use serde::{de::DeserializeOwned, Serialize};
use std::io;
use sun_rpc::{
    AcceptedReplyBody, AuthSysParameters, CallBody, Gid, Message, MessageBody, OpaqueAuth,
    ReplyBody, Uid, Xid,
};

/// Result type returned by every fallible operation in this module.
pub type Result<T> = std::result::Result<T, Error>;

/// Errors produced while sending an RPC call or receiving its reply.
#[derive(Debug, From)]
pub enum Error {
    /// Decoding XDR data read from the transport failed.
    ///
    /// (The variant name's spelling is kept as-is for compatibility with existing callers.)
    Deseralization(serde_xdr::CompatDeserializationError),
    /// Encoding an outgoing message as XDR failed.
    Serialization(serde_xdr::CompatSerializationError),
    /// The underlying transport reported an I/O error, or a record fragment was truncated.
    Io(io::Error),
    /// The server accepted the call but reported that the program is unavailable.
    ProgramUnavailable,
    /// The server accepted the call but does not support the requested program version.
    ProgramMismatch,
    /// The server accepted the call but does not support the requested procedure.
    ProcedureUnavailable,
    /// The server accepted the call but could not decode its arguments.
    GarbageArguments,
    /// The server accepted the call but reported a system error while handling it.
    SystemError,
    /// The received message was not an accepted reply.
    UnexpectedReply,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The serde_xdr errors are formatted through `Debug`, which `#[derive(Debug)]`
            // above already requires them to implement.
            Self::Deseralization(e) => write!(f, "failed to deserialize XDR data: {:?}", e),
            Self::Serialization(e) => write!(f, "failed to serialize XDR data: {:?}", e),
            Self::Io(e) => write!(f, "RPC transport I/O error: {}", e),
            Self::ProgramUnavailable => f.write_str("remote RPC program is unavailable"),
            Self::ProgramMismatch => {
                f.write_str("remote RPC program does not support the requested version")
            }
            Self::ProcedureUnavailable => f.write_str("remote RPC procedure is unavailable"),
            Self::GarbageArguments => f.write_str("server could not decode the call arguments"),
            Self::SystemError => f.write_str("server reported a system error"),
            Self::UnexpectedReply => f.write_str("received an unexpected RPC reply message"),
        }
    }
}

impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            _ => None,
        }
    }
}

/// Byte stream that RPC calls are written to and replies are read from.
///
/// Every type that is both readable and writable (e.g. `std::net::TcpStream`) is a
/// `Transport` automatically via the blanket impl that follows.
pub trait Transport: io::Read + io::Write {}

impl<S: io::Read + io::Write> Transport for S {}

/// RPC program number of the port mapper service.
pub const PORT_MAPPER: u32 = 100_000;
/// Well-known TCP/UDP port on which the port mapper listens.
pub const PORT_MAPPER_PORT: u16 = 111;
/// Procedure number 0, which by ONC RPC convention is the NULL (no-op) procedure.
pub const NULL_PROCEDURE: u32 = 0;

pub struct RpcClient {
    xid: Xid,
    program: u32,
}

impl RpcClient {
    pub fn new(program: u32) -> Self {
        Self {
            xid: Xid(1),
            program,
        }
    }

    pub fn send_request<T: Serialize>(
        &mut self,
        transport: &mut impl Transport,
        procedure: u32,
        call_args: T,
    ) -> Result<()> {
        let message = Message {
            xid: self.xid.clone(),
            body: MessageBody::Call(CallBody {
                rpc_version: 2,
                program: self.program,
                version: 4,
                procedure,
                credential: OpaqueAuth::auth_sys(AuthSysParameters {
                    stamp: 0,
                    machine_name: "test-machine".into(),
                    uid: Uid(0),
                    gid: Gid(0),
                    gids: vec![Gid(0)],
                }),
                verifier: OpaqueAuth::none(),
                call_args,
            }),
        };
        let mut serialized = vec![0; 4];
        serde_xdr::to_writer(&mut serialized, &message)?;

        let fragment_header = (serialized.len() - 4) as u32 | 0x1 << 31;
        serde_xdr::to_writer(&mut &mut serialized[..4], &fragment_header)?;

        transport.write_all(&serialized[..])?;

        self.xid = Xid(self.xid.0 + 1);

        Ok(())
    }

    pub fn receive_reply<T: DeserializeOwned>(
        &mut self,
        mut transport: &mut impl Transport,
    ) -> Result<T> {
        let fragment_header: u32 = serde_xdr::from_reader(transport)?;
        let length = fragment_header & !(0x1 << 31);
        let reply: Message<T> =
            serde_xdr::from_reader(&mut io::Read::take(&mut transport, length as u64))?;

        if let Message {
            body: MessageBody::Reply(ReplyBody::Accepted(accepted_reply)),
            ..
        } = reply
        {
            match accepted_reply.body {
                AcceptedReplyBody::Success(b) => Ok(b),
                AcceptedReplyBody::ProgramUnavailable => Err(Error::ProgramUnavailable),
                AcceptedReplyBody::ProgramMismatch { .. } => Err(Error::ProgramMismatch),
                AcceptedReplyBody::ProcedureUnavailable => Err(Error::ProcedureUnavailable),
                AcceptedReplyBody::GarbageArguments => Err(Error::GarbageArguments),
                AcceptedReplyBody::SystemError => Err(Error::SystemError),
            }
        } else {
            Err(Error::UnexpectedReply)
        }
    }
}

#[test]
fn ping() {
    vm_test_fixture::fixture(|machine| {
        // Locate the host port that the guest's port-mapper port is forwarded to.
        let forwarded_ports = machine.forwarded_ports();
        let port_mapper = forwarded_ports
            .iter()
            .find(|forward| forward.guest == PORT_MAPPER_PORT)
            .unwrap();
        let address = ("127.0.0.1", port_mapper.host);
        let mut stream = std::net::TcpStream::connect(address).unwrap();

        // The NULL procedure takes no arguments and returns nothing, so a successful
        // round trip only shows that the service is reachable and answering.
        let mut client = RpcClient::new(PORT_MAPPER);
        client.send_request(&mut stream, NULL_PROCEDURE, ()).unwrap();
        client.receive_reply::<()>(&mut stream).unwrap();
    });
}