use std::collections::hash_map::HashMap;
use std::default::Default;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use arc_swap::ArcSwapOption;
use chrono::Utc;
use enum_map::{Enum, EnumMap};
use futures_timer::Delay;
use log::{debug, trace, warn};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::api::{Feature, Features, Metrics, MetricsBucket, Registration};
use crate::context::Context;
use crate::http::HTTP;
use crate::strategy;
/// Configures and constructs a [`Client`].
///
/// Start from `ClientBuilder::default()`, adjust it with the builder methods,
/// then call `into_client`.
pub struct ClientBuilder {
    /// When false, `Client::is_enabled_str` panics.
    enable_str_features: bool,
    /// Poll interval in milliseconds; also sent to the server at registration.
    interval: u64,
    /// Strategy implementations keyed by the strategy name used in the API.
    strategies: HashMap<String, strategy::Strategy>,
}
impl ClientBuilder {
pub fn into_client<C, F>(
self,
api_url: &str,
app_name: &str,
instance_id: &str,
authorization: Option<String>,
) -> Result<Client<C, F>, http_client::Error>
where
C: http_client::HttpClient + Default,
F: Enum<CachedFeature> + Debug + DeserializeOwned + Serialize,
{
Ok(Client {
api_url: api_url.into(),
app_name: app_name.into(),
enable_str_features: self.enable_str_features,
instance_id: instance_id.into(),
interval: self.interval,
polling: AtomicBool::new(false),
http: HTTP::new(app_name.into(), instance_id.into(), authorization)?,
cached_state: ArcSwapOption::from(None),
strategies: Mutex::new(self.strategies),
})
}
pub fn enable_string_features(mut self) -> Self {
self.enable_str_features = true;
self
}
pub fn interval(mut self, interval: u64) -> Self {
self.interval = interval;
self
}
pub fn strategy(mut self, name: &str, strategy: strategy::Strategy) -> Self {
self.strategies.insert(name.into(), strategy);
self
}
}
impl Default for ClientBuilder {
    /// A builder with string features disabled, a 15000 ms poll interval and
    /// every built-in strategy registered under its API name.
    fn default() -> ClientBuilder {
        // Each built-in strategy is registered exactly once; the original
        // chain registered "default" twice.
        ClientBuilder {
            enable_str_features: false,
            interval: 15000,
            strategies: Default::default(),
        }
        .strategy("default", Box::new(&strategy::default))
        .strategy("applicationHostname", Box::new(&strategy::hostname))
        .strategy("gradualRolloutRandom", Box::new(&strategy::random))
        .strategy("gradualRolloutSessionId", Box::new(&strategy::session_id))
        .strategy("gradualRolloutUserId", Box::new(&strategy::user_id))
        .strategy("remoteAddress", Box::new(&strategy::remote_address))
        .strategy("userWithId", Box::new(&strategy::user_with_id))
        .strategy("flexibleRollout", Box::new(&strategy::flexible_rollout))
    }
}
/// Memoized evaluation state for one feature.
///
/// Built by `Client::memoize` from the API response. `Client::is_enabled_str`
/// also inserts a stub (no strategies, `known: false`) for names the API did
/// not report.
#[derive(Default)]
pub struct CachedFeature {
    /// Evaluators built from the API strategies that have a registered
    /// implementation; a lookup is enabled when any of them returns true.
    strategies: Vec<strategy::Evaluate>,
    /// Whether the API reported this feature; unknown features fall back to
    /// the caller-supplied default.
    known: bool,
    /// Counter reported as the "yes" metric for this feature.
    enabled: AtomicU64,
    /// Counter reported as the "no" metric for this feature.
    disabled: AtomicU64,
}
/// One generation of memoized API state, swapped in wholesale by
/// `Client::memoize`.
struct CachedState<F: Enum<CachedFeature>> {
    /// When this generation was installed; used as the metrics bucket start.
    start: chrono::DateTime<chrono::Utc>,
    /// Features whose API name deserializes into `F`.
    features: EnumMap<F, CachedFeature>,
    /// Features whose API name does not map onto `F`, keyed by name.
    str_features: HashMap<String, CachedFeature>,
}
/// Feature-flag client that memoizes the API's feature definitions,
/// evaluates them locally, and accumulates usage metrics between polls.
///
/// `C` is the HTTP client implementation; `F` is the enum of features known
/// at compile time. Construct one through [`ClientBuilder`].
pub struct Client<
    C: http_client::HttpClient,
    F: Enum<CachedFeature> + Debug + DeserializeOwned + Serialize,
