walker_common/sender/provider/
openid.rs

1use super::{Credentials, Expires, TokenProvider};
2use crate::sender::Error;
3use core::fmt::{self, Debug, Formatter};
4use std::{ops::Deref, sync::Arc};
5use tokio::sync::RwLock;
6
7#[cfg(feature = "clap")]
8use {anyhow::Context, url::Url};
9
// Command line / environment options for OpenID Connect based authentication.
//
// The client ID, client secret and issuer URL are all optional at this level;
// `OpenIdTokenProviderConfig::from_args` decides whether they form a usable configuration.
//
// The `///` comments on the fields are consumed by clap as help text, which is why the
// explanatory notes in this struct are plain `//` comments instead.
//
// NOTE(review): the three credential options `requires` the argument group named after this
// struct. That does not look like it enforces "all three or none" -- confirm the intent.
#[cfg(feature = "clap")]
#[derive(Clone, Debug, PartialEq, Eq, clap::Args)]
#[command(next_help_heading = "OIDC")]
pub struct OpenIdTokenProviderConfigArguments {
    /// The client ID for using Open ID connect
    #[arg(
        id = "oidc_client_id",
        long = "oidc-client-id",
        env = "OIDC_CLIENT_ID",
        requires = "OpenIdTokenProviderConfigArguments"
    )]
    pub client_id: Option<String>,

    /// The client secret for using Open ID connect
    #[arg(
        id = "oidc_client_secret",
        long = "oidc-client-secret",
        env = "OIDC_CLIENT_SECRET",
        requires = "OpenIdTokenProviderConfigArguments"
    )]
    pub client_secret: Option<String>,

    /// The issuer URL for using Open ID connect
    #[arg(
        id = "oidc_issuer_url",
        long = "oidc-issuer-url",
        env = "OIDC_ISSUER_URL",
        requires = "OpenIdTokenProviderConfigArguments"
    )]
    pub issuer_url: Option<String>,

    /// The time a token must be valid before refreshing it
    #[arg(
        id = "oidc_refresh_before",
        long = "oidc-refresh-before",
        env = "OIDC_REFRESH_BEFORE",
        default_value = "30s"
    )]
    pub refresh_before: humantime::Duration,

    /// Allows using TLS in an insecure mode when connecting the OIDC issuer (DANGER!)
    #[arg(
        id = "oidc_tls_insecure",
        long = "oidc-tls-insecure",
        default_value = "false"
    )]
    pub tls_insecure: bool,

    /// Allows adding additional trust anchors
    // The option may be repeated; every occurrence appends one path.
    #[arg(
        id = "oidc_tls_ca_certificates",
        long = "oidc-tls-ca-certificate",
        action = clap::ArgAction::Append,
    )]
    pub tls_ca_certificates: Vec<std::path::PathBuf>,
}
61
#[cfg(feature = "clap")]
impl OpenIdTokenProviderConfigArguments {
    /// Turn the parsed arguments into a token provider.
    ///
    /// The arguments are first converted using [`OpenIdTokenProviderConfig::from_args`]; the
    /// outcome is then handed to [`OpenIdTokenProviderConfig::new_provider`], whose result is
    /// returned as-is.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`OpenIdTokenProviderConfig::new_provider`] reports.
    pub async fn into_provider(self) -> anyhow::Result<Arc<dyn TokenProvider>> {
        let config = OpenIdTokenProviderConfig::from_args(self);
        OpenIdTokenProviderConfig::new_provider(config).await
    }
}
68
// A complete OpenID Connect client configuration.
//
// In contrast to `OpenIdTokenProviderConfigArguments`, the client ID, client secret and issuer
// URL are mandatory here.
//
// Plain `//` comments are used on purpose: the type derives `clap::Args`, and doc comments on
// clap-derived items can end up in generated help output.
//
// NOTE(review): the `clap::Args` derive carries no `#[arg]` attributes; it is kept only so that
// existing users of the derive keep compiling -- confirm whether it is still needed.
#[cfg(feature = "clap")]
#[derive(Clone, Debug, PartialEq, Eq, clap::Args)]
pub struct OpenIdTokenProviderConfig {
    // The OAuth2 client ID.
    pub client_id: String,
    // The OAuth2 client secret.
    pub client_secret: String,
    // The issuer URL; parsed as a URL when the provider is created.
    pub issuer_url: String,
    // The time a token must still be valid, otherwise it gets replaced.
    pub refresh_before: humantime::Duration,
    // Disables TLS certificate and hostname validation towards the issuer when `true`.
    pub tls_insecure: bool,
    // Files with additional trust anchors for the connection to the issuer.
    pub tls_ca_certificates: Vec<std::path::PathBuf>,
}

