1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
use std::io::{IoSlice, IoSliceMut, Read as _, Write as _};
use std::os::unix::io::{FromRawFd, IntoRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::{fs, io};

use log::{error, info, trace};
use nix::cmsg_space;
use nix::sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags};
use serde_bolt::{Error as SError, Read, Result as SResult, Write};

use nix::libc;
use nix::unistd::close;
use vls_protocol::serde_bolt;
use vls_protocol_signer::vls_protocol;

/// Descriptor number used as the channel to the parent process when running
/// under a lightningd parent (probed by `open_parent_fd`).
const PARENT_FD: u16 = 3;

/// A Unix-domain socket connection that implements the serde_bolt `Read`/`Write`
/// traits and can pass file descriptors to / receive them from the peer via
/// `SCM_RIGHTS` (`send_fd` / `recv_fd`).
pub struct UnixConnection {
    /// Raw descriptor of the socket. `new` sets this to the same fd that
    /// `stream` wraps; it is kept for the `sendmsg`/`recvmsg` calls and as an id.
    fd: RawFd,
    /// Stream view of `fd`, used for ordinary byte reads and writes.
    stream: UnixStream,
    /// A byte already pulled off the socket by `peek`; the next `read` returns
    /// it first.
    peek: Option<u8>,
}

impl UnixConnection {
    pub fn new(fd: RawFd) -> Self {
        UnixConnection { fd, stream: unsafe { UnixStream::from_raw_fd(fd) }, peek: None }
    }

    pub(crate) fn id(&self) -> u64 {
        self.fd as u64
    }

    pub(crate) fn send_fd(&self, fd: RawFd) {
        info!("sending fd {}", fd);
        let fds = [fd];
        let fd_msg = ControlMessage::ScmRights(&fds);
        let c = [0xff];
        let x = IoSlice::new(&c);
        sendmsg::<()>(self.fd, &[x], &[fd_msg], MsgFlags::empty(), None).unwrap();
        close(fd).unwrap();
    }

    pub(crate) fn recv_fd(&self) -> Result<RawFd, ()> {
        let mut cmsgs = cmsg_space!(RawFd);
        let mut c = [0];
        let x = IoSliceMut::new(&mut c);
        let result = recvmsg::<()>(self.fd, &mut [x], Some(&mut cmsgs), MsgFlags::empty()).unwrap();
        let mut iter = result.cmsgs();
        let cmsg = iter.next().ok_or_else(|| {
            error!("expected a control message");
        })?;
        if iter.next().is_some() {
            error!("expected exactly one control message");
            return Err(());
        }
        match cmsg {
            ControlMessageOwned::ScmRights(r) =>
                if r.len() != 1 {
                    error!("expected exactly one fd");
                    Err(())
                } else {
                    if c[0] != 0xff {
                        error!("expected a 0xff byte ancillary byte, got {}", c[0]);
                        return Err(());
                    }
                    Ok(r[0])
                },
            m => {
                error!("unexpected cmsg {:?}", m);
                Err(())
            }
        }
    }
}

impl Read for UnixConnection {
    type Error = SError;

    /// Reads until `dest` is full or the peer closes the stream.
    ///
    /// A byte previously buffered by `peek` is returned first. Returns the number
    /// of bytes stored in `dest`; this is less than `dest.len()` only at EOF.
    /// `ErrorKind::Interrupted` is retried, because `std::io::Read::read` documents
    /// it as non-fatal. Any other I/O error becomes `SError::Message`.
    fn read(&mut self, dest: &mut [u8]) -> SResult<usize> {
        if dest.is_empty() {
            return Ok(0);
        }
        let mut cursor = 0;
        if let Some(peek) = self.peek.take() {
            dest[0] = peek;
            cursor = 1;
        }
        while cursor < dest.len() {
            let res: io::Result<usize> = self.stream.read(&mut dest[cursor..]);
            trace!("read {}: {:?} cursor={} expected={}", self.id(), res, cursor, dest.len());
            match res {
                Ok(0) => return Ok(cursor),
                Ok(n) => cursor += n,
                // A signal interrupted the syscall before any data arrived; try again.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(SError::Message(format!("{}", e))),
            }
        }
        Ok(cursor)
    }

    /// Returns the next byte without consuming it, or `None` at EOF.
    ///
    /// The byte is buffered in `self.peek` and handed out by the next `read`.
    /// `ErrorKind::Interrupted` is retried; other I/O errors become `SError::Message`.
    fn peek(&mut self) -> SResult<Option<u8>> {
        if let Some(byte) = self.peek {
            return Ok(Some(byte));
        }
        let mut buf = [0; 1];
        loop {
            let res: io::Result<usize> = self.stream.read(&mut buf);
            match res {
                Ok(0) => return Ok(None),
                Ok(n) => {
                    assert_eq!(n, 1);
                    self.peek = Some(buf[0]);
                    return Ok(self.peek);
                }
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(SError::Message(format!("{}", e))),
            }
        }
    }
}

impl Write for UnixConnection {
    type Error = SError;

    /// Writes every byte of `buf` to the underlying socket.
    ///
    /// Any I/O error is converted into `SError::Message` carrying the error's
    /// display text.
    fn write_all(&mut self, buf: &[u8]) -> SResult<()> {
        match self.stream.write_all(buf) {
            Ok(()) => Ok(()),
            Err(err) => Err(SError::Message(err.to_string())),
        }
    }
}

/// Returns the fd to use for talking to the parent process.
///
/// If fd [`PARENT_FD`] is already open, we assume a lightningd parent supplied it
/// and return it. Otherwise an error is logged and a freshly opened `/dev/null`
/// descriptor is returned instead.
///
/// Must be called before this process opens any other files or sockets.
/// Otherwise one of our own descriptors could occupy fd 3 and be mistaken for
/// the parent's.
///
/// # Panics
/// Panics if there is no parent and `/dev/null` cannot be opened.
pub fn open_parent_fd() -> RawFd {
    // Only use fd 3 if we are really running with a lightningd parent, so we don't conflict with future fd allocation
    // Check this before opening any files or sockets!
    // SAFETY: F_GETFD only queries descriptor flags; it touches no memory and on a
    // closed fd merely fails (returns -1).
    let have_parent = unsafe { libc::fcntl(PARENT_FD as libc::c_int, libc::F_GETFD) } != -1;

    if have_parent {
        RawFd::from(PARENT_FD)
    } else {
        // Only open /dev/null when it is actually needed. It is opened before
        // logging, as in the original ordering.
        let dummy_file = fs::File::open("/dev/null").expect("open /dev/null").into_raw_fd();
        error!("no parent on {}, using /dev/null", PARENT_FD);
        dummy_file
    }
}