1use std::cmp::min;
2use std::fmt;
3use std::fmt::{Debug, Display, Formatter};
4use std::net::IpAddr;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::{
8 collections::HashMap,
9 hash::{Hash, Hasher},
10 str::FromStr,
11};
12
13use axum::Json;
14use axum::http::HeaderValue;
15use axum::response::{IntoResponse, Response};
16use chrono::{DateTime, Duration, Utc};
17use dashmap::DashMap;
18use etag::EntityTag;
19use http::StatusCode;
20use serde::{Deserialize, Deserializer, Serialize, Serializer};
21use shadow_rs::shadow;
22use unleash_types::client_features::Context;
23use unleash_types::client_features::{ClientFeatures, ClientFeaturesDelta};
24use unleash_types::client_metrics::{ClientApplication, ClientMetricsEnv, ImpactMetric};
25use unleash_yggdrasil::EngineState;
26use utoipa::{IntoParams, ToSchema};
27
28pub mod enterprise;
29pub mod errors;
30pub mod filters;
31pub mod headers;
32pub mod metrics;
33pub mod tokens;
34pub mod urls;
35use crate::enterprise::LicenseState;
36use crate::errors::EdgeError;
37use crate::tokens::EdgeToken;
38
/// Handler result whose success value is sent as a JSON body.
pub type EdgeJsonResult<T> = Result<Json<T>, EdgeError>;
/// General-purpose result type for fallible Edge operations.
pub type EdgeResult<T> = Result<T, EdgeError>;
/// Handler result whose success value is sent as JSON with `202 Accepted`.
pub type EdgeAcceptedJsonResult<T> = Result<AcceptedJson<T>, EdgeError>;
/// JSON response body that is sent with `202 Accepted` instead of `200 OK`.
///
/// The `Serialize` requirement lives on the `IntoResponse` impl rather than on
/// the struct, so generic code can name `AcceptedJson<T>` without having to
/// repeat the bound.
pub struct AcceptedJson<T> {
    /// Payload serialized as the JSON response body.
    pub body: T,
}
48impl<T> IntoResponse for AcceptedJson<T>
49where
50 T: Serialize,
51{
52 fn into_response(self) -> Response {
53 (StatusCode::ACCEPTED, Json(self.body)).into_response()
54 }
55}
56
/// Concurrent map of tokens keyed by `String` (presumably the token string — confirm at call sites).
pub type TokenCache = DashMap<String, EdgeToken>;
/// Concurrent map of evaluation engine states keyed by `String` (key meaning not visible here — TODO confirm).
pub type EngineCache = DashMap<String, EngineState>;
59pub fn entity_tag_to_header_value(etag: EntityTag) -> HeaderValue {
60 HeaderValue::from_str(&etag.to_string()).expect("Failed to convert ETag to HeaderValue")
61}
62
/// Pinned, boxed, `Send` future with no output, used as a background task.
pub type BackgroundTask = Pin<Box<dyn Future<Output = ()> + Send>>;
64
/// Run/pause switch for refreshing. Obtained from a [`LicenseState`] via
/// `From`, where only an invalid license maps to `Paused`.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum RefreshState {
    Running,
    Paused,
}
70
71impl From<LicenseState> for RefreshState {
72 fn from(val: LicenseState) -> Self {
73 match val {
74 LicenseState::Valid => RefreshState::Running,
75 LicenseState::Invalid => RefreshState::Paused,
76 LicenseState::Expired => RefreshState::Running,
77 }
78 }
79}
80
/// Evaluation context as received from a client. Both fields are flattened
/// from the same JSON object; converting into [`Context`] folds
/// `extra_properties` into `Context::properties`.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct IncomingContext {
    /// Context fields read from the top level of the object.
    #[serde(flatten)]
    pub context: Context,

    /// Top-level string values collected alongside `context` (presumably the
    /// keys `Context` itself does not consume — confirm against its deserializer).
    #[serde(flatten)]
    pub extra_properties: HashMap<String, String>,
}
90
91impl From<IncomingContext> for Context {
92 fn from(input: IncomingContext) -> Self {
93 let properties = if input.extra_properties.is_empty() {
94 input.context.properties
95 } else {
96 let mut input_properties = input.extra_properties;
97 input_properties.extend(input.context.properties.unwrap_or_default());
98 Some(input_properties)
99 };
100 Context {
101 properties,
102 ..input.context
103 }
104 }
105}
106
/// Context body that may be either nested under `"context"` or written flat
/// at the top level (presumably the body of POST evaluation endpoints —
/// confirm). Converting into [`Context`] prefers the nested form.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PostContext {
    /// Nested `"context"` object. When present it is used as-is.
    pub context: Option<Context>,
    /// Context fields read from the top level of the body.
    #[serde(flatten)]
    pub flattened_context: Option<Context>,
    /// Remaining top-level string values, used only for the flat form.
    #[serde(flatten)]
    pub extra_properties: HashMap<String, String>,
}
116
117impl From<PostContext> for Context {
118 fn from(input: PostContext) -> Self {
119 if let Some(context) = input.context {
120 context
121 } else {
122 IncomingContext {
123 context: input.flattened_context.unwrap_or_default(),
124 extra_properties: input.extra_properties,
125 }
126 .into()
127 }
128 }
129}
130
// Kind of API token. Serialized in lowercase (`frontend`, `backend`, `admin`,
// `invalid`). Deserialization also accepts the upper-case aliases, and
// `client`/`CLIENT` for backend tokens. `//` comments are used instead of `///`
// because `ToSchema` turns doc comments into OpenAPI descriptions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, utoipa::ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum TokenType {
    #[serde(alias = "FRONTEND")]
    Frontend,
    #[serde(alias = "CLIENT", alias = "client", alias = "BACKEND")]
    Backend,
    #[serde(alias = "ADMIN")]
    Admin,
    Invalid,
}
142
/// Outcome of fetching client features. It is either `NoUpdate(tag)`
/// (presumably "nothing changed", echoing the current entity tag) or
/// `Updated(features, tag)` carrying new features and an optional entity tag.
#[derive(Clone, Debug)]
// NOTE(review): `Updated` holds a full `ClientFeatures` and dwarfs `NoUpdate`;
// the lint is silenced rather than boxing — presumably deliberate, confirm.
#[allow(clippy::large_enum_variant)]
pub enum ClientFeaturesResponse {
    NoUpdate(EntityTag),
    Updated(ClientFeatures, Option<EntityTag>),
}

