sigstat_grpc/
statsig_grpc_client.rs1use crate::statsig_forward_proxy::config_spec_request::ApiVersion;
2use crate::statsig_forward_proxy::statsig_forward_proxy_client::StatsigForwardProxyClient;
3use crate::statsig_forward_proxy::{ConfigSpecRequest, ConfigSpecResponse};
4use crate::statsig_grpc_err::StatsigGrpcErr;
5use parking_lot::Mutex;
6use std::time::Duration;
7use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
8use tonic::Streaming;
9
10pub struct StatsigGrpcClient {
11 sdk_key: String,
12 proxy_api: String,
13 grpc_client: Mutex<Option<StatsigForwardProxyClient<Channel>>>,
14 tls_config: Option<ClientTlsConfig>,
15}
16
17impl StatsigGrpcClient {
18 pub fn new(
19 sdk_key: &str,
20 proxy_api: &str,
21 authentication_mode: Option<String>,
22 ca_cert_path: Option<String>,
23 client_cert_path: Option<String>,
24 client_key_path: Option<String>,
25 domain_name: Option<String>,
26 ) -> Self {
27 Self {
28 sdk_key: sdk_key.to_string(),
29 proxy_api: proxy_api.to_string(),
30 tls_config: Self::setup_tls_client(
31 authentication_mode,
32 ca_cert_path,
33 client_cert_path,
34 client_key_path,
35 domain_name,
36 proxy_api,
37 ),
38 grpc_client: Mutex::new(None),
39 }
40 }
41
42 pub async fn connect_client(&self) -> Result<(), StatsigGrpcErr> {
43 self.get_or_setup_grpc_client().await.map(|_| ())
44 }
45
46 pub fn reset_client(&self) {
47 match self.grpc_client.try_lock_for(Duration::from_secs(5)) {
48 Some(mut lock) => {
49 *lock = None;
50 }
51 None => {
52 eprintln!("Failed to reset grpc client");
53 }
54 };
55 }
56
57 pub async fn get_specs(&self, lcut: Option<u64>) -> Result<ConfigSpecResponse, StatsigGrpcErr> {
58 let request = create_config_spec_request(&self.sdk_key, lcut);
59 let mut client = self.get_or_setup_grpc_client().await?;
60
61 client
62 .get_config_spec(request)
63 .await
64 .map_err(StatsigGrpcErr::ErrorGrpcStatus)
65 .map(|r| r.into_inner())
66 }
67
68 pub async fn get_specs_stream(
69 &self,
70 lcut: Option<u64>,
71 ) -> Result<Streaming<ConfigSpecResponse>, StatsigGrpcErr> {
72 let request = create_config_spec_request(&self.sdk_key, lcut);
73 let mut client = self.get_or_setup_grpc_client().await?;
74
75 client
76 .stream_config_spec(request)
77 .await
78 .map_err(StatsigGrpcErr::ErrorGrpcStatus)
79 .map(|s| s.into_inner())
80 }
81
82 fn setup_tls_client(
83 authentication_mode: Option<String>,
84 ca_cert_path: Option<String>,
85 client_cert_path: Option<String>,
86 client_key_path: Option<String>,
87 domain_name: Option<String>,
88 proxy_api: &str,
89 ) -> Option<ClientTlsConfig> {
90 let domain_name = domain_name.unwrap_or_else(|| {
91 Self::extract_host(proxy_api)
92 .unwrap_or_default()
93 .to_string()
94 });
95 match authentication_mode
96 .as_deref()
97 .map(str::to_ascii_lowercase)
98 .as_deref()
99 {
100 Some("tls") => {
101 let ca_cert_path = ca_cert_path?;
102 let ca_cert: Vec<u8> = std::fs::read(ca_cert_path).ok()?;
103 let ca_cert = Certificate::from_pem(ca_cert);
104
105 Some(
106 ClientTlsConfig::new()
107 .ca_certificate(ca_cert)
108 .domain_name(domain_name), )
110 }
111 Some("mtls") => {
112 let ca_cert_path = ca_cert_path?;
113 let client_cert_path = client_cert_path?;
114 let client_key_path = client_key_path?;
115
116 let ca_cert = std::fs::read(ca_cert_path).ok()?;
117 let client_cert = std::fs::read(client_cert_path).ok()?;
118 let client_key = std::fs::read(client_key_path).ok()?;
119
120 let ca_cert = Certificate::from_pem(ca_cert);
121 let identity = Identity::from_pem(client_cert, client_key);
122
123 Some(
124 ClientTlsConfig::new()
125 .ca_certificate(ca_cert)
126 .identity(identity)
127 .domain_name(domain_name), )
129 }
130 _ => None,
131 }
132 }
133
134 fn extract_host(url: &str) -> Option<&str> {
135 let without_scheme = if let Some(pos) = url.find("://") {
137 &url[(pos + 3)..]
138 } else {
139 url
140 };
141
142 let host_port = without_scheme.split('/').next()?; host_port.split(':').next()
147 }
148
149 async fn get_or_setup_grpc_client(
150 &self,
151 ) -> Result<StatsigForwardProxyClient<Channel>, StatsigGrpcErr> {
152 {
153 let lock = self
154 .grpc_client
155 .try_lock_for(Duration::from_secs(5))
156 .ok_or(StatsigGrpcErr::FailedToGetLock)?;
157
158 if let Some(client) = lock.as_ref() {
159 return Ok(client.clone());
160 }
161 }
162
163 let mut channel_builder = Channel::from_shared(self.proxy_api.clone())
164 .map_err(|e| StatsigGrpcErr::FailedToConnect(e.to_string()))?
165 .connect_timeout(Duration::from_secs(5))
166 .tcp_keepalive(Some(Duration::from_secs(30)))
167 .keep_alive_while_idle(true)
168 .http2_keep_alive_interval(Duration::from_secs(30));
169
170 if let Some(tls_config) = self.tls_config.clone() {
171 channel_builder = channel_builder
172 .tls_config(tls_config)
173 .map_err(|e| StatsigGrpcErr::Authentication(e.to_string()))?;
174 }
175 let channel = channel_builder
176 .connect()
177 .await
178 .map_err(|e| StatsigGrpcErr::FailedToConnect(e.to_string()))?;
179
180 let new_client = StatsigForwardProxyClient::new(channel);
181
182 let mut lock = self
183 .grpc_client
184 .try_lock_for(Duration::from_secs(5))
185 .ok_or(StatsigGrpcErr::FailedToGetLock)?;
186
187 *lock = Some(new_client.clone());
188 Ok(new_client)
189 }
190}
191
192fn create_config_spec_request(sdk_key: &str, current_lcut: Option<u64>) -> ConfigSpecRequest {
193 ConfigSpecRequest {
194 since_time: current_lcut,
195 sdk_key: sdk_key.to_string(),
196 version: Some(ApiVersion::V2 as i32),
197 zstd_dict_id: None,
198 }
199}