sigstat_grpc/
mock_forward_proxy.rs

1use lazy_static::lazy_static;
2use std::net::SocketAddr;
3use std::sync::atomic::{AtomicI32, Ordering};
4use std::sync::{Arc, Mutex};
5use std::time::Duration;
6use tokio::sync::mpsc::{Receiver, Sender};
7use tokio::sync::Notify;
8use tokio::task::JoinHandle;
9use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
10use tonic::transport::Server;
11use tonic::{Request, Response, Status};
12
/// Message and service types generated by tonic from the `statsig_forward_proxy`
/// protobuf package (pulled in at compile time via `tonic::include_proto!`).
pub mod api {
    tonic::include_proto!("statsig_forward_proxy");
}
16
17use api::statsig_forward_proxy_server::{StatsigForwardProxy, StatsigForwardProxyServer};
18use api::{ConfigSpecRequest, ConfigSpecResponse};
19
/// Next loopback port to hand out to a newly spawned mock proxy.
///
/// Each `MockForwardProxy::spawn` call takes the current value and advances the
/// counter, so mocks spawned within this process get distinct ports.
///
/// `AtomicI32::new` is a `const fn`, so a plain `static` is sufficient; no lazy
/// initialisation is required.
static PORT_ID: AtomicI32 = AtomicI32::new(50051);
23
24#[tokio::main]
25pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
26    let mock_server = MockForwardProxy::spawn().await;
27    mock_server
28        .send_stream_update(Ok(ConfigSpecResponse {
29            spec: "bg_sync".to_string(),
30            last_updated: 123,
31            zstd_dict_id: None,
32        }))
33        .await;
34
35    Ok(())
36}
37
38pub async fn wait_one_ms() {
39    tokio::time::sleep(Duration::from_millis(1)).await;
40}
/// Mock of the Statsig forward proxy gRPC service.
///
/// A tonic server bound to `proxy_address` runs on a spawned Tokio task. Unary
/// `get_config_spec` calls are answered with `stubbed_get_config_spec_response`;
/// a `stream_config_spec` call receives whatever is pushed via `send_stream_update`.
pub struct MockForwardProxy {
    /// Loopback (`127.0.0.1`) address the gRPC server listens on.
    pub proxy_address: SocketAddr,
    /// Value cloned and returned for every `get_config_spec` request
    /// (presumably overwritten by tests to stub data).
    pub stubbed_get_config_spec_response: Mutex<ConfigSpecResponse>,

    /// Shared signal passed to `serve_with_shutdown`; notified by `stop` to end the server.
    shutdown_notifier: Arc<Notify>,
    /// Handle of the currently running server task; `None` before start and after `stop`.
    server_handle: Mutex<Option<JoinHandle<()>>>,

    /// Producer side of the config-spec stream channel; replaced on every restart.
    stream_tx: Mutex<Option<Sender<Result<ConfigSpecResponse, Status>>>>,
    /// Consumer side of the channel, parked here until a `stream_config_spec` call takes it.
    stream_rx: Mutex<Option<Receiver<Result<ConfigSpecResponse, Status>>>>,
}
51
52impl MockForwardProxy {
53    pub async fn spawn() -> Arc<MockForwardProxy> {
54        let port = PORT_ID.fetch_add(1, Ordering::SeqCst);
55        let proxy_address: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap();
56
57        let forward_proxy = Arc::new(MockForwardProxy {
58            proxy_address,
59            stubbed_get_config_spec_response: Mutex::new(ConfigSpecResponse {
60                spec: "NOT STUBBED".to_string(),
61                last_updated: 0,
62                zstd_dict_id: None,
63            }),
64
65            shutdown_notifier: Arc::new(Notify::new()),
66            server_handle: Mutex::new(None),
67
68            stream_tx: Mutex::new(None),
69            stream_rx: Mutex::new(None),
70        });
71
72        forward_proxy.clone().restart().await;
73        forward_proxy
74    }
75
76    pub async fn send_stream_update(&self, update: Result<ConfigSpecResponse, Status>) {
77        let sender = {
78            let guard = self.stream_tx.lock().unwrap();
79            guard.as_ref().unwrap().clone()
80        };
81
82        if let Err(err) = sender.send(update).await {
83            print!("Failed to send update {}", err)
84        }
85    }
86
87    pub async fn stop(&self) {
88        let handle = self.server_handle.lock().unwrap().take();
89        if let Some(handle) = handle {
90            self.send_stream_update(Err(Status::unavailable("Connection Lost")))
91                .await;
92            self.shutdown_notifier.notify_one();
93            wait_one_ms().await;
94
95            let _ = handle.await;
96        }
97    }
98
99    pub async fn restart(self: Arc<Self>) {
100        self.stop().await;
101
102        let mock_service = MockForwardProxyService {
103            proxy: self.clone(),
104        };
105
106        let shutdown_notify = self.shutdown_notifier.clone();
107        let address = self.proxy_address;
108
109        let server_handle = tokio::spawn(async move {
110            let _ = Server::builder()
111                .add_service(StatsigForwardProxyServer::new(mock_service))
112                .serve_with_shutdown(address, async {
113                    shutdown_notify.notified().await;
114                })
115                .await;
116        });
117
118        let (tx, rx) = tokio::sync::mpsc::channel(4);
119
120        *self.stream_tx.lock().unwrap() = Some(tx);
121        *self.stream_rx.lock().unwrap() = Some(rx);
122        *self.server_handle.lock().unwrap() = Some(server_handle);
123
124        wait_one_ms().await; // wait for the update to be applied
125    }
126}
127
/// The tonic service object handed to the gRPC server. It holds no state of its own
/// and reads everything from the shared [`MockForwardProxy`].
struct MockForwardProxyService {
    /// Shared mock state: the stubbed unary response and the stream receiver.
    pub proxy: Arc<MockForwardProxy>,
}
131
132#[tonic::async_trait]
133impl StatsigForwardProxy for MockForwardProxyService {
134    async fn get_config_spec(
135        &self,
136        _request: Request<ConfigSpecRequest>,
137    ) -> Result<Response<ConfigSpecResponse>, Status> {
138        let response = self
139            .proxy
140            .stubbed_get_config_spec_response
141            .lock()
142            .unwrap()
143            .clone();
144        Ok(Response::new(response))
145    }
146
147    type StreamConfigSpecStream = ReceiverStream<Result<ConfigSpecResponse, Status>>;
148
149    async fn stream_config_spec(
150        &self,
151        _request: Request<ConfigSpecRequest>,
152    ) -> Result<Response<Self::StreamConfigSpecStream>, Status> {
153        let rx = self.proxy.stream_rx.lock().unwrap().take().unwrap();
154
155        let stream = ReceiverStream::new(rx);
156        Ok(Response::new(stream))
157    }
158}