vls_proxy/grpc/
signer_loop.rs

1use async_trait::async_trait;
2use backoff::Error as BackoffError;
3use log::*;
4use lru::LruCache;
5use std::num::NonZeroUsize;
6use std::process;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use std::time::Duration;
10use std::time::SystemTime;
11use tokio::sync::{mpsc, oneshot};
12use tokio::task::spawn_blocking;
13use triggered::{Listener, Trigger};
14
15use lightning_signer::bitcoin::hashes::sha256::Hash as Sha256Hash;
16use lightning_signer::bitcoin::hashes::Hash;
17
18use super::adapter::{ChannelReply, ChannelRequest, ClientId};
19use vls_protocol::{
20    msgs, msgs::DeBolt as _, msgs::Message, msgs::SerBolt as _, Error as ProtocolError,
21};
22use vls_protocol_client::{ClientResult as Result, Error, SignerPort};
23use vls_protocol_signer::vls_protocol;
24
25use crate::client::Client;
26use crate::*;
27
/// Maximum number of approved preapproval replies remembered by each loop's LRU cache.
const PREAPPROVE_CACHE_SIZE: usize = 6;
/// How long a cached preapproval reply may be replayed to the node without asking the signer.
const PREAPPROVE_CACHE_TTL: Duration = Duration::from_secs(60);

/// A signer reply to a preapproval request, remembered so that an identical request can be
/// answered locally while the entry is still fresh.
struct PreapprovalCacheEntry {
    /// Wall-clock time at which the reply was cached.
    tstamp: SystemTime,
    /// Raw reply bytes exactly as they were received from the signer.
    reply_bytes: Vec<u8>,
}
35
36pub struct GrpcSignerPort {
37    sender: mpsc::Sender<ChannelRequest>,
38    is_ready: Arc<AtomicBool>,
39}
40
41// create a Backoff
42fn backoff() -> backoff::ExponentialBackoff {
43    backoff::ExponentialBackoffBuilder::default()
44        .with_initial_interval(Duration::from_secs(1))
45        .with_max_interval(Duration::from_secs(10))
46        .with_max_elapsed_time(Some(Duration::from_secs(300)))
47        .build()
48}
49
50#[async_trait]
51impl SignerPort for GrpcSignerPort {
52    async fn handle_message(&self, message: Vec<u8>) -> Result<Vec<u8>> {
53        let result = backoff::future::retry(backoff(), || async {
54            let reply_rx =
55                self.send_request(message.clone()).await.map_err(|e| BackoffError::permanent(e))?;
56            // Wait for the signer reply
57            // Can fail if the adapter shut down
58            let reply = reply_rx.await.map_err(|_| BackoffError::permanent(Error::Transport))?;
59            if reply.is_temporary_failure {
60                // Retry with backoff
61                info!("temporary error, retrying");
62                return Err(BackoffError::transient(Error::Transport));
63            }
64
65            return Ok(reply.reply);
66        })
67        .await
68        .map_err(|e| {
69            error!("signer retry failed: {:?}", e);
70            e
71        })?;
72        Ok(result)
73    }
74
75    fn is_ready(&self) -> bool {
76        self.is_ready.load(Ordering::Relaxed)
77    }
78}
79
80impl GrpcSignerPort {
81    pub fn new(sender: mpsc::Sender<ChannelRequest>) -> Self {
82        GrpcSignerPort { sender, is_ready: Arc::new(AtomicBool::new(false)) }
83    }
84
85    pub(crate) fn set_ready(&self) {
86        info!("setting is_ready true");
87        self.is_ready.store(true, Ordering::Relaxed);
88    }
89
90    async fn send_request(&self, message: Vec<u8>) -> Result<oneshot::Receiver<ChannelReply>> {
91        let (reply_rx, request) = Self::prepare_request(message, None);
92
93        // Send a request to the gRPC handler to send to signer
94        // This can fail if gRPC adapter shut down
95        self.sender.send(request).await.map_err(|_| ProtocolError::Eof)?;
96
97        Ok(reply_rx)
98    }
99
100    // Send a blocking request to the signer with an optional client_id
101    // for use in [`SignerLoop`]
102    fn send_request_blocking(
103        &self,
104        message: Vec<u8>,
105        client_id: Option<ClientId>,
106    ) -> Result<oneshot::Receiver<ChannelReply>> {
107        let (reply_rx, request) = Self::prepare_request(message, client_id);
108
109        // Send a request to the gRPC handler to send to signer
110        // This can fail if gRPC adapter shut down
111        self.sender.blocking_send(request).map_err(|_| ProtocolError::Eof)?;
112
113        Ok(reply_rx)
114    }
115
116    fn prepare_request(
117        message: Vec<u8>,
118        client_id: Option<ClientId>,
119    ) -> (oneshot::Receiver<ChannelReply>, ChannelRequest) {
120        // Create a one-shot channel to receive the reply
121        let (reply_tx, reply_rx) = oneshot::channel();
122
123        let request = ChannelRequest { client_id, message, reply_tx };
124        (reply_rx, request)
125    }
126}
127
/// A cache of the init message from the node, in case the signer reconnects.
///
/// The cache starts empty. [`InitMessageCache::new`] and [`Default::default`] are equivalent.
#[derive(Clone, Debug, Default)]
pub struct InitMessageCache {
    /// The HsmdInit or HsmdInit2 message from node, or `None` if none has been recorded yet.
    pub init_message: Option<Vec<u8>>,
}