#[cfg(feature = "clap")]
impl OpenIdTokenProviderConfig {
    /// Create a token provider from an optional configuration.
    ///
    /// With `Some(config)` an [`OpenIdTokenProvider`] is created through
    /// [`OpenIdTokenProvider::with_config`]. With `None`, the unit value `()` is returned as the
    /// provider (its behavior is defined by the `TokenProvider` implementation for `()`, which
    /// lives outside of this file).
    ///
    /// # Errors
    ///
    /// Fails if [`OpenIdTokenProvider::with_config`] fails.
    pub async fn new_provider(config: Option<Self>) -> anyhow::Result<Arc<dyn TokenProvider>> {
        Ok(match config {
            Some(config) => Arc::new(OpenIdTokenProvider::with_config(config).await?),
            None => Arc::new(()),
        })
    }

    /// Build a configuration from command line arguments.
    ///
    /// Returns `Some` only if the client ID, the client secret, and the issuer URL are all
    /// present. If none of them is present, `None` is returned silently. If only some of them
    /// are present, `None` is returned as well, but a warning is logged, as this most likely is
    /// a configuration mistake which would otherwise go unnoticed.
    pub fn from_args(arguments: OpenIdTokenProviderConfigArguments) -> Option<Self> {
        let OpenIdTokenProviderConfigArguments {
            client_id,
            client_secret,
            issuer_url,
            refresh_before,
            tls_insecure,
            tls_ca_certificates,
        } = arguments;

        match (client_id, client_secret, issuer_url) {
            (Some(client_id), Some(client_secret), Some(issuer_url)) => Some(Self {
                client_id,
                client_secret,
                issuer_url,
                refresh_before,
                tls_insecure,
                tls_ca_certificates,
            }),
            // nothing configured at all, that is a valid choice
            (None, None, None) => None,
            // partially configured: only report which values are present, never their content
            (client_id, client_secret, issuer_url) => {
                log::warn!(
                    "Incomplete OIDC configuration, ignoring OIDC settings (client ID set: {}, client secret set: {}, issuer URL set: {})",
                    client_id.is_some(),
                    client_secret.is_some(),
                    issuer_url.is_some(),
                );
                None
            }
        }
    }
}
109
#[cfg(feature = "clap")]
impl From<OpenIdTokenProviderConfigArguments> for Option<OpenIdTokenProviderConfig> {
    /// Convert the arguments by delegating to [`OpenIdTokenProviderConfig::from_args`].
    fn from(arguments: OpenIdTokenProviderConfigArguments) -> Self {
        OpenIdTokenProviderConfig::from_args(arguments)
    }
}
116
117/// A provider which provides access tokens for clients.
118#[derive(Clone)]
119pub struct OpenIdTokenProvider {
120    client: Arc<openid::Client>,
121    current_token: Arc<RwLock<Option<openid::TemporalBearerGuard>>>,
122    refresh_before: time::Duration,
123}
124
125impl Debug for OpenIdTokenProvider {
126    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
127        f.debug_struct("TokenProvider")
128            .field(
129                "client",
130                &format!("{} / {:?}", self.client.client_id, self.client.http_client),
131            )
132            .field("current_token", &"...")
133            .finish()
134    }
135}
136
137impl OpenIdTokenProvider {
138    /// Create a new provider using the provided client.
139    pub fn new(client: openid::Client, refresh_before: time::Duration) -> Self {
140        Self {
141            client: Arc::new(client),
142            current_token: Arc::new(RwLock::new(None)),
143            refresh_before,
144        }
145    }
146
147    #[cfg(feature = "clap")]
148    pub async fn with_config(config: OpenIdTokenProviderConfig) -> anyhow::Result<Self> {
149        let issuer = Url::parse(&config.issuer_url).context("Parse issuer URL")?;
150
151        let mut client = reqwest::ClientBuilder::new();
152
153        if config.tls_insecure {
154            log::warn!("Using insecure TLS when communicating with the OIDC issuer");
155            client = client
156                .danger_accept_invalid_hostnames(true)
157                .danger_accept_invalid_certs(true);
158        }
159
160        for cert in config.tls_ca_certificates {
161            client = crate::utils::pem::add_cert(client, &cert)
162                .with_context(|| format!("adding trust anchor: {}", cert.display()))?;
163        }
164
165        let client = openid::Client::discover_with_client(
166            client.build()?,
167            config.client_id,
168            config.client_secret,
169            None,
170            issuer,
171        )
172        .await
173        .context("Discover OIDC client")?;
174
175        Ok(Self::new(
176            client,
177            time::Duration::try_from(<_ as Into<std::time::Duration>>::into(
178                config.refresh_before,
179            ))?,
180        ))
181    }
182
183    /// return a fresh token, this may be an existing (non-expired) token
184    /// a newly refreshed token.
185    pub async fn provide_token(&self) -> Result<openid::Bearer, openid::error::Error> {
186        match self.current_token.read().await.deref() {
187            Some(token) if !token.expires_before(self.refresh_before) => {
188                log::debug!("Token still valid");
189                return Ok(token.as_ref().clone());
190            }
191            _ => {}
192        }
193
194        // fetch fresh token after releasing the read lock
195
196        self.fetch_fresh_token().await
197    }
198
199    async fn fetch_fresh_token(&self) -> Result<openid::Bearer, openid::error::Error> {
200        log::debug!("Fetching fresh token...");
201
202        let mut lock = self.current_token.write().await;
203
204        match lock.deref() {
205            // check if someone else refreshed the token in the meantime
206            Some(token) if !token.expires_before(self.refresh_before) => {
207                log::debug!("Token already got refreshed");
208                return Ok(token.as_ref().clone());
209            }
210            _ => {}
211        }
212
213        // we hold the write-lock now, and can perform the refresh operation
214
215        let next_token = match lock.take() {
216            // if we don't have any token, fetch an initial one
217            None => {
218                log::debug!("Fetching initial token... ");
219                self.initial_token().await?
220            }
221            // if we have an expired one, refresh it
222            Some(current_token) => {
223                log::debug!("Refreshing token ... ");
224                match current_token.as_ref().refresh_token.is_some() {
225                    true => self.client.refresh_token(current_token, None).await?.into(),
226                    false => self.initial_token().await?,
227                }
228            }
229        };
230
231        log::debug!("Next token: {:?}", next_token.as_ref());
232
233        let result = next_token.as_ref().clone();
234        lock.replace(next_token);
235
236        // done
237
238        Ok(result)
239    }
240
241    async fn initial_token(&self) -> Result<openid::TemporalBearerGuard, openid::error::Error> {
242        Ok(self
243            .client
244            .request_token_using_client_credentials(None)
245            .await?
246            .into())
247    }
248}
249
250#[async_trait::async_trait]
251impl TokenProvider for OpenIdTokenProvider {
252    async fn provide_access_token(&self) -> Result<Option<Credentials>, Error> {
253        Ok(self
254            .provide_token()
255            .await
256            .map(|token| Some(Credentials::Bearer(token.access_token)))?)
257    }
258}