switchgear_components/pool/lnd/grpc/client.rs

1use crate::pool::error::LnPoolError;
2use crate::pool::lnd::grpc::config::{LndGrpcClientAuth, LndGrpcDiscoveryBackendImplementation};
3use crate::pool::{Bolt11InvoiceDescription, LnFeatures, LnMetrics, LnRpcClient};
4use async_trait::async_trait;
5use rustls::pki_types::CertificateDer;
6use sha2::Digest;
7use std::fs;
8use std::sync::Arc;
9use std::time::Duration;
10use switchgear_service_api::service::ServiceErrorSource;
11use tokio::sync::Mutex;
12use tonic::service::Interceptor;
13use tonic::transport::{Certificate, Channel, ClientTlsConfig};
14
15#[allow(clippy::all)]
16pub mod lnrpc {
17    tonic::include_proto!("lnrpc");
18}
19
20use lnrpc::lightning_client::LightningClient;
21
22pub struct TonicLndGrpcClient {
23    timeout: Duration,
24    config: LndGrpcDiscoveryBackendImplementation,
25    features: Option<LnFeatures>,
26    inner: Arc<Mutex<Option<Arc<InnerTonicLndGrpcClient>>>>,
27    ca_certificates: Vec<Certificate>,
28    macaroon: String,
29}
30
31impl TonicLndGrpcClient {
32    pub fn create(
33        timeout: Duration,
34        config: LndGrpcDiscoveryBackendImplementation,
35        trusted_roots: &[CertificateDer],
36    ) -> Result<Self, LnPoolError> {
37        let LndGrpcClientAuth::Path(auth) = &config.auth;
38
39        let mut ca_certificates = trusted_roots
40            .iter()
41            .map(|c| {
42                let c = Self::certificate_der_as_pem(c);
43                Certificate::from_pem(&c)
44            })
45            .collect::<Vec<_>>();
46
47        if let Some(tls_cert_path) = &auth.tls_cert_path {
48            let ca_certificate = fs::read(tls_cert_path).map_err(|e| {
49                LnPoolError::from_invalid_credentials(
50                    e.to_string(),
51                    ServiceErrorSource::Internal,
52                    format!(
53                        "loading LND credentials and reading CA certificate from path {}",
54                        tls_cert_path.to_string_lossy()
55                    ),
56                )
57            })?;
58            ca_certificates.push(Certificate::from_pem(&ca_certificate));
59        }
60
61        let macaroon = fs::read(&auth.macaroon_path).map_err(|e| {
62            LnPoolError::from_invalid_credentials(
63                e.to_string(),
64                ServiceErrorSource::Internal,
65                format!(
66                    "loading LND macaroon from {}",
67                    auth.macaroon_path.to_string_lossy()
68                ),
69            )
70        })?;
71        let macaroon = hex::encode(&macaroon);
72
73        Ok(Self {
74            timeout,
75            config,
76            features: Some(LnFeatures {
77                invoice_from_desc_hash: true,
78            }),
79            inner: Arc::new(Default::default()),
80            ca_certificates,
81            macaroon,
82        })
83    }
84
85    async fn inner_connect(&self) -> Result<Arc<InnerTonicLndGrpcClient>, LnPoolError> {
86        let mut inner = self.inner.lock().await;
87        match inner.as_ref() {
88            None => {
89                let inner_connect = Arc::new(
90                    InnerTonicLndGrpcClient::connect(
91                        self.timeout,
92                        self.ca_certificates.clone(),
93                        self.macaroon.clone(),
94                        self.config.url.to_string(),
95                        self.config.domain.as_deref(),
96                        self.config.amp_invoice,
97                    )
98                    .await?,
99                );
100                *inner = Some(inner_connect.clone());
101                Ok(inner_connect)
102            }
103            Some(inner) => Ok(inner.clone()),
104        }
105    }
106
107    async fn inner_disconnect(&self) {
108        let mut inner = self.inner.lock().await;
109        *inner = None;
110    }
111
112    fn certificate_der_as_pem(certificate: &CertificateDer) -> String {
113        use base64::Engine;
114        let base64_cert = base64::engine::general_purpose::STANDARD.encode(certificate.as_ref());
115        format!("-----BEGIN CERTIFICATE-----\n{base64_cert}\n-----END CERTIFICATE-----")
116    }
117}
118
119#[async_trait]
120impl LnRpcClient for TonicLndGrpcClient {
121    type Error = LnPoolError;
122
123    async fn get_invoice<'a>(
124        &self,
125        amount_msat: Option<u64>,
126        description: Bolt11InvoiceDescription<'a>,
127        expiry_secs: Option<u64>,
128    ) -> Result<String, Self::Error> {
129        let inner = self.inner_connect().await?;
130
131        let r = inner
132            .get_invoice(amount_msat, description, expiry_secs)
133            .await;
134
135        if r.is_err() {
136            self.inner_disconnect().await;
137        }
138        r
139    }
140
141    async fn get_metrics(&self) -> Result<LnMetrics, Self::Error> {
142        let inner = self.inner_connect().await?;
143
144        let r = inner.get_metrics().await;
145
146        if r.is_err() {
147            self.inner_disconnect().await;
148        }
149        r
150    }
151
152    fn get_features(&self) -> Option<&LnFeatures> {
153        self.features.as_ref()
154    }
155}
156
157struct InnerTonicLndGrpcClient {
158    client: LightningClient<
159        tonic::service::interceptor::InterceptedService<Channel, MacaroonInterceptor>,
160    >,
161    url: String,
162    amp_invoice: bool,
163}
164
165impl InnerTonicLndGrpcClient {
166    async fn connect(
167        timeout: Duration,
168        ca_certificates: Vec<Certificate>,
169        macaroon: String,
170        url: String,
171        domain: Option<&str>,
172        amp_invoice: bool,
173    ) -> Result<Self, LnPoolError> {
174        let endpoint = Channel::from_shared(url.clone()).map_err(|e| {
175            LnPoolError::from_invalid_configuration(
176                format!("Invalid endpoint URI: {}", e),
177                ServiceErrorSource::Internal,
178                format!("LND connecting to endpoint address {url}"),
179            )
180        })?;
181
182        let mut tls_config = ClientTlsConfig::new()
183            .with_native_roots()
184            .ca_certificates(ca_certificates);
185
186        if let Some(domain) = domain {
187            tls_config = tls_config.domain_name(domain);
188        }
189
190        let endpoint = endpoint.tls_config(tls_config).map_err(|e| {
191            LnPoolError::from_invalid_credentials(
192                e.to_string(),
193                ServiceErrorSource::Internal,
194                format!("loading LND TLS configuration into client for {url}"),
195            )
196        })?;
197
198        let channel = endpoint
199            .connect_timeout(timeout)
200            .timeout(timeout)
201            .connect()
202            .await
203            .map_err(|e| {
204                LnPoolError::from_transport_error(
205                    e,
206                    ServiceErrorSource::Upstream,
207                    format!("connecting LND client to {url}"),
208                )
209            })?;
210
211        let interceptor = MacaroonInterceptor { macaroon };
212
213        let client = LightningClient::with_interceptor(channel, interceptor);
214        Ok(Self {
215            client,
216            url,
217            amp_invoice,
218        })
219    }
220
221    async fn get_invoice<'a>(
222        &self,
223        amount_msat: Option<u64>,
224        description: Bolt11InvoiceDescription<'a>,
225        expiry_secs: Option<u64>,
226    ) -> Result<String, LnPoolError> {
227        let mut client = self.client.clone();
228
229        let (memo, description_hash) = match description {
230            Bolt11InvoiceDescription::Direct(d) => (d.to_string(), vec![]),
231            Bolt11InvoiceDescription::DirectIntoHash(d) => {
232                (String::new(), sha2::Sha256::digest(d.as_bytes()).to_vec())
233            }
234            Bolt11InvoiceDescription::Hash(h) => (String::new(), h.to_vec()),
235        };
236
237        let invoice_request = lnrpc::Invoice {
238            memo,
239            value_msat: amount_msat.unwrap_or(0) as i64,
240            description_hash,
241            expiry: expiry_secs.unwrap_or(3600) as i64,
242            is_amp: self.amp_invoice,
243            ..Default::default()
244        };
245
246        let response = client
247            .add_invoice(invoice_request)
248            .await
249            .map_err(|e| {
250                LnPoolError::from_tonic_error(
251                    e,
252                    format!("LND get invoice from {}, requesting invoice", self.url),
253                )
254            })?
255            .into_inner();
256
257        Ok(response.payment_request)
258    }
259
260    async fn get_metrics(&self) -> Result<LnMetrics, LnPoolError> {
261        let mut client = self.client.clone();
262
263        let channel_balance_request = lnrpc::ChannelBalanceRequest {};
264        let channels_balance_response = client
265            .channel_balance(channel_balance_request)
266            .await
267            .map_err(|e| {
268                LnPoolError::from_tonic_error(
269                    e,
270                    format!("LND get metrics for {}, requesting channels", self.url),
271                )
272            })?
273            .into_inner();
274
275        let node_effective_inbound_msat = channels_balance_response
276            .remote_balance
277            .map(|balance| balance.msat)
278            .unwrap_or(0);
279
280        Ok(LnMetrics {
281            healthy: true,
282            node_effective_inbound_msat,
283        })
284    }
285}
286
287#[derive(Clone)]
288struct MacaroonInterceptor {
289    macaroon: String,
290}
291
292impl Interceptor for MacaroonInterceptor {
293    fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
294        req.metadata_mut().insert(
295            "macaroon",
296            tonic::metadata::MetadataValue::try_from(self.macaroon.clone())
297                .map_err(|_| tonic::Status::invalid_argument("Invalid macaroon"))?,
298        );
299        Ok(req)
300    }
301}