impl InitMessageCache {
    /// Create a new, empty cache.
    pub fn new() -> Self {
        Self::default()
    }
}
141
142/// Implement the hsmd UNIX fd protocol.
143/// This doesn't actually perform the signing - the hsmd packets are transported via gRPC to the
144/// real signer.
145pub struct SignerLoop<C: 'static + Client> {
146    client: C,
147    log_prefix: String,
148    signer_port: Arc<GrpcSignerPort>,
149    client_id: Option<ClientId>,
150    shutdown_trigger: Option<Trigger>,
151    shutdown_signal: Option<Listener>,
152    preapproval_cache: LruCache<Sha256Hash, PreapprovalCacheEntry>,
153    init_message_cache: Arc<Mutex<InitMessageCache>>,
154}
155
156impl<C: 'static + Client> SignerLoop<C> {
157    /// Create a loop for the root (lightningd) connection, but doesn't start it yet
158    pub fn new(
159        client: C,
160        signer_port: Arc<GrpcSignerPort>,
161        shutdown_trigger: Trigger,
162        shutdown_signal: Listener,
163        init_message_cache: Arc<Mutex<InitMessageCache>>,
164    ) -> Self {
165        let log_prefix = format!("{}/{}/{}", std::process::id(), client.id(), 0);
166        let preapproval_cache = LruCache::new(NonZeroUsize::new(PREAPPROVE_CACHE_SIZE).unwrap());
167        Self {
168            client,
169            log_prefix,
170            signer_port,
171            client_id: None,
172            shutdown_trigger: Some(shutdown_trigger),
173            shutdown_signal: Some(shutdown_signal),
174            preapproval_cache,
175            init_message_cache,
176        }
177    }
178
179    // Create a loop for a non-root connection
180    fn new_for_client(client: C, signer_port: Arc<GrpcSignerPort>, client_id: ClientId) -> Self {
181        let log_prefix = format!("{}/{}/{}", std::process::id(), client.id(), client_id.dbid);
182        let preapproval_cache = LruCache::new(NonZeroUsize::new(PREAPPROVE_CACHE_SIZE).unwrap());
183        Self {
184            client,
185            log_prefix,
186            signer_port,
187            client_id: Some(client_id),
188            shutdown_trigger: None,
189            shutdown_signal: None,
190            preapproval_cache,
191            init_message_cache: Arc::new(Mutex::new(InitMessageCache::new())),
192        }
193    }
194
195    fn is_root(&self) -> bool {
196        self.client_id.is_none()
197    }
198
199    /// The init message cache
200    pub fn init_message_cache(&self) -> Arc<Mutex<InitMessageCache>> {
201        self.init_message_cache.clone()
202    }
203
204    /// Start the read loop
205    pub fn start(&mut self) {
206        info!("read loop {}: start", self.log_prefix);
207        if let Some(shutdown_signal) = self.shutdown_signal.as_ref() {
208            // TODO exit more cleanly
209            // Right now there's no clean way to stop the UNIX fd reader loop so just be
210            // aggressive here and exit when it's time to shutdown
211            let shutdown_signal_clone = shutdown_signal.clone();
212            let log_prefix_clone = self.log_prefix.clone();
213            tokio::spawn(async move {
214                info!("read loop {} waiting for shutdown", log_prefix_clone);
215                tokio::select! {
216                    _ = shutdown_signal_clone => {
217                        info!("read loop {} saw shutdown, calling exit", log_prefix_clone);
218                        process::exit(0);
219                    }
220                }
221            });
222        }
223        match self.do_loop() {
224            Ok(()) => info!("read loop {} done", self.log_prefix),
225            Err(Error::Protocol(ProtocolError::Eof)) =>
226                info!("read loop {} saw EOF; ending", self.log_prefix),
227            Err(e) => error!("read loop {} saw error {:?}; ending", self.log_prefix, e),
228        }
229        if let Some(trigger) = self.shutdown_trigger.as_ref() {
230            warn!("read loop {} terminated; triggering shutdown", self.log_prefix);
231            trigger.trigger();
232        }
233    }
234
235    fn do_loop(&mut self) -> Result<()> {
236        loop {
237            let raw_msg = self.client.read_raw()?;
238            let msg = msgs::from_vec(raw_msg.clone())?;
239            log_request!(msg);
240            match msg {
241                Message::ClientHsmFd(m) => {
242                    self.client.write(msgs::ClientHsmFdReply {}).unwrap();
243                    let new_client = self.client.new_client();
244                    info!("new client {} -> {}", self.log_prefix, new_client.id());
245                    let peer_id = m.peer_id.0;
246                    let client_id = ClientId { peer_id, dbid: m.dbid };
247                    let mut new_loop =
248                        SignerLoop::new_for_client(new_client, self.signer_port.clone(), client_id);
249                    spawn_blocking(move || new_loop.start());
250                }
251                Message::PreapproveInvoice(_) | Message::PreapproveKeysend(_) => {
252                    let now = SystemTime::now();
253                    let req_hash = Sha256Hash::hash(&raw_msg);
254                    if let Some(entry) = self.preapproval_cache.get(&req_hash) {
255                        let age = now.duration_since(entry.tstamp).expect("age");
256                        if age < PREAPPROVE_CACHE_TTL {
257                            debug!("{} found in preapproval cache", self.log_prefix);
258                            let reply = entry.reply_bytes.clone();
259                            log_reply!(reply, self);
260                            self.client.write_vec(reply)?;
261                            continue;
262                        }
263                    }
264                    let reply_bytes = self.do_proxy_msg(raw_msg, false)?;
265                    let reply = msgs::from_vec(reply_bytes.clone()).expect("parse reply failed");
266                    // Did we just witness an approval?
267                    match reply {
268                        Message::PreapproveKeysendReply(pkr) =>
269                            if pkr.result == true {
270                                debug!("{} adding keysend to preapproval cache", self.log_prefix);
271                                self.preapproval_cache.put(
272                                    req_hash,
273                                    PreapprovalCacheEntry { tstamp: now, reply_bytes },
274                                );
275                            },
276                        Message::PreapproveInvoiceReply(pir) =>
277                            if pir.result == true {
278                                debug!("{} adding invoice to preapproval cache", self.log_prefix);
279                                self.preapproval_cache.put(
280                                    req_hash,
281                                    PreapprovalCacheEntry { tstamp: now, reply_bytes },
282                                );
283                            },
284                        _ => {} // allow future out-of-band reply types
285                    }
286                }
287                #[cfg(feature = "developer")]
288                Message::HsmdDevPreinit2(_) => {
289                    if !self.is_root() {
290                        error!(
291                            "read loop {}: unexpected HsmdDevPreinit2 on non-root connection",
292                            self.log_prefix
293                        );
294                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
295                            msgs::HsmdInit::TYPE,
296                        )));
297                    }
298                    _ = self.do_proxy_msg(raw_msg, /*ONEWAY*/ true)?;
299                }
300                Message::HsmdInit(mut m) => {
301                    if !self.is_root() {
302                        error!(
303                            "read loop {}: unexpected HsmdInit on non-root connection",
304                            self.log_prefix
305                        );
306                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
307                            msgs::HsmdInit::TYPE,
308                        )));
309                    }
310                    let raw_reply = self.do_proxy_msg(raw_msg, false)?;
311                    // decode the reply and extract the protocol version
312                    let reply = msgs::from_vec(raw_reply)?;
313                    // we expect a HsmdInitReply
314                    let init_reply = match reply {
315                        Message::HsmdInitReplyV4(m) => m,
316                        x => {
317                            error!(
318                                "read loop {}: unexpected reply to HsmdInit {:?}",
319                                self.log_prefix, x
320                            );
321                            return Err(Error::Protocol(ProtocolError::UnexpectedType(0)));
322                        }
323                    };
324
325                    // We will only accept the version that was negotiated
326                    m.hsm_wire_max_version = init_reply.hsm_version;
327                    m.hsm_wire_min_version = init_reply.hsm_version;
328
329                    let mut init_message_cache = self.init_message_cache.lock().unwrap();
330                    if init_message_cache.init_message.is_some() {
331                        error!("read loop {}: unexpected duplicate HsmdInit", self.log_prefix);
332                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
333                            msgs::HsmdInit::TYPE,
334                        )));
335                    }
336                    init_message_cache.init_message = Some(m.as_vec());
337
338                    // The signer is not ready for requests until the HsmdInit
339                    // has been handled.
340                    self.signer_port.set_ready();
341                }
342                Message::HsmdInit2(m) => {
343                    if !self.is_root() {
344                        error!(
345                            "read loop {}: unexpected HsmdInit on non-root connection",
346                            self.log_prefix
347                        );
348                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
349                            msgs::HsmdInit2::TYPE,
350                        )));
351                    }
352                    self.do_proxy_msg(raw_msg, false)?;
353
354                    // TODO HsmdInit2 does not have version negotiation
355                    let mut init_message_cache = self.init_message_cache.lock().unwrap();
356                    if init_message_cache.init_message.is_some() {
357                        error!("read loop {}: unexpected duplicate HsmdInit", self.log_prefix);
358                        return Err(Error::Protocol(ProtocolError::UnexpectedType(
359                            msgs::HsmdInit2::TYPE,
360                        )));
361                    }
362                    init_message_cache.init_message = Some(m.as_vec());
363                }
364                _ => {
365                    self.do_proxy_msg(raw_msg, false)?;
366                }
367            }
368        }
369    }
370
371    // Proxy the request to the signer, return the result to the node.
372    // Returns the last response for caching
373    fn do_proxy_msg(&mut self, raw_msg: Vec<u8>, is_oneway: bool) -> Result<Vec<u8>> {
374        let result = self.handle_message(raw_msg, is_oneway);
375        if let Err(ref err) = result {
376            log_error!(err, self);
377        }
378        let reply = result?;
379        if is_oneway {
380            debug!("oneway sent {}", self.log_prefix);
381        } else {
382            log_reply!(reply, self);
383            self.client.write_vec(reply.clone())?;
384            debug!("replied {}", self.log_prefix);
385        }
386        Ok(reply)
387    }
388
389    fn handle_message(&mut self, message: Vec<u8>, is_oneway: bool) -> Result<Vec<u8>> {
390        let result = backoff::retry(backoff(), || {
391            info!(
392                "read loop {}: request {}{}",
393                self.log_prefix,
394                msgs::message_name_from_vec(&message),
395                if is_oneway { " (oneway)" } else { "" }
396            );
397            let reply_rx =
398                self.send_request(message.clone()).map_err(|e| BackoffError::permanent(e))?;
399            if is_oneway {
400                Ok(vec![])
401            } else {
402                // Wait for the signer reply
403                // Can fail if the adapter shut down
404                let reply = reply_rx
405                    .blocking_recv()
406                    .map_err(|_| BackoffError::permanent(Error::Transport))?;
407                if reply.is_temporary_failure {
408                    // Retry with backoff
409                    info!("read loop {}: temporary error, retrying", self.log_prefix);
410                    return Err(BackoffError::transient(Error::Transport));
411                };
412                info!(
413                    "read loop {}: reply {}",
414                    self.log_prefix,
415                    msgs::message_name_from_vec(&reply.reply)
416                );
417                Ok(reply.reply)
418            }
419        })
420        .map_err(|e| error_from_backoff(e))
421        .map_err(|e| {
422            error!("read loop {}: signer retry failed: {:?}", self.log_prefix, e);
423            e
424        })?;
425        Ok(result)
426    }
427
428    fn send_request(&mut self, message: Vec<u8>) -> Result<oneshot::Receiver<ChannelReply>> {
429        self.signer_port.send_request_blocking(message, self.client_id.clone())
430    }
431}
432
433fn error_from_backoff(e: BackoffError<Error>) -> Error {
434    match e {
435        BackoffError::Transient { err, .. } => err,
436        BackoffError::Permanent(err) => err,
437    }
438}