switchgear_components/pool/cln/grpc/
client.rs1use crate::pool::cln::grpc::config::{ClnGrpcClientAuth, ClnGrpcDiscoveryBackendImplementation};
2use crate::pool::error::LnPoolError;
3use crate::pool::{Bolt11InvoiceDescription, LnFeatures, LnMetrics, LnRpcClient};
4use async_trait::async_trait;
5use hex::ToHex;
6use rustls::pki_types::CertificateDer;
7use sha2::Digest;
8use std::fs;
9use std::sync::Arc;
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11use switchgear_service_api::service::ServiceErrorSource;
12use tokio::sync::Mutex;
13use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
14
15#[allow(clippy::all)]
16pub mod cln {
17 tonic::include_proto!("cln");
18}
19
20use cln::node_client::NodeClient;
21
22pub struct TonicClnGrpcClient {
23 timeout: Duration,
24 config: ClnGrpcDiscoveryBackendImplementation,
25 features: Option<LnFeatures>,
26 inner: Arc<Mutex<Option<Arc<InnerTonicClnGrpcClient>>>>,
27 ca_certificates: Vec<Certificate>,
28 identity: Identity,
29}
30
31impl TonicClnGrpcClient {
32 pub fn create(
33 timeout: Duration,
34 config: ClnGrpcDiscoveryBackendImplementation,
35 trusted_roots: &[CertificateDer],
36 ) -> Result<Self, LnPoolError> {
37 let ClnGrpcClientAuth::Path(auth) = &config.auth;
38
39 let mut ca_certificates = trusted_roots
40 .iter()
41 .map(|c| {
42 let c = Self::certificate_der_as_pem(c);
43 Certificate::from_pem(&c)
44 })
45 .collect::<Vec<_>>();
46
47 if let Some(ca_cert_path) = &auth.ca_cert_path {
48 let ca_certificate = fs::read(ca_cert_path).map_err(|e| {
49 LnPoolError::from_invalid_credentials(
50 e.to_string(),
51 ServiceErrorSource::Internal,
52 format!(
53 "loading CLN credentials and reading CA certificate from path {}",
54 ca_cert_path.to_string_lossy()
55 ),
56 )
57 })?;
58 ca_certificates.push(Certificate::from_pem(&ca_certificate));
59 }
60
61 let client_cert = fs::read(&auth.client_cert_path).map_err(|e| {
62 LnPoolError::from_invalid_credentials(
63 e.to_string(),
64 ServiceErrorSource::Internal,
65 format!(
66 "loading CLN credentials and reading client certificate from path {}",
67 auth.client_cert_path.to_string_lossy()
68 ),
69 )
70 })?;
71
72 let client_key = fs::read(&auth.client_key_path).map_err(|e| {
73 LnPoolError::from_invalid_credentials(
74 e.to_string(),
75 ServiceErrorSource::Internal,
76 format!(
77 "loading CLN credentials and reading client key from path {}",
78 auth.client_key_path.to_string_lossy()
79 ),
80 )
81 })?;
82
83 let identity = Identity::from_pem(client_cert, client_key);
84
85 Ok(Self {
86 timeout,
87 config,
88 features: Some(LnFeatures {
89 invoice_from_desc_hash: false,
90 }),
91 inner: Arc::new(Default::default()),
92 ca_certificates,
93 identity,
94 })
95 }
96
97 async fn inner_connect(&self) -> Result<Arc<InnerTonicClnGrpcClient>, LnPoolError> {
98 let mut inner = self.inner.lock().await;
99 match inner.as_ref() {
100 None => {
101 let inner_connect = Arc::new(
102 InnerTonicClnGrpcClient::connect(
103 self.timeout,
104 self.ca_certificates.clone(),
105 self.identity.clone(),
106 self.config.url.to_string(),
107 self.config.domain.as_deref(),
108 )
109 .await?,
110 );
111 *inner = Some(inner_connect.clone());
112 Ok(inner_connect)
113 }
114 Some(inner) => Ok(inner.clone()),
115 }
116 }
117
118 async fn inner_disconnect(&self) {
119 let mut inner = self.inner.lock().await;
120 *inner = None;
121 }
122
123 fn certificate_der_as_pem(certificate: &CertificateDer) -> String {
124 use base64::Engine;
125 let base64_cert = base64::engine::general_purpose::STANDARD.encode(certificate.as_ref());
126 format!("-----BEGIN CERTIFICATE-----\n{base64_cert}\n-----END CERTIFICATE-----")
127 }
128}
129
130#[async_trait]
131impl LnRpcClient for TonicClnGrpcClient {
132 type Error = LnPoolError;
133
134 async fn get_invoice<'a>(
135 &self,
136 amount_msat: Option<u64>,
137 description: Bolt11InvoiceDescription<'a>,
138 expiry_secs: Option<u64>,
139 ) -> Result<String, Self::Error> {
140 let inner = self.inner_connect().await?;
141
142 let r = inner
143 .get_invoice(amount_msat, description, expiry_secs)
144 .await;
145
146 if r.is_err() {
147 self.inner_disconnect().await;
148 }
149 r
150 }
151
152 async fn get_metrics(&self) -> Result<LnMetrics, Self::Error> {
153 let inner = self.inner_connect().await?;
154
155 let r = inner.get_metrics().await;
156
157 if r.is_err() {
158 self.inner_disconnect().await;
159 }
160 r
161 }
162
163 fn get_features(&self) -> Option<&LnFeatures> {
164 self.features.as_ref()
165 }
166}
167
168struct InnerTonicClnGrpcClient {
169 client: NodeClient<Channel>,
170 url: String,
171}
172
173impl InnerTonicClnGrpcClient {
174 async fn connect(
175 timeout: Duration,
176 ca_certificates: Vec<Certificate>,
177 identity: Identity,
178 url: String,
179 domain: Option<&str>,
180 ) -> Result<Self, LnPoolError> {
181 let endpoint = Channel::from_shared(url.clone()).map_err(|e| {
182 LnPoolError::from_invalid_configuration(
183 format!("Invalid endpoint URI: {}", e),
184 ServiceErrorSource::Internal,
185 format!("CLN connecting to endpoint address {url}"),
186 )
187 })?;
188
189 let mut tls_config = ClientTlsConfig::new()
190 .with_native_roots()
191 .ca_certificates(ca_certificates)
192 .identity(identity);
193
194 if let Some(domain) = domain {
195 tls_config = tls_config.domain_name(domain);
196 }
197
198 let endpoint = endpoint.tls_config(tls_config).map_err(|e| {
199 LnPoolError::from_invalid_credentials(
200 e.to_string(),
201 ServiceErrorSource::Internal,
202 format!("loading CLN TLS configuration into client for {url}"),
203 )
204 })?;
205
206 let channel = endpoint
207 .connect_timeout(timeout)
208 .timeout(timeout)
209 .connect()
210 .await
211 .map_err(|e| {
212 LnPoolError::from_transport_error(
213 e,
214 ServiceErrorSource::Upstream,
215 format!("connecting CLN client to {url}"),
216 )
217 })?;
218
219 let client = NodeClient::new(channel);
220 Ok(Self { client, url })
221 }
222
223 async fn get_invoice<'a>(
224 &self,
225 amount_msat: Option<u64>,
226 description: Bolt11InvoiceDescription<'a>,
227 expiry_secs: Option<u64>,
228 ) -> Result<String, LnPoolError> {
229 let (description_str, deschashonly, label) = match description {
230 Bolt11InvoiceDescription::Direct(d) => (d.to_string(), Some(false), d.to_string()),
231 Bolt11InvoiceDescription::DirectIntoHash(d) => {
232 let hash = sha2::Sha256::digest(d.as_bytes()).to_vec();
233 (d.to_string(), Some(true), hash.encode_hex())
234 }
235 Bolt11InvoiceDescription::Hash(_) => {
236 return Err(LnPoolError::from_invalid_configuration(
237 "hash descriptions unsupported".to_string(),
238 ServiceErrorSource::Internal,
239 format!(
240 "CLN get invoice from {}, parsing invoice description",
241 self.url
242 ),
243 ))
244 }
245 };
246
247 let now = SystemTime::now().duration_since(UNIX_EPOCH).map_err(|e| {
248 LnPoolError::from_invalid_configuration(
249 e.to_string(),
250 ServiceErrorSource::Internal,
251 format!(
252 "CLN get invoice from {}, getting current time for label",
253 self.url
254 ),
255 )
256 })?;
257 let label = format!("{label}:{}", now.as_nanos());
258
259 let mut client = self.client.clone();
260 let request = cln::InvoiceRequest {
261 amount_msat: match amount_msat {
262 Some(msat) => Some(cln::AmountOrAny {
263 value: Some(cln::amount_or_any::Value::Amount(cln::Amount { msat })),
264 }),
265 None => Some(cln::AmountOrAny {
266 value: Some(cln::amount_or_any::Value::Any(true)),
267 }),
268 },
269 description: description_str,
270 label,
271 deschashonly,
272 expiry: expiry_secs,
273 ..Default::default()
274 };
275
276 let response = client
277 .invoice(request)
278 .await
279 .map_err(|e| {
280 LnPoolError::from_tonic_error(
281 e,
282 format!("CLN get invoice from {}, requesting invoice", self.url),
283 )
284 })?
285 .into_inner();
286
287 Ok(response.bolt11)
288 }
289
290 async fn get_metrics(&self) -> Result<LnMetrics, LnPoolError> {
291 let channels_request = cln::ListpeerchannelsRequest {
292 id: None,
293 short_channel_id: None,
294 };
295 let mut client = self.client.clone();
296 let channels_response = client
297 .list_peer_channels(channels_request)
298 .await
299 .map_err(|e| {
300 LnPoolError::from_tonic_error(
301 e,
302 format!("CLN get metrics for {}, requesting channels", self.url),
303 )
304 })?
305 .into_inner();
306
307 let mut node_effective_inbound_msat = 0u64;
308
309 const CHANNELD_NORMAL: i32 = 2;
310
311 for channel in &channels_response.channels {
312 if channel.state == CHANNELD_NORMAL {
313 let receivable_msat = channel
314 .receivable_msat
315 .as_ref()
316 .map(|a| a.msat)
317 .unwrap_or(0);
318 node_effective_inbound_msat += receivable_msat;
319 }
320 }
321
322 Ok(LnMetrics {
323 healthy: true,
324 node_effective_inbound_msat,
325 })
326 }
327}