vls_proxy/grpc/signer_loop.rs

use async_trait::async_trait;
use backon::{BlockingRetryable, ExponentialBuilder, Retryable};
use log::*;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::process;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::time::SystemTime;
use tokio::sync::{mpsc, oneshot};
use tokio::task::spawn_blocking;
use triggered::{Listener, Trigger};

use lightning_signer::bitcoin::hashes::sha256::Hash as Sha256Hash;
use lightning_signer::bitcoin::hashes::Hash;

use super::adapter::{ChannelReply, ChannelRequest, ClientId};
use vls_common::*;
use vls_protocol::{
    msgs, msgs::DeBolt as _, msgs::Message, msgs::SerBolt as _, Error as ProtocolError,
};
use vls_protocol_client::{ClientResult as Result, Error, SignerPort};
use vls_protocol_signer::vls_protocol;

use crate::client::Client;
use crate::*;

// The default backon rules retry after 1s and then start doubling,
// so this number equates to 1 + 2^(n - 1) seconds
const BACKOFF_RETRY_TIMES: usize = 8;
const PREAPPROVE_CACHE_TTL: Duration = Duration::from_secs(60);
const PREAPPROVE_CACHE_SIZE: usize = 6;

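// A cached preapproval reply, keyed in the LRU cache by the sha256 hash of the
// raw request and considered fresh for PREAPPROVE_CACHE_TTL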
struct PreapprovalCacheEntry {
    tstamp: SystemTime,
    reply_bytes: Vec<u8>,
}

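/// Node-side port to the remote signer.
///
/// Requests are forwarded to the gRPC adapter over an mpsc channel; `is_ready`
/// flips to true once the initial `HsmdInit` exchange has been handled.
///
/// A minimal usage sketch (the channel capacity and the request bytes are
/// illustrative assumptions; the receiver end belongs to the gRPC adapter):
///
/// ```ignore
/// let (sender, _receiver) = tokio::sync::mpsc::channel(1000);
/// let port = GrpcSignerPort::new(sender);
/// // Forward raw hsmd request bytes and await the signer's reply bytes.
/// let reply: Vec<u8> = port.handle_message(raw_request_bytes).await?;
/// ```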
pub struct GrpcSignerPort {
    sender: mpsc::Sender<ChannelRequest>,
    is_ready: Arc<AtomicBool>,
}

// Create the exponential backoff policy used when retrying transient failures
fn backoff() -> ExponentialBuilder {
    ExponentialBuilder::new().with_max_times(BACKOFF_RETRY_TIMES)
}

#[async_trait]
impl SignerPort for GrpcSignerPort {
    async fn handle_message(&self, message: Vec<u8>) -> Result<Vec<u8>> {
        let result = (|| async {
            let reply_rx = self.send_request(message.clone()).await?;
            // Wait for the signer reply
            // Can fail if the adapter shut down
            let reply = reply_rx.await.map_err(|_| Error::Transport)?;
            if reply.is_temporary_failure {
                // Retry with backoff
                info!("temporary error, retrying");
                return Err(Error::TransportTransient);
            }

            Ok(reply.reply)
        })
        .retry(backoff())
        .when(|e| *e == Error::TransportTransient)
        .await
        .map_err(|e| {
            error!("signer retry failed: {:?}", e);
            e
        })?;
        Ok(result)
    }

    fn is_ready(&self) -> bool {
        self.is_ready.load(Ordering::Relaxed)
    }
}

impl GrpcSignerPort {
    pub fn new(sender: mpsc::Sender<ChannelRequest>) -> Self {
        GrpcSignerPort { sender, is_ready: Arc::new(AtomicBool::new(false)) }
    }

    pub(crate) fn set_ready(&self) {
        info!("setting is_ready true");
        self.is_ready.store(true, Ordering::Relaxed);
    }

    async fn send_request(&self, message: Vec<u8>) -> Result<oneshot::Receiver<ChannelReply>> {
        let (reply_rx, request) = Self::prepare_request(message, None);

        // Send a request to the gRPC handler to send to signer
        // This can fail if gRPC adapter shut down
        self.sender.send(request).await.map_err(|_| ProtocolError::Eof)?;

        Ok(reply_rx)
    }

    // Send a blocking request to the signer with an optional client_id
    // for use in [`SignerLoop`]
    fn send_request_blocking(
        &self,
        message: Vec<u8>,
        client_id: Option<ClientId>,
    ) -> Result<oneshot::Receiver<ChannelReply>> {
        let (reply_rx, request) = Self::prepare_request(message, client_id);

        // Send a request to the gRPC handler to send to signer
        // This can fail if gRPC adapter shut down
        self.sender.blocking_send(request).map_err(|_| ProtocolError::Eof)?;

        Ok(reply_rx)
    }

    fn prepare_request(
        message: Vec<u8>,
        client_id: Option<ClientId>,
    ) -> (oneshot::Receiver<ChannelReply>, ChannelRequest) {
        // Create a one-shot channel to receive the reply
        let (reply_tx, reply_rx) = oneshot::channel();

        let request = ChannelRequest { client_id, message, reply_tx };
        (reply_rx, request)
    }
}


/// A cache of the init message from the node, in case the signer reconnects
#[derive(Clone)]
pub struct InitMessageCache {
    /// The HsmdInit or HsmdInit2 message from the node
    pub init_message: Option<Vec<u8>>,
}

impl InitMessageCache {
    /// Create a new cache
    pub fn new() -> Self {
        Self { init_message: None }
    }
}

/// Implement the hsmd UNIX fd protocol.
/// This doesn't actually perform the signing; the hsmd packets are transported via gRPC to the
/// real signer.
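///
/// A minimal wiring sketch for the root connection (the `client` value, channel
/// capacity, and task setup here are illustrative assumptions, not the crate's
/// exact startup code):
///
/// ```ignore
/// let (sender, _receiver) = tokio::sync::mpsc::channel(1000);
/// let signer_port = Arc::new(GrpcSignerPort::new(sender));
/// let (shutdown_trigger, shutdown_signal) = triggered::trigger();
/// let init_cache = Arc::new(Mutex::new(InitMessageCache::new()));
/// let mut root_loop =
///     SignerLoop::new(client, signer_port, shutdown_trigger, shutdown_signal, init_cache);
/// // start() blocks on the UNIX fd, so run it on a blocking task.
/// tokio::task::spawn_blocking(move || root_loop.start());
/// ```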
pub struct SignerLoop<C: 'static + Client> {
    client: C,
    log_prefix: String,
    signer_port: Arc<GrpcSignerPort>,
    client_id: Option<ClientId>,
    shutdown_trigger: Option<Trigger>,
    shutdown_signal: Option<Listener>,
    preapproval_cache: LruCache<Sha256Hash, PreapprovalCacheEntry>,
    init_message_cache: Arc<Mutex<InitMessageCache>>,
}

impl<C: 'static + Client> SignerLoop<C> {
    /// Create a loop for the root (lightningd) connection, but don't start it yet
    pub fn new(
        client: C,
        signer_port: Arc<GrpcSignerPort>,
        shutdown_trigger: Trigger,
        shutdown_signal: Listener,
        init_message_cache: Arc<Mutex<InitMessageCache>>,
    ) -> Self {
        let log_prefix = format!("{}/{}/{}", std::process::id(), client.id(), 0);
        let preapproval_cache = LruCache::new(NonZeroUsize::new(PREAPPROVE_CACHE_SIZE).unwrap());
        Self {
            client,
            log_prefix,
            signer_port,
            client_id: None,
            shutdown_trigger: Some(shutdown_trigger),
            shutdown_signal: Some(shutdown_signal),
            preapproval_cache,
            init_message_cache,
        }
    }

    // Create a loop for a non-root connection
    fn new_for_client(client: C, signer_port: Arc<GrpcSignerPort>, client_id: ClientId) -> Self {
        let log_prefix = format!("{}/{}/{}", std::process::id(), client.id(), client_id.dbid);
        let preapproval_cache = LruCache::new(NonZeroUsize::new(PREAPPROVE_CACHE_SIZE).unwrap());
        Self {
            client,
            log_prefix,
            signer_port,
            client_id: Some(client_id),
            shutdown_trigger: None,
            shutdown_signal: None,
            preapproval_cache,
            init_message_cache: Arc::new(Mutex::new(InitMessageCache::new())),
        }
    }

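    // The root connection (from lightningd itself) has no client_id; per-channel
    // connections created via ClientHsmFd do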
    fn is_root(&self) -> bool {
        self.client_id.is_none()
    }

    /// The init message cache
    pub fn init_message_cache(&self) -> Arc<Mutex<InitMessageCache>> {
        self.init_message_cache.clone()
    }

    /// Start the read loop
    pub fn start(&mut self) {
        info!("read loop {}: start", self.log_prefix);
        if let Some(shutdown_signal) = self.shutdown_signal.as_ref() {
            // TODO exit more cleanly
            // Right now there's no clean way to stop the UNIX fd reader loop, so just be
            // aggressive here and exit when it's time to shut down
            let shutdown_signal_clone = shutdown_signal.clone();
            let log_prefix_clone = self.log_prefix.clone();
            tokio::spawn(async move {
                info!("read loop {} waiting for shutdown", log_prefix_clone);
                tokio::select! {
                    _ = shutdown_signal_clone => {
                        info!("read loop {} saw shutdown, calling exit", log_prefix_clone);
                        process::exit(0);
                    }
                }
            });
        }
        match self.do_loop() {
            Ok(()) => info!("read loop {} done", self.log_prefix),
            Err(Error::Protocol(ProtocolError::Eof)) => {
                info!("read loop {} saw EOF; ending", self.log_prefix)
            }
            Err(e) => error!("read loop {} saw error {:?}; ending", self.log_prefix, e),
        }
        if let Some(trigger) = self.shutdown_trigger.as_ref() {
            warn!("read loop {} terminated; triggering shutdown", self.log_prefix);
            trigger.trigger();
        }
    }

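    // Read raw hsmd requests from the node and dispatch them until EOF or error,
    // proxying each one to the signer over the gRPC port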
    fn do_loop(&mut self) -> Result<()> {
        loop {
            let raw_msg = self.client.read_raw()?;
            let msg = msgs::from_vec(raw_msg.clone())?;
            log_request!(msg);
            match msg {
                Message::ClientHsmFd(m) => {
                    self.client.write(msgs::ClientHsmFdReply {}).unwrap();
                    let new_client = self.client.new_client();
                    info!("new client {} -> {}", self.log_prefix, new_client.id());
                    let peer_id = m.peer_id.0;
                    let client_id = ClientId { peer_id, dbid: m.dbid };
                    let mut new_loop =
                        SignerLoop::new_for_client(new_client, self.signer_port.clone(), client_id);
                    spawn_blocking(move || new_loop.start());
                }
                Message::PreapproveInvoice(_) | Message::PreapproveKeysend(_) => {
                    let now = SystemTime::now();
                    let req_hash = Sha256Hash::hash(&raw_msg);
                    if let Some(entry) = self.preapproval_cache.get(&req_hash) {
                        let age = now.duration_since(entry.tstamp).expect("age");
                        if age < PREAPPROVE_CACHE_TTL {
                            debug!("{} found in preapproval cache", self.log_prefix);
                            let reply = entry.reply_bytes.clone();
                            log_reply!(reply, self);
                            self.client.write_vec(reply)?;
                            continue;
                        }
                    }
                    let reply_bytes = self.do_proxy_msg(raw_msg, false)?;
                    let reply = msgs::from_vec(reply_bytes.clone()).expect("parse reply failed");
                    // Did we just witness an approval?
                    match reply {
                        Message::PreapproveKeysendReply(pkr) =>
                            if pkr.result {
                                debug!("{} adding keysend to preapproval cache", self.log_prefix);
                                self.preapproval_cache.put(
                                    req_hash,
                                    PreapprovalCacheEntry { tstamp: now, reply_bytes },
                                );
                            },
                        Message::PreapproveInvoiceReply(pir) =>
                            if pir.result {
                                debug!("{} adding invoice to preapproval cache", self.log_prefix);
                                self.preapproval_cache.put(
                                    req_hash,
                                    PreapprovalCacheEntry { tstamp: now, reply_bytes },
                                );
                            },
                        _ => {} // allow future out-of-band reply types
                    }
                }
                #[cfg(feature = "developer")]
                Message::HsmdDevPreinit2(_) => {
                    if !self.is_root() {
                        error!(
                            "read loop {}: unexpected HsmdDevPreinit2 on non-root connection",
                            self.log_prefix
                        );
                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
                            msgs::HsmdInit::TYPE,
                        )));
                    }
                    _ = self.do_proxy_msg(raw_msg, /*ONEWAY*/ true)?;
                }
                Message::HsmdInit(mut m) => {
                    if !self.is_root() {
                        error!(
                            "read loop {}: unexpected HsmdInit on non-root connection",
                            self.log_prefix
                        );
                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
                            msgs::HsmdInit::TYPE,
                        )));
                    }
                    let raw_reply = self.do_proxy_msg(raw_msg, false)?;
                    // decode the reply and extract the protocol version
                    let reply = msgs::from_vec(raw_reply)?;
                    // we expect a HsmdInitReply
                    let init_reply = match reply {
                        Message::HsmdInitReplyV4(m) => m,
                        x => {
                            error!(
                                "read loop {}: unexpected reply to HsmdInit {:?}",
                                self.log_prefix, x
                            );
                            return Err(Error::Protocol(ProtocolError::UnexpectedType(0)));
                        }
                    };

                    // We will only accept the version that was negotiated
                    m.hsm_wire_max_version = init_reply.hsm_version;
                    m.hsm_wire_min_version = init_reply.hsm_version;

                    let mut init_message_cache = self.init_message_cache.lock().unwrap();
                    if init_message_cache.init_message.is_some() {
                        error!("read loop {}: unexpected duplicate HsmdInit", self.log_prefix);
                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
                            msgs::HsmdInit::TYPE,
                        )));
                    }
                    init_message_cache.init_message = Some(m.as_vec());

                    // The signer is not ready for requests until the HsmdInit
                    // has been handled.
                    self.signer_port.set_ready();
                }
                Message::HsmdInit2(m) => {
                    if !self.is_root() {
                        error!(
                            "read loop {}: unexpected HsmdInit2 on non-root connection",
                            self.log_prefix
                        );
                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
                            msgs::HsmdInit2::TYPE,
                        )));
                    }
                    self.do_proxy_msg(raw_msg, false)?;

                    // TODO HsmdInit2 does not have version negotiation
                    let mut init_message_cache = self.init_message_cache.lock().unwrap();
                    if init_message_cache.init_message.is_some() {
                        error!("read loop {}: unexpected duplicate HsmdInit2", self.log_prefix);
                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
                            msgs::HsmdInit2::TYPE,
                        )));
                    }
                    init_message_cache.init_message = Some(m.as_vec());
                }
                _ => {
                    self.do_proxy_msg(raw_msg, false)?;
                }
            }
        }
    }

    // Proxy the request to the signer and return the reply to the node.
    // Returns the reply bytes so callers can cache them.
    fn do_proxy_msg(&mut self, raw_msg: Vec<u8>, is_oneway: bool) -> Result<Vec<u8>> {
        let result = self.handle_message(raw_msg, is_oneway);
        if let Err(ref err) = result {
            log_error!(err, self);
        }
        let reply = result?;
        if is_oneway {
            debug!("oneway sent {}", self.log_prefix);
        } else {
            log_reply!(reply, self);
            self.client.write_vec(reply.clone())?;
            debug!("replied {}", self.log_prefix);
        }
        Ok(reply)
    }

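    // Forward one message to the signer, retrying transient transport failures with
    // exponential backoff; oneway messages return an empty reply without waiting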
    fn handle_message(&mut self, message: Vec<u8>, is_oneway: bool) -> Result<Vec<u8>> {
        let result = (|| {
            info!(
                "read loop {}: request {}{}",
                self.log_prefix,
                msgs::message_name_from_vec(&message),
                if is_oneway { " (oneway)" } else { "" }
            );
            let reply_rx = self.send_request(message.clone())?;
            if is_oneway {
                Ok(vec![])
            } else {
                // Wait for the signer reply
                // Can fail if the adapter shut down
                let reply = reply_rx.blocking_recv().map_err(|_| Error::Transport)?;
                if reply.is_temporary_failure {
                    // Retry with backoff
                    info!("read loop {}: temporary error, retrying", self.log_prefix);
                    return Err(Error::TransportTransient);
                }
                info!(
                    "read loop {}: reply {}",
                    self.log_prefix,
                    msgs::message_name_from_vec(&reply.reply)
                );
                Ok(reply.reply)
            }
        })
        .retry(backoff())
        .when(|e| *e == Error::TransportTransient)
        .call()
        .map_err(|e| {
            error!("read loop {}: signer retry failed: {:?}", self.log_prefix, e);
            e
        })?;
        Ok(result)
    }

    fn send_request(&mut self, message: Vec<u8>) -> Result<oneshot::Receiver<ChannelReply>> {
        self.signer_port.send_request_blocking(message, self.client_id.clone())
    }
}