/// Delta counterpart of [`ClientFeaturesResponse`]: `Updated` carries a
/// [`ClientFeaturesDelta`] instead of full features.
#[derive(Clone, Debug)]
pub enum ClientFeaturesDeltaResponse {
    NoUpdate(EntityTag),
    Updated(ClientFeaturesDelta, Option<EntityTag>),
}
155
// Validation state of a token; `Unknown` is the default. Only `Trusted` and
// `Validated` count as valid in `is_valid`. `//` comments are used instead of
// `///` because `ToSchema` turns doc comments into OpenAPI descriptions.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Default, Deserialize, utoipa::ToSchema)]
pub enum TokenValidationStatus {
    Invalid,
    #[default]
    Unknown,
    Trusted,
    Validated,
}
164
165impl TokenValidationStatus {
166 pub fn is_valid(&self) -> bool {
167 matches!(
168 self,
169 &TokenValidationStatus::Trusted | &TokenValidationStatus::Validated
170 )
171 }
172}
173
/// Status value, serialized in upper case (`"OK"`, `"NOTOK"`, `"NOTREADY"`,
/// `"READY"`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
pub enum Status {
    Ok,
    NotOk,
    NotReady,
    Ready,
}
/// Parameters for a client-features fetch.
#[derive(Clone, Debug)]
pub struct ClientFeaturesRequest {
    /// API key the fetch is made with.
    pub api_key: String,
    /// Entity tag from a previous fetch, if any (presumably sent so upstream
    /// can answer "not modified" — confirm at the call site).
    pub etag: Option<EntityTag>,
    /// Optional interval; units are not visible here — TODO confirm.
    pub interval: Option<i64>,
}

/// Body listing raw token strings to validate.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ValidateTokensRequest {
    pub tokens: Vec<String>,
}

