weaviate_community/
oidc.rs

//! Bindings for Weaviate's `/v1/.well-known` REST endpoints.
//!
//! See <https://weaviate.io/developers/weaviate/api/rest/well-known>.
2use reqwest::Url;
3use std::error::Error;
4use std::sync::Arc;
5
6use crate::collections::error::NotConfiguredError;
7use crate::collections::oidc::OidcResponse;
8
9#[derive(Debug)]
10pub struct Oidc {
11    endpoint: Url,
12    client: Arc<reqwest::Client>,
13}
14
15impl Oidc {
16    pub(super) fn new(url: &Url, client: Arc<reqwest::Client>) -> Result<Self, Box<dyn Error>> {
17        let endpoint = url.join("/v1/.well-known")?;
18        Ok(Oidc { endpoint, client })
19    }
20
21    /// Get OIDC information if OpenID Connect (OIDC) authentication is enabled. The endpoint
22    /// redirects to the token issued if one is configured.
23    ///
24    /// The redirect will return the following fields:
25    /// - href      => The reference to the client
26    /// - cliendID  => The ID of the client
27    ///
28    /// # Examples
29    ///
30    /// GET /v1/.well-known/openid-configuration
31    /// ```
32    /// ```
33    pub async fn get_open_id_configuration(&self) -> Result<OidcResponse, Box<dyn Error>> {
34        let endpoint = self.endpoint.join("/openid-configuration")?;
35        let resp = self.client.get(endpoint).send().await?;
36        match resp.status() {
37            reqwest::StatusCode::OK => {
38                let parsed: OidcResponse = resp.json::<OidcResponse>().await?;
39                Ok(parsed)
40            }
41            _ => Err(Box::new(NotConfiguredError(
42                "OIDC is not configured or is unavailable".into(),
43            ))),
44        }
45    }
46}
47
48#[cfg(test)]
49mod tests {
50    use crate::{collections::oidc::OidcResponse, WeaviateClient};
51
52    fn test_oidc_response() -> OidcResponse {
53        let response: OidcResponse = serde_json::from_value(
54            serde_json::json!({
55                "clientId": "wcs",
56                "href": "https://auth.wcs.api.weaviate.io/auth/realms/SeMI/.well-known/openid-configuration"
57            })
58        ).unwrap();
59        response
60    }
61
62    fn get_test_harness() -> (mockito::ServerGuard, WeaviateClient) {
63        let mock_server = mockito::Server::new();
64        let mut host = "http://".to_string();
65        host.push_str(&mock_server.host_with_port());
66        let client = WeaviateClient::builder(&host).build().unwrap();
67        (mock_server, client)
68    }
69
70    fn mock_get(
71        server: &mut mockito::ServerGuard,
72        endpoint: &str,
73        status_code: usize,
74        body: &str,
75    ) -> mockito::Mock {
76        server
77            .mock("GET", endpoint)
78            .with_status(status_code)
79            .with_header("content-type", "application/json")
80            .with_body(body)
81            .create()
82    }
83
84    #[tokio::test]
85    async fn test_get_open_id_configuration_ok() {
86        let resp = test_oidc_response();
87        let resp_str = serde_json::to_string(&resp).unwrap();
88        let (mut mock_server, client) = get_test_harness();
89        let mock = mock_get(&mut mock_server, "/openid-configuration", 200, &resp_str);
90        let res = client.oidc.get_open_id_configuration().await;
91        mock.assert();
92        assert!(res.is_ok());
93        assert_eq!(resp.client_id, res.unwrap().client_id);
94    }
95
96    #[tokio::test]
97    async fn test_get_open_id_configuration_err() {
98        let (mut mock_server, client) = get_test_harness();
99        let mock = mock_get(&mut mock_server, "/openid-configuration", 404, "");
100        let res = client.oidc.get_open_id_configuration().await;
101        mock.assert();
102        assert!(res.is_err());
103    }
104}