unleash_edge_types/
lib.rs

1use std::cmp::min;
2use std::fmt;
3use std::fmt::{Debug, Display, Formatter};
4use std::net::IpAddr;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::{
8    collections::HashMap,
9    hash::{Hash, Hasher},
10    str::FromStr,
11};
12
13use axum::Json;
14use axum::http::HeaderValue;
15use axum::response::{IntoResponse, Response};
16use chrono::{DateTime, Duration, Utc};
17use dashmap::DashMap;
18use etag::EntityTag;
19use http::StatusCode;
20use serde::{Deserialize, Deserializer, Serialize, Serializer};
21use shadow_rs::shadow;
22use unleash_types::client_features::Context;
23use unleash_types::client_features::{ClientFeatures, ClientFeaturesDelta};
24use unleash_types::client_metrics::{ClientApplication, ClientMetricsEnv, ImpactMetric};
25use unleash_yggdrasil::EngineState;
26use utoipa::{IntoParams, ToSchema};
27
28pub mod enterprise;
29pub mod errors;
30pub mod filters;
31pub mod headers;
32pub mod metrics;
33pub mod tokens;
34pub mod urls;
35use crate::enterprise::LicenseState;
36use crate::errors::EdgeError;
37use crate::tokens::EdgeToken;
38
39pub type EdgeJsonResult<T> = Result<Json<T>, EdgeError>;
40pub type EdgeResult<T> = Result<T, EdgeError>;
41pub type EdgeAcceptedJsonResult<T> = Result<AcceptedJson<T>, EdgeError>;
42pub struct AcceptedJson<T>
43where
44    T: Serialize,
45{
46    pub body: T,
47}
48impl<T> IntoResponse for AcceptedJson<T>
49where
50    T: Serialize,
51{
52    fn into_response(self) -> Response {
53        (StatusCode::ACCEPTED, Json(self.body)).into_response()
54    }
55}
56
57pub type TokenCache = DashMap<String, EdgeToken>;
58pub type EngineCache = DashMap<String, EngineState>;
59pub fn entity_tag_to_header_value(etag: EntityTag) -> HeaderValue {
60    HeaderValue::from_str(&etag.to_string()).expect("Failed to convert ETag to HeaderValue")
61}
62
63pub type BackgroundTask = Pin<Box<dyn Future<Output = ()> + Send>>;
64
65#[derive(Debug, Clone, PartialEq, Eq, Copy)]
66pub enum RefreshState {
67    Running,
68    Paused,
69}
70
71impl From<LicenseState> for RefreshState {
72    fn from(val: LicenseState) -> Self {
73        match val {
74            LicenseState::Valid => RefreshState::Running,
75            LicenseState::Invalid => RefreshState::Paused,
76            LicenseState::Expired => RefreshState::Running,
77        }
78    }
79}
80
81#[derive(Deserialize, Serialize, Debug, Clone)]
82#[serde(rename_all = "camelCase")]
83pub struct IncomingContext {
84    #[serde(flatten)]
85    pub context: Context,
86
87    #[serde(flatten)]
88    pub extra_properties: HashMap<String, String>,
89}
90
91impl From<IncomingContext> for Context {
92    fn from(input: IncomingContext) -> Self {
93        let properties = if input.extra_properties.is_empty() {
94            input.context.properties
95        } else {
96            let mut input_properties = input.extra_properties;
97            input_properties.extend(input.context.properties.unwrap_or_default());
98            Some(input_properties)
99        };
100        Context {
101            properties,
102            ..input.context
103        }
104    }
105}
106
107#[derive(Deserialize, Serialize, Debug, Clone)]
108#[serde(rename_all = "camelCase")]
109pub struct PostContext {
110    pub context: Option<Context>,
111    #[serde(flatten)]
112    pub flattened_context: Option<Context>,
113    #[serde(flatten)]
114    pub extra_properties: HashMap<String, String>,
115}
116
117impl From<PostContext> for Context {
118    fn from(input: PostContext) -> Self {
119        if let Some(context) = input.context {
120            context
121        } else {
122            IncomingContext {
123                context: input.flattened_context.unwrap_or_default(),
124                extra_properties: input.extra_properties,
125            }
126            .into()
127        }
128    }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, utoipa::ToSchema)]
132#[serde(rename_all = "lowercase")]
133pub enum TokenType {
134    #[serde(alias = "FRONTEND")]
135    Frontend,
136    #[serde(alias = "CLIENT", alias = "client", alias = "BACKEND")]
137    Backend,
138    #[serde(alias = "ADMIN")]
139    Admin,
140    Invalid,
141}
142
143#[derive(Clone, Debug)]
144#[allow(clippy::large_enum_variant)]
145pub enum ClientFeaturesResponse {
146    NoUpdate(EntityTag),
147    Updated(ClientFeatures, Option<EntityTag>),
148}
149
150#[derive(Clone, Debug)]
151pub enum ClientFeaturesDeltaResponse {
152    NoUpdate(EntityTag),
153    Updated(ClientFeaturesDelta, Option<EntityTag>),
154}
155
156#[derive(Clone, Debug, PartialEq, Eq, Serialize, Default, Deserialize, utoipa::ToSchema)]
157pub enum TokenValidationStatus {
158    Invalid,
159    #[default]
160    Unknown,
161    Trusted,
162    Validated,
163}
164
165impl TokenValidationStatus {
166    pub fn is_valid(&self) -> bool {
167        matches!(
168            self,
169            &TokenValidationStatus::Trusted | &TokenValidationStatus::Validated
170        )
171    }
172}
173
174#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
175#[serde(rename_all = "UPPERCASE")]
176pub enum Status {
177    Ok,
178    NotOk,
179    NotReady,
180    Ready,
181}
182#[derive(Clone, Debug)]
183pub struct ClientFeaturesRequest {
184    pub api_key: String,
185    pub etag: Option<EntityTag>,
186    pub interval: Option<i64>,
187}
188
189#[derive(Clone, Debug, Serialize, Deserialize)]
190pub struct ValidateTokensRequest {
191    pub tokens: Vec<String>,
192}
193
/// Newtype around a client's IP address.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientIp {
    pub ip: IpAddr,
}