/// IP address of the calling client. Its `Display` impl prints the bare address.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientIp {
    pub ip: IpAddr,
}
198
199impl Display for ClientIp {
200 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
201 write!(f, "{}", self.ip)
202 }
203}
204
/// Refresh bookkeeping for one token, which tracks:
/// - the last known entity tag;
/// - when the last check and the last successful refresh happened;
/// - when the next refresh is due;
/// - a back-off counter.
///
/// `PartialEq` and `Debug` are implemented by hand rather than derived.
#[derive(Clone, Deserialize, Serialize)]
pub struct TokenRefresh {
    pub token: EdgeToken,
    /// (De)serialized in its string form through custom serde helpers.
    #[serde(
        deserialize_with = "deserialize_entity_tag",
        serialize_with = "serialize_entity_tag"
    )]
    pub etag: Option<EntityTag>,
    pub next_refresh: Option<DateTime<Utc>>,
    pub last_refreshed: Option<DateTime<Utc>>,
    pub last_feature_count: Option<usize>,
    pub last_check: Option<DateTime<Utc>>,
    /// Back-off counter. `backoff` raises it (capped), and each successful
    /// check or refresh lowers it by one.
    pub failure_count: u32,
}
219
220impl PartialEq for TokenRefresh {
221 fn eq(&self, other: &Self) -> bool {
222 self.token == other.token
223 && self.etag == other.etag
224 && self.last_refreshed == other.last_refreshed
225 && self.last_check == other.last_check
226 }
227}
228
/// One validation problem listed in an [`UnleashBadRequest`]; every field is optional.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct UnleashValidationDetail {
    pub path: Option<String>,
    pub description: Option<String>,
    pub message: Option<String>,
}

/// Error body, presumably what the Unleash API returns on a 400 response —
/// confirm. Every field is optional, so partial bodies still deserialize.
#[derive(Clone, Deserialize, Serialize, Debug)]
pub struct UnleashBadRequest {
    pub id: Option<String>,
    pub name: Option<String>,
    pub message: Option<String>,
    pub details: Option<Vec<UnleashValidationDetail>>,
}
243
244impl TokenRefresh {
245 pub fn new(token: EdgeToken, etag: Option<EntityTag>) -> Self {
246 Self {
247 token,
248 etag,
249 last_refreshed: None,
250 last_check: None,
251 next_refresh: None,
252 failure_count: 0,
253 last_feature_count: None,
254 }
255 }
256
257 pub fn backoff(&self, refresh_interval: &Duration) -> Self {
259 let failure_count: u32 = min(self.failure_count + 1, 10);
260 let now = Utc::now();
261 let next_refresh = calculate_next_refresh(now, *refresh_interval, failure_count as u64);
262 Self {
263 failure_count,
264 next_refresh: Some(next_refresh),
265 last_check: Some(now),
266 ..self.clone()
267 }
268 }
269 pub fn successful_check(&self, refresh_interval: &Duration) -> Self {
271 let failure_count = if self.failure_count > 0 {
272 self.failure_count - 1
273 } else {
274 0
275 };
276 let now = Utc::now();
277 let next_refresh = calculate_next_refresh(now, *refresh_interval, failure_count as u64);
278 Self {
279 failure_count,
280 next_refresh: Some(next_refresh),
281 last_check: Some(now),
282 ..self.clone()
283 }
284 }
285 pub fn successful_refresh(
287 &self,
288 refresh_interval: &Duration,
289 etag: Option<EntityTag>,
290 feature_count: usize,
291 ) -> Self {
292 let failure_count = if self.failure_count > 0 {
293 self.failure_count - 1
294 } else {
295 0
296 };
297 let now = Utc::now();
298 let next_refresh = calculate_next_refresh(now, *refresh_interval, failure_count as u64);
299 Self {
300 failure_count,
301 next_refresh: Some(next_refresh),
302 last_refreshed: Some(now),
303 last_check: Some(now),
304 last_feature_count: Some(feature_count),
305 etag,
306 ..self.clone()
307 }
308 }
309}
310
311fn calculate_next_refresh(
312 now: DateTime<Utc>,
313 refresh_interval: Duration,
314 failure_count: u64,
315) -> DateTime<Utc> {
316 if failure_count == 0 {
317 now + refresh_interval
318 } else {
319 now + refresh_interval + (refresh_interval * (failure_count.try_into().unwrap_or(0)))
320 }
321}
322
323impl fmt::Debug for TokenRefresh {
324 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
325 f.debug_struct("FeatureRefresh")
326 .field("token", &"***")
327 .field("etag", &self.etag)
328 .field("last_refreshed", &self.last_refreshed)
329 .field("last_check", &self.last_check)
330 .finish()
331 }
332}
333
/// Bundle of shared concurrent caches. Cloning is cheap: each map sits behind
/// an `Arc`, so clones share the same maps.
#[derive(Clone, Default)]
pub struct CacheHolder {
    pub token_cache: Arc<DashMap<String, EdgeToken>>,
    pub features_cache: Arc<DashMap<String, ClientFeatures>>,
    pub engine_cache: Arc<DashMap<String, EngineState>>,
}
340
341fn deserialize_entity_tag<'de, D>(deserializer: D) -> Result<Option<EntityTag>, D::Error>
342where
343 D: Deserializer<'de>,
344{
345 let s: Option<String> = Option::deserialize(deserializer)?;
346
347 s.map(|s| EntityTag::from_str(&s).map_err(serde::de::Error::custom))
348 .transpose()
349}
350
351fn serialize_entity_tag<S>(etag: &Option<EntityTag>, serializer: S) -> Result<S::Ok, S::Error>
352where
353 S: Serializer,
354{
355 let s = etag.as_ref().map(|e| e.to_string());
356 serializer.serialize_some(&s)
357}
358
359pub fn into_entity_tag(client_features: ClientFeatures) -> Option<EntityTag> {
360 client_features
361 .xx3_hash()
362 .ok()
363 .map(|hash| EntityTag::new(true, &hash))
364}
365
/// Batched metrics payload paired with the API key it is sent with.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BatchMetricsRequest {
    pub api_key: String,
    pub body: BatchMetricsRequestBody,
}

