1use std::os::fd::IntoRawFd;
2use std::os::unix::io::RawFd;
3
4use bitcoin::consensus::Encodable;
5use lightning_signer::bitcoin;
6use nix::sys::socket::{socketpair, AddressFamily, SockFlag, SockType};
7
8use serde_bolt::{io::Read, ReadBigEndian};
9use vls_protocol::serde_bolt;
10use vls_protocol::serde_bolt::io::FromStd;
11use vls_protocol::{msgs, Error, Result};
12use vls_protocol_signer::vls_protocol;
13
14use crate::connection::UnixConnection;
15
16pub trait Client: Send {
17 fn write<M: msgs::DeBolt + Encodable>(&mut self, msg: M) -> Result<()>;
18 fn write_vec(&mut self, v: Vec<u8>) -> Result<()>;
19 fn read(&mut self) -> Result<msgs::Message>;
20 fn read_raw(&mut self) -> Result<Vec<u8>>;
21 fn id(&self) -> u64;
22 #[must_use = "don't leak the client fd"]
23 fn new_client(&mut self) -> Self;
24}
25
26pub struct UnixClient {
27 conn: FromStd<UnixConnection>,
28}
29
30impl UnixClient {
31 pub fn new(conn: UnixConnection) -> Self {
32 Self { conn: FromStd::new(conn) }
33 }
34
35 pub fn recv_fd(&mut self) -> core::result::Result<RawFd, ()> {
36 self.conn.inner().recv_fd()
37 }
38}
39
40impl Client for UnixClient {
41 fn write<M: msgs::DeBolt + Encodable>(&mut self, msg: M) -> Result<()> {
42 msgs::write(&mut self.conn, msg)?;
43 Ok(())
44 }
45
46 fn write_vec(&mut self, v: Vec<u8>) -> Result<()> {
47 msgs::write_vec(&mut self.conn, v)?;
48 Ok(())
49 }
50
51 fn read(&mut self) -> Result<msgs::Message> {
52 msgs::read(&mut self.conn)
53 }
54
55 fn read_raw(&mut self) -> Result<Vec<u8>> {
56 let len = self.conn.read_u32_be().map_err(|e| Error::Io(e.to_string()))?;
58 let mut data = Vec::new();
59 data.resize(len as usize, 0);
60 let len = self.conn.read(&mut data)?;
61 if len < data.len() {
62 return Err(Error::ShortRead);
63 }
64 Ok(data)
65 }
66
67 fn id(&self) -> u64 {
68 self.conn.inner().id()
69 }
70
71 fn new_client(&mut self) -> UnixClient {
72 let (fd_a, fd_b) =
73 socketpair(AddressFamily::Unix, SockType::Stream, None, SockFlag::empty()).unwrap();
74 self.conn.inner().send_fd(fd_a.into_raw_fd());
75 UnixClient::new(UnixConnection::new(fd_b.into_raw_fd()))
76 }
77}