Skip to main content

sigstat_grpc/
mock_forward_proxy.rs

1use lazy_static::lazy_static;
2use parking_lot::Mutex;
3use std::net::SocketAddr;
4use std::sync::atomic::{AtomicI32, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::sync::mpsc::{Receiver, Sender};
8use tokio::sync::Notify;
9use tokio::task::JoinHandle;
10use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
11use tonic::transport::Server;
12use tonic::{Request, Response, Status};
13
14pub mod api {
15    tonic::include_proto!("statsig_forward_proxy");
16}
17
18use api::statsig_forward_proxy_server::{StatsigForwardProxy, StatsigForwardProxyServer};
19use api::{ConfigSpecRequest, ConfigSpecResponse};
20
/// Next localhost port to hand to a freshly spawned mock proxy.
///
/// Starts at 50051 and is bumped by one on every `MockForwardProxy::spawn`, so mock
/// servers created within one process do not reuse each other's ports.
/// `AtomicI32::new` is a `const fn`, so no lazy initialisation is required.
static PORT_ID: AtomicI32 = AtomicI32::new(50051);
24
25#[tokio::main]
26pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
27    let mock_server = MockForwardProxy::spawn().await;
28    mock_server
29        .send_stream_update(Ok(ConfigSpecResponse {
30            spec: "bg_sync".to_string(),
31            last_updated: 123,
32            zstd_dict_id: None,
33        }))
34        .await;
35
36    Ok(())
37}
38
39pub async fn wait_one_ms() {
40    tokio::time::sleep(Duration::from_millis(1)).await;
41}
42pub struct MockForwardProxy {
43    pub proxy_address: SocketAddr,
44    pub stubbed_get_config_spec_response: Mutex<ConfigSpecResponse>,
45
46    shutdown_notifier: Arc<Notify>,
47    server_handle: Mutex<Option<JoinHandle<()>>>,
48
49    stream_tx: Mutex<Option<Sender<Result<ConfigSpecResponse, Status>>>>,
50    stream_rx: Mutex<Option<Receiver<Result<ConfigSpecResponse, Status>>>>,
51}
52
53impl MockForwardProxy {
54    pub async fn spawn() -> Arc<MockForwardProxy> {
55        let port = PORT_ID.fetch_add(1, Ordering::SeqCst);
56        let proxy_address: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap();
57
58        let forward_proxy = Arc::new(MockForwardProxy {
59            proxy_address,
60            stubbed_get_config_spec_response: Mutex::new(ConfigSpecResponse {
61                spec: "NOT STUBBED".to_string(),
62                last_updated: 0,
63                zstd_dict_id: None,
64            }),
65
66            shutdown_notifier: Arc::new(Notify::new()),
67            server_handle: Mutex::new(None),
68
69            stream_tx: Mutex::new(None),
70            stream_rx: Mutex::new(None),
71        });
72
73        forward_proxy.clone().restart().await;
74        forward_proxy
75    }
76
77    pub async fn send_stream_update(&self, update: Result<ConfigSpecResponse, Status>) {
78        let sender = {
79            let guard = self.stream_tx.try_lock().unwrap();
80            guard.as_ref().unwrap().clone()
81        };
82
83        if let Err(err) = sender.send(update).await {
84            print!("Failed to send update {err}")
85        }
86    }
87
88    pub async fn stop(&self) {
89        let handle = self.server_handle.try_lock().unwrap().take();
90        if let Some(handle) = handle {
91            self.send_stream_update(Err(Status::unavailable("Connection Lost")))
92                .await;
93            self.shutdown_notifier.notify_one();
94            wait_one_ms().await;
95
96            let _ = handle.await;
97        }
98    }
99
100    pub async fn restart(self: Arc<Self>) {
101        self.stop().await;
102
103        let mock_service = MockForwardProxyService {
104            proxy: self.clone(),
105        };
106
107        let shutdown_notify = self.shutdown_notifier.clone();
108        let address = self.proxy_address;
109
110        let server_handle = tokio::spawn(async move {
111            let _ = Server::builder()
112                .add_service(StatsigForwardProxyServer::new(mock_service))
113                .serve_with_shutdown(address, async {
114                    shutdown_notify.notified().await;
115                })
116                .await;
117        });
118
119        let (tx, rx) = tokio::sync::mpsc::channel(4);
120
121        *self.stream_tx.try_lock().unwrap() = Some(tx);
122        *self.stream_rx.try_lock().unwrap() = Some(rx);
123        *self.server_handle.try_lock().unwrap() = Some(server_handle);
124
125        wait_one_ms().await; // wait for the update to be applied
126    }
127}
128
129struct MockForwardProxyService {
130    pub proxy: Arc<MockForwardProxy>,
131}
132
133#[tonic::async_trait]
134impl StatsigForwardProxy for MockForwardProxyService {
135    async fn get_config_spec(
136        &self,
137        _request: Request<ConfigSpecRequest>,
138    ) -> Result<Response<ConfigSpecResponse>, Status> {
139        let response = self
140            .proxy
141            .stubbed_get_config_spec_response
142            .try_lock()
143            .unwrap()
144            .clone();
145        Ok(Response::new(response))
146    }
147
148    type StreamConfigSpecStream = ReceiverStream<Result<ConfigSpecResponse, Status>>;
149
150    async fn stream_config_spec(
151        &self,
152        _request: Request<ConfigSpecRequest>,
153    ) -> Result<Response<Self::StreamConfigSpecStream>, Status> {
154        let rx = self.proxy.stream_rx.try_lock().unwrap().take().unwrap();
155
156        let stream = ReceiverStream::new(rx);
157        Ok(Response::new(stream))
158    }
159}