1use std::collections::BTreeMap;
2use std::pin::Pin;
3use std::result::Result as StdResult;
4use std::sync::Arc;
5
6use futures::{Stream, StreamExt};
7use log::*;
8use tokio::sync::mpsc::{Receiver, Sender};
9use tokio::sync::{mpsc, oneshot, Mutex};
10use tokio::task::JoinHandle;
11use tonic::{transport::Server, Request, Response, Status, Streaming};
12
13use super::hsmd::{
14 hsmd_server, HsmRequestContext, PingReply, PingRequest, SignerRequest, SignerResponse,
15};
16use super::incoming::TcpIncoming;
17use crate::grpc::signer_loop::InitMessageCache;
18use std::sync::atomic::{AtomicU64, Ordering};
19use tonic::transport::Error;
20use triggered::{Listener, Trigger};
21
22struct Requests {
23 requests: BTreeMap<u64, ChannelRequest>,
24 request_id: AtomicU64,
25}
26
/// Request ID attached to the replayed init message.
///
/// Responses carrying this ID are not matched against the outstanding-request table.
/// Regular IDs are allocated upward from 0, so they do not reach this value in practice.
const DUMMY_REQUEST_ID: u64 = u64::MAX;
28
29#[derive(Clone)]
31pub struct ProtocolAdapter {
32 receiver: Arc<Mutex<Receiver<ChannelRequest>>>,
33 requests: Arc<Mutex<Requests>>,
34 #[allow(unused)]
35 shutdown_trigger: Trigger,
36 shutdown_signal: Listener,
37 init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
38}
39
40pub type SignerStream =
41 Pin<Box<dyn Stream<Item = StdResult<SignerRequest, Status>> + Send + 'static>>;
42
43impl ProtocolAdapter {
44 pub fn new(
45 receiver: Receiver<ChannelRequest>,
46 shutdown_trigger: Trigger,
47 shutdown_signal: Listener,
48 init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
49 ) -> Self {
50 ProtocolAdapter {
51 receiver: Arc::new(Mutex::new(receiver)),
52 requests: Arc::new(Mutex::new(Requests {
53 requests: BTreeMap::new(),
54 request_id: AtomicU64::new(0),
55 })),
56 shutdown_trigger,
57 shutdown_signal,
58 init_message_cache,
59 }
60 }
61 pub async fn writer_stream(&self, stream_reader_task: JoinHandle<()>) -> SignerStream {
64 let receiver = self.receiver.clone();
65 let requests = self.requests.clone();
66 let shutdown_signal = self.shutdown_signal.clone();
67
68 let cache = self.init_message_cache.lock().unwrap().clone();
69 let output = async_stream::try_stream! {
70 if let Some(message) = cache.init_message.as_ref() {
72 yield SignerRequest {
73 request_id: DUMMY_REQUEST_ID,
74 message: message.clone(),
75 context: None,
76 };
77 }
78
79 let mut ind = 0;
82 loop {
83 let reqs = requests.lock().await;
84 if ind == 0 {
85 info!("retransmitting {} outstanding requests", reqs.requests.len());
86 }
87 if let Some((&request_id, req)) = reqs.requests.range(ind..).next() {
89 ind = request_id + 1;
90 debug!("writer sending request {} to signer", request_id);
91 yield Self::make_signer_request(request_id, req);
92 } else {
93 break;
94 }
95 };
96
97 let mut receiver = receiver.lock().await;
98
99 loop {
101 tokio::select! {
102 _ = shutdown_signal.clone() => {
103 info!("writer got shutdown_signal");
104 break;
105 }
106 resp_opt = receiver.recv() => {
107 if let Some(req) = resp_opt {
108 let mut reqs = requests.lock().await;
109 let request_id = reqs.request_id.fetch_add(1, Ordering::AcqRel);
110 debug!("writer sending request {} to signer", request_id);
111 let signer_request = Self::make_signer_request(request_id, &req);
112 reqs.requests.insert(request_id, req);
113 yield signer_request;
114 } else {
115 info!("writer: parent closed - shutting down signer stream");
117 break;
118 }
119 }
120 }
121 }
122 info!("stream writer loop finished");
123 stream_reader_task.abort();
124 let _ = stream_reader_task.await;
126 };
127
128 Box::pin(output)
129 }
130
131 pub fn start_stream_reader(&self, mut stream: Streaming<SignerResponse>) -> JoinHandle<()> {
133 let requests = self.requests.clone();
134 let shutdown_signal = self.shutdown_signal.clone();
135 let shutdown_trigger = self.shutdown_trigger.clone();
136 tokio::spawn(async move {
137 loop {
138 tokio::select! {
139 _ = shutdown_signal.clone() => {
140 info!("reader got shutdown_signal");
141 break;
142 }
143 resp_opt = stream.next() => {
144 match resp_opt {
145 Some(Ok(resp)) => {
146 debug!("got signer response {}", resp.request_id);
147 if !resp.error.is_empty() && !resp.is_temporary_failure {
149 error!("signer error: {}; triggering shutdown", resp.error);
150 shutdown_trigger.trigger();
151 break;
152 }
153
154 if resp.is_temporary_failure {
155 warn!("signer temporary failure on {}: {}", resp.request_id, resp.error);
156 }
157
158 if resp.request_id == DUMMY_REQUEST_ID {
159 continue;
161 }
162
163 let mut reqs = requests.lock().await;
164 let channel_req_opt = reqs.requests.remove(&resp.request_id);
165 if let Some(channel_req) = channel_req_opt {
166 let reply = ChannelReply { reply: resp.message, is_temporary_failure: resp.is_temporary_failure };
167 let send_res = channel_req.reply_tx.send(reply);
168 if send_res.is_err() {
169 error!("failed to send response back to internal channel; \
170 triggering shutdown");
171 shutdown_trigger.trigger();
172 break;
173 }
174 } else {
175 error!("got response for unknown request ID {}; \
176 triggering shutdown", resp.request_id);
177 shutdown_trigger.trigger();
178 break;
179 }
180 }
181 Some(Err(err)) => {
182 error!("got signer gRPC error {}", err);
184 break;
185 }
186 None => {
187 info!("response task closing - EOF");
189 break;
190 }
191 }
192 }
193 }
194 }
195 info!("stream reader loop finished");
196 })
197 }
198
199 fn make_signer_request(request_id: u64, req: &ChannelRequest) -> SignerRequest {
200 let context = req.client_id.as_ref().map(|c| HsmRequestContext {
201 peer_id: c.peer_id.to_vec(),
202 dbid: c.dbid,
203 capabilities: 0,
204 });
205 SignerRequest { request_id, message: req.message.clone(), context }
206 }
207}
208
209pub struct ChannelRequest {
212 pub message: Vec<u8>,
213 pub reply_tx: oneshot::Sender<ChannelReply>,
214 pub client_id: Option<ClientId>,
215}
216
/// The signer's answer to a [`ChannelRequest`], delivered over its oneshot channel.
pub struct ChannelReply {
    /// Raw reply bytes, taken from the signer response's `message`.
    pub reply: Vec<u8>,
    /// Copied from the signer response's `is_temporary_failure` flag.
    pub is_temporary_failure: bool,
}
222
/// Identifies the client on whose behalf a request is made.
///
/// It is forwarded to the signer as the request's `HsmRequestContext`.
#[derive(Clone, Debug)]
pub struct ClientId {
    /// 33-byte peer identifier, sent as `HsmRequestContext::peer_id`.
    pub peer_id: [u8; 33],
    /// Database id, sent as `HsmRequestContext::dbid`.
    pub dbid: u64,
}
228
229#[derive(Clone)]
231pub struct HsmdService {
232 #[allow(unused)]
233 shutdown_trigger: Trigger,
234 adapter: ProtocolAdapter,
235 sender: Sender<ChannelRequest>,
236}
237
238impl HsmdService {
239 pub fn new(
241 shutdown_trigger: Trigger,
242 shutdown_signal: Listener,
243 init_message_cache: Arc<std::sync::Mutex<InitMessageCache>>,
244 ) -> Self {
245 let (sender, receiver) = mpsc::channel(1000);
246 let adapter = ProtocolAdapter::new(
247 receiver,
248 shutdown_trigger.clone(),
249 shutdown_signal.clone(),
250 init_message_cache,
251 );
252
253 HsmdService { shutdown_trigger, adapter, sender }
254 }
255
256 pub async fn start(
257 self,
258 incoming: TcpIncoming,
259 shutdown_signal: Listener,
260 ) -> Result<(), Error> {
261 let service = Server::builder()
262 .add_service(hsmd_server::HsmdServer::new(self))
263 .serve_with_incoming_shutdown(incoming, shutdown_signal);
264 service.await
265 }
266
267 pub fn sender(&self) -> Sender<ChannelRequest> {
269 self.sender.clone()
270 }
271}
272
273#[tonic::async_trait]
274impl hsmd_server::Hsmd for HsmdService {
275 async fn ping(&self, request: Request<PingRequest>) -> StdResult<Response<PingReply>, Status> {
276 info!("got ping request");
277 let r = request.into_inner();
278 Ok(Response::new(PingReply { message: r.message }))
279 }
280
281 type SignerStreamStream = SignerStream;
282
283 async fn signer_stream(
284 &self,
285 request: Request<Streaming<SignerResponse>>,
286 ) -> StdResult<Response<Self::SignerStreamStream>, Status> {
287 let request_stream = request.into_inner();
288
289 let stream_reader_task = self.adapter.start_stream_reader(request_stream);
290
291 let response_stream = self.adapter.writer_stream(stream_reader_task).await;
292
293 Ok(Response::new(response_stream as Self::SignerStreamStream))
294 }
295}