unleash_api_client/
http.rs

1// Copyright 2020, 2022 Cognite AS
2//! The HTTP Layer
3
4#[cfg(feature = "reqwest")]
5mod reqwest;
6#[cfg(feature = "reqwest-11")]
7mod reqwest_11;
8mod shim;
9
10pub struct HTTP<C: HttpClient> {
11    authorization_header: C::HeaderName,
12    app_name_header: C::HeaderName,
13    unleash_app_name_header: C::HeaderName,
14    unleash_sdk_header: C::HeaderName,
15    unleash_connection_id_header: C::HeaderName,
16    instance_id_header: C::HeaderName,
17    app_name: String,
18    sdk_version: &'static str,
19    instance_id: String,
20    // The connection_id represents a logical connection from the SDK to Unleash.
21    // It's assigned internally by the SDK and lives as long as the Unleash client instance.
22    // We can't reuse instance_id since some SDKs allow to override it while
23    // connection_id has to be uniquely defined by the SDK.
24    connection_id: String,
25    authorization: Option<String>,
26    client: C,
27}
28
29use crate::version::get_sdk_version;
30use serde::{de::DeserializeOwned, Serialize};
31#[doc(inline)]
32pub use shim::HttpClient;
33
34impl<C> HTTP<C>
35where
36    C: HttpClient + Default,
37{
38    /// The error type on this will change in future.
39    pub fn new(
40        app_name: String,
41        instance_id: String,
42        connection_id: String,
43        authorization: Option<String>,
44    ) -> Result<Self, C::Error> {
45        Ok(HTTP {
46            client: C::default(),
47            app_name,
48            sdk_version: get_sdk_version(),
49            connection_id,
50            instance_id,
51            authorization,
52            authorization_header: C::build_header("authorization")?,
53            app_name_header: C::build_header("appname")?,
54            unleash_app_name_header: C::build_header("unleash-appname")?,
55            unleash_sdk_header: C::build_header("unleash-sdk")?,
56            unleash_connection_id_header: C::build_header("unleash-connection-id")?,
57            instance_id_header: C::build_header("instance_id")?,
58        })
59    }
60
61    /// Perform a GET. Returns errors per HttpClient::get.
62    pub fn get(&self, uri: &str) -> C::RequestBuilder {
63        let request = self.client.get(uri);
64        self.attach_headers(request)
65    }
66
67    /// Make a get request and parse into JSON
68    pub async fn get_json<T: DeserializeOwned>(
69        &self,
70        endpoint: &str,
71        interval: Option<u64>,
72    ) -> Result<T, C::Error> {
73        let mut request = self.get(endpoint);
74        if let Some(interval) = interval {
75            request = C::header(
76                request,
77                &C::build_header("unleash-interval")?,
78                &interval.to_string(),
79            );
80        }
81        C::get_json(request).await
82    }
83
84    /// Perform a POST. Returns errors per HttpClient::post.
85    pub fn post(&self, uri: &str) -> C::RequestBuilder {
86        let request = self.client.post(uri);
87        self.attach_headers(request)
88    }
89
90    /// Encode content into JSON and post to an endpoint. Returns the statuscode
91    /// is_success() value.
92    pub async fn post_json<T: Serialize + Sync>(
93        &self,
94        endpoint: &str,
95        content: T,
96        interval: Option<u64>,
97    ) -> Result<bool, C::Error> {
98        let mut request = self.post(endpoint);
99        if let Some(interval) = interval {
100            request = C::header(
101                request,
102                &C::build_header("unleash-interval")?,
103                &interval.to_string(),
104            );
105        }
106        C::post_json(request, &content).await
107    }
108
109    fn attach_headers(&self, request: C::RequestBuilder) -> C::RequestBuilder {
110        let request = C::header(request, &self.app_name_header, self.app_name.as_str());
111        let request = C::header(
112            request,
113            &self.unleash_app_name_header,
114            self.app_name.as_str(),
115        );
116        let request = C::header(request, &self.unleash_sdk_header, self.sdk_version);
117        let request = C::header(
118            request,
119            &self.unleash_connection_id_header,
120            self.connection_id.as_str(),
121        );
122        let request = C::header(request, &self.instance_id_header, self.instance_id.as_str());
123        if let Some(auth) = &self.authorization {
124            C::header(request, &self.authorization_header.clone(), auth.as_str())
125        } else {
126            request
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use regex::Regex;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};
    use uuid::Uuid;

    /// Test double that records every header set on a request.
    ///
    /// The client doubles as its own request builder; clones share the same
    /// map, so headers written to a builder are visible through the client.
    #[derive(Clone, Default)]
    struct MockHttpClient {
        headers: Arc<RwLock<HashMap<String, String>>>,
    }

    #[async_trait]
    impl HttpClient for MockHttpClient {
        type Error = std::io::Error;
        type HeaderName = String;
        type RequestBuilder = Self;

        fn build_header(name: &'static str) -> Result<Self::HeaderName, Self::Error> {
            Ok(name.to_string())
        }

        fn header(builder: Self, key: &Self::HeaderName, value: &str) -> Self::RequestBuilder {
            // A poisoned lock is ignored: the header is simply not recorded.
            if let Ok(mut recorded) = builder.headers.write() {
                recorded.insert(key.clone(), value.to_string());
            }
            builder
        }

        fn get(&self, _uri: &str) -> Self::RequestBuilder {
            self.clone()
        }

        fn post(&self, _uri: &str) -> Self::RequestBuilder {
            self.clone()
        }

        async fn get_json<T: DeserializeOwned>(
            _req: Self::RequestBuilder,
        ) -> Result<T, Self::Error> {
            Ok(serde_json::from_value(json!({})).unwrap())
        }

        async fn post_json<T: Serialize + Sync>(
            _req: Self::RequestBuilder,
            _content: &T,
        ) -> Result<bool, Self::Error> {
            Ok(true)
        }
    }

    /// A GET with an interval must carry the identification, interval and
    /// authorization headers with the expected values.
    #[tokio::test]
    async fn test_specific_headers() {
        let http = HTTP::<MockHttpClient>::new(
            "my_app".to_string(),
            "my_instance_id".to_string(),
            "d512f8ec-d972-40a5-9a30-a0a6e85d93ac".to_string(),
            Some("auth_token".to_string()),
        )
        .unwrap();

        // Only the recorded headers matter; the parsed body is irrelevant.
        let _ = http
            .get_json::<serde_json::Value>("http://example.com", Some(15))
            .await;
        let recorded = http.client.headers.read().unwrap();

        let exact_values = [
            ("unleash-appname", "my_app"),
            ("instance_id", "my_instance_id"),
            (
                "unleash-connection-id",
                "d512f8ec-d972-40a5-9a30-a0a6e85d93ac",
            ),
            ("unleash-interval", "15"),
            ("authorization", "auth_token"),
        ];
        for (name, value) in exact_values {
            assert_eq!(recorded.get(name).unwrap(), value);
        }

        let sdk_version = recorded.get("unleash-sdk").unwrap();
        let version_pattern = Regex::new(r"^unleash-client-rust:\d+\.\d+\.\d+$").unwrap();
        assert!(
            version_pattern.is_match(sdk_version),
            "Version output did not match expected format: {sdk_version}"
        );

        let connection_id = recorded.get("unleash-connection-id").unwrap();
        assert!(
            Uuid::parse_str(connection_id).is_ok(),
            "Connection ID is not a valid UUID"
        );
    }
}