switchgear_components/pool/
client_pool.rs

1use crate::pool::cln::grpc::client::TonicClnGrpcClient;
2use crate::pool::error::LnPoolError;
3use crate::pool::lnd::grpc::client::TonicLndGrpcClient;
4use crate::pool::{
5    Bolt11InvoiceDescription, DiscoveryBackendImplementation, LnMetrics, LnRpcClient,
6};
7use std::collections::HashMap;
8use std::fmt::Debug;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use switchgear_service_api::discovery::DiscoveryBackend;
12use switchgear_service_api::offer::Offer;
13use switchgear_service_api::service::ServiceErrorSource;
14use tonic::transport::CertificateDer;
15
16type LnClientMap<K> =
17    HashMap<K, Arc<Box<dyn LnRpcClient<Error = LnPoolError> + Send + Sync + 'static>>>;
18
19#[derive(Clone)]
20pub struct LnClientPool<K>
21where
22    K: Clone + std::hash::Hash + Eq,
23{
24    timeout: Duration,
25    pool: Arc<Mutex<LnClientMap<K>>>,
26    metrics_cache: Arc<Mutex<HashMap<K, LnMetrics>>>,
27    trusted_roots: Vec<CertificateDer<'static>>,
28}
29
30impl<K> LnClientPool<K>
31where
32    K: Clone + std::hash::Hash + Eq + Debug,
33{
34    pub fn new(timeout: Duration, trusted_roots: Vec<CertificateDer<'static>>) -> LnClientPool<K> {
35        Self {
36            timeout,
37            pool: Default::default(),
38            metrics_cache: Default::default(),
39            trusted_roots,
40        }
41    }
42
43    async fn get_client(
44        &self,
45        key: &K,
46    ) -> Result<Arc<Box<dyn LnRpcClient<Error = LnPoolError> + Send + Sync + 'static>>, LnPoolError>
47    {
48        let pool = self.pool.lock().map_err(|e| {
49            LnPoolError::from_memory_error(
50                e.to_string(),
51                format!("fetching client from pool for key: {key:?}"),
52            )
53        })?;
54        let client = pool.get(key).ok_or_else(|| {
55            LnPoolError::from_invalid_configuration(
56                format!("client for key: {key:?} not found in pool"),
57                ServiceErrorSource::Internal,
58                format!("fetching client from pool for key: {key:?}"),
59            )
60        })?;
61        Ok(client.clone())
62    }
63
64    pub async fn get_invoice(
65        &self,
66        offer: &Offer,
67        key: &K,
68        amount_msat: Option<u64>,
69        expiry_secs: Option<u64>,
70    ) -> Result<String, LnPoolError> {
71        let client = self.get_client(key).await?;
72
73        let capabilities = client.get_features();
74
75        let invoice_from_desc_hash =
76            capabilities.map_or_else(|| false, |c| c.invoice_from_desc_hash);
77
78        let description = if invoice_from_desc_hash {
79            Bolt11InvoiceDescription::Hash(&offer.metadata_json_hash)
80        } else {
81            Bolt11InvoiceDescription::DirectIntoHash(offer.metadata_json_string.as_str())
82        };
83
84        client
85            .get_invoice(amount_msat, description, expiry_secs)
86            .await
87    }
88
89    pub async fn get_metrics(&self, key: &K) -> Result<LnMetrics, LnPoolError> {
90        let client = self.get_client(key).await?;
91
92        let metrics = client.get_metrics().await?;
93
94        let mut cache = self.metrics_cache.lock().map_err(|e| {
95            LnPoolError::from_memory_error(e.to_string(), format!("get node metrics key: {key:?}"))
96        })?;
97
98        cache.insert(key.clone(), metrics.clone());
99        Ok(metrics)
100    }
101
102    pub fn connect(&self, key: K, backend: &DiscoveryBackend) -> Result<(), LnPoolError> {
103        let implementation: DiscoveryBackendImplementation =
104            serde_json::from_slice(backend.backend.implementation.as_slice())
105                .map_err(|e| LnPoolError::from_json_error(e, "parsing backend implementation"))?;
106        let client: Box<dyn LnRpcClient<Error = LnPoolError> + Send + Sync> = match implementation {
107            DiscoveryBackendImplementation::ClnGrpc(implementation) => Box::new(
108                TonicClnGrpcClient::create(self.timeout, implementation, &self.trusted_roots)?,
109            ),
110            DiscoveryBackendImplementation::LndGrpc(implementation) => Box::new(
111                TonicLndGrpcClient::create(self.timeout, implementation, &self.trusted_roots)?,
112            ),
113        };
114
115        let mut pool = self.pool.lock().map_err(|e| {
116            LnPoolError::from_memory_error(e.to_string(), format!("connecting ln client {key:?}"))
117        })?;
118        pool.insert(key, Arc::new(client));
119
120        Ok(())
121    }
122
123    pub fn get_cached_metrics(&self, key: &K) -> Option<LnMetrics> {
124        match self.metrics_cache.lock() {
125            Ok(cache) => cache.get(key).cloned(),
126            Err(_) => None,
127        }
128    }
129}