> {
    /// Base URL of the API server.
    api_url: String,
    /// Application name sent with registration and metrics.
    app_name: String,
    /// Whether `is_enabled_str` may be used.
    enable_str_features: bool,
    /// Instance identifier sent with registration and metrics.
    instance_id: String,
    /// Poll interval in milliseconds.
    interval: u64,
    /// Set while `poll_for_updates` should keep looping.
    polling: AtomicBool,
    /// HTTP wrapper used for all API calls.
    http: HTTP<C>,
    /// Strategy implementations keyed by API strategy name.
    strategies: Mutex<HashMap<String, strategy::Strategy>>,
    /// Current memoized state; `None` until the first `memoize`.
    cached_state: ArcSwapOption<CachedState<F>>,
}
impl<C, F> Client<C, F>
where
C: http_client::HttpClient + std::default::Default,
F: Enum<CachedFeature> + Clone + Debug + DeserializeOwned + Serialize,
{
pub fn is_enabled(&self, feature_enum: F, context: Option<&Context>, default: bool) -> bool {
trace!(
"is_enabled: feature {:?} default {}, context {:?}",
feature_enum,
default,
context
);
let cache = self.cached_state.load();
let cache = if let Some(cache) = &*cache {
cache
} else {
trace!("is_enabled: No API state");
return false;
};
let feature = &cache.features[feature_enum.clone()];
let default_context: Context = Default::default();
let context = context.unwrap_or(&default_context);
for memo in feature.strategies.iter() {
if memo(context) {
debug!(
"is_enabled: feature {:?} enabled by memo {:p}, context {:?}",
feature_enum, memo, context
);
feature.enabled.fetch_add(1, Ordering::Relaxed);
return true;
} else {
feature.disabled.fetch_add(1, Ordering::Relaxed);
trace!(
"is_enabled: feature {:?} not enabled by memo {:p}, context {:?}",
feature_enum,
memo,
context
);
}
}
if !feature.known {
trace!(
"is_enabled: Unknown feature {:?}, using default {}",
feature_enum,
default
);
if default {
feature.enabled.fetch_add(1, Ordering::Relaxed);
} else {
feature.disabled.fetch_add(1, Ordering::Relaxed);
}
default
} else {
false
}
}
pub fn is_enabled_str(
&self,
feature_name: &str,
context: Option<&Context>,
default: bool,
) -> bool {
trace!(
"is_enabled: feature_str {:?} default {}, context {:?}",
feature_name,
default,
context
);
assert!(
self.enable_str_features,
"String feature lookup not enabled"
);
let cache = self.cached_state.load();
let cache = if let Some(cache) = &*cache {
cache
} else {
trace!("is_enabled: No API state");
return false;
};
if let Some(feature) = cache.str_features.get(feature_name) {
let default_context: Context = Default::default();
let context = context.unwrap_or(&default_context);
for memo in feature.strategies.iter() {
if memo(context) {
debug!(
"is_enabled: feature {} enabled by memo {:p}, context {:?}",
feature_name, memo, context
);
feature.enabled.fetch_add(1, Ordering::Relaxed);
return true;
} else {
feature.disabled.fetch_add(1, Ordering::Relaxed);
trace!(
"is_enabled: feature {} not enabled by memo {:p}, context {:?}",
feature_name,
memo,
context
);
}
}
if !feature.known {
trace!(
"is_enabled: Unknown feature {}, using default {}",
feature_name,
default
);
if default {
feature.enabled.fetch_add(1, Ordering::Relaxed);
} else {
feature.disabled.fetch_add(1, Ordering::Relaxed);
}
default
} else {
false
}
} else {
trace!(
"is_enabled: Unknown feature {}, using default {}",
feature_name,
default
);
self.cached_state
.rcu(|cached_state: &Option<Arc<CachedState<F>>>| {
if let Some(cached_state) = cached_state {
let cached_state = cached_state.clone();
if let Some(feature) = cached_state.str_features.get(feature_name) {
if default {
feature.enabled.fetch_add(1, Ordering::Relaxed);
} else {
feature.disabled.fetch_add(1, Ordering::Relaxed);
}
Some(cached_state)
} else {
let mut new_state = CachedState {
start: cached_state.start,
features: EnumMap::new(),
str_features: HashMap::new(),
};
fn cloned_feature(feature: &CachedFeature) -> CachedFeature {
CachedFeature {
disabled: AtomicU64::new(
feature.disabled.load(Ordering::Relaxed),
),
enabled: AtomicU64::new(
feature.enabled.load(Ordering::Relaxed),
),
known: feature.known,
strategies: feature.strategies.clone(),
}
};
for (key, feature) in &cached_state.features {
new_state.features[key] = cloned_feature(&feature);
}
for (name, feature) in &cached_state.str_features {
new_state
.str_features
.insert(name.clone(), cloned_feature(&feature));
}
let stub_feature = CachedFeature {
disabled: AtomicU64::new(if default { 0 } else { 1 }),
enabled: AtomicU64::new(if default { 1 } else { 0 }),
known: false,
strategies: vec![],
};
new_state
.str_features
.insert(feature_name.into(), stub_feature);
Some(Arc::new(new_state))
}
} else {
None
}
});
default
}
}
pub fn memoize(
&self,
features: Vec<Feature>,
) -> Result<Option<Metrics>, Box<dyn std::error::Error>> {
let now = Utc::now();
trace!("memoize: start with {} features", features.len());
let source_strategies = self.strategies.lock().unwrap();
let mut unenumerated_features: HashMap<String, CachedFeature> = HashMap::new();
let mut cached_features: EnumMap<F, CachedFeature> = EnumMap::new();
for feature in features {
let cached_feature = {
if !feature.enabled {
let strategies = vec![];
CachedFeature {
strategies,
disabled: AtomicU64::new(0),
enabled: AtomicU64::new(0),
known: true,
}
} else {
let mut strategies = vec![];
for api_strategy in feature.strategies {
if let Some(code_strategy) = source_strategies.get(&api_strategy.name) {
strategies.push(code_strategy(api_strategy.parameters));
}
}
CachedFeature {
strategies,
disabled: AtomicU64::new(0),
enabled: AtomicU64::new(0),
known: true,
}
}
};
if let Ok(feature_enum) = serde_plain::from_str::<F>(feature.name.as_str()) {
cached_features[feature_enum] = cached_feature;
} else {
unenumerated_features.insert(feature.name.clone(), cached_feature);
}
}
let new_cache = CachedState {
start: now,
features: cached_features,
str_features: unenumerated_features,
};
let old = self.cached_state.swap(Some(Arc::new(new_cache)));
trace!("memoize: swapped memoized state in");
if let Some(old) = old {
let mut bucket = MetricsBucket {
start: old.start,
stop: now,
toggles: HashMap::new(),
};
for (key, feature) in &old.features {
bucket.toggles.insert(
serde_plain::to_string(&key).unwrap(),
[
("yes".into(), feature.enabled.load(Ordering::Relaxed)),
("no".into(), feature.disabled.load(Ordering::Relaxed)),
]
.iter()
.cloned()
.collect(),
);
}
for (name, feature) in &old.str_features {
bucket.toggles.insert(
name.clone(),
[
("yes".into(), feature.enabled.load(Ordering::Relaxed)),
("no".into(), feature.disabled.load(Ordering::Relaxed)),
]
.iter()
.cloned()
.collect(),
);
}
let metrics = Metrics {
app_name: self.app_name.clone(),
instance_id: self.instance_id.clone(),
bucket,
};
Ok(Some(metrics))
} else {
Ok(None)
}
}
pub async fn poll_for_updates(&self) {
let endpoint = Features::endpoint(&self.api_url);
let metrics_endpoint = Metrics::endpoint(&self.api_url);
self.polling.store(true, Ordering::Relaxed);
loop {
debug!("poll: retrieving features");
let res = self.http.get(&endpoint).recv_json().await;
if let Ok(res) = res {
let features: Features = res;
match self.memoize(features.features) {
Ok(None) => {}
Ok(Some(metrics)) => {
let mut metrics_uploaded = false;
let req = self.http.post(&metrics_endpoint).body_json(&metrics);
if let Ok(req) = req {
let res = req.await;
if let Ok(res) = res {
if res.status().is_success() {
metrics_uploaded = true;
debug!("poll: uploaded feature metrics")
}
}
}
if !metrics_uploaded {
warn!("poll: error uploading feature metrics");
}
}
Err(_) => {
warn!("poll: failed to memoize features");
}
}
} else {
warn!("poll: failed to retrieve features");
}
Delay::new(Duration::from_millis(self.interval)).await;
if !self.polling.load(Ordering::Relaxed) {
return;
}
}
}
pub async fn register(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
let registration = Registration {
app_name: self.app_name.clone(),
instance_id: self.instance_id.clone(),
interval: self.interval,
strategies: self
.strategies
.lock()
.unwrap()
.keys()
.map(|s| s.to_owned())
.collect(),
..Default::default()
};
let res = self
.http
.post(Registration::endpoint(&self.api_url))
.body_json(®istration)?
.await?;
if !res.status().is_success() {
return Err(anyhow::anyhow!("Failed to register with unleash API server").into());
}
Ok(())
}
pub async fn stop_poll(&self) {
loop {
match self
.polling
.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed)
{
Ok(_) => {
return;
}
Err(_) => {
Delay::new(Duration::from_millis(50)).await;
}
}
}
}
}
// Unit tests for memoization and feature evaluation. They only call
// `memoize` and the `is_enabled*` lookups; `register` and `poll_for_updates`
// are never invoked.
#[cfg(test)]
mod tests {
    use std::collections::hash_map::HashMap;
    use std::collections::hash_set::HashSet;
    use std::hash::BuildHasher;
    use enum_map::Enum;
    use maplit::hashmap;
    use serde::{Deserialize, Serialize};
    use super::ClientBuilder;
    use crate::api::{Feature, Features, Strategy};
    use crate::context::Context;
    use crate::strategy;

