walker_common/sender/provider/
openid.rs1use super::{Credentials, Expires, TokenProvider};
2use crate::sender::Error;
3use core::fmt::{self, Debug, Formatter};
4use std::{ops::Deref, sync::Arc};
5use tokio::sync::RwLock;
6
7#[cfg(feature = "clap")]
8use {anyhow::Context, url::Url};
9
// Command line / environment arguments for configuring an OIDC client.
//
// All three of client ID, client secret and issuer URL are optional at this level;
// `OpenIdTokenProviderConfig::from_args` only produces a configuration when all of them
// are present.
//
// NOTE: plain `//` comments are used on purpose in this struct. clap's derive turns `///`
// doc comments into user-visible help text, so switching to `///` changes the CLI output.
#[cfg(feature = "clap")]
#[derive(Clone, Debug, PartialEq, Eq, clap::Args)]
#[cfg_attr(feature = "clap", command(next_help_heading = "OIDC"))]
pub struct OpenIdTokenProviderConfigArguments {
    // OIDC client ID (`--oidc-client-id`, env `OIDC_CLIENT_ID`).
    // NOTE(review): `requires` names this struct; presumably that refers to the argument
    // group clap derives for an `Args` struct — confirm against the clap version in use.
    #[arg(
        id = "oidc_client_id",
        long = "oidc-client-id",
        env = "OIDC_CLIENT_ID",
        requires("OpenIdTokenProviderConfigArguments")
    )]
    pub client_id: Option<String>,
    // OIDC client secret (`--oidc-client-secret`, env `OIDC_CLIENT_SECRET`).
    #[arg(
        id = "oidc_client_secret",
        long = "oidc-client-secret",
        env = "OIDC_CLIENT_SECRET",
        requires("OpenIdTokenProviderConfigArguments")
    )]
    pub client_secret: Option<String>,
    // OIDC issuer URL (`--oidc-issuer-url`, env `OIDC_ISSUER_URL`). Kept as a string here;
    // it is parsed into a `Url` in `OpenIdTokenProvider::with_config`.
    #[arg(
        id = "oidc_issuer_url",
        long = "oidc-issuer-url",
        env = "OIDC_ISSUER_URL",
        requires("OpenIdTokenProviderConfigArguments")
    )]
    pub issuer_url: Option<String>,
    // Refresh margin (`--oidc-refresh-before`, env `OIDC_REFRESH_BEFORE`), defaulting to
    // 30 seconds. It ends up as the `refresh_before` value of `OpenIdTokenProvider`.
    #[arg(
        id = "oidc_refresh_before",
        long = "oidc-refresh-before",
        env = "OIDC_REFRESH_BEFORE",
        default_value = "30s"
    )]
    pub refresh_before: humantime::Duration,
    // `--oidc-tls-insecure`: when set, `OpenIdTokenProvider::with_config` accepts invalid
    // certificates and hostnames when talking to the issuer. No `env` is declared for it.
    #[arg(
        id = "oidc_tls_insecure",
        long = "oidc-tls-insecure",
        default_value = "false"
    )]
    pub tls_insecure: bool,
    // `--oidc-tls-ca-certificate` (singular flag name, `Append` action so it can be given
    // several times): paths that `OpenIdTokenProvider::with_config` adds as trust anchors.
    #[arg(
        id = "oidc_tls_ca_certificates",
        long = "oidc-tls-ca-certificate",
        action = clap::ArgAction::Append,
    )]
    pub tls_ca_certificates: Vec<std::path::PathBuf>,
}
61
#[cfg(feature = "clap")]
impl OpenIdTokenProviderConfigArguments {
    /// Consumes the arguments and builds the matching token provider.
    ///
    /// The arguments are first reduced to an optional [`OpenIdTokenProviderConfig`] and then
    /// handed to [`OpenIdTokenProviderConfig::new_provider`], which decides which provider to
    /// create.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`OpenIdTokenProviderConfig::new_provider`] reports.
    pub async fn into_provider(self) -> anyhow::Result<Arc<dyn TokenProvider>> {
        let config = OpenIdTokenProviderConfig::from_args(self);
        OpenIdTokenProviderConfig::new_provider(config).await
    }
}
68
// Complete OIDC client configuration: unlike `OpenIdTokenProviderConfigArguments`, the
// client ID, client secret and issuer URL are mandatory here.
//
// NOTE: plain `//` comments are used on purpose. This struct derives `clap::Args`, and
// clap's derive turns `///` doc comments into user-visible help text.
#[cfg(feature = "clap")]
#[derive(Clone, Debug, PartialEq, Eq, clap::Args)]
pub struct OpenIdTokenProviderConfig {
    // Client ID passed to the OIDC client discovery.
    pub client_id: String,
    // Client secret passed to the OIDC client discovery.
    pub client_secret: String,
    // Issuer URL as a string; parsed into a `Url` in `OpenIdTokenProvider::with_config`.
    pub issuer_url: String,
    // Refresh margin, converted to a `time::Duration` for the provider.
    pub refresh_before: humantime::Duration,
    // When true, certificate and hostname validation towards the issuer is disabled.
    pub tls_insecure: bool,
    // Certificate files added as trust anchors to the HTTP client.
    pub tls_ca_certificates: Vec<std::path::PathBuf>,
}
79
#[cfg(feature = "clap")]
impl OpenIdTokenProviderConfig {
    /// Builds a token provider from an optional configuration.
    ///
    /// With a configuration, an [`OpenIdTokenProvider`] is set up through
    /// [`OpenIdTokenProvider::with_config`]. Without one, the unit value `()` is returned as
    /// the provider.
    ///
    /// # Errors
    ///
    /// Fails when [`OpenIdTokenProvider::with_config`] fails.
    pub async fn new_provider(config: Option<Self>) -> anyhow::Result<Arc<dyn TokenProvider>> {
        match config {
            None => Ok(Arc::new(())),
            Some(config) => {
                let provider = OpenIdTokenProvider::with_config(config).await?;
                Ok(Arc::new(provider))
            }
        }
    }

