Skip to main content

sigstat_grpc/
statsig_grpc_client.rs

1use crate::statsig_forward_proxy::config_spec_request::ApiVersion;
2use crate::statsig_forward_proxy::statsig_forward_proxy_client::StatsigForwardProxyClient;
3use crate::statsig_forward_proxy::{ConfigSpecRequest, ConfigSpecResponse};
4use crate::statsig_grpc_err::StatsigGrpcErr;
5use parking_lot::Mutex;
6use std::time::Duration;
7use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
8use tonic::Streaming;
9
10pub struct StatsigGrpcClient {
11    sdk_key: String,
12    proxy_api: String,
13    grpc_client: Mutex<Option<StatsigForwardProxyClient<Channel>>>,
14    tls_config: Option<ClientTlsConfig>,
15}
16
17impl StatsigGrpcClient {
18    pub fn new(
19        sdk_key: &str,
20        proxy_api: &str,
21        authentication_mode: Option<String>,
22        ca_cert_path: Option<String>,
23        client_cert_path: Option<String>,
24        client_key_path: Option<String>,
25        domain_name: Option<String>,
26    ) -> Self {
27        Self {
28            sdk_key: sdk_key.to_string(),
29            proxy_api: proxy_api.to_string(),
30            tls_config: Self::setup_tls_client(
31                authentication_mode,
32                ca_cert_path,
33                client_cert_path,
34                client_key_path,
35                domain_name,
36                proxy_api,
37            ),
38            grpc_client: Mutex::new(None),
39        }
40    }
41
42    pub async fn connect_client(&self) -> Result<(), StatsigGrpcErr> {
43        self.get_or_setup_grpc_client().await.map(|_| ())
44    }
45
46    pub fn reset_client(&self) {
47        match self.grpc_client.try_lock_for(Duration::from_secs(5)) {
48            Some(mut lock) => {
49                *lock = None;
50            }
51            None => {
52                eprintln!("Failed to reset grpc client");
53            }
54        };
55    }
56
57    pub async fn get_specs(&self, lcut: Option<u64>) -> Result<ConfigSpecResponse, StatsigGrpcErr> {
58        let request = create_config_spec_request(&self.sdk_key, lcut);
59        let mut client = self.get_or_setup_grpc_client().await?;
60
61        client
62            .get_config_spec(request)
63            .await
64            .map_err(StatsigGrpcErr::ErrorGrpcStatus)
65            .map(|r| r.into_inner())
66    }
67
68    pub async fn get_specs_stream(
69        &self,
70        lcut: Option<u64>,
71    ) -> Result<Streaming<ConfigSpecResponse>, StatsigGrpcErr> {
72        let request = create_config_spec_request(&self.sdk_key, lcut);
73        let mut client = self.get_or_setup_grpc_client().await?;
74
75        client
76            .stream_config_spec(request)
77            .await
78            .map_err(StatsigGrpcErr::ErrorGrpcStatus)
79            .map(|s| s.into_inner())
80    }
81
82    fn setup_tls_client(
83        authentication_mode: Option<String>,
84        ca_cert_path: Option<String>,
85        client_cert_path: Option<String>,
86        client_key_path: Option<String>,
87        domain_name: Option<String>,
88        proxy_api: &str,
89    ) -> Option<ClientTlsConfig> {
90        let domain_name = domain_name.unwrap_or_else(|| {
91            Self::extract_host(proxy_api)
92                .unwrap_or_default()
93                .to_string()
94        });
95        match authentication_mode
96            .as_deref()
97            .map(str::to_ascii_lowercase)
98            .as_deref()
99        {
100            Some("tls") => {
101                let ca_cert_path = ca_cert_path?;
102                let ca_cert: Vec<u8> = std::fs::read(ca_cert_path).ok()?;
103                let ca_cert = Certificate::from_pem(ca_cert);
104
105                Some(
106                    ClientTlsConfig::new()
107                        .ca_certificate(ca_cert)
108                        .domain_name(domain_name), // <-- adjust this as needed
109                )
110            }
111            Some("mtls") => {
112                let ca_cert_path = ca_cert_path?;
113                let client_cert_path = client_cert_path?;
114                let client_key_path = client_key_path?;
115
116                let ca_cert = std::fs::read(ca_cert_path).ok()?;
117                let client_cert = std::fs::read(client_cert_path).ok()?;
118                let client_key = std::fs::read(client_key_path).ok()?;
119
120                let ca_cert = Certificate::from_pem(ca_cert);
121                let identity = Identity::from_pem(client_cert, client_key);
122
123                Some(
124                    ClientTlsConfig::new()
125                        .ca_certificate(ca_cert)
126                        .identity(identity)
127                        .domain_name(domain_name), // <-- adjust this as needed
128                )
129            }
130            _ => None,
131        }
132    }
133
134    fn extract_host(url: &str) -> Option<&str> {
135        // Strip scheme if present
136        let without_scheme = if let Some(pos) = url.find("://") {
137            &url[(pos + 3)..]
138        } else {
139            url
140        };
141
142        // Split off path/query/fragment after the host[:port]
143        let host_port = without_scheme.split('/').next()?; // First part is host[:port]
144
145        // Split off port if present
146        host_port.split(':').next()
147    }
148
149    async fn get_or_setup_grpc_client(
150        &self,
151    ) -> Result<StatsigForwardProxyClient<Channel>, StatsigGrpcErr> {
152        {
153            let lock = self
154                .grpc_client
155                .try_lock_for(Duration::from_secs(5))
156                .ok_or(StatsigGrpcErr::FailedToGetLock)?;
157
158            if let Some(client) = lock.as_ref() {
159                return Ok(client.clone());
160            }
161        }
162
163        let mut channel_builder = Channel::from_shared(self.proxy_api.clone())
164            .map_err(|e| StatsigGrpcErr::FailedToConnect(e.to_string()))?
165            .connect_timeout(Duration::from_secs(5))
166            .tcp_keepalive(Some(Duration::from_secs(30)))
167            .keep_alive_while_idle(true)
168            .http2_keep_alive_interval(Duration::from_secs(30));
169
170        if let Some(tls_config) = self.tls_config.clone() {
171            channel_builder = channel_builder
172                .tls_config(tls_config)
173                .map_err(|e| StatsigGrpcErr::Authentication(e.to_string()))?;
174        }
175        let channel = channel_builder
176            .connect()
177            .await
178            .map_err(|e| StatsigGrpcErr::FailedToConnect(e.to_string()))?;
179
180        let new_client = StatsigForwardProxyClient::new(channel);
181
182        let mut lock = self
183            .grpc_client
184            .try_lock_for(Duration::from_secs(5))
185            .ok_or(StatsigGrpcErr::FailedToGetLock)?;
186
187        *lock = Some(new_client.clone());
188        Ok(new_client)
189    }
190}
191
192fn create_config_spec_request(sdk_key: &str, current_lcut: Option<u64>) -> ConfigSpecRequest {
193    ConfigSpecRequest {
194        since_time: current_lcut,
195        sdk_key: sdk_key.to_string(),
196        version: Some(ApiVersion::V2 as i32),
197        zstd_dict_id: None,
198    }
199}