use super::{Credentials, Expires, TokenProvider};
use crate::sender::Error;
use core::fmt::{self, Debug, Formatter};
use std::{ops::Deref, sync::Arc};
use tokio::sync::RwLock;
#[cfg(feature = "clap")]
use {anyhow::Context, url::Url};
// Command-line arguments for an OIDC client-credentials token provider.
//
// The client ID, client secret and issuer URL belong together: each of them requires the
// other two, so a partial set is rejected at parse time instead of silently resulting in
// no OIDC provider (`OpenIdTokenProviderConfig::from_args` returns `None` unless all three
// are present). Requiring the struct's own implicit group, as before, was a no-op because
// every argument is itself a member of that group.
#[cfg(feature = "clap")]
#[derive(Clone, Debug, PartialEq, Eq, clap::Args)]
#[command(next_help_heading = "OIDC")]
pub struct OpenIdTokenProviderConfigArguments {
    /// OAuth2 client ID used to request tokens from the OIDC issuer
    #[arg(
        id = "oidc_client_id",
        long = "oidc-client-id",
        requires("oidc_client_secret"),
        requires("oidc_issuer_url")
    )]
    pub client_id: Option<String>,
    /// OAuth2 client secret matching the client ID
    #[arg(
        id = "oidc_client_secret",
        long = "oidc-client-secret",
        requires("oidc_client_id"),
        requires("oidc_issuer_url")
    )]
    pub client_secret: Option<String>,
    /// URL of the OIDC issuer, used for discovery
    #[arg(
        id = "oidc_issuer_url",
        long = "oidc-issuer-url",
        requires("oidc_client_id"),
        requires("oidc_client_secret")
    )]
    pub issuer_url: Option<String>,
    /// Refresh the access token once it expires within this duration
    #[arg(
        id = "oidc_refresh_before",
        long = "oidc-refresh-before",
        default_value = "30s"
    )]
    pub refresh_before: humantime::Duration,
    /// Accept invalid TLS certificates and hostnames from the OIDC issuer (insecure)
    #[arg(
        id = "oidc_tls_insecure",
        long = "oidc-tls-insecure",
        default_value = "false"
    )]
    pub tls_insecure: bool,
    /// Additional CA certificate file to trust for the OIDC issuer (may be repeated)
    #[arg(
        id = "oidc_tls_ca_certificates",
        long = "oidc-tls-ca-certificate",
        action = clap::ArgAction::Append,
    )]
    pub tls_ca_certificates: Vec<std::path::PathBuf>,
}
#[cfg(feature = "clap")]
impl OpenIdTokenProviderConfigArguments {
    /// Turns these CLI arguments into a shareable token provider.
    ///
    /// The arguments are first converted with [`OpenIdTokenProviderConfig::from_args`]
    /// (which yields `None` unless client ID, secret and issuer URL are all set), and the
    /// result is handed to [`OpenIdTokenProviderConfig::new_provider`].
    ///
    /// # Errors
    ///
    /// Propagates any error raised while building the OIDC provider.
    pub async fn into_provider(self) -> anyhow::Result<Arc<dyn TokenProvider>> {
        let config = OpenIdTokenProviderConfig::from_args(self);
        OpenIdTokenProviderConfig::new_provider(config).await
    }
}
// Complete OIDC token provider configuration; `OpenIdTokenProviderConfig::from_args`
// builds it from `OpenIdTokenProviderConfigArguments` when all mandatory values are set.
//
// Plain `//` comments are used on purpose: `///` on a `clap::Args` type becomes CLI help text.
// NOTE(review): this type also derives `clap::Args`; its non-`Option` fields would presumably
// become required options if it were ever flattened into a command — confirm that derive is
// actually needed.
#[cfg(feature = "clap")]
#[derive(Clone, Debug, PartialEq, Eq, clap::Args)]
pub struct OpenIdTokenProviderConfig {
    // OAuth2 client ID passed to OIDC discovery.
    pub client_id: String,
    // OAuth2 client secret passed to OIDC discovery.
    pub client_secret: String,
    // Issuer URL; parsed and used for discovery in `OpenIdTokenProvider::with_config`.
    pub issuer_url: String,
    // Tokens expiring within this margin are replaced before use.
    pub refresh_before: humantime::Duration,
    // When set, TLS certificate and hostname validation towards the issuer is disabled.
    pub tls_insecure: bool,
    // Extra trust-anchor certificate files added to the issuer HTTP client.
    pub tls_ca_certificates: Vec<std::path::PathBuf>,
}
#[cfg(feature = "clap")]
impl OpenIdTokenProviderConfig {
    /// Creates a token provider from an optional configuration.
    ///
    /// With `Some(config)` an [`OpenIdTokenProvider`] is built via
    /// [`OpenIdTokenProvider::with_config`]; with `None` the unit provider `()` is returned,
    /// which is what callers get when OIDC is not configured.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`OpenIdTokenProvider::with_config`].
    pub async fn new_provider(config: Option<Self>) -> anyhow::Result<Arc<dyn TokenProvider>> {
        Ok(match config {
            Some(config) => Arc::new(OpenIdTokenProvider::with_config(config).await?),
            None => Arc::new(()),
        })
    }