    /// Converts parsed arguments into a configuration.
    ///
    /// Returns `None` unless the client ID, the client secret and the issuer URL are all
    /// present; the remaining settings are carried over unchanged.
    pub fn from_args(arguments: OpenIdTokenProviderConfigArguments) -> Option<Self> {
        let OpenIdTokenProviderConfigArguments {
            client_id,
            client_secret,
            issuer_url,
            refresh_before,
            tls_insecure,
            tls_ca_certificates,
        } = arguments;

        // Each `?` bails out with `None` as soon as one mandatory value is missing.
        Some(Self {
            client_id: client_id?,
            client_secret: client_secret?,
            issuer_url: issuer_url?,
            refresh_before,
            tls_insecure,
            tls_ca_certificates,
        })
    }
}
109
#[cfg(feature = "clap")]
impl From<OpenIdTokenProviderConfigArguments> for Option<OpenIdTokenProviderConfig> {
    /// Conversion counterpart of [`OpenIdTokenProviderConfig::from_args`]: yields `Some`
    /// only when that function does.
    fn from(arguments: OpenIdTokenProviderConfigArguments) -> Self {
        OpenIdTokenProviderConfig::from_args(arguments)
    }
}
116
/// A token provider backed by an OpenID Connect client, caching the most recent token.
///
/// Cloning is cheap: the client and the token cache both sit behind an [`Arc`], so all
/// clones share the same cached token.
#[derive(Clone)]
pub struct OpenIdTokenProvider {
    /// The OIDC client used to request and refresh tokens.
    client: Arc<openid::Client>,
    /// The cached token; `None` until the first token has been fetched.
    current_token: Arc<RwLock<Option<openid::TemporalBearerGuard>>>,
    /// Margin passed to the expiry check of the cached token.
    refresh_before: time::Duration,
}
124
125impl Debug for OpenIdTokenProvider {
126 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
127 f.debug_struct("TokenProvider")
128 .field(
129 "client",
130 &format!("{} / {:?}", self.client.client_id, self.client.http_client),
131 )
132 .field("current_token", &"...")
133 .finish()
134 }
135}
136
137impl OpenIdTokenProvider {
138 pub fn new(client: openid::Client, refresh_before: time::Duration) -> Self {
140 Self {
141 client: Arc::new(client),
142 current_token: Arc::new(RwLock::new(None)),
143 refresh_before,
144 }
145 }
146
147 #[cfg(feature = "clap")]
148 pub async fn with_config(config: OpenIdTokenProviderConfig) -> anyhow::Result<Self> {
149 let issuer = Url::parse(&config.issuer_url).context("Parse issuer URL")?;
150
151 let mut client = reqwest::ClientBuilder::new();
152
153 if config.tls_insecure {
154 log::warn!("Using insecure TLS when communicating with the OIDC issuer");
155 client = client
156 .danger_accept_invalid_hostnames(true)
157 .danger_accept_invalid_certs(true);
158 }
159
160 for cert in config.tls_ca_certificates {
161 client = crate::utils::pem::add_cert(client, &cert)
162 .with_context(|| format!("adding trust anchor: {}", cert.display()))?;
163 }
164
165 let client = openid::Client::discover_with_client(
166 client.build()?,
167 config.client_id,
168 config.client_secret,
169 None,
170 issuer,
171 )
172 .await
173 .context("Discover OIDC client")?;
174
175 Ok(Self::new(
176 client,
177 time::Duration::try_from(<_ as Into<std::time::Duration>>::into(
178 config.refresh_before,
179 ))?,
180 ))
181 }
182
183 pub async fn provide_token(&self) -> Result<openid::Bearer, openid::error::Error> {
186 match self.current_token.read().await.deref() {
187 Some(token) if !token.expires_before(self.refresh_before) => {
188 log::debug!("Token still valid");
189 return Ok(token.as_ref().clone());
190 }
191 _ => {}
192 }
193
194 self.fetch_fresh_token().await
197 }
198
199 async fn fetch_fresh_token(&self) -> Result<openid::Bearer, openid::error::Error> {
200 log::debug!("Fetching fresh token...");
201
202 let mut lock = self.current_token.write().await;
203
204 match lock.deref() {
205 Some(token) if !token.expires_before(self.refresh_before) => {
207 log::debug!("Token already got refreshed");
208 return Ok(token.as_ref().clone());
209 }
210 _ => {}
211 }
212
213 let next_token = match lock.take() {
216 None => {
218 log::debug!("Fetching initial token... ");
219 self.initial_token().await?
220 }
221 Some(current_token) => {
223 log::debug!("Refreshing token ... ");
224 match current_token.as_ref().refresh_token.is_some() {
225 true => self.client.refresh_token(current_token, None).await?.into(),
226 false => self.initial_token().await?,
227 }
228 }
229 };
230
231 log::debug!("Next token: {:?}", next_token.as_ref());
232
233 let result = next_token.as_ref().clone();
234 lock.replace(next_token);
235
236 Ok(result)
239 }
240
241 async fn initial_token(&self) -> Result<openid::TemporalBearerGuard, openid::error::Error> {
242 Ok(self
243 .client
244 .request_token_using_client_credentials(None)
245 .await?
246 .into())
247 }
248}
249
250#[async_trait::async_trait]
251impl TokenProvider for OpenIdTokenProvider {
252 async fn provide_access_token(&self) -> Result<Option<Credentials>, Error> {
253 Ok(self
254 .provide_token()
255 .await
256 .map(|token| Some(Credentials::Bearer(token.access_token)))?)
257 }
258}