Skip to main content

scion_sdk_reqwest_connect_rpc/
token_source.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! Token source trait for the connect RPC client.
15
16use async_trait::async_trait;
17use tokio::sync::watch;
18
19pub mod mock;
20pub mod refresh;
21pub mod static_token;
22
23/// The error type for token sources.
24pub type TokenSourceError = Box<dyn std::error::Error + Sync + Send>;
25/// A watch receiver for token source updates.
26pub type TokenSourceWatch = watch::Receiver<Option<Result<String, TokenSourceError>>>;
27
28/// A source for authentication tokens.
29#[async_trait]
30pub trait TokenSource: Send + Sync + 'static {
31    /// Returns a watch receiver that always holds the latest valid token.
32    ///
33    /// The receiver allows both grabbing the current value immediately
34    /// and awaiting updates.
35    fn watch(&self) -> TokenSourceWatch;
36
37    /// Gets a token, possibly refreshing it.
38    ///
39    /// If the token cannot be obtained, returns a `TokenSourceError`.
40    ///
41    /// Prefer using `watch` if a subscription to token updates is needed.
42    ///
43    /// ### Implementation Note
44    ///
45    /// The default implementation uses the watch channel to get the latest token.
46    ///``
47    /// - Should be efficient to call multiple times.
48    /// - Errors should be returned if no valid token can be obtained.
49    /// - Should try to not return errors as long as a valid token is available.
50    async fn get_token(&self) -> Result<String, TokenSourceError> {
51        let mut watch = self.watch();
52
53        // First, try to get the current value without waiting. and return immediately if available.
54        match watch.borrow_and_update().as_ref() {
55            Some(Ok(token)) => return Ok(token.clone()),
56            Some(Err(e)) => return Err(e.to_string().into()),
57            None => {}
58        }
59
60        // If there is no current value, wait for an update.
61        watch.changed().await.map_err(|_| {
62            Box::<dyn std::error::Error + Sync + Send>::from("token source watch channel closed")
63        })?;
64
65        // After being notified, get the updated value.
66        match watch.borrow().as_ref() {
67            Some(Ok(token)) => Ok(token.clone()),
68            Some(Err(e)) => Err(e.to_string().into()),
69            None => {
70                Err(Box::<dyn std::error::Error + Sync + Send>::from(
71                    "token source watch channel has no value",
72                ))
73            }
74        }
75    }
76
77    /// Formats the token for use in an `Authorization` header.
78    ///
79    /// The default implementation formats the token as a Bearer token.
80    /// Override this method if a different format is required.
81    fn format_header(&self, token: String) -> String {
82        format!("Bearer {token}")
83    }
84}