    // Fixture API response:
    // - "default": enabled, single `default` strategy.
    // - "userWithId": enabled, matches user id "present".
    // - "userWithId+default": enabled, `userWithId` then `default`.
    // - "disabled": disabled at the API level despite carrying a `default`
    //   strategy.
    fn features() -> Features {
        Features {
            version: 1,
            features: vec![
                Feature {
                    description: "default".into(),
                    enabled: true,
                    created_at: None,
                    variants: None,
                    name: "default".into(),
                    strategies: vec![Strategy {
                        name: "default".into(),
                        parameters: None,
                    }],
                },
                Feature {
                    description: "userWithId".into(),
                    enabled: true,
                    created_at: None,
                    variants: None,
                    name: "userWithId".into(),
                    strategies: vec![Strategy {
                        name: "userWithId".into(),
                        parameters: Some(hashmap!["userIds".into()=>"present".into()]),
                    }],
                },
                Feature {
                    description: "userWithId+default".into(),
                    enabled: true,
                    created_at: None,
                    variants: None,
                    name: "userWithId+default".into(),
                    strategies: vec![
                        Strategy {
                            name: "userWithId".into(),
                            parameters: Some(hashmap!["userIds".into()=>"present".into()]),
                        },
                        Strategy {
                            name: "default".into(),
                            parameters: None,
                        },
                    ],
                },
                Feature {
                    description: "disabled".into(),
                    enabled: false,
                    created_at: None,
                    variants: None,
                    name: "disabled".into(),
                    strategies: vec![Strategy {
                        name: "default".into(),
                        parameters: None,
                    }],
                },
            ],
        }
    }

