// vls_proxy/grpc/adapter.rs

1use std::collections::BTreeMap;
2use std::pin::Pin;
3use std::result::Result as StdResult;
4use std::sync::Arc;
5
6use futures::{Stream, StreamExt};
7use log::*;
8use tokio::sync::mpsc::{Receiver, Sender};
9use tokio::sync::{mpsc, oneshot, Mutex};
10use tokio::task::JoinHandle;
11use tonic::{transport::Server, Request, Response, Status, Streaming};
12
13use super::hsmd::{
14    hsmd_server, HsmRequestContext, PingReply, PingRequest, SignerRequest, SignerResponse,
15};
16use super::incoming::TcpIncoming;
17use crate::grpc::signer_loop::InitMessageCache;
18use std::sync::atomic::{AtomicU64, Ordering};
19use tonic::transport::Error;
20use triggered::{Listener, Trigger};
21
/// Bookkeeping for requests that have been sent to the signer but not yet
/// answered, plus the counter that mints request IDs.
struct Requests {
    // In-flight requests keyed by request ID.  Entries are removed when the
    // matching response arrives, and are retransmitted in ascending ID order
    // when a signer (re)connects (see `ProtocolAdapter::writer_stream`).
    requests: BTreeMap<u64, ChannelRequest>,
    // Monotonic request-ID source.  It is only ever touched while holding the
    // Mutex around this struct, so the atomicity is belt-and-braces.
    request_id: AtomicU64,
}

/// Request ID used for the unsolicited init message replayed to a
/// (re)connecting signer.  The stream reader recognizes this ID and
/// discards the corresponding reply instead of matching it to a request.
const DUMMY_REQUEST_ID: u64 = u64::MAX;
28
/// Adapt the hsmd UNIX socket protocol to gRPC streaming
#[derive(Clone)]
pub struct ProtocolAdapter {
    // Requests arriving from the parent process; drained by `writer_stream`.
    // The Mutex lets this clonable adapter hand the single receiver to
    // whichever writer stream is currently active.
    receiver: Arc<Mutex<Receiver<ChannelRequest>>>,
    // In-flight requests and the request-ID counter (see `Requests`).
    requests: Arc<Mutex<Requests>>,
    #[allow(unused)]
    // Fired by the stream reader on fatal signer errors to bring the
    // process down.
    shutdown_trigger: Trigger,
    // Observed by both the writer and reader loops to stop on shutdown.
    shutdown_signal: Listener,
    // Cached init message, replayed to a (re)connecting signer under
    // DUMMY_REQUEST_ID before any queued requests.
    init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
}
39
/// Pinned, boxed stream of requests flowing toward the signer — the
/// server-to-client direction of the bidirectional `signer_stream` RPC.
pub type SignerStream =
    Pin<Box<dyn Stream<Item = StdResult<SignerRequest, Status>> + Send + 'static>>;
