vls_proxy/grpc/adapter.rs

use std::collections::BTreeMap;
use std::pin::Pin;
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
use futures::{Stream, StreamExt};
use log::*;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::task::JoinHandle;
use tonic::transport::{Error, Server};
use tonic::{Request, Response, Status, Streaming};
use triggered::{Listener, Trigger};
use vlsd::grpc::hsmd::{
    hsmd_server, HsmRequestContext, PingReply, PingRequest, SignerRequest, SignerResponse,
};

use super::incoming::TcpIncoming;
use crate::grpc::signer_loop::InitMessageCache;

/// In-flight requests awaiting a signer response
struct Requests {
    /// Outstanding requests keyed by request ID, retransmitted on reconnect
    requests: BTreeMap<u64, ChannelRequest>,
    /// The next request ID to assign
    request_id: AtomicU64,
}

/// Sentinel request ID for the unsolicited init message; replies carrying this
/// ID are not matched against the in-flight request map
const DUMMY_REQUEST_ID: u64 = u64::MAX;

/// Adapt the hsmd UNIX socket protocol to gRPC streaming
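///
/// A usage sketch mirroring [`HsmdService::signer_stream`] below; `responses`
/// stands in for the `Streaming<SignerResponse>` taken from the gRPC request:
///
/// ```ignore
/// let reader_task = adapter.start_stream_reader(responses);
/// let requests: SignerStream = adapter.writer_stream(reader_task).await;
/// ```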
#[derive(Clone)]
pub struct ProtocolAdapter {
    receiver: Arc<Mutex<Receiver<ChannelRequest>>>,
    requests: Arc<Mutex<Requests>>,
    #[allow(unused)]
    shutdown_trigger: Trigger,
    shutdown_signal: Listener,
    init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
}

pub type SignerStream =
    Pin<Box<dyn Stream<Item = StdResult<SignerRequest, Status>> + Send + 'static>>;

impl ProtocolAdapter {
    pub fn new(
        receiver: Receiver<ChannelRequest>,
        shutdown_trigger: Trigger,
        shutdown_signal: Listener,
        init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
    ) -> Self {
        ProtocolAdapter {
            receiver: Arc::new(Mutex::new(receiver)),
            requests: Arc::new(Mutex::new(Requests {
                requests: BTreeMap::new(),
                request_id: AtomicU64::new(0),
            })),
            shutdown_trigger,
            shutdown_signal,
            init_message_cache,
        }
    }

    // Get requests from the parent process and feed them to gRPC.
    // Will abort the stream reader task if the parent process goes away.
    pub async fn writer_stream(&self, stream_reader_task: JoinHandle<()>) -> SignerStream {
        let receiver = self.receiver.clone();
        let requests = self.requests.clone();
        let shutdown_signal = self.shutdown_signal.clone();

        let cache = self.init_message_cache.lock().unwrap().clone();
        let output = async_stream::try_stream! {
            // send any init message
            if let Some(message) = cache.init_message.as_ref() {
                yield SignerRequest {
                    request_id: DUMMY_REQUEST_ID,
                    message: message.clone(),
                    context: None,
                };
            }

            // Retransmit any requests that were not processed during the signer's previous connection.
            // We reacquire the lock on each iteration because we yield inside the loop.
            let mut ind = 0;
            loop {
                let reqs = requests.lock().await;
                if ind == 0 {
                    info!("retransmitting {} outstanding requests", reqs.requests.len());
                }
                // get the first key/value where key >= ind
                if let Some((&request_id, req)) = reqs.requests.range(ind..).next() {
                    ind = request_id + 1;
                    debug!("writer sending request {} to signer", request_id);
                    yield Self::make_signer_request(request_id, req);
                } else {
                    break;
                }
            };

            let mut receiver = receiver.lock().await;

            // read requests from parent
            loop {
                tokio::select! {
                    _ = shutdown_signal.clone() => {
                        info!("writer got shutdown_signal");
                        break;
                    }
                    resp_opt = receiver.recv() => {
                        if let Some(req) = resp_opt {
                            let mut reqs = requests.lock().await;
                            let request_id = reqs.request_id.fetch_add(1, Ordering::AcqRel);
                            debug!("writer sending request {} to signer", request_id);
                            let signer_request = Self::make_signer_request(request_id, &req);
                            reqs.requests.insert(request_id, req);
                            yield signer_request;
                        } else {
                            // parent closed UNIX fd - we are shutting down
                            info!("writer: parent closed - shutting down signer stream");
                            break;
                        }
                    }
                }
            }
            info!("stream writer loop finished");
            stream_reader_task.abort();
            // ignore join result
            let _ = stream_reader_task.await;
        };

        Box::pin(output)
    }

    // Get signer responses from gRPC and feed them back to the parent process
    pub fn start_stream_reader(&self, mut stream: Streaming<SignerResponse>) -> JoinHandle<()> {
        let requests = self.requests.clone();
        let shutdown_signal = self.shutdown_signal.clone();
        let shutdown_trigger = self.shutdown_trigger.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = shutdown_signal.clone() => {
                        info!("reader got shutdown_signal");
                        break;
                    }
                    resp_opt = stream.next() => {
                        match resp_opt {
                            Some(Ok(resp)) => {
                                debug!("got signer response {}", resp.request_id);
                                // temporary failures are not fatal and are handled below
                                if !resp.error.is_empty() && !resp.is_temporary_failure {
                                    error!("signer error: {}; triggering shutdown", resp.error);
                                    shutdown_trigger.trigger();
                                    break;
                                }

                                if resp.is_temporary_failure {
                                    warn!("signer temporary failure on {}: {}", resp.request_id, resp.error);
                                }

                                if resp.request_id == DUMMY_REQUEST_ID {
                                    // TODO do something clever with the init reply message
                                    continue;
                                }

                                let mut reqs = requests.lock().await;
                                let channel_req_opt = reqs.requests.remove(&resp.request_id);
                                if let Some(channel_req) = channel_req_opt {
                                    let reply = ChannelReply { reply: resp.message, is_temporary_failure: resp.is_temporary_failure };
                                    let send_res = channel_req.reply_tx.send(reply);
                                    if send_res.is_err() {
                                        error!("failed to send response back to internal channel; \
                                               triggering shutdown");
                                        shutdown_trigger.trigger();
                                        break;
                                    }
                                } else {
                                    error!("got response for unknown request ID {}; \
                                            triggering shutdown", resp.request_id);
                                    shutdown_trigger.trigger();
                                    break;
                                }
                            }
                            Some(Err(err)) => {
                                // signer connection error
                                error!("got signer gRPC error {}", err);
                                break;
                            }
                            None => {
                                // signer closed connection
                                info!("response task closing - EOF");
                                break;
                            }
                        }
                    }
                }
            }
            info!("stream reader loop finished");
        })
    }

    /// Build the [`SignerRequest`] sent over gRPC for an internal [`ChannelRequest`]
    fn make_signer_request(request_id: u64, req: &ChannelRequest) -> SignerRequest {
        let context = req.client_id.as_ref().map(|c| HsmRequestContext {
            peer_id: c.peer_id.to_vec(),
            dbid: c.dbid,
            capabilities: 0,
        });
        SignerRequest { request_id, message: req.message.clone(), context }
    }
}

/// A request to the signer
/// The response is delivered on the oneshot sender inside this struct
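///
/// A minimal round-trip sketch, assuming a `sender` obtained from
/// [`HsmdService::sender`] and `message` bytes already in hsmd wire format:
///
/// ```ignore
/// let (reply_tx, reply_rx) = oneshot::channel();
/// sender.send(ChannelRequest { message, reply_tx, client_id: None }).await?;
/// let reply: ChannelReply = reply_rx.await?; // sent by the stream reader
/// ```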
pub struct ChannelRequest {
    pub message: Vec<u8>,
    pub reply_tx: oneshot::Sender<ChannelReply>,
    pub client_id: Option<ClientId>,
}

/// Reply to a [`ChannelRequest`], sent back over its oneshot channel
pub struct ChannelReply {
    pub reply: Vec<u8>,
    pub is_temporary_failure: bool,
}

/// Identifies the requesting client: the peer's compressed public key and database ID
#[derive(Clone, Debug)]
pub struct ClientId {
    pub peer_id: [u8; 33],
    pub dbid: u64,
}

/// Listens for a connection from the signer, and then sends requests to it
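///
/// A minimal wiring sketch; `incoming` and `cache` are assumed to be built
/// elsewhere (a [`TcpIncoming`] listener and a shared [`InitMessageCache`]):
///
/// ```ignore
/// let (shutdown_trigger, shutdown_signal) = triggered::trigger();
/// let service = HsmdService::new(shutdown_trigger, shutdown_signal.clone(), cache);
/// let sender = service.sender(); // hand this to the UNIX socket side
/// service.start(incoming, shutdown_signal).await?;
/// ```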
#[derive(Clone)]
pub struct HsmdService {
    #[allow(unused)]
    shutdown_trigger: Trigger,
    adapter: ProtocolAdapter,
    sender: Sender<ChannelRequest>,
}

impl HsmdService {
    /// Create the service
    pub fn new(
        shutdown_trigger: Trigger,
        shutdown_signal: Listener,
        init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
    ) -> Self {
        let (sender, receiver) = mpsc::channel(1000);
        let adapter = ProtocolAdapter::new(
            receiver,
            shutdown_trigger.clone(),
            shutdown_signal.clone(),
            init_message_cache,
        );

        HsmdService { shutdown_trigger, adapter, sender }
    }

    /// Serve the gRPC service on `incoming` until the shutdown signal fires
    pub async fn start(
        self,
        incoming: TcpIncoming,
        shutdown_signal: Listener,
    ) -> Result<(), Error> {
        let service = Server::builder()
            .add_service(hsmd_server::HsmdServer::new(self))
            .serve_with_incoming_shutdown(incoming, shutdown_signal);
        service.await
    }

    /// Get the sender for the request channel
    pub fn sender(&self) -> Sender<ChannelRequest> {
        self.sender.clone()
    }
}

#[async_trait]
impl hsmd_server::Hsmd for HsmdService {
    async fn ping(&self, request: Request<PingRequest>) -> StdResult<Response<PingReply>, Status> {
        info!("got ping request");
        let r = request.into_inner();
        Ok(Response::new(PingReply { message: r.message }))
    }

    type SignerStreamStream = SignerStream;

    async fn signer_stream(
        &self,
        request: Request<Streaming<SignerResponse>>,
    ) -> StdResult<Response<Self::SignerStreamStream>, Status> {
        let request_stream = request.into_inner();

        let stream_reader_task = self.adapter.start_stream_reader(request_stream);

        let response_stream = self.adapter.writer_stream(stream_reader_task).await;

        Ok(Response::new(response_stream as Self::SignerStreamStream))
    }
}