// switchgear_components/pool/cln/grpc/client.rs

1use crate::pool::cln::grpc::config::{ClnGrpcClientAuth, ClnGrpcDiscoveryBackendImplementation};
2use crate::pool::error::LnPoolError;
3use crate::pool::{Bolt11InvoiceDescription, LnFeatures, LnMetrics, LnRpcClient};
4use async_trait::async_trait;
5use hex::ToHex;
6use rustls::pki_types::CertificateDer;
7use sha2::Digest;
8use std::fs;
9use std::sync::Arc;
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11use switchgear_service_api::service::ServiceErrorSource;
12use tokio::sync::Mutex;
13use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
14
/// Protobuf/gRPC bindings for the CLN `cln` package, pulled in via
/// `tonic::include_proto!` (provides `node_client::NodeClient` and the request/response
/// message types used below).
// The included code is generated, not hand-written, so clippy findings in it are not
// actionable here; every clippy lint is silenced for this module only.
#[allow(clippy::all)]
pub mod cln {
    tonic::include_proto!("cln");
}
19
20use cln::node_client::NodeClient;
21
22pub struct TonicClnGrpcClient {
23    timeout: Duration,
24    config: ClnGrpcDiscoveryBackendImplementation,
25    features: Option<LnFeatures>,
26    inner: Arc<Mutex<Option<Arc<InnerTonicClnGrpcClient>>>>,
27    ca_certificates: Vec<Certificate>,
28    identity: Identity,
29}
30
31impl TonicClnGrpcClient {
32    pub fn create(
33        timeout: Duration,
34        config: ClnGrpcDiscoveryBackendImplementation,
35        trusted_roots: &[CertificateDer],
36    ) -> Result<Self, LnPoolError> {
37        let ClnGrpcClientAuth::Path(auth) = &config.auth;
38
39        let mut ca_certificates = trusted_roots
40            .iter()
41            .map(|c| {
42                let c = Self::certificate_der_as_pem(c);
43                Certificate::from_pem(&c)
44            })
45            .collect::<Vec<_>>();
46
47        if let Some(ca_cert_path) = &auth.ca_cert_path {
48            let ca_certificate = fs::read(ca_cert_path).map_err(|e| {
49                LnPoolError::from_invalid_credentials(
50                    e.to_string(),
51                    ServiceErrorSource::Internal,
52                    format!(
53                        "loading CLN credentials and reading CA certificate from path {}",
54                        ca_cert_path.to_string_lossy()
55                    ),
56                )
57            })?;
58            ca_certificates.push(Certificate::from_pem(&ca_certificate));
59        }
60
61        let client_cert = fs::read(&auth.client_cert_path).map_err(|e| {
62            LnPoolError::from_invalid_credentials(
63                e.to_string(),
64                ServiceErrorSource::Internal,
65                format!(
66                    "loading CLN credentials and reading client certificate from path {}",
67                    auth.client_cert_path.to_string_lossy()
68                ),
69            )
70        })?;
71
72        let client_key = fs::read(&auth.client_key_path).map_err(|e| {
73            LnPoolError::from_invalid_credentials(
74                e.to_string(),
75                ServiceErrorSource::Internal,
76                format!(
77                    "loading CLN credentials and reading client key from path {}",
78                    auth.client_key_path.to_string_lossy()
79                ),
80            )
81        })?;
82
83        let identity = Identity::from_pem(client_cert, client_key);
84
85        Ok(Self {
86            timeout,
87            config,
88            features: Some(LnFeatures {
89                invoice_from_desc_hash: false,
90            }),
91            inner: Arc::new(Default::default()),
92            ca_certificates,
93            identity,
94        })
95    }
96
97    async fn inner_connect(&self) -> Result<Arc<InnerTonicClnGrpcClient>, LnPoolError> {
98        let mut inner = self.inner.lock().await;
99        match inner.as_ref() {
100            None => {
101                let inner_connect = Arc::new(
102                    InnerTonicClnGrpcClient::connect(
103                        self.timeout,
104                        self.ca_certificates.clone(),
105                        self.identity.clone(),
106                        self.config.url.to_string(),
107                        self.config.domain.as_deref(),
108                    )
109                    .await?,
110                );
111                *inner = Some(inner_connect.clone());
112                Ok(inner_connect)
113            }
114            Some(inner) => Ok(inner.clone()),
115        }
116    }
117
118    async fn inner_disconnect(&self) {
119        let mut inner = self.inner.lock().await;
120        *inner = None;
121    }
122
123    fn certificate_der_as_pem(certificate: &CertificateDer) -> String {
124        use base64::Engine;
125        let base64_cert = base64::engine::general_purpose::STANDARD.encode(certificate.as_ref());
126        format!("-----BEGIN CERTIFICATE-----\n{base64_cert}\n-----END CERTIFICATE-----")
127    }
128}
129
130#[async_trait]
131impl LnRpcClient for TonicClnGrpcClient {
132    type Error = LnPoolError;
133
134    async fn get_invoice<'a>(
135        &self,
136        amount_msat: Option<u64>,
137        description: Bolt11InvoiceDescription<'a>,
138        expiry_secs: Option<u64>,
139    ) -> Result<String, Self::Error> {
140        let inner = self.inner_connect().await?;
141
142        let r = inner
143            .get_invoice(amount_msat, description, expiry_secs)
144            .await;
145
146        if r.is_err() {
147            self.inner_disconnect().await;
148        }
149        r
150    }
151
152    async fn get_metrics(&self) -> Result<LnMetrics, Self::Error> {
153        let inner = self.inner_connect().await?;
154
155        let r = inner.get_metrics().await;
156
157        if r.is_err() {
158            self.inner_disconnect().await;
159        }
160        r
161    }
162
163    fn get_features(&self) -> Option<&LnFeatures> {
164        self.features.as_ref()
165    }
166}
167
168struct InnerTonicClnGrpcClient {
169    client: NodeClient<Channel>,
170    url: String,
171}
172
173impl InnerTonicClnGrpcClient {
174    async fn connect(
175        timeout: Duration,
176        ca_certificates: Vec<Certificate>,
177        identity: Identity,
178        url: String,
179        domain: Option<&str>,
180    ) -> Result<Self, LnPoolError> {
181        let endpoint = Channel::from_shared(url.clone()).map_err(|e| {
182            LnPoolError::from_invalid_configuration(
183                format!("Invalid endpoint URI: {}", e),
184                ServiceErrorSource::Internal,
185                format!("CLN connecting to endpoint address {url}"),
186            )
187        })?;
188
189        let mut tls_config = ClientTlsConfig::new()
190            .with_native_roots()
191            .ca_certificates(ca_certificates)
192            .identity(identity);
193
194        if let Some(domain) = domain {
195            tls_config = tls_config.domain_name(domain);
196        }
197
198        let endpoint = endpoint.tls_config(tls_config).map_err(|e| {
199            LnPoolError::from_invalid_credentials(
200                e.to_string(),
201                ServiceErrorSource::Internal,
202                format!("loading CLN TLS configuration into client for {url}"),
203            )
204        })?;
205
206        let channel = endpoint
207            .connect_timeout(timeout)
208            .timeout(timeout)
209            .connect()
210            .await
211            .map_err(|e| {
212                LnPoolError::from_transport_error(
213                    e,
214                    ServiceErrorSource::Upstream,
215                    format!("connecting CLN client to {url}"),
216                )
217            })?;
218
219        let client = NodeClient::new(channel);
220        Ok(Self { client, url })
221    }
222
223    async fn get_invoice<'a>(
224        &self,
225        amount_msat: Option<u64>,
226        description: Bolt11InvoiceDescription<'a>,
227        expiry_secs: Option<u64>,
228    ) -> Result<String, LnPoolError> {
229        let (description_str, deschashonly, label) = match description {
230            Bolt11InvoiceDescription::Direct(d) => (d.to_string(), Some(false), d.to_string()),
231            Bolt11InvoiceDescription::DirectIntoHash(d) => {
232                let hash = sha2::Sha256::digest(d.as_bytes()).to_vec();
233                (d.to_string(), Some(true), hash.encode_hex())
234            }
235            Bolt11InvoiceDescription::Hash(_) => {
236                return Err(LnPoolError::from_invalid_configuration(
237                    "hash descriptions unsupported".to_string(),
238                    ServiceErrorSource::Internal,
239                    format!(
240                        "CLN get invoice from {}, parsing invoice description",
241                        self.url
242                    ),
243                ))
244            }
245        };
246
247        let now = SystemTime::now().duration_since(UNIX_EPOCH).map_err(|e| {
248            LnPoolError::from_invalid_configuration(
249                e.to_string(),
250                ServiceErrorSource::Internal,
251                format!(
252                    "CLN get invoice from {}, getting current time for label",
253                    self.url
254                ),
255            )
256        })?;
257        let label = format!("{label}:{}", now.as_nanos());
258
259        let mut client = self.client.clone();
260        let request = cln::InvoiceRequest {
261            amount_msat: match amount_msat {
262                Some(msat) => Some(cln::AmountOrAny {
263                    value: Some(cln::amount_or_any::Value::Amount(cln::Amount { msat })),
264                }),
265                None => Some(cln::AmountOrAny {
266                    value: Some(cln::amount_or_any::Value::Any(true)),
267                }),
268            },
269            description: description_str,
270            label,
271            deschashonly,
272            expiry: expiry_secs,
273            ..Default::default()
274        };
275
276        let response = client
277            .invoice(request)
278            .await
279            .map_err(|e| {
280                LnPoolError::from_tonic_error(
281                    e,
282                    format!("CLN get invoice from {}, requesting invoice", self.url),
283                )
284            })?
285            .into_inner();
286
287        Ok(response.bolt11)
288    }
289
290    async fn get_metrics(&self) -> Result<LnMetrics, LnPoolError> {
291        let channels_request = cln::ListpeerchannelsRequest {
292            id: None,
293            short_channel_id: None,
294        };
295        let mut client = self.client.clone();
296        let channels_response = client
297            .list_peer_channels(channels_request)
298            .await
299            .map_err(|e| {
300                LnPoolError::from_tonic_error(
301                    e,
302                    format!("CLN get metrics for {}, requesting channels", self.url),
303                )
304            })?
305            .into_inner();
306
307        let mut node_effective_inbound_msat = 0u64;
308
309        const CHANNELD_NORMAL: i32 = 2;
310
311        for channel in &channels_response.channels {
312            if channel.state == CHANNELD_NORMAL {
313                let receivable_msat = channel
314                    .receivable_msat
315                    .as_ref()
316                    .map(|a| a.msat)
317                    .unwrap_or(0);
318                node_effective_inbound_msat += receivable_msat;
319            }
320        }
321
322        Ok(LnMetrics {
323            healthy: true,
324            node_effective_inbound_msat,
325        })
326    }
327}