42
impl ProtocolAdapter {
    /// Construct an adapter that drains [`ChannelRequest`]s from `receiver`
    /// and forwards them to a connected signer over gRPC streaming.
    ///
    /// Request IDs start at 0; the in-flight request map starts empty.
    pub fn new(
        receiver: Receiver<ChannelRequest>,
        shutdown_trigger: Trigger,
        shutdown_signal: Listener,
        init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
    ) -> Self {
        ProtocolAdapter {
            receiver: Arc::new(Mutex::new(receiver)),
            requests: Arc::new(Mutex::new(Requests {
                requests: BTreeMap::new(),
                request_id: AtomicU64::new(0),
            })),
            shutdown_trigger,
            shutdown_signal,
            init_message_cache,
        }
    }
    /// Get requests from the parent process and feed them to gRPC.
    /// Will abort the stream reader task if the parent process goes away.
    ///
    /// On (re)connect the stream first replays the cached init message (if
    /// any) under DUMMY_REQUEST_ID, then retransmits every request still in
    /// the in-flight map, then settles into forwarding new requests from the
    /// parent.  When the loop exits (shutdown signal or parent hangup) the
    /// paired reader task is aborted and awaited.
    pub async fn writer_stream(&self, stream_reader_task: JoinHandle<()>) -> SignerStream {
        let receiver = self.receiver.clone();
        let requests = self.requests.clone();
        let shutdown_signal = self.shutdown_signal.clone();

        // std Mutex here, not tokio: the guard is dropped before any await.
        let cache = self.init_message_cache.lock().unwrap().clone();
        let output = async_stream::try_stream! {
            // send any init message
            if let Some(message) = cache.init_message.as_ref() {
                yield SignerRequest {
                    request_id: DUMMY_REQUEST_ID,
                    message: message.clone(),
                    context: None,
                };
            }

            // Retransmit any requests that were not processed during the signer's previous connection.
            // We reacquire the lock on each iteration because we yield inside the loop.
            // NOTE(review): the guard is still held across the `yield` within an
            // iteration (the yielded value borrows from it); it is released at
            // the end of each iteration.  This is a tokio Mutex, so holding it
            // across an await point is legal.
            let mut ind = 0;
            loop {
                let reqs = requests.lock().await;
                if ind == 0 {
                    info!("retransmitting {} outstanding requests", reqs.requests.len());
                }
                // get the first key/value where key >= ind
                if let Some((&request_id, req)) = reqs.requests.range(ind..).next() {
                    ind = request_id + 1;
                    debug!("writer sending request {} to signer", request_id);
                    yield Self::make_signer_request(request_id, req);
                } else {
                    break;
                }
            };

            let mut receiver = receiver.lock().await;

            // read requests from parent
            loop {
                tokio::select! {
                    _ = shutdown_signal.clone() => {
                        info!("writer got shutdown_signal");
                        break;
                    }
                    resp_opt = receiver.recv() => {
                        if let Some(req) = resp_opt {
                            let mut reqs = requests.lock().await;
                            let request_id = reqs.request_id.fetch_add(1, Ordering::AcqRel);
                            debug!("writer sending request {} to signer", request_id);
                            // Record the request before yielding so a later
                            // reconnect can retransmit it if no response arrives.
                            let signer_request = Self::make_signer_request(request_id, &req);
                            reqs.requests.insert(request_id, req);
                            yield signer_request;
                        } else {
                            // parent closed UNIX fd - we are shutting down
                            info!("writer: parent closed - shutting down signer stream");
                            break;
                        }
                    }
                }
            }
            info!("stream writer loop finished");
            stream_reader_task.abort();
            // ignore join result
            let _ = stream_reader_task.await;
        };

        Box::pin(output)
    }