    // Enum-keyed lookups: variants are matched to API feature names through
    // their serde names.
    #[test]
    fn test_memoization_enum() {
        let _ = simple_logger::init();
        let f = features();
        // Variant names are spelled exactly like the API feature names so the
        // derived serde impls match them without per-variant renames.
        #[allow(non_camel_case_types)]
        #[derive(Debug, Deserialize, Serialize, Enum, Clone)]
        enum UserFeatures {
            // Not present in the fixture: lookups fall back to `default`.
            unknown,
            default,
            userWithId,
            #[serde(rename = "userWithId+default")]
            userWithId_Default,
            // Reported but disabled: false even when `default` is true.
            disabled,
        }
        let c = ClientBuilder::default()
            .into_client::<http_client::native::NativeClient, UserFeatures>(
                "http://127.0.0.1:1234/",
                "foo",
                "test",
                None,
            )
            .unwrap();
        c.memoize(f.features).unwrap();
        let present: Context = Context {
            user_id: Some("present".into()),
            ..Default::default()
        };
        let missing: Context = Context {
            user_id: Some("missing".into()),
            ..Default::default()
        };
        assert_eq!(false, c.is_enabled(UserFeatures::unknown, None, false));
        assert_eq!(true, c.is_enabled(UserFeatures::unknown, None, true));
        assert_eq!(true, c.is_enabled(UserFeatures::default, None, false));
        assert_eq!(
            true,
            c.is_enabled(UserFeatures::userWithId, Some(&present), false)
        );
        assert_eq!(
            false,
            c.is_enabled(UserFeatures::userWithId, Some(&missing), false)
        );
        // The trailing `default` strategy enables it even for a non-matching user.
        assert_eq!(
            true,
            c.is_enabled(UserFeatures::userWithId_Default, Some(&missing), false)
        );
        assert_eq!(false, c.is_enabled(UserFeatures::disabled, None, true));
    }

