//! End-to-end encryption on top of an sbd client.
//!
//! [`SbdClientCrypto`] wraps an `sbd_client::SbdClient` and keeps per-peer
//! outgoing/incoming encryption stream state (via [`SodokenCrypto`]), framing
//! traffic with the [`protocol`] module. Received traffic is decrypted through
//! the paired [`MsgRecv`].
#![deny(missing_docs)]
use std::collections::HashMap;
use std::io::{Error, Result};
use std::sync::{Arc, Mutex, Weak};

pub use sbd_client::{PubKey, SbdClientConfig};

/// Wire framing for the stream-control and encrypted message frames.
pub mod protocol;

mod sodoken_crypto;
pub use sodoken_crypto::*;
16
17pub struct Config {
19 pub client_config: SbdClientConfig,
21
22 pub listener: bool,
26
27 pub max_connections: usize,
29
30 pub max_idle: std::time::Duration,
32}
33
34impl Default for Config {
35 fn default() -> Self {
36 Self {
37 client_config: Default::default(),
38 listener: false,
39 max_connections: 4096,
40 max_idle: std::time::Duration::from_secs(10),
41 }
42 }
43}
44
/// The shared sbd client behind an async (tokio) mutex, so the guard can be
/// held across `.await` while sending. Held strongly by [`SbdClientCrypto`]
/// and weakly by [`MsgRecv`].
type ClientSync = tokio::sync::Mutex<sbd_client::SbdClient>;
48
49pub struct MsgRecv {
51 inner: Arc<Mutex<Inner>>,
52 recv: sbd_client::MsgRecv,
53 client: Weak<ClientSync>,
54}
55
56impl MsgRecv {
57 pub async fn recv(&mut self) -> Option<(PubKey, bytes::Bytes)> {
59 while let Some(msg) = self.recv.recv().await {
60 let pk = msg.pub_key();
61 let dec = self.inner.lock().unwrap().dec(msg);
62 match dec {
63 DecRes::Ok(msg) => return Some((pk, msg)),
64 DecRes::Ignore => (),
65 DecRes::ReqNewStream => {
66 if let Some(client) = self.client.upgrade() {
68 let msg =
69 protocol::Protocol::request_new_stream(&*pk.0);
70 if let Err(err) =
71 client.lock().await.send(&pk, msg.base_msg()).await
72 {
73 tracing::debug!(?err, "failure sending request_new_stream in message receive handler");
74 }
75 } else {
76 return None;
77 }
78 }
79 }
80 }
81 None
82 }
83}
84
85pub struct SbdClientCrypto {
87 pub_key: PubKey,
88 inner: Arc<Mutex<Inner>>,
89 client: Arc<ClientSync>,
90}
91
92impl SbdClientCrypto {
93 pub async fn new(
95 url: &str,
96 config: Arc<Config>,
97 ) -> Result<(Self, MsgRecv)> {
98 let crypto = SodokenCrypto::new()?;
100 use sbd_client::Crypto;
101 let pub_key = PubKey(Arc::new(*crypto.pub_key()));
102
103 let (client, recv) = sbd_client::SbdClient::connect_config(
105 url,
106 &crypto,
107 config.client_config.clone(),
108 )
109 .await?;
110
111 let client = Arc::new(tokio::sync::Mutex::new(client));
112 let inner = Arc::new(Mutex::new(Inner::new(config, crypto)));
113
114 let recv = MsgRecv {
115 inner: inner.clone(),
116 recv,
117 client: Arc::downgrade(&client),
118 };
119
120 let this = Self {
121 pub_key,
122 inner,
123 client,
124 };
125
126 Ok((this, recv))
127 }
128
129 pub fn pub_key(&self) -> &PubKey {
131 &self.pub_key
132 }
133
134 pub fn active_peers(&self) -> Vec<PubKey> {
136 let mut inner = self.inner.lock().unwrap();
137 let max_idle = inner.config.max_idle;
138 Inner::prune(&mut inner.c_map, max_idle);
139 inner.c_map.keys().cloned().collect()
140 }
141
142 pub async fn assert(&self, pk: &PubKey) -> Result<()> {
144 let enc = self.inner.lock().unwrap().enc(pk, None)?;
145
146 {
147 let client = self.client.lock().await;
148 for enc in enc {
149 client.send(pk, &enc).await?;
150 }
151 }
152
153 Ok(())
154 }
155
156 pub async fn send(&self, pk: &PubKey, msg: &[u8]) -> Result<()> {
158 const SBD_MAX: usize = 20_000;
159 const SBD_HDR: usize = 32;
160 const SBD_SS_HDR: usize = 1;
163 const SS_ABYTES: usize = sodoken::secretstream::ABYTES;
164 const MAX_MSG: usize = SBD_MAX - SBD_HDR - SBD_SS_HDR - SS_ABYTES;
165
166 if msg.len() > MAX_MSG {
167 return Err(Error::other("message too long"));
168 }
169
170 let enc = self.inner.lock().unwrap().enc(pk, Some(msg))?;
172
173 {
174 let client = self.client.lock().await;
175
176 for enc in enc {
178 client.send(pk, &enc).await?;
179 }
180 }
181
182 Ok(())
183 }
184
185 pub async fn close_peer(&self, pk: &PubKey) {
187 self.inner.lock().unwrap().close(pk);
188 }
189
190 pub async fn close(&self) {
192 self.client.lock().await.close().await;
193 }
194}
195
/// Outcome of processing one incoming sbd message in `Inner::dec`.
enum DecRes {
    /// A decrypted application payload to hand to the caller.
    Ok(bytes::Bytes),
    /// Nothing to deliver (stream-control frame, or the message was dropped).
    Ignore,
    /// The sender should be asked to start a new encryption stream.
    ReqNewStream,
}
201
/// Per-peer record: outgoing and incoming stream state plus the time of the
/// most recent send/receive involving the peer.
struct InnerRec {
    /// Outgoing stream; `None` until one is started, and reset to `None`
    /// when the peer requests a new stream.
    enc: Option<Encryptor>,
    /// Incoming stream, built from the peer's stream header; `None` until
    /// one is received, or after an undecodable frame.
    dec: Option<Decryptor>,
    /// Compared against `Config::max_idle` when pruning.
    last_active: std::time::Instant,
}
207
208impl InnerRec {
209 pub fn new() -> Self {
210 Self {
211 enc: None,
212 dec: None,
213 last_active: std::time::Instant::now(),
214 }
215 }
216}
217
/// Shared mutable state behind the std mutex: configuration, the local
/// crypto identity, and per-peer records keyed by public key.
struct Inner {
    /// Limits and listener mode.
    config: Arc<Config>,
    /// Creates the per-peer encryptors and decryptors.
    crypto: SodokenCrypto,
    /// One record per tracked peer.
    c_map: HashMap<PubKey, InnerRec>,
}
223
224impl Inner {
225 fn new(config: Arc<Config>, crypto: SodokenCrypto) -> Self {
227 Self {
228 config,
229 crypto,
230 c_map: HashMap::new(),
231 }
232 }
233
234 fn close(&mut self, pk: &PubKey) {
236 self.c_map.remove(pk);
237 }
238
239 fn prune(
241 c_map: &mut HashMap<PubKey, InnerRec>,
242 max_idle: std::time::Duration,
243 ) {
244 let now = std::time::Instant::now();
245 c_map.retain(|_pk, r| now - r.last_active < max_idle);
246 }
247
248 fn loc_assert<'a>(
250 config: &'a Config,
251 c_map: &'a mut HashMap<PubKey, InnerRec>,
252 pk: PubKey,
253 do_create: bool,
254 ) -> Result<&'a mut InnerRec> {
255 use std::collections::hash_map::Entry;
256 let tot = c_map.len();
257 Self::prune(c_map, config.max_idle);
258 match c_map.entry(pk.clone()) {
259 Entry::Vacant(e) => {
260 if do_create {
261 if tot >= config.max_connections {
262 return Err(Error::other("too many connections"));
263 }
264 Ok(e.insert(InnerRec::new()))
265 } else {
266 Err(Error::other("ignore unsolicited"))
267 }
268 }
269 Entry::Occupied(e) => Ok(e.into_mut()),
270 }
271 }
272
273 fn enc(
278 &mut self,
279 pk: &PubKey,
280 msg: Option<&[u8]>,
281 ) -> Result<Vec<bytes::Bytes>> {
282 let Self {
283 config,
284 crypto,
285 c_map,
286 } = self;
287
288 let mut out = Vec::new();
289
290 let rec = Self::loc_assert(config, c_map, pk.clone(), true)?;
292 rec.last_active = std::time::Instant::now();
293
294 if rec.enc.is_none() {
296 let (enc, hdr) = crypto.new_enc(pk)?;
297 rec.enc = Some(enc);
298 let msg = protocol::Protocol::new_stream(&**pk, &hdr);
299
300 out.push(msg.base_msg().clone());
302 }
303
304 if let Some(msg) = msg {
305 let msg = rec.enc.as_mut().unwrap().encrypt(&*pk.0, msg)?;
307
308 out.push(msg.base_msg().clone());
309 }
310
311 Ok(out)
312 }
313
314 fn dec(&mut self, msg: sbd_client::Msg) -> DecRes {
317 let Self {
318 config,
319 crypto,
320 c_map,
321 } = self;
322
323 let rec = match Self::loc_assert(
325 config,
326 c_map,
327 msg.pub_key(),
328 config.listener,
329 ) {
330 Ok(rec) => rec,
331 Err(_) => {
332 return DecRes::Ignore;
334 }
335 };
336
337 rec.last_active = std::time::Instant::now();
339
340 let dec = match protocol::Protocol::from_full(
342 bytes::Bytes::copy_from_slice(&msg.0),
343 ) {
344 Some(dec) => dec,
345 None => {
346 rec.dec = None;
347 return DecRes::ReqNewStream;
350 }
351 };
352
353 match dec {
355 protocol::Protocol::NewStream { header, .. } => {
356 let dec =
358 match crypto.new_dec(msg.pub_key_ref(), header.as_ref()) {
359 Ok(dec) => dec,
360 Err(_) => return DecRes::ReqNewStream,
361 };
362 rec.dec = Some(dec);
363 DecRes::Ignore
364 }
365 protocol::Protocol::Message { message, .. } => {
366 match rec.dec.as_mut() {
368 Some(dec) => match dec.decrypt(message.as_ref()) {
369 Ok(message) => DecRes::Ok(message),
370 Err(_) => DecRes::ReqNewStream,
371 },
372 None => {
373 DecRes::Ignore
376 }
377 }
378 }
379 protocol::Protocol::RequestNewStream { .. } => {
380 rec.enc = None;
382 DecRes::Ignore
383 }
384 }
385 }
386}
387
// Unit tests, compiled only for `cargo test`.
#[cfg(test)]
mod test;