    /// Get signer responses from gRPC and feed them back to the parent process
    ///
    /// Spawns and returns the reader task.  The task exits on: shutdown
    /// signal, fatal signer error (non-temporary error string, which also
    /// fires the shutdown trigger), unknown request ID (also fatal), gRPC
    /// transport error, or clean EOF from the signer.
    pub fn start_stream_reader(&self, mut stream: Streaming<SignerResponse>) -> JoinHandle<()> {
        let requests = self.requests.clone();
        let shutdown_signal = self.shutdown_signal.clone();
        let shutdown_trigger = self.shutdown_trigger.clone();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = shutdown_signal.clone() => {
                        info!("reader got shutdown_signal");
                        break;
                    }
                    resp_opt = stream.next() => {
                        match resp_opt {
                            Some(Ok(resp)) => {
                                debug!("got signer response {}", resp.request_id);
                                // temporary failures are not fatal and are handled below
                                if !resp.error.is_empty() && !resp.is_temporary_failure {
                                    error!("signer error: {}; triggering shutdown", resp.error);
                                    shutdown_trigger.trigger();
                                    break;
                                }

                                if resp.is_temporary_failure {
                                    warn!("signer temporary failure on {}: {}", resp.request_id, resp.error);
                                }

                                if resp.request_id == DUMMY_REQUEST_ID {
                                    // TODO do something clever with the init reply message
                                    continue;
                                }

                                // Match the response to its in-flight request and
                                // complete the caller's oneshot.
                                let mut reqs = requests.lock().await;
                                let channel_req_opt = reqs.requests.remove(&resp.request_id);
                                if let Some(channel_req) = channel_req_opt {
                                    let reply = ChannelReply { reply: resp.message, is_temporary_failure: resp.is_temporary_failure };
                                    let send_res = channel_req.reply_tx.send(reply);
                                    if send_res.is_err() {
                                        error!("failed to send response back to internal channel; \
                                               triggering shutdown");
                                        shutdown_trigger.trigger();
                                        break;
                                    }
                                } else {
                                    error!("got response for unknown request ID {}; \
                                            triggering shutdown", resp.request_id);
                                    shutdown_trigger.trigger();
                                    break;
                                }
                            }
                            Some(Err(err)) => {
                                // signer connection error
                                error!("got signer gRPC error {}", err);
                                break;
                            }
                            None => {
                                // signer closed connection
                                info!("response task closing - EOF");
                                break;
                            }
                        }
                    }
                }
            }
            info!("stream reader loop finished");
        })
    }

    /// Build the wire-level SignerRequest for a queued ChannelRequest,
    /// attaching the per-client context when a client ID is present.
    fn make_signer_request(request_id: u64, req: &ChannelRequest) -> SignerRequest {
        let context = req.client_id.as_ref().map(|c| HsmRequestContext {
            peer_id: c.peer_id.to_vec(),
            dbid: c.dbid,
            // capabilities are not carried on ChannelRequest; always 0 here
            capabilities: 0,
        });
        SignerRequest { request_id, message: req.message.clone(), context }
    }
}
208
/// A request to be forwarded to the signer
/// Responses are received on the oneshot sender inside this struct
pub struct ChannelRequest {
    // Raw hsmd wire message to deliver to the signer.
    pub message: Vec<u8>,
    // One-shot channel on which the matching ChannelReply is delivered
    // by the stream reader.
    pub reply_tx: oneshot::Sender<ChannelReply>,
    // Client (peer + dbid) this request is made on behalf of; None when
    // there is no per-client context to attach.
    pub client_id: Option<ClientId>,
}
216
// Reply delivered on a ChannelRequest's oneshot channel
pub struct ChannelReply {
    // Raw hsmd wire response from the signer.
    pub reply: Vec<u8>,
    // Propagated from SignerResponse: the signer reported a temporary
    // failure rather than a fatal error.
    pub is_temporary_failure: bool,
}
222
/// Identifies the client a request is made on behalf of
#[derive(Clone, Debug)]
pub struct ClientId {
    // 33-byte peer node ID — presumably a compressed secp256k1 public key;
    // TODO confirm against the hsmd protocol definition
    pub peer_id: [u8; 33],
    // Database ID forwarded into HsmRequestContext — per-client identifier;
    // exact semantics defined by hsmd, not visible here
    pub dbid: u64,
}
228
/// Listens for a connection from the signer, and then sends requests to it
#[derive(Clone)]
pub struct HsmdService {
    #[allow(unused)]
    // Held here in addition to the adapter's copy; marked unused — presumably
    // retained to keep the trigger alive / available. TODO confirm.
    shutdown_trigger: Trigger,
    // Bridges the request channel to the signer's gRPC stream.
    adapter: ProtocolAdapter,
    // Producer side of the request channel; handed out via `sender()`.
    sender: Sender<ChannelRequest>,
}
237
238impl HsmdService {
239    /// Create the service
240    pub fn new(
241        shutdown_trigger: Trigger,
242        shutdown_signal: Listener,
243        init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
244    ) -> Self {
245        let (sender, receiver) = mpsc::channel(1000);
246        let adapter = ProtocolAdapter::new(
247            receiver,
248            shutdown_trigger.clone(),
249            shutdown_signal.clone(),
250            init_message_cache,
251        );
252
253        HsmdService { shutdown_trigger, adapter, sender }
254    }
255
256    pub async fn start(
257        self,
258        incoming: TcpIncoming,
259        shutdown_signal: Listener,
260    ) -> Result<(), Error> {
261        let service = Server::builder()
262            .add_service(hsmd_server::HsmdServer::new(self))
263            .serve_with_incoming_shutdown(incoming, shutdown_signal);
264        service.await
265    }
266
267    /// Get the sender for the request channel
268    pub fn sender(&self) -> Sender<ChannelRequest> {
269        self.sender.clone()
270    }
271}
272
273#[tonic::async_trait]
274impl hsmd_server::Hsmd for HsmdService {
275    async fn ping(&self, request: Request<PingRequest>) -> StdResult<Response<PingReply>, Status> {
276        info!("got ping request");
277        let r = request.into_inner();
278        Ok(Response::new(PingReply { message: r.message }))
279    }
280
281    type SignerStreamStream = SignerStream;
282
283    async fn signer_stream(
284        &self,
285        request: Request<Streaming<SignerResponse>>,
286    ) -> StdResult<Response<Self::SignerStreamStream>, Status> {
287        let request_stream = request.into_inner();
288
289        let stream_reader_task = self.adapter.start_stream_reader(request_stream);
290
291        let response_stream = self.adapter.writer_stream(stream_reader_task).await;
292
293        Ok(Response::new(response_stream as Self::SignerStreamStream))
294    }
295}