impl Display for ClientIp {
    /// Writes the bare address. Formatter options such as width are not forwarded.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { ip } = self;
        write!(f, "{ip}")
    }
}
204
205#[derive(Clone, Deserialize, Serialize)]
206pub struct TokenRefresh {
207    pub token: EdgeToken,
208    #[serde(
209        deserialize_with = "deserialize_entity_tag",
210        serialize_with = "serialize_entity_tag"
211    )]
212    pub etag: Option<EntityTag>,
213    pub next_refresh: Option<DateTime<Utc>>,
214    pub last_refreshed: Option<DateTime<Utc>>,
215    pub last_feature_count: Option<usize>,
216    pub last_check: Option<DateTime<Utc>>,
217    pub failure_count: u32,
218}
219
220impl PartialEq for TokenRefresh {
221    fn eq(&self, other: &Self) -> bool {
222        self.token == other.token
223            && self.etag == other.etag
224            && self.last_refreshed == other.last_refreshed
225            && self.last_check == other.last_check
226    }
227}
228
229#[derive(Clone, Deserialize, Serialize, Debug)]
230pub struct UnleashValidationDetail {
231    pub path: Option<String>,
232    pub description: Option<String>,
233    pub message: Option<String>,
234}
235
236#[derive(Clone, Deserialize, Serialize, Debug)]
237pub struct UnleashBadRequest {
238    pub id: Option<String>,
239    pub name: Option<String>,
240    pub message: Option<String>,
241    pub details: Option<Vec<UnleashValidationDetail>>,
242}
243
244impl TokenRefresh {
245    pub fn new(token: EdgeToken, etag: Option<EntityTag>) -> Self {
246        Self {
247            token,
248            etag,
249            last_refreshed: None,
250            last_check: None,
251            next_refresh: None,
252            failure_count: 0,
253            last_feature_count: None,
254        }
255    }
256
257    /// Something went wrong (but it was retriable. Increment our failure count and set last_checked and next_refresh
258    pub fn backoff(&self, refresh_interval: &Duration) -> Self {
259        let failure_count: u32 = min(self.failure_count + 1, 10);
260        let now = Utc::now();
261        let next_refresh = calculate_next_refresh(now, *refresh_interval, failure_count as u64);
262        Self {
263            failure_count,
264            next_refresh: Some(next_refresh),
265            last_check: Some(now),
266            ..self.clone()
267        }
268    }
269    /// We successfully talked to upstream, but there was no updates. Update our next_refresh, decrement our failure count and set when we last_checked
270    pub fn successful_check(&self, refresh_interval: &Duration) -> Self {
271        let failure_count = if self.failure_count > 0 {
272            self.failure_count - 1
273        } else {
274            0
275        };
276        let now = Utc::now();
277        let next_refresh = calculate_next_refresh(now, *refresh_interval, failure_count as u64);
278        Self {
279            failure_count,
280            next_refresh: Some(next_refresh),
281            last_check: Some(now),
282            ..self.clone()
283        }
284    }
285    /// We successfully talked to upstream. There were updates. Update next_refresh, last_refreshed and last_check, and decrement our failure count
286    pub fn successful_refresh(
287        &self,
288        refresh_interval: &Duration,
289        etag: Option<EntityTag>,
290        feature_count: usize,
291    ) -> Self {
292        let failure_count = if self.failure_count > 0 {
293            self.failure_count - 1
294        } else {
295            0
296        };
297        let now = Utc::now();
298        let next_refresh = calculate_next_refresh(now, *refresh_interval, failure_count as u64);
299        Self {
300            failure_count,
301            next_refresh: Some(next_refresh),
302            last_refreshed: Some(now),
303            last_check: Some(now),
304            last_feature_count: Some(feature_count),
305            etag,
306            ..self.clone()
307        }
308    }
309}
310
311fn calculate_next_refresh(
312    now: DateTime<Utc>,
313    refresh_interval: Duration,
314    failure_count: u64,
315) -> DateTime<Utc> {
316    if failure_count == 0 {
317        now + refresh_interval
318    } else {
319        now + refresh_interval + (refresh_interval * (failure_count.try_into().unwrap_or(0)))
320    }
321}
322
323impl fmt::Debug for TokenRefresh {
324    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
325        f.debug_struct("FeatureRefresh")
326            .field("token", &"***")
327            .field("etag", &self.etag)
328            .field("last_refreshed", &self.last_refreshed)
329            .field("last_check", &self.last_check)
330            .finish()
331    }
332}
333
334#[derive(Clone, Default)]
335pub struct CacheHolder {
336    pub token_cache: Arc<DashMap<String, EdgeToken>>,
337    pub features_cache: Arc<DashMap<String, ClientFeatures>>,
338    pub engine_cache: Arc<DashMap<String, EngineState>>,
339}
340
341fn deserialize_entity_tag<'de, D>(deserializer: D) -> Result<Option<EntityTag>, D::Error>
342where
343    D: Deserializer<'de>,
344{
345    let s: Option<String> = Option::deserialize(deserializer)?;
346
347    s.map(|s| EntityTag::from_str(&s).map_err(serde::de::Error::custom))
348        .transpose()
349}
350
351fn serialize_entity_tag<S>(etag: &Option<EntityTag>, serializer: S) -> Result<S::Ok, S::Error>
352where
353    S: Serializer,
354{
355    let s = etag.as_ref().map(|e| e.to_string());
356    serializer.serialize_some(&s)
357}
358
359pub fn into_entity_tag(client_features: ClientFeatures) -> Option<EntityTag> {
360    client_features
361        .xx3_hash()
362        .ok()
363        .map(|hash| EntityTag::new(true, &hash))
364}
365
366#[derive(Clone, Debug, Serialize, Deserialize)]
367pub struct BatchMetricsRequest {
368    pub api_key: String,
369    pub body: BatchMetricsRequestBody,
370}
371
372#[derive(Clone, Debug, Serialize, Deserialize, utoipa::ToSchema)]
373pub struct BatchMetricsRequestBody {
374    pub applications: Vec<ClientApplication>,
375    pub metrics: Vec<ClientMetricsEnv>,
376    #[serde(
377        default,
378        skip_serializing_if = "Option::is_none",
379        rename = "impactMetrics"
380    )]
381    pub impact_metrics: Option<Vec<ImpactMetric>>,
382}
383
384#[derive(Debug, Serialize, Deserialize, Clone)]
385#[serde(rename_all = "camelCase")]
386pub struct ClientTokenRequest {
387    pub token_name: String,
388    #[serde(rename = "type")]
389    pub token_type: TokenType,
390    pub projects: Vec<String>,
391    pub environment: String,
392    pub expires_at: DateTime<Utc>,
393}
394
395#[derive(Debug, Serialize, Deserialize, Clone)]
396pub struct BuildInfo {
397    pub package_version: String,
398    pub app_name: String,
399    pub package_major: String,
400    pub package_minor: String,
401    pub package_patch: String,
402    pub package_version_pre: Option<String>,
403    pub branch: String,
404    pub tag: String,
405    pub rust_version: String,
406    pub rust_channel: String,
407    pub short_commit_hash: String,
408    pub full_commit_hash: String,
409    pub build_os: String,
410    pub build_target: String,
411}
412shadow!(build); // Get build information set to build placeholder
413pub const EDGE_VERSION: &str = build::PKG_VERSION;
414impl Default for BuildInfo {
415    fn default() -> Self {
416        BuildInfo {
417            package_version: build::PKG_VERSION.into(),
418            app_name: build::PROJECT_NAME.into(),
419            package_major: build::PKG_VERSION_MAJOR.into(),
420            package_minor: build::PKG_VERSION_MINOR.into(),
421            package_patch: build::PKG_VERSION_PATCH.into(),
422            #[allow(clippy::const_is_empty)]
423            package_version_pre: if build::PKG_VERSION_PRE.is_empty() {
424                None
425            } else {
426                Some(build::PKG_VERSION_PRE.into())
427            },
428            branch: build::BRANCH.into(),
429            tag: build::TAG.into(),
430            rust_version: build::RUST_VERSION.into(),
431            rust_channel: build::RUST_CHANNEL.into(),
432            short_commit_hash: build::SHORT_COMMIT.into(),
433            full_commit_hash: build::COMMIT_HASH.into(),
434            build_os: build::BUILD_OS.into(),
435            build_target: build::BUILD_TARGET.into(),
436        }
437    }
438}
439
440#[derive(Clone, Debug, Serialize, Deserialize, IntoParams)]
441#[serde(rename_all = "camelCase")]
442pub struct FeatureFilters {
443    pub name_prefix: Option<String>,
444}
445
446#[derive(Serialize, Deserialize, Debug, Clone)]
447#[serde(rename_all = "camelCase")]
448pub struct TokenInfo {
449    pub token_refreshes: Vec<TokenRefresh>,
450    pub token_validation_status: Vec<EdgeToken>,
451    pub invalid_token_count: usize,
452}
453
454#[derive(Debug, Clone, Eq, Deserialize, Serialize, ToSchema)]
455pub struct MetricsKey {
456    pub app_name: String,
457    pub feature_name: String,
458    pub environment: String,
459    pub timestamp: DateTime<Utc>,
460}
461
462impl Hash for MetricsKey {
463    fn hash<H: Hasher>(&self, state: &mut H) {
464        self.app_name.hash(state);
465        self.feature_name.hash(state);
466        self.environment.hash(state);
467        to_time_key(&self.timestamp).hash(state);
468    }
469}
470
471fn to_time_key(timestamp: &DateTime<Utc>) -> String {
472    format!("{}", timestamp.format("%Y-%m-%d %H"))
473}
474
475impl PartialEq for MetricsKey {
476    fn eq(&self, other: &Self) -> bool {
477        let other_hour_bin = to_time_key(&other.timestamp);
478        let self_hour_bin = to_time_key(&self.timestamp);
479
480        self.app_name == other.app_name
481            && self.feature_name == other.feature_name
482            && self.environment == other.environment
483            && self_hour_bin == other_hour_bin
484    }
485}
486
487#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
488pub struct ClientMetric {
489    pub key: MetricsKey,
490    pub bucket: ClientMetricsEnv,
491}
492#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
493pub struct MetricsInfo {
494    pub applications: Vec<ClientApplication>,
495    pub metrics: Vec<ClientMetric>,
496}
497
498#[derive(Clone, Debug, Serialize, Deserialize)]
499pub struct EdgeTokens {
500    pub tokens: Vec<EdgeToken>,
501}
502
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::str::FromStr;

    use serde_json::json;
    use test_case::test_case;
    use unleash_types::client_features::Context;

    use crate::errors::EdgeError::EdgeTokenParseError;

    use super::{EdgeResult, EdgeToken, EdgeTokens};

    /// Parses `project:environment` into a token by appending a fixed secret.
    fn test_str(token: &str) -> EdgeToken {
        EdgeToken::from_str(
            &(token.to_owned() + ".614a75cf68bef8703aa1bd8304938a81ec871f86ea40c975468eabd6"),
        )
        .unwrap()
    }

    /// Builds a token with the given environment and projects; other fields are defaults.
    fn test_token(env: Option<&str>, projects: Vec<&str>) -> EdgeToken {
        EdgeToken {
            environment: env.map(|env| env.into()),
            projects: projects.into_iter().map(|p| p.into()).collect(),
            ..EdgeToken::default()
        }
    }

    #[test_case("demo-app:production.614a75cf68bef8703aa1bd8304938a81ec871f86ea40c975468eabd6"; "demo token with project and environment")]
    #[test_case("*:default.5fa5ac2580c7094abf0d87c68b1eeb54bdc485014aef40f9fcb0673b"; "demo token with access to all projects and default environment")]
    fn edge_token_from_string(token: &str) {
        // Put the parse error in the panic message itself; a log line emitted
        // without a tracing subscriber would never be seen.
        let parsed_token = EdgeToken::from_str(token)
            .unwrap_or_else(|e| panic!("Could not parse token: {e}"));
        assert_eq!(parsed_token.token, token);
    }

    #[test_case("943ca9171e2c884c545c5d82417a655fb77cec970cc3b78a8ff87f4406b495d0"; "old java client token")]
    #[test_case("secret-123"; "old example proxy token")]
    fn offline_token_from_string(token: &str) {
        let offline_token = EdgeToken::offline_token(token);
        assert_eq!(offline_token.environment, None);
        assert!(offline_token.projects.is_empty());
    }

    #[test_case(
        "demo-app:production",
        "demo-app:production"
        => true
    ; "idempotency")]
    #[test_case(
        "aproject:production",
        "another:production"
        => false
    ; "project mismatch")]
    #[test_case(
        "demo-app:development",
        "demo-app:production"
        => false
    ; "environment mismatch")]
    #[test_case(
        "*:production",
        "demo-app:production"
        => true
    ; "* subsumes a project token")]
    fn edge_token_subsumes_edge_token(token1: &str, token2: &str) -> bool {
        let t1 = test_str(token1);
        let t2 = test_str(token2);
        t1.subsumes(&t2)
    }

    #[test]
    fn edge_token_unrelated_by_subsume() {
        let t1 = test_str("demo-app:production");
        let t2 = test_str("another:production");
        assert!(!t1.subsumes(&t2));
        assert!(!t2.subsumes(&t1));
    }

    /// A token covering a superset of projects subsumes one covering a subset,
    /// but not the other way round.
    #[test]
    fn edge_token_with_project_superset_subsumes_subset_but_not_vice_versa() {
        let token1 = test_token(None, vec!["p1", "p2"]);

        let token2 = test_token(None, vec!["p1"]);

        assert!(token1.subsumes(&token2));
        assert!(!token2.subsumes(&token1));
    }

    #[test]
    fn token_type_should_be_case_insensitive() {
        let json = r#"{ "tokens": [{
              "token": "chriswk-test:development.notusedsecret",
              "type": "CLIENT",
              "projects": [
                "chriswk-test"
              ]
            },
            {
              "token": "demo-app:production.notusedsecret",
              "type": "client",
              "projects": [
                "demo-app"
              ]
            }] }"#;
        let tokens: EdgeResult<EdgeTokens> =
            serde_json::from_str(json).map_err(|_| EdgeTokenParseError);
        assert!(tokens.is_ok());
        assert_eq!(tokens.unwrap().tokens.len(), 2);
    }

    #[test]
    fn context_conversion_works() {
        let json = json!({
            "context": {
                "userId": "user",
                "sessionId": "session",
                "environment": "env",
                "appName": "app",
                "currentTime": "2024-03-12T11:42:46+01:00",
                "remoteAddress": "127.0.0.1",
                "properties": {
                    "normal property": "normal",
                },
                "top-level property": "top"
            }
        });

        let converted: Context = serde_json::from_value(json).unwrap();

        assert_eq!(converted.user_id, Some("user".into()));
        assert_eq!(converted.session_id, Some("session".into()));
        assert_eq!(converted.environment, Some("env".into()));
        assert_eq!(converted.app_name, Some("app".into()));
        assert_eq!(
            converted.current_time,
            Some("2024-03-12T11:42:46+01:00".into())
        );
        assert_eq!(converted.remote_address, Some("127.0.0.1".into()));
        assert_eq!(
            converted.properties,
            Some(HashMap::from([
                ("normal property".into(), "normal".into()),
                ("top-level property".into(), "top".into())
            ]))
        );
    }

    #[test]
    fn context_conversion_properties_level_properties_take_precedence_over_top_level() {
        let json = json!({
            "context": {
                "properties": {
                    "duplicated property": "lower"
                }
            },
            "extraProperties": {
                "duplicated property": "upper"
            }
        });

        let parsed_context: Context = serde_json::from_value(json).unwrap();
        assert_eq!(
            parsed_context.properties,
            Some(HashMap::from([(
                "duplicated property".into(),
                "lower".into()
            ),]))
        );
    }

    #[test]
    fn context_conversion_if_there_are_no_extra_properties_the_properties_hash_map_is_none() {
        let json = json!({
            "context": {
                "userId": "7",
            }
        });

        let parsed_context: Context = serde_json::from_value(json).unwrap();

        assert_eq!(parsed_context.properties, None);
    }

    #[test]
    fn completely_flat_json_parses_to_a_context() {
        let json = json!(
            {
                "userId": "7",
                "flat": "endsUpInProps",
                "invalidProperty": "alsoEndsUpInProps"
            }
        );

        let parsed_context: Context = serde_json::from_value(json).unwrap();

        assert_eq!(parsed_context.user_id, Some("7".into()));
        assert_eq!(
            parsed_context.properties,
            Some(HashMap::from([
                ("flat".into(), "endsUpInProps".into()),
                ("invalidProperty".into(), "alsoEndsUpInProps".into())
            ]))
        );
    }

    #[test]
    fn post_context_root_level_properties_are_ignored_if_context_property_is_set() {
        let json = json!(
            {
                "context": {
                    "userId":"7",
                },
                "invalidProperty": "thisNeverGoesAnywhere",
                "anotherInvalidProperty": "alsoGoesNoWhere"
            }
        );

        let parsed_context: Context = serde_json::from_value(json).unwrap();

        assert_eq!(parsed_context.properties, None);
        assert_eq!(parsed_context.user_id, Some("7".into()));
    }

    #[test]
    fn post_context_properties_are_taken_from_nested_context_object_but_root_levels_are_ignored() {
        let json = json!(
            {
                "context": {
                    "userId":"7",
                    "properties": {
                        "nested": "nestedValue"
                    }
                },
                "invalidProperty": "thisNeverGoesAnywhere"
            }
        );

        let parsed_context: Context = serde_json::from_value(json).unwrap();
        assert_eq!(
            parsed_context.properties,
            Some(HashMap::from([("nested".into(), "nestedValue".into()),]))
        );

        assert_eq!(parsed_context.user_id, Some("7".into()));
    }
}