sfo_cmd_server/server/
peer_manager.rs1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use crate::{CmdTunnelRead, CmdTunnelWrite};
4use crate::peer_connection::PeerConnection;
5use crate::peer_id::PeerId;
6use crate::server::CmdServerEventListener;
7use crate::tunnel_id::{TunnelId, TunnelIdGenerator};
8
9#[derive(Clone)]
10pub struct CachedPeerInfo {
11 pub conn_list: Vec<TunnelId>,
12}
13
14pub struct PeerManager<R: CmdTunnelRead, W: CmdTunnelWrite> {
15 conn_cache: Mutex<HashMap<TunnelId, (PeerId, Arc<tokio::sync::Mutex<PeerConnection<R, W>>>)>>,
16 device_conn_map: Mutex<HashMap<PeerId, CachedPeerInfo>>,
17 conn_id_generator: TunnelIdGenerator,
18 listener: Arc<dyn CmdServerEventListener>,
19}
20pub type PeerManagerRef<R, W> = Arc<PeerManager<R, W>>;
21
22
23impl<R: CmdTunnelRead, W: CmdTunnelWrite> PeerManager<R, W> {
24 pub fn new(listener: Arc<dyn CmdServerEventListener>) -> PeerManagerRef<R, W> {
25 Arc::new(PeerManager {
26 conn_cache: Mutex::new(HashMap::new()),
27 device_conn_map: Mutex::new(HashMap::new()),
28 conn_id_generator: TunnelIdGenerator::new(),
29 listener,
30 })
31 }
32
33 pub fn generate_conn_id(&self) -> TunnelId {
34 self.conn_id_generator.generate()
35 }
36
37 pub async fn add_peer_connection(self: &Arc<Self>, mut conn: PeerConnection<R, W>) {
38 let recv_handle = conn.handle.take().unwrap();
39 let peer_id = conn.peer_id.clone();
40 let conn_id = conn.conn_id;
41 let conn_count = {
42 self.conn_cache.lock().unwrap().insert(conn_id, (peer_id.clone(), Arc::new(tokio::sync::Mutex::new(conn))));
43 let mut device_conn_map = self.device_conn_map.lock().unwrap();
44 let peer_info = device_conn_map.entry(peer_id.clone()).or_insert(CachedPeerInfo { conn_list: Vec::new() });
45 peer_info.conn_list.push(conn_id);
46 peer_info.conn_list.len()
47 };
48
49 let this = self.clone();
50 tokio::spawn(async move {
51 let _ = recv_handle.await;
52 this.remove_peer_connection(conn_id).await;
53 });
54 if conn_count == 1 {
55 let _ = self.listener.on_peer_connected(&peer_id).await;
56 }
57 }
58
59 pub async fn remove_peer_connection(&self, conn_id: TunnelId) {
60 let mut peer_id = None;
61 {
62 let mut conn_cache = self.conn_cache.lock().unwrap();
63 if let Some(conn) = conn_cache.remove(&conn_id) {
64 let mut device_conn_map = self.device_conn_map.lock().unwrap();
65 if let Some(peer_info) = device_conn_map.get_mut(&conn.0) {
66 peer_info.conn_list.retain(|&id| id != conn_id);
67 if peer_info.conn_list.is_empty() {
68 device_conn_map.remove(&conn.0);
69 peer_id = Some(conn.0.clone());
70 }
71 }
72 }
73 }
74 if peer_id.is_some() {
75 let _ = self.listener.on_peer_disconnected(peer_id.as_ref().unwrap()).await;
76 }
77 }
78
79 pub fn find_connection(&self, conn_id: TunnelId) -> Option<Arc<tokio::sync::Mutex<PeerConnection<R, W>>>> {
80 let conn_cache = self.conn_cache.lock().unwrap();
81 conn_cache.get(&conn_id).map(|c| c.1.clone())
82 }
83
84 pub fn find_connections(&self, device_id: &PeerId) -> Vec<Arc<tokio::sync::Mutex<PeerConnection<R, W>>>> {
85 let conn_cache = self.conn_cache.lock().unwrap();
86 let device_conn_map = self.device_conn_map.lock().unwrap();
87 device_conn_map.get(device_id).map(|conns| {
88 conns.conn_list.iter().filter_map(|c| conn_cache.get(c).map(|c| c.1.clone())).collect()
89 }).unwrap_or_default()
90 }
91
92}