// Body of a batched metrics submission: applications, metric buckets and
// optional impact metrics. `//` comments are used instead of `///` because
// `ToSchema` turns doc comments into OpenAPI descriptions.
#[derive(Clone, Debug, Serialize, Deserialize, utoipa::ToSchema)]
pub struct BatchMetricsRequestBody {
    pub applications: Vec<ClientApplication>,
    pub metrics: Vec<ClientMetricsEnv>,
    // Read and written as `impactMetrics`. Omitted from output when `None`,
    // and defaults to `None` when missing from input.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        rename = "impactMetrics"
    )]
    pub impact_metrics: Option<Vec<ImpactMetric>>,
}

/// Request to create a client token. Fields use camelCase, and `token_type`
/// is read and written as `"type"`.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ClientTokenRequest {
    pub token_name: String,
    #[serde(rename = "type")]
    pub token_type: TokenType,
    pub projects: Vec<String>,
    pub environment: String,
    pub expires_at: DateTime<Utc>,
}
394
/// Compile-time build metadata: package version parts, git branch, tag and
/// commit, Rust toolchain, build OS and target. `Default` fills it from the
/// shadow-rs `build` module.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BuildInfo {
    pub package_version: String,
    pub app_name: String,
    pub package_major: String,
    pub package_minor: String,
    pub package_patch: String,
    /// `None` when the package version has no pre-release part.
    pub package_version_pre: Option<String>,
    pub branch: String,
    pub tag: String,
    pub rust_version: String,
    pub rust_channel: String,
    pub short_commit_hash: String,
    pub full_commit_hash: String,
    pub build_os: String,
    pub build_target: String,
}
// Generates the `build` module of compile-time build constants (shadow-rs).
shadow!(build);
/// Package version of this build, taken from `build::PKG_VERSION`.
pub const EDGE_VERSION: &str = build::PKG_VERSION;
414impl Default for BuildInfo {
415 fn default() -> Self {
416 BuildInfo {
417 package_version: build::PKG_VERSION.into(),
418 app_name: build::PROJECT_NAME.into(),
419 package_major: build::PKG_VERSION_MAJOR.into(),
420 package_minor: build::PKG_VERSION_MINOR.into(),
421 package_patch: build::PKG_VERSION_PATCH.into(),
422 #[allow(clippy::const_is_empty)]
423 package_version_pre: if build::PKG_VERSION_PRE.is_empty() {
424 None
425 } else {
426 Some(build::PKG_VERSION_PRE.into())
427 },
428 branch: build::BRANCH.into(),
429 tag: build::TAG.into(),
430 rust_version: build::RUST_VERSION.into(),
431 rust_channel: build::RUST_CHANNEL.into(),
432 short_commit_hash: build::SHORT_COMMIT.into(),
433 full_commit_hash: build::COMMIT_HASH.into(),
434 build_os: build::BUILD_OS.into(),
435 build_target: build::BUILD_TARGET.into(),
436 }
437 }
438}
439
// Query parameters for filtering features; `name_prefix` is read as
// `namePrefix`. `//` comments are used instead of `///` because `IntoParams`
// turns doc comments into OpenAPI descriptions.
#[derive(Clone, Debug, Serialize, Deserialize, IntoParams)]
#[serde(rename_all = "camelCase")]
pub struct FeatureFilters {
    pub name_prefix: Option<String>,
}