    /// Converts CLI arguments into a configuration.
    ///
    /// Returns `Some` only if the client ID, client secret and issuer URL are all present.
    /// If none of them is set, `None` is returned quietly. If only some are set, `None` is
    /// returned as well, but a warning is logged: otherwise a misconfiguration would silently
    /// disable OIDC. The warning reports which values are present, never their contents.
    pub fn from_args(arguments: OpenIdTokenProviderConfigArguments) -> Option<Self> {
        match (
            arguments.client_id,
            arguments.client_secret,
            arguments.issuer_url,
        ) {
            (Some(client_id), Some(client_secret), Some(issuer_url)) => {
                Some(OpenIdTokenProviderConfig {
                    client_id,
                    client_secret,
                    issuer_url,
                    refresh_before: arguments.refresh_before,
                    tls_insecure: arguments.tls_insecure,
                    tls_ca_certificates: arguments.tls_ca_certificates,
                })
            }
            (None, None, None) => None,
            (client_id, client_secret, issuer_url) => {
                log::warn!(
                    "Ignoring incomplete OIDC configuration (client ID set: {}, client secret set: {}, issuer URL set: {})",
                    client_id.is_some(),
                    client_secret.is_some(),
                    issuer_url.is_some()
                );
                None
            }
        }
    }
}
/// Enables `arguments.into()` as shorthand for [`OpenIdTokenProviderConfig::from_args`].
#[cfg(feature = "clap")]
impl From<OpenIdTokenProviderConfigArguments> for Option<OpenIdTokenProviderConfig> {
    fn from(arguments: OpenIdTokenProviderConfigArguments) -> Self {
        // Single source of truth for the conversion rules lives in `from_args`.
        OpenIdTokenProviderConfig::from_args(arguments)
    }
}
/// Token provider that obtains bearer tokens from an OpenID Connect issuer and caches them.
///
/// Cloning is cheap: the client and the token cache are held in `Arc`s, so all clones share
/// the same cached token.
#[derive(Clone)]
pub struct OpenIdTokenProvider {
    /// OpenID client used to request and refresh tokens.
    client: Arc<openid::Client>,
    /// Most recently obtained token; `None` until a token has been fetched.
    current_token: Arc<RwLock<Option<openid::TemporalBearerGuard>>>,
    /// Cached tokens that expire within this margin are treated as stale and replaced.
    refresh_before: time::Duration,
}
impl Debug for OpenIdTokenProvider {
    /// Formats the provider for diagnostics without exposing the cached token.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("OpenIdTokenProvider")
            .field(
                "client",
                &format!("{} / {:?}", self.client.client_id, self.client.http_client),
            )
            // The cached bearer carries the access token, so it is deliberately masked.
            .field("current_token", &"...")
            .field("refresh_before", &self.refresh_before)
            .finish()
    }
}
impl OpenIdTokenProvider {
    /// Creates a provider around an already discovered OpenID client.
    ///
    /// `refresh_before` is the safety margin: a cached token that expires within this
    /// duration is not handed out, but replaced on the next request.
    pub fn new(client: openid::Client, refresh_before: time::Duration) -> Self {
        Self {
            client: Arc::new(client),
            current_token: Arc::new(RwLock::new(None)),
            refresh_before,
        }
    }

