unleash_proxy/
lib.rs

1//! Copyright 2020-2022 Cognite AS
2//!
3//! The included binary `unleash-proxy` is equivalent to this code:
4//! ```rust,no_run
5//! #[tokio::main]
6//! async fn main() -> anyhow::Result<()> {
7//!   // Add deployment specific concerns here
8//!   env_logger::init();
9//!   unleash_proxy::main().await
10//! }
11//! ```
12//!
13//! Here is an example adding a custom strategy:
14//! ```rust,no_run
15//! use std::collections::{HashMap, HashSet};
16//! use std::hash::BuildHasher;
17//! use serde::{Deserialize, Serialize};
18//! use unleash_api_client::context::Context;
19//! use unleash_api_client::strategy;
20//!
21//! pub fn example<S: BuildHasher>(
22//!     parameters: Option<HashMap<String, String, S>>,
23//! ) -> strategy::Evaluate {
24//!     let mut items: HashSet<String> = HashSet::new();
25//!     if let Some(parameters) = parameters {
26//!         if let Some(item_list) = parameters.get("exampleParameter") {
27//!             for item in item_list.split(',') {
28//!                 items.insert(item.trim().into());
29//!             }
30//!         }
31//!     }
32//!     Box::new(move |context: &Context| -> bool {
33//!         matches!(
34//!             context
35//!                 .properties
36//!                 .get("exampleProperty")
37//!                 .map(|item| items.contains(item)),
38//!             Some(true)
39//!         )
40//!     })
41//! }
42//!
43//! #[tokio::main]
44//! async fn main() -> anyhow::Result<()> {
45//!   // Add deployment specific concerns here
46//!   env_logger::init();
//!   unleash_proxy::ProxyBuilder::default()
//!       .strategy("example", Box::new(&example))
//!       .execute()
//!       .await
49//! }
50//! ```
51#![warn(clippy::all)]
52
53use std::collections::HashMap;
54use std::convert::Infallible;
55use std::sync::{Arc, Mutex};
56use std::time::Duration;
57
58use anyhow::{anyhow, Context as AnyhowContext, Result};
59use chrono::Utc;
60use enum_map::Enum;
61use futures_timer::Delay;
62use hyper::service::{make_service_fn, service_fn};
63use hyper::{Body, Request, Response, Server};
64use hyper::{Method, StatusCode};
65use log::{debug, warn};
66use serde::{Deserialize, Serialize};
67use unleash_api_client::{
68    api::{Metrics, MetricsBucket},
69    client,
70    config::EnvironmentConfig,
71    context::{Context, IPAddress},
72    strategy::Strategy,
73    ClientBuilder,
74};
75
76const ALLOWED_HEADERS: &str = "authorization,content-type,if-none-match";
77
78#[allow(non_camel_case_types)]
79#[derive(Debug, Deserialize, Serialize, Enum, Clone)]
80enum UserFeatures {}
81
82#[derive(Deserialize, Serialize, Debug, Clone)]
83struct Payload {
84    #[serde(rename = "type")]
85    _type: String,
86    value: String,
87}
88
89#[derive(Deserialize, Serialize, Debug, Clone)]
90struct Variant {
91    name: String,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    payload: Option<Payload>,
94}
95
96#[derive(Deserialize, Serialize, Debug, Clone)]
97struct Toggle {
98    name: String,
99    enabled: bool,
100    variant: Variant,
101}
102
103#[derive(Default, Deserialize, Serialize, Debug, Clone)]
104struct Toggles {
105    toggles: Vec<Toggle>,
106}
107
/// Prefix of query-string keys that carry custom context properties,
/// e.g. `properties[tenant]`.
const PROPERTY_PREFIX: &str = "properties[";