/// Token status summary: refresh entries, tokens (presumably with their
/// validation status — confirm) and a count of invalid tokens. Fields use camelCase.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct TokenInfo {
    pub token_refreshes: Vec<TokenRefresh>,
    pub token_validation_status: Vec<EdgeToken>,
    pub invalid_token_count: usize,
}

// Aggregation key for metrics. `Hash` and `PartialEq` are implemented by hand
// so that `timestamp` is compared by UTC hour only: keys from the same hour
// are equal. (`//` rather than `///` because `ToSchema` turns doc comments into
// OpenAPI descriptions.)
#[derive(Debug, Clone, Eq, Deserialize, Serialize, ToSchema)]
pub struct MetricsKey {
    pub app_name: String,
    pub feature_name: String,
    pub environment: String,
    pub timestamp: DateTime<Utc>,
}
461
462impl Hash for MetricsKey {
463 fn hash<H: Hasher>(&self, state: &mut H) {
464 self.app_name.hash(state);
465 self.feature_name.hash(state);
466 self.environment.hash(state);
467 to_time_key(&self.timestamp).hash(state);
468 }
469}
470
471fn to_time_key(timestamp: &DateTime<Utc>) -> String {
472 format!("{}", timestamp.format("%Y-%m-%d %H"))
473}
474
475impl PartialEq for MetricsKey {
476 fn eq(&self, other: &Self) -> bool {
477 let other_hour_bin = to_time_key(&other.timestamp);
478 let self_hour_bin = to_time_key(&self.timestamp);
479
480 self.app_name == other.app_name
481 && self.feature_name == other.feature_name
482 && self.environment == other.environment
483 && self_hour_bin == other_hour_bin
484 }
485}
486
// A metrics bucket together with the key it is aggregated under. `//` comments
// are used instead of `///` because `ToSchema` turns doc comments into OpenAPI
// descriptions.
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
pub struct ClientMetric {
    pub key: MetricsKey,
    pub bucket: ClientMetricsEnv,
}
// Collected applications and keyed metric buckets (`//` for the same ToSchema
// reason as above).
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema)]
pub struct MetricsInfo {
    pub applications: Vec<ClientApplication>,
    pub metrics: Vec<ClientMetric>,
}

