Skip to main content

sts_cat/
exchange.rs

1use crate::error::Error;
2
/// JSON body of a token exchange request.
#[derive(serde::Deserialize)]
pub struct ExchangeRequest {
    /// Target of the requested token: either `owner/repo`, or a bare `owner`
    /// for an org-level request (see `parse_scope`).
    pub scope: String,
    /// Name of the trust policy to evaluate; it is inserted into the policy
    /// file path, so it must pass `is_valid_name`.
    pub identity: String,
}
8
/// JSON body returned by a successful token exchange.
#[derive(serde::Serialize)]
pub struct ExchangeResponse {
    /// The freshly created GitHub installation access token.
    pub token: GitHubToken,
}
13
14/// Inner type for GitHub installation access tokens.
15/// Implements SerializableSecret intentionally — sts-cat is a token vending service.
16#[derive(
17    Clone, Debug, serde::Serialize, serde::Deserialize, zeroize::Zeroize, zeroize::ZeroizeOnDrop,
18)]
19pub struct GitHubTokenInner(String);
20impl secrecy::SerializableSecret for GitHubTokenInner {}
21impl secrecy::CloneableSecret for GitHubTokenInner {}
22
23impl GitHubTokenInner {
24    pub fn as_str(&self) -> &str {
25        &self.0
26    }
27}
28
29impl From<String> for GitHubTokenInner {
30    fn from(s: String) -> Self {
31        Self(s)
32    }
33}
34
/// A GitHub installation access token held in a `secrecy::SecretBox`; the raw
/// value is only reachable through `expose_secret()`.
pub type GitHubToken = secrecy::SecretBox<GitHubTokenInner>;
36
/// Shared state handed to the request handlers through axum's `State` extractor.
pub struct AppState {
    /// Service configuration this state was built from.
    pub config: crate::config::Config,
    /// Client used for all GitHub API calls (installation lookup, policy
    /// fetch, token creation).
    pub github: crate::github::GitHubClient,
    /// Verifier for incoming OIDC bearer tokens.
    pub oidc: crate::oidc::OidcVerifier,
    /// Per-owner override of the repository that holds org-level trust
    /// policies; looked up by lowercased owner name in `handle_exchange`.
    org_repos: std::collections::HashMap<String, String>,
    /// Cache of compiled trust policies, keyed by a triple of strings that
    /// `handle_exchange` assembles from the owner, repository and identity.
    policy_cache: moka::future::Cache<
        (String, String, String),
        std::sync::Arc<crate::trust_policy::CompiledTrustPolicy>,
    >,
    /// Cache mapping a lowercased owner name to its installation id.
    installation_cache: moka::future::Cache<String, u64>,
}
48
49impl AppState {
50    pub async fn build(
51        config: crate::config::Config,
52    ) -> Result<std::sync::Arc<Self>, anyhow::Error> {
53        let signer = config.build_signer().await?;
54        let github =
55            crate::github::GitHubClient::new(&config.github_api_url, &config.github_app_id, signer);
56        let oidc = crate::oidc::OidcVerifier::new(config.allowed_issuer_urls.clone());
57        let org_repos = config.parse_org_repos()?;
58
59        let policy_cache = moka::future::Cache::builder()
60            .max_capacity(200)
61            .time_to_live(std::time::Duration::from_secs(300))
62            .build();
63
64        let installation_cache = moka::future::Cache::builder()
65            .max_capacity(200)
66            .time_to_live(std::time::Duration::from_secs(3600))
67            .build();
68
69        Ok(std::sync::Arc::new(Self {
70            config,
71            github,
72            oidc,
73            org_repos,
74            policy_cache,
75            installation_cache,
76        }))
77    }
78}
79
80pub async fn handle_exchange(
81    axum::extract::State(state): axum::extract::State<std::sync::Arc<AppState>>,
82    bearer: Result<
83        axum_extra::TypedHeader<headers::Authorization<headers::authorization::Bearer>>,
84        axum_extra::typed_header::TypedHeaderRejection,
85    >,
86    axum::Json(req): axum::Json<ExchangeRequest>,
87) -> Result<axum::Json<ExchangeResponse>, Error> {
88    let axum_extra::TypedHeader(authorization) = bearer
89        .map_err(|_| Error::Unauthenticated("missing or invalid Authorization header".into()))?;
90    let bearer_token = authorization.token();
91
92    if req.scope.is_empty() {
93        return Err(Error::BadRequest("scope must not be empty".into()));
94    }
95    if req.identity.is_empty() {
96        return Err(Error::BadRequest("identity must not be empty".into()));
97    }
98    if !is_valid_name(&req.identity) {
99        return Err(Error::BadRequest("invalid identity format".into()));
100    }
101
102    let (owner, mut repo, is_org_level) = parse_scope(&req.scope)?;
103    let owner = owner.to_ascii_lowercase();
104    if is_org_level && let Some(override_repo) = state.org_repos.get(&owner) {
105        repo = override_repo.clone();
106    }
107    let claims = state.oidc.verify(bearer_token).await?;
108
109    let installation_id = if let Some(id) = state.installation_cache.get(&owner).await {
110        id
111    } else {
112        let id = state.github.get_installation_id(&owner).await?;
113        state.installation_cache.insert(owner.clone(), id).await;
114        id
115    };
116
117    let policy_path = format!(
118        "{}/{}{}",
119        state.config.policy_path_prefix, req.identity, state.config.policy_file_extension
120    );
121
122    let cache_key = (owner.clone(), repo.clone(), req.identity.clone());
123    let compiled = if let Some(cached) = state.policy_cache.get(&cache_key).await {
124        cached
125    } else {
126        let content = state
127            .github
128            .get_trust_policy_content(installation_id, &owner, &repo, &policy_path)
129            .await?;
130        let policy = crate::trust_policy::TrustPolicy::parse(&content)?;
131        let compiled = std::sync::Arc::new(policy.compile(is_org_level)?);
132        state.policy_cache.insert(cache_key, compiled.clone()).await;
133        compiled
134    };
135
136    let actor = match compiled.check_token(&claims, &state.config.identifier) {
137        Ok(actor) => actor,
138        Err(e) => {
139            tracing::warn!(
140                event = "exchange_denied",
141                scope = %req.scope,
142                identity = %req.identity,
143                issuer = %claims.iss,
144                subject = %claims.sub,
145                reason = %e,
146            );
147            return Err(e);
148        }
149    };
150
151    tracing::info!(
152        event = "exchange_authorized",
153        scope = %req.scope,
154        identity = %req.identity,
155        issuer = %actor.issuer,
156        subject = %actor.subject,
157        installation_id = installation_id,
158        policy_path = %policy_path,
159    );
160
161    let repositories = if is_org_level {
162        compiled.repositories.clone().unwrap_or_default()
163    } else {
164        vec![repo.clone()]
165    };
166
167    let token = state
168        .github
169        .create_installation_token(installation_id, &compiled.permissions, &repositories)
170        .await?;
171
172    use secrecy::ExposeSecret as _;
173    use sha2::Digest as _;
174    let token_hash = hex::encode(sha2::Sha256::digest(
175        token.expose_secret().as_str().as_bytes(),
176    ));
177    tracing::info!(
178        event = "exchange_success",
179        scope = %req.scope,
180        identity = %req.identity,
181        issuer = %actor.issuer,
182        subject = %actor.subject,
183        installation_id = installation_id,
184        token_sha256 = %token_hash,
185    );
186
187    Ok(axum::Json(ExchangeResponse { token }))
188}
189
/// JSON body returned by the health-check endpoint.
#[derive(serde::Serialize)]
pub(crate) struct HealthResponse {
    /// Liveness flag; `handle_healthz` always sets it to `true`.
    ok: bool,
}
194
195pub(crate) async fn handle_healthz() -> axum::Json<HealthResponse> {
196    axum::Json(HealthResponse { ok: true })
197}
198
/// Returns `true` when `s` is acceptable as an owner, repository, or identity
/// name.
///
/// A valid name is non-empty, consists only of ASCII letters, digits, `.`,
/// `_` and `-`, and is not one of the dot-segments `.` or `..`.
fn is_valid_name(s: &str) -> bool {
    // "." and ".." satisfy the character rule but are path dot-segments.
    // Names are inserted into the policy file path and passed on to the GitHub
    // client, so they are rejected here rather than trusting every consumer
    // to handle them safely.
    if s.is_empty() || s == "." || s == ".." {
        return false;
    }
    s.bytes()
        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-'))
}
204
205fn parse_scope(scope: &str) -> Result<(String, String, bool), Error> {
206    if let Some((owner, repo)) = scope.split_once('/') {
207        if owner.is_empty() || repo.is_empty() {
208            return Err(Error::BadRequest("invalid scope format".into()));
209        }
210        if !is_valid_name(owner) || !is_valid_name(repo) {
211            return Err(Error::BadRequest("invalid scope format".into()));
212        }
213        let is_org_level = repo == ".github";
214        Ok((owner.to_owned(), repo.to_owned(), is_org_level))
215    } else {
216        // Org-level scope: "org" → reads from ".github" repo
217        if scope.is_empty() {
218            return Err(Error::BadRequest("invalid scope format".into()));
219        }
220        if !is_valid_name(scope) {
221            return Err(Error::BadRequest("invalid scope format".into()));
222        }
223        Ok((scope.to_owned(), ".github".to_owned(), true))
224    }
225}
226
227pub fn build_router(state: std::sync::Arc<AppState>) -> axum::Router {
228    let server_header = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
229    axum::Router::new()
230        .route("/token", axum::routing::post(handle_exchange))
231        .route("/healthz", axum::routing::get(handle_healthz))
232        .layer(axum::middleware::from_fn(
233            move |req, next: axum::middleware::Next| {
234                let val = server_header.clone();
235                async move {
236                    let mut resp = next.run(req).await;
237                    resp.headers_mut().insert(
238                        axum::http::header::SERVER,
239                        axum::http::HeaderValue::from_str(&val).unwrap(),
240                    );
241                    resp
242                }
243            },
244        ))
245        .with_state(state)
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `scope`, panicking with the offending input if it is rejected.
    fn parsed(scope: &str) -> (String, String, bool) {
        parse_scope(scope).unwrap_or_else(|_| panic!("scope {scope:?} should parse"))
    }

    #[test]
    fn test_parse_scope_repo() {
        assert_eq!(
            parsed("myorg/myrepo"),
            ("myorg".to_owned(), "myrepo".to_owned(), false)
        );
    }

    #[test]
    fn test_parse_scope_org() {
        assert_eq!(
            parsed("myorg"),
            ("myorg".to_owned(), ".github".to_owned(), true)
        );
    }

    #[test]
    fn test_parse_scope_org_dotgithub() {
        assert_eq!(
            parsed("myorg/.github"),
            ("myorg".to_owned(), ".github".to_owned(), true)
        );
    }

    #[test]
    fn test_parse_scope_empty() {
        for scope in ["", "/repo", "owner/"] {
            assert!(parse_scope(scope).is_err(), "{scope:?} should be rejected");
        }
    }

    #[test]
    fn test_parse_scope_rejects_invalid_chars() {
        for scope in ["org/../evil", "org/repo name", "org/<script>", "org\0evil"] {
            assert!(parse_scope(scope).is_err(), "{scope:?} should be rejected");
        }
    }

    #[test]
    fn test_is_valid_name() {
        for name in ["my-repo", "my.repo", "my_repo", ".github"] {
            assert!(is_valid_name(name), "{name:?} should be accepted");
        }
        for name in ["../etc/passwd", "repo name", "repo/name", ""] {
            assert!(!is_valid_name(name), "{name:?} should be rejected");
        }
    }
}