/// Extract the property name from a `properties[<name>]` query key.
///
/// `properties[foo]` yields `"foo"`, `properties[]` yields `""`, and only the
/// final `]` is treated as the delimiter (`properties[a]b]` yields `"a]b"`).
/// The prefix and the closing `]` are each removed only when present, so a
/// malformed key never causes an out-of-range or mid-character slice panic.
fn extract_key(k: &str) -> String {
    let inner = k.strip_prefix(PROPERTY_PREFIX).unwrap_or(k);
    inner.strip_suffix(']').unwrap_or(inner).to_owned()
}
113
114async fn toggles(
115    client: Arc<client::Client<UserFeatures>>,
116    req: Request<Body>,
117) -> Result<Response<Body>> {
118    let cache = client.cached_state();
119    let toggles = match cache.as_ref() {
120        // Make an empty API doc with nothing in it
121        None => Toggles::default(),
122        Some(cache) => {
123            let mut toggles = Toggles::default();
124            let mut context: Context = Default::default();
125            let fake_root = url::Url::parse("http://fakeroot.example.com/")?;
126            // unwrap should be safe because to get here the uri must have been valid already
127            // but perhaps we should handle it
128            let url = fake_root
129                .join(&req.uri().to_string())
130                .context("bad uri in request")?;
131            for (k, v) in url.query_pairs() {
132                match k.as_ref() {
133                    "environment" => context.environment = v.to_string(),
134                    "appName" => context.app_name = v.to_string(),
135                    "userId" => context.user_id = Some(v.to_string()),
136                    "sessionId" => context.session_id = Some(v.to_string()),
137                    "remoteAddress" => {
138                        let ip_parsed = ipaddress::IPAddress::parse(v.to_string());
139                        // should we report errors on bad IP address formats?
140                        context.remote_address = ip_parsed.ok().map(IPAddress);
141                    }
142                    k if k.starts_with(PROPERTY_PREFIX) && k.ends_with(']') => {
143                        let k = extract_key(k);
144                        context.properties.insert(k, v.to_string());
145                    }
146                    _ => {}
147                }
148            }
149            for (name, feature) in cache.str_features() {
150                let mut enabled = false;
151                for memo in feature.strategies.iter() {
152                    if memo(&context) {
153                        enabled = true;
154                        break;
155                    }
156                }
157                let toggle = Toggle {
158                    name: name.to_string(),
159                    enabled,
160                    variant: Variant {
161                        // TODO: support variants in the underlying client
162                        name: "default".into(),
163                        payload: None,
164                    },
165                };
166                toggles.toggles.push(toggle);
167            }
168            toggles
169        }
170    };
171
172    Ok(Response::builder()
173        .header(hyper::header::CONTENT_TYPE, "application/json")
174        .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
175        .header(hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, ALLOWED_HEADERS)
176        .status(StatusCode::OK)
177        .body(serde_json::to_vec(&toggles)?.into())?)
178}
179
180async fn metrics(
181    metrics: Arc<Mutex<HashMap<String, Metrics>>>,
182    req: Request<Body>,
183) -> Result<Response<Body>> {
184    // TODO: fixup the result types.
185    let whole_body = hyper::body::to_bytes(req.into_body())
186        .await
187        .expect("failed to get body");
188    let req_metrics: Metrics = serde_json::from_slice(&whole_body).expect("valid metrics");
189    // We could could be super clever here and only merge buckets that are broadly compatible by time, but honestly,
190    // supporting time skewed clients doesn't make sense. The use case for metrics is really just to know whether a
191    // thing is or isn't being used and by which application.
192
193    // Secondly, most folk running this are going to have just a few web apps, so being super scalable in the app-name
194    // dimension isn't very useful: we actually need to be scalable in accepting the updates, and this is a write-heavy
195    // workload, so arc_swap isn't useful. If this becomes a hot spot, we need to look at a journalling mechanism. For
196    // now, we lock around updates, making this a serialisation point but hopefully fast.
197    {
198        let mut metrics = metrics.lock().unwrap();
199        let entry = metrics
200            .entry(req_metrics.app_name.clone())
201            .or_insert_with(|| Metrics {
202                app_name: req_metrics.app_name.clone(),
203                instance_id: "proxy".into(),
204                bucket: MetricsBucket {
205                    // Save on computing times here: we will calculate appropriate buckets when we submit to the API
206                    // server.
207                    start: req_metrics.bucket.start,
208                    stop: req_metrics.bucket.stop,
209                    toggles: HashMap::new(),
210                },
211            });
212        for (toggle, info) in req_metrics.bucket.toggles {
213            for (state, count) in info {
214                let toggle_map = entry.bucket.toggles.entry(toggle.clone());
215                let counter = toggle_map
216                    .or_insert_with(HashMap::new)
217                    .entry(state)
218                    .or_insert(0);
219                *counter += count;
220            }
221        }
222    }
223    Ok(Response::builder()
224        .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
225        .header(hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, ALLOWED_HEADERS)
226        .status(StatusCode::OK)
227        .body(Body::empty())?)
228}
229
230async fn send_metrics(
231    url: &str,
232    client: Arc<client::Client<UserFeatures>>,
233    metrics: Arc<Mutex<HashMap<String, Metrics>>>,
234    interval: Duration,
235) {
236    let metrics_endpoint = Metrics::endpoint(url);
237    loop {
238        let start = Utc::now();
239        debug!("send_metrics: waiting {:?}", interval);
240        Delay::new(interval).await;
241        let mut batch = HashMap::new();
242        {
243            let mut locked = metrics.lock().unwrap();
244            std::mem::swap(&mut batch, &mut locked);
245        }
246        debug!("sending metrics");
247        let stop = Utc::now();
248        // TODO: very large numbers of discrete apps will cause this loop to
249        // start exceeding 15 seconds and require assembling a concurrent
250        // approach here as well, but this is probably a very very long way off.
251        for (app_name, mut metrics) in batch {
252            let mut metrics_uploaded = false;
253            metrics.bucket.start = start;
254            metrics.bucket.stop = stop;
255            let req = client.http.post(&metrics_endpoint);
256            if let Ok(body) = http_types::Body::from_json(&metrics) {
257                let res = req.body(body).await;
258                if let Ok(res) = res {
259                    if res.status().is_success() {
260                        metrics_uploaded = true;
261                        debug!("poll: uploaded feature metrics `{}`", app_name);
262                    }
263                }
264            }
265            if !metrics_uploaded {
266                warn!("poll: error uploading feature metrics `{}`", app_name);
267            }
268        }
269    }
270}
271
272/// Core workhorse for the proxy. Code in this function is generic across
273/// different deployment configurations e.g. logging, tracing, metrics
274/// implementations. See the [`crate`] level documentation for examples.
275pub async fn main() -> Result<()> {
276    ProxyBuilder::default().execute().await
277}
278
279async fn _main(builder: ClientBuilder) -> Result<()> {
280    // Not deployment specific:
281    // We'll bind to 127.0.0.1:3000
282    debug!("serving on 127.0.0.1:3000");
283    let addr = ([127, 0, 0, 1], 3000).into();
284
285    let config = EnvironmentConfig::from_env().map_err(|e| anyhow!(e))?;
286    let client = Arc::new(
287        builder
288            .into_client::<UserFeatures>(
289                &config.api_url,
290                &config.app_name,
291                &config.instance_id,
292                config.secret.clone(),
293            )
294            .map_err(|e| anyhow!(e))?,
295    );
296    client.register().await.map_err(|e| anyhow!(e))?;
297
298    let client_metrics = Arc::new(Mutex::new(HashMap::new()));
299
300    let make_svc = make_service_fn(|_conn| {
301        let conn_client = client.clone();
302        let conn_metrics = client_metrics.clone();
303        async move {
304            Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
305                // Consider making a single struct to reduce Arc reference
306                // taking overheads if this service gets busy.
307                let req_client = conn_client.clone();
308                let req_metrics = conn_metrics.clone();
309                async move {
310                    match (req.method(), req.uri().path()) {
311                        (&Method::GET, "/") => toggles(req_client, req).await,
312                        (&Method::OPTIONS, "/") => Ok(Response::builder()
313                            .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
314                            .header(hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, ALLOWED_HEADERS)
315                            .status(StatusCode::OK)
316                            .body(Body::empty())?),
317                        (&Method::POST, "/client/metrics") => metrics(req_metrics, req).await,
318                        (&Method::OPTIONS, "/client/metrics") => Ok(Response::builder()
319                            .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
320                            .header(hyper::header::ACCESS_CONTROL_ALLOW_HEADERS, ALLOWED_HEADERS)
321                            .status(StatusCode::OK)
322                            .body(Body::empty())?),
323                        _ => Ok(Response::builder()
324                            .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
325                            .status(StatusCode::NOT_FOUND)
326                            .body(Body::empty())?),
327                    }
328                }
329            }))
330        }
331    });
332
333    let server = Server::bind(&addr).serve(make_svc);
334    if let Err(e) = futures::try_join!(
335        async {
336            client.poll_for_updates().await;
337            Ok(())
338        },
339        async {
340            send_metrics(
341                &config.api_url,
342                client.clone(),
343                client_metrics.clone(),
344                // 30 seconds is the default interval for metrics in the browser client source
345                Duration::from_secs(30),
346            )
347            .await;
348            Ok(())
349        },
350        server,
351    ) {
352        eprintln!("server error: {}", e);
353    }
354    Ok(())
355}
356
357/// Permits customising the Proxy behaviour. See the [`crate`] level docs for examples.
358pub struct ProxyBuilder {
359    client_builder: ClientBuilder,
360}
361
362impl ProxyBuilder {
363    /// Run the configured proxy
364    pub async fn execute(self) -> Result<()> {
365        _main(self.client_builder).await
366    }
367
368    /// Add a [`Strategy`] to this proxy.
369    pub fn strategy(self, name: &str, strategy: Strategy) -> Self {
370        ProxyBuilder {
371            client_builder: self.client_builder.strategy(name, strategy),
372        }
373    }
374}
375
376impl Default for ProxyBuilder {
377    fn default() -> Self {
378        ProxyBuilder {
379            client_builder: ClientBuilder::default()
380                .disable_metric_submission()
381                .enable_string_features(),
382        }
383    }
384}
385
/// Unit tests; compiled only for `cargo test`.
#[cfg(test)]
mod tests {
    use super::extract_key;

    #[test]
    fn properties() {
        assert_eq!("foo", extract_key("properties[foo]"));
        assert_eq!("", extract_key("properties[]"));
        // Only the final `]` closes the key; inner brackets belong to the name.
        assert_eq!("a]b", extract_key("properties[a]b]"));
    }
}