/// Wrapper for a list of tokens, read from and written as `{ "tokens": [...] }`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EdgeTokens {
    pub tokens: Vec<EdgeToken>,
}
502
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::str::FromStr;

    use serde_json::json;
    use test_case::test_case;
    use tracing::warn;
    use unleash_types::client_features::Context;

    use crate::errors::EdgeError::EdgeTokenParseError;

    use super::{EdgeResult, EdgeToken, EdgeTokens};

    /// Parses `"<projects>:<environment>"` into an [`EdgeToken`] after
    /// appending a fixed dummy secret; panics if parsing fails.
    fn test_str(token: &str) -> EdgeToken {
        EdgeToken::from_str(
            &(token.to_owned() + ".614a75cf68bef8703aa1bd8304938a81ec871f86ea40c975468eabd6"),
        )
        .unwrap()
    }

    /// Builds an [`EdgeToken`] with the given environment and projects, and
    /// every other field at its default.
    fn test_token(env: Option<&str>, projects: Vec<&str>) -> EdgeToken {
        EdgeToken {
            environment: env.map(|env| env.into()),
            projects: projects.into_iter().map(|p| p.into()).collect(),
            ..EdgeToken::default()
        }
    }

    // A `project:environment.secret` string parses and keeps the full input in `token`.
    #[test_case("demo-app:production.614a75cf68bef8703aa1bd8304938a81ec871f86ea40c975468eabd6"; "demo token with project and environment")]
    #[test_case("*:default.5fa5ac2580c7094abf0d87c68b1eeb54bdc485014aef40f9fcb0673b"; "demo token with access to all projects and default environment")]
    fn edge_token_from_string(token: &str) {
        let parsed_token = EdgeToken::from_str(token);
        match parsed_token {
            Ok(t) => {
                assert_eq!(t.token, token);
            }
            Err(e) => {
                warn!("{}", e);
                panic!("Could not parse token");
            }
        }
    }

    // Legacy-style token strings (see the case names) become offline tokens
    // with no environment and no projects.
    #[test_case("943ca9171e2c884c545c5d82417a655fb77cec970cc3b78a8ff87f4406b495d0"; "old java client token")]
    #[test_case("secret-123"; "old example proxy token")]
    fn offline_token_from_string(token: &str) {
        let offline_token = EdgeToken::offline_token(token);
        assert_eq!(offline_token.environment, None);
        assert!(offline_token.projects.is_empty());
    }

    // `subsumes` cases: identical tokens, project mismatch, environment
    // mismatch, and `*` covering a specific project.
    #[test_case(
        "demo-app:production",
        "demo-app:production"
        => true
        ; "idempotency")]
    #[test_case(
        "aproject:production",
        "another:production"
        => false
        ; "project mismatch")]
    #[test_case(
        "demo-app:development",
        "demo-app:production"
        => false
        ; "environment mismatch")]
    #[test_case(
        "*:production",
        "demo-app:production"
        => true
        ; "* subsumes a project token")]
    fn edge_token_subsumes_edge_token(token1: &str, token2: &str) -> bool {
        let t1 = test_str(token1);
        let t2 = test_str(token2);
        t1.subsumes(&t2)
    }

    // Tokens for different projects do not subsume each other in either direction.
    #[test]
    fn edge_token_unrelated_by_subsume() {
        let t1 = test_str("demo-app:production");
        let t2 = test_str("another:production");
        assert!(!t1.subsumes(&t2));
        assert!(!t2.subsumes(&t1));
    }

    // The token with the larger project set subsumes the one with a subset,
    // but not the other way round.
    #[test]
    fn edge_token_does_not_subsume_if_projects_is_subset_of_other_tokens_project() {
        let token1 = test_token(None, vec!["p1", "p2"]);

        let token2 = test_token(None, vec!["p1"]);

        assert!(token1.subsumes(&token2));
        assert!(!token2.subsumes(&token1));
    }

    // `type` accepts both `CLIENT` and `client` (aliases of `TokenType::Backend`).
    #[test]
    fn token_type_should_be_case_insensitive() {
        let json = r#"{ "tokens": [{
 "token": "chriswk-test:development.notusedsecret",
 "type": "CLIENT",
 "projects": [
 "chriswk-test"
 ]
 },
 {
 "token": "demo-app:production.notusedsecret",
 "type": "client",
 "projects": [
 "demo-app"
 ]
 }] }"#;
        let tokens: EdgeResult<EdgeTokens> =
            serde_json::from_str(json).map_err(|_| EdgeTokenParseError);
        assert!(tokens.is_ok());
        assert_eq!(tokens.unwrap().tokens.len(), 2);
    }

    // NOTE(review): the context tests below deserialize straight into
    // `unleash_types`' `Context`. They exercise that crate's deserializer, not
    // the `IncomingContext`/`PostContext` conversions defined in this file.

    // Known fields of a nested `context` object are read, and an unknown key
    // inside it lands in `properties` next to the explicit ones.
    #[test]
    fn context_conversion_works() {
        let json = json!({
            "context": {
                "userId": "user",
                "sessionId": "session",
                "environment": "env",
                "appName": "app",
                "currentTime": "2024-03-12T11:42:46+01:00",
                "remoteAddress": "127.0.0.1",
                "properties": {
                    "normal property": "normal",
                },
                "top-level property": "top"
            }
        });

        let converted: Context = serde_json::from_value(json).unwrap();

        assert_eq!(converted.user_id, Some("user".into()));
        assert_eq!(converted.session_id, Some("session".into()));
        assert_eq!(converted.environment, Some("env".into()));
        assert_eq!(converted.app_name, Some("app".into()));
        assert_eq!(
            converted.current_time,
            Some("2024-03-12T11:42:46+01:00".into())
        );
        assert_eq!(converted.remote_address, Some("127.0.0.1".into()));
        assert_eq!(
            converted.properties,
            Some(HashMap::from([
                ("normal property".into(), "normal".into()),
                ("top-level property".into(), "top".into())
            ]))
        );
    }

    // A `properties` entry inside `context` is kept; the sibling
    // `extraProperties` object does not override it.
    #[test]
    fn context_conversion_properties_level_properties_take_precedence_over_top_level() {
        let json = json!({
            "context": {
                "properties": {
                    "duplicated property": "lower"
                }
            },
            "extraProperties": {
                "duplicated property": "upper"
            }
        });

        let parsed_context: Context = serde_json::from_value(json).unwrap();
        assert_eq!(
            parsed_context.properties,
            Some(HashMap::from([(
                "duplicated property".into(),
                "lower".into()
            ),]))
        );
    }

    // With no properties anywhere, `properties` stays `None`.
    #[test]
    fn context_conversion_if_there_are_no_extra_properties_the_properties_hash_map_is_none() {
        let json = json!({
            "context": {
                "userId": "7",
            }
        });

        let parsed_context: Context = serde_json::from_value(json).unwrap();

        assert_eq!(parsed_context.properties, None);
    }

    // A flat object without a `context` wrapper is read as a context, and
    // unknown keys go into `properties`.
    #[test]
    fn completely_flat_json_parses_to_a_context() {
        let json = json!(
            {
                "userId": "7",
                "flat": "endsUpInProps",
                "invalidProperty": "alsoEndsUpInProps"
            }
        );

        let parsed_context: Context = serde_json::from_value(json).unwrap();

        assert_eq!(parsed_context.user_id, Some("7".into()));
        assert_eq!(
            parsed_context.properties,
            Some(HashMap::from([
                ("flat".into(), "endsUpInProps".into()),
                ("invalidProperty".into(), "alsoEndsUpInProps".into())
            ]))
        );
    }

    // When a `context` object is present, sibling top-level keys are dropped.
    #[test]
    fn post_context_root_level_properties_are_ignored_if_context_property_is_set() {
        let json = json!(
            {
                "context": {
                    "userId":"7",
                },
                "invalidProperty": "thisNeverGoesAnywhere",
                "anotherInvalidProperty": "alsoGoesNoWhere"
            }
        );

        let parsed_context: Context = serde_json::from_value(json).unwrap();

        assert_eq!(parsed_context.properties, None);
        assert_eq!(parsed_context.user_id, Some("7".into()));
    }

    // Nested `properties` are kept, while sibling top-level keys are dropped.
    #[test]
    fn post_context_properties_are_taken_from_nested_context_object_but_root_levels_are_ignored() {
        let json = json!(
            {
                "context": {
                    "userId":"7",
                    "properties": {
                        "nested": "nestedValue"
                    }
                },
                "invalidProperty": "thisNeverGoesAnywhere"
            }
        );

        let parsed_context: Context = serde_json::from_value(json).unwrap();
        assert_eq!(
            parsed_context.properties,
            Some(HashMap::from([("nested".into(), "nestedValue".into()),]))
        );

        assert_eq!(parsed_context.user_id, Some("7".into()));
    }
}