    /// Builds a provider from a complete [`OpenIdTokenProviderConfig`].
    ///
    /// Configures the HTTP client used towards the issuer (optionally disabling TLS
    /// validation and adding extra trust anchors), then runs OIDC discovery on `issuer_url`.
    ///
    /// # Errors
    ///
    /// Fails if the issuer URL cannot be parsed, `refresh_before` does not fit into a
    /// [`time::Duration`], a CA certificate cannot be added, the HTTP client cannot be
    /// built, or discovery fails.
    #[cfg(feature = "clap")]
    pub async fn with_config(config: OpenIdTokenProviderConfig) -> anyhow::Result<Self> {
        let issuer = Url::parse(&config.issuer_url).context("Parse issuer URL")?;

        // Convert up front so an invalid value fails before any network round-trip.
        let refresh_before: std::time::Duration = config.refresh_before.into();
        let refresh_before = time::Duration::try_from(refresh_before)
            .context("Convert OIDC refresh-before duration")?;

        let mut http_client = reqwest::ClientBuilder::new();
        if config.tls_insecure {
            log::warn!("Using insecure TLS when communicating with the OIDC issuer");
            http_client = http_client
                .danger_accept_invalid_hostnames(true)
                .danger_accept_invalid_certs(true);
        }
        for cert in config.tls_ca_certificates {
            http_client = crate::utils::pem::add_cert(http_client, &cert)
                .with_context(|| format!("adding trust anchor: {}", cert.display()))?;
        }
        let http_client = http_client.build().context("Build OIDC HTTP client")?;

        let client = openid::Client::discover_with_client(
            http_client,
            config.client_id,
            config.client_secret,
            None,
            issuer,
        )
        .await
        .context("Discover OIDC client")?;

        Ok(Self::new(client, refresh_before))
    }

    /// Returns a bearer token, reusing the cached one while it is still fresh enough.
    ///
    /// The fast path only takes the read lock; if the cached token is missing or about to
    /// expire, [`Self::fetch_fresh_token`] obtains a new one.
    pub async fn provide_token(&self) -> Result<openid::Bearer, openid::error::Error> {
        {
            // The read guard is dropped at the end of this scope, before the write lock
            // is requested in `fetch_fresh_token`.
            let guard = self.current_token.read().await;
            if let Some(token) = self.usable_token(guard.deref().as_ref()) {
                log::debug!("Token still valid");
                return Ok(token);
            }
        }
        self.fetch_fresh_token().await
    }

    /// Obtains a new token while holding the write lock and stores it in the cache.
    ///
    /// Concurrent callers queue on the write lock; once acquired, the cache is checked
    /// again so callers that waited reuse the token the first caller just fetched.
    /// A refresh-token grant is used when available; if it fails, this falls back to a
    /// fresh client-credentials grant instead of failing the request.
    async fn fetch_fresh_token(&self) -> Result<openid::Bearer, openid::error::Error> {
        log::debug!("Fetching fresh token...");
        let mut lock = self.current_token.write().await;

        if let Some(token) = self.usable_token(lock.deref().as_ref()) {
            log::debug!("Token already got refreshed");
            return Ok(token);
        }

        let next_token: openid::TemporalBearerGuard = match lock.take() {
            None => {
                log::debug!("Fetching initial token... ");
                self.initial_token().await?
            }
            Some(current_token) if current_token.as_ref().refresh_token.is_some() => {
                log::debug!("Refreshing token ... ");
                match self.client.refresh_token(current_token, None).await {
                    Ok(token) => token.into(),
                    Err(err) => {
                        log::warn!(
                            "Failed to refresh OIDC token, requesting a new one: {}",
                            err
                        );
                        self.initial_token().await?
                    }
                }
            }
            Some(_) => {
                log::debug!("No refresh token available, fetching new token... ");
                self.initial_token().await?
            }
        };

        // Never log the token itself: it contains the access (and possibly refresh) token.
        log::debug!("Obtained next token");
        let result = next_token.as_ref().clone();
        *lock = Some(next_token);
        Ok(result)
    }

    /// Returns a clone of `token` if it is present and does not expire within
    /// `refresh_before`.
    fn usable_token(&self, token: Option<&openid::TemporalBearerGuard>) -> Option<openid::Bearer> {
        token
            .filter(|token| !token.expires_before(self.refresh_before))
            .map(|token| token.as_ref().clone())
    }

    /// Requests a new token using the OAuth2 client-credentials grant.
    async fn initial_token(&self) -> Result<openid::TemporalBearerGuard, openid::error::Error> {
        Ok(self
            .client
            .request_token_using_client_credentials(None)
            .await?
            .into())
    }
}
#[async_trait::async_trait]
impl TokenProvider for OpenIdTokenProvider {
    /// Supplies the current OIDC access token as bearer credentials.
    ///
    /// On success this always yields `Some`; errors from the OIDC client are converted
    /// into the sender's [`Error`] type by `?`.
    async fn provide_access_token(&self) -> Result<Option<Credentials>, Error> {
        let bearer = self.provide_token().await?;
        Ok(Some(Credentials::Bearer(bearer.access_token)))
    }
}