1use std::collections::BTreeMap;
2use std::pin::Pin;
3use std::result::Result as StdResult;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use futures::{Stream, StreamExt};
8use log::*;
9use tokio::sync::mpsc::{Receiver, Sender};
10use tokio::sync::{mpsc, oneshot, Mutex};
11use tokio::task::JoinHandle;
12use tonic::transport::Server;
13use tonic::{Request, Response, Status, Streaming};
14
15use super::incoming::TcpIncoming;
16use crate::grpc::signer_loop::InitMessageCache;
17use std::sync::atomic::{AtomicU64, Ordering};
18use tonic::transport::Error;
19use triggered::{Listener, Trigger};
20use vlsd::grpc::hsmd::{
21 hsmd_server, HsmRequestContext, PingReply, PingRequest, SignerRequest, SignerResponse,
22};
23
24struct Requests {
25 requests: BTreeMap<u64, ChannelRequest>,
26 request_id: AtomicU64,
27}
28
/// Request ID attached to the replayed init message. No `ChannelRequest` is
/// registered under it, so the stream reader silently skips responses that
/// carry this ID.
const DUMMY_REQUEST_ID: u64 = u64::MAX;
30
31#[derive(Clone)]
33pub struct ProtocolAdapter {
34 receiver: Arc<Mutex<Receiver<ChannelRequest>>>,
35 requests: Arc<Mutex<Requests>>,
36 #[allow(unused)]
37 shutdown_trigger: Trigger,
38 shutdown_signal: Listener,
39 init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
40}
41
42pub type SignerStream =
43 Pin<Box<dyn Stream<Item = StdResult<SignerRequest, Status>> + Send + 'static>>;
44
45impl ProtocolAdapter {
46 pub fn new(
47 receiver: Receiver<ChannelRequest>,
48 shutdown_trigger: Trigger,
49 shutdown_signal: Listener,
50 init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
51 ) -> Self {
52 ProtocolAdapter {
53 receiver: Arc::new(Mutex::new(receiver)),
54 requests: Arc::new(Mutex::new(Requests {
55 requests: BTreeMap::new(),
56 request_id: AtomicU64::new(0),
57 })),
58 shutdown_trigger,
59 shutdown_signal,
60 init_message_cache,
61 }
62 }
63 pub async fn writer_stream(&self, stream_reader_task: JoinHandle<()>) -> SignerStream {
66 let receiver = self.receiver.clone();
67 let requests = self.requests.clone();
68 let shutdown_signal = self.shutdown_signal.clone();
69
70 let cache = self.init_message_cache.lock().unwrap().clone();
71 let output = async_stream::try_stream! {
72 if let Some(message) = cache.init_message.as_ref() {
74 yield SignerRequest {
75 request_id: DUMMY_REQUEST_ID,
76 message: message.clone(),
77 context: None,
78 };
79 }
80
81 let mut ind = 0;
84 loop {
85 let reqs = requests.lock().await;
86 if ind == 0 {
87 info!("retransmitting {} outstanding requests", reqs.requests.len());
88 }
89 if let Some((&request_id, req)) = reqs.requests.range(ind..).next() {
91 ind = request_id + 1;
92 debug!("writer sending request {} to signer", request_id);
93 yield Self::make_signer_request(request_id, req);
94 } else {
95 break;
96 }
97 };
98
99 let mut receiver = receiver.lock().await;
100
101 loop {
103 tokio::select! {
104 _ = shutdown_signal.clone() => {
105 info!("writer got shutdown_signal");
106 break;
107 }
108 resp_opt = receiver.recv() => {
109 if let Some(req) = resp_opt {
110 let mut reqs = requests.lock().await;
111 let request_id = reqs.request_id.fetch_add(1, Ordering::AcqRel);
112 debug!("writer sending request {} to signer", request_id);
113 let signer_request = Self::make_signer_request(request_id, &req);
114 reqs.requests.insert(request_id, req);
115 yield signer_request;
116 } else {
117 info!("writer: parent closed - shutting down signer stream");
119 break;
120 }
121 }
122 }
123 }
124 info!("stream writer loop finished");
125 stream_reader_task.abort();
126 let _ = stream_reader_task.await;
128 };
129
130 Box::pin(output)
131 }
132
133 pub fn start_stream_reader(&self, mut stream: Streaming<SignerResponse>) -> JoinHandle<()> {
135 let requests = self.requests.clone();
136 let shutdown_signal = self.shutdown_signal.clone();
137 let shutdown_trigger = self.shutdown_trigger.clone();
138 tokio::spawn(async move {
139 loop {
140 tokio::select! {
141 _ = shutdown_signal.clone() => {
142 info!("reader got shutdown_signal");
143 break;
144 }
145 resp_opt = stream.next() => {
146 match resp_opt {
147 Some(Ok(resp)) => {
148 debug!("got signer response {}", resp.request_id);
149 if !resp.error.is_empty() && !resp.is_temporary_failure {
151 error!("signer error: {}; triggering shutdown", resp.error);
152 shutdown_trigger.trigger();
153 break;
154 }
155
156 if resp.is_temporary_failure {
157 warn!("signer temporary failure on {}: {}", resp.request_id, resp.error);
158 }
159
160 if resp.request_id == DUMMY_REQUEST_ID {
161 continue;
163 }
164
165 let mut reqs = requests.lock().await;
166 let channel_req_opt = reqs.requests.remove(&resp.request_id);
167 if let Some(channel_req) = channel_req_opt {
168 let reply = ChannelReply { reply: resp.message, is_temporary_failure: resp.is_temporary_failure };
169 let send_res = channel_req.reply_tx.send(reply);
170 if send_res.is_err() {
171 error!("failed to send response back to internal channel; \
172 triggering shutdown");
173 shutdown_trigger.trigger();
174 break;
175 }
176 } else {
177 error!("got response for unknown request ID {}; \
178 triggering shutdown", resp.request_id);
179 shutdown_trigger.trigger();
180 break;
181 }
182 }
183 Some(Err(err)) => {
184 error!("got signer gRPC error {}", err);
186 break;
187 }
188 None => {
189 info!("response task closing - EOF");
191 break;
192 }
193 }
194 }
195 }
196 }
197 info!("stream reader loop finished");
198 })
199 }
200
201 fn make_signer_request(request_id: u64, req: &ChannelRequest) -> SignerRequest {
202 let context = req.client_id.as_ref().map(|c| HsmRequestContext {
203 peer_id: c.peer_id.to_vec(),
204 dbid: c.dbid,
205 capabilities: 0,
206 });
207 SignerRequest { request_id, message: req.message.clone(), context }
208 }
209}
210
211pub struct ChannelRequest {
214 pub message: Vec<u8>,
215 pub reply_tx: oneshot::Sender<ChannelReply>,
216 pub client_id: Option<ClientId>,
217}
218
/// The signer's answer to a queued request.
pub struct ChannelReply {
    /// Raw reply bytes as returned by the signer.
    pub reply: Vec<u8>,
    /// Copy of the signer response's `is_temporary_failure` flag.
    pub is_temporary_failure: bool,
}
224
/// Identifies the client a request is made for. It is turned into the
/// request context sent alongside the message.
#[derive(Clone, Debug)]
pub struct ClientId {
    /// 33-byte peer identifier, copied into the context's `peer_id`.
    pub peer_id: [u8; 33],
    /// Database ID, copied into the context's `dbid`.
    pub dbid: u64,
}
230
231#[derive(Clone)]
233pub struct HsmdService {
234 #[allow(unused)]
235 shutdown_trigger: Trigger,
236 adapter: ProtocolAdapter,
237 sender: Sender<ChannelRequest>,
238}
239
240impl HsmdService {
241 pub fn new(
243 shutdown_trigger: Trigger,
244 shutdown_signal: Listener,
245 init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
246 ) -> Self {
247 let (sender, receiver) = mpsc::channel(1000);
248 let adapter = ProtocolAdapter::new(
249 receiver,
250 shutdown_trigger.clone(),
251 shutdown_signal.clone(),
252 init_message_cache,
253 );
254
255 HsmdService { shutdown_trigger, adapter, sender }
256 }
257
258 pub async fn start(
259 self,
260 incoming: TcpIncoming,
261 shutdown_signal: Listener,
262 ) -> Result<(), Error> {
263 let service = Server::builder()
264 .add_service(hsmd_server::HsmdServer::new(self))
265 .serve_with_incoming_shutdown(incoming, shutdown_signal);
266 service.await
267 }
268
269 pub fn sender(&self) -> Sender<ChannelRequest> {
271 self.sender.clone()
272 }
273}
274
275#[async_trait]
276impl hsmd_server::Hsmd for HsmdService {
277 async fn ping(&self, request: Request<PingRequest>) -> StdResult<Response<PingReply>, Status> {
278 info!("got ping request");
279 let r = request.into_inner();
280 Ok(Response::new(PingReply { message: r.message }))
281 }
282
283 type SignerStreamStream = SignerStream;
284
285 async fn signer_stream(
286 &self,
287 request: Request<Streaming<SignerResponse>>,
288 ) -> StdResult<Response<Self::SignerStreamStream>, Status> {
289 let request_stream = request.into_inner();
290
291 let stream_reader_task = self.adapter.start_stream_reader(request_stream);
292
293 let response_stream = self.adapter.writer_stream(stream_reader_task).await;
294
295 Ok(Response::new(response_stream as Self::SignerStreamStream))
296 }
297}