    // String-keyed lookups. `NoFeatures` has no variants, so every fixture
    // feature lands in the string map. Assertion order matters: the first
    // lookup of "unknown" inserts a stub entry, and the second lookup takes
    // the existing-entry path.
    #[test]
    fn test_memoization_strs() {
        let _ = simple_logger::init();
        let f = features();
        #[derive(Debug, Deserialize, Serialize, Enum, Clone)]
        enum NoFeatures {}
        let c = ClientBuilder::default()
            .enable_string_features()
            .into_client::<http_client::native::NativeClient, NoFeatures>(
                "http://127.0.0.1:1234/",
                "foo",
                "test",
                None,
            )
            .unwrap();
        c.memoize(f.features).unwrap();
        let present: Context = Context {
            user_id: Some("present".into()),
            ..Default::default()
        };
        let missing: Context = Context {
            user_id: Some("missing".into()),
            ..Default::default()
        };
        assert_eq!(false, c.is_enabled_str("unknown", None, false));
        assert_eq!(true, c.is_enabled_str("unknown", None, true));
        assert_eq!(true, c.is_enabled_str("default", None, false));
        assert_eq!(true, c.is_enabled_str("userWithId", Some(&present), false));
        assert_eq!(false, c.is_enabled_str("userWithId", Some(&missing), false));
        assert_eq!(
            true,
            c.is_enabled_str("userWithId+default", Some(&missing), false)
        );
        assert_eq!(false, c.is_enabled_str("disabled", None, true));
    }

    // Custom strategy: the comma-separated `userIds` parameter is stored
    // reversed, so a context matches when its user id is the reverse of a
    // configured id.
    fn _reversed_uids<S: BuildHasher>(
        parameters: Option<HashMap<String, String, S>>,
    ) -> strategy::Evaluate {
        let mut uids: HashSet<String> = HashSet::new();
        if let Some(parameters) = parameters {
            if let Some(uids_list) = parameters.get("userIds") {
                for uid in uids_list.split(',') {
                    uids.insert(uid.chars().rev().collect());
                }
            }
        }
        Box::new(move |context: &Context| -> bool {
            context
                .user_id
                .as_ref()
                .map(|uid| uids.contains(uid))
                .unwrap_or(false)
        })
    }

    // A user-registered strategy is used for features that name it.
    #[test]
    fn test_custom_strategy() {
        let _ = simple_logger::init();
        #[allow(non_camel_case_types)]
        #[derive(Debug, Deserialize, Serialize, Enum, Clone)]
        enum UserFeatures {
            default,
            reversed,
        }
        let client = ClientBuilder::default()
            .strategy("reversed", Box::new(&_reversed_uids))
            .into_client::<http_client::native::NativeClient, UserFeatures>(
                "http://127.0.0.1:1234/",
                "foo",
                "test",
                None,
            )
            .unwrap();
        let f = Features {
            version: 1,
            features: vec![
                Feature {
                    description: "default".into(),
                    enabled: true,
                    created_at: None,
                    variants: None,
                    name: "default".into(),
                    strategies: vec![Strategy {
                        name: "default".into(),
                        parameters: None,
                    }],
                },
                Feature {
                    description: "reversed".into(),
                    enabled: true,
                    created_at: None,
                    variants: None,
                    name: "reversed".into(),
                    strategies: vec![Strategy {
                        name: "reversed".into(),
                        parameters: Some(hashmap!["userIds".into()=>"abc".into()]),
                    }],
                },
            ],
        };
        client.memoize(f.features).unwrap();
        // "cba" reversed is the configured "abc", so it matches.
        let present: Context = Context {
            user_id: Some("cba".into()),
            ..Default::default()
        };
        let missing: Context = Context {
            user_id: Some("abc".into()),
            ..Default::default()
        };
        assert_eq!(
            true,
            client.is_enabled(UserFeatures::reversed, Some(&present), false)
        );
        assert_eq!(
            false,
            client.is_enabled(UserFeatures::reversed, Some(&missing), false)
        );
        assert_eq!(true, client.is_enabled(UserFeatures::default, None, false));
    }
}