Skip to main content

vorpal_sdk/
context.rs

1use crate::{
2    api::{
3        agent::{agent_service_client::AgentServiceClient, PrepareArtifactRequest},
4        artifact::{
5            artifact_service_client::ArtifactServiceClient, Artifact, ArtifactRequest,
6            ArtifactSystem, ArtifactsRequest, ArtifactsResponse, GetArtifactAliasRequest,
7        },
8        context::context_service_server::{ContextService, ContextServiceServer},
9    },
10    artifact::system::get_system,
11    cli::{Cli, Command},
12};
13use anyhow::{anyhow, bail, Context, Result};
14use clap::Parser;
15use http::uri::{InvalidUri, Uri};
16use oauth2::{basic::BasicClient, AuthUrl, ClientId, RefreshToken, TokenResponse, TokenUrl};
17use serde::{Deserialize, Serialize};
18use sha256::digest;
19use std::{
20    collections::{BTreeMap, HashMap},
21    path::{Path, PathBuf},
22};
23use tokio::{
24    fs::{read, OpenOptions},
25    io::AsyncWriteExt,
26};
27use tonic::{
28    metadata::{Ascii, MetadataValue},
29    transport::{Certificate, Channel, ClientTlsConfig, Server},
30    Code::NotFound,
31    Request, Response, Status,
32};
33use tracing::info;
34
35#[derive(Clone)]
36pub struct ConfigContextStore {
37    artifact: HashMap<String, Artifact>,
38    artifact_input_cache: HashMap<String, String>,
39    variable: HashMap<String, String>,
40}
41
42#[derive(Clone)]
43pub struct ConfigContext {
44    artifact: String,
45    artifact_context: PathBuf,
46    artifact_namespace: String,
47    artifact_system: ArtifactSystem,
48    artifact_unlock: bool,
49    client_agent: AgentServiceClient<Channel>,
50    client_artifact: ArtifactServiceClient<Channel>,
51    port: u16,
52    registry: String,
53    store: ConfigContextStore,
54}
55
/// gRPC `ContextService` implementation that answers from a [`ConfigContextStore`].
///
/// `ConfigContext::run` constructs this from a clone of its own store.
#[derive(Clone)]
pub struct ConfigServer {
    /// Store whose artifacts are served by `get_artifact` / `get_artifacts`.
    pub store: ConfigContextStore,
}
60
61#[derive(Debug, Deserialize, Serialize)]
62pub struct VorpalCredentialsContent {
63    pub access_token: String,
64    #[serde(default, skip_serializing_if = "Option::is_none")]
65    pub audience: Option<String>,
66    pub client_id: String,
67    pub expires_in: u64,
68    pub issued_at: u64,
69    pub refresh_token: String,
70    pub scopes: Vec<String>,
71}
72
/// Contents of the credentials file at `get_key_credentials_path()`.
///
/// `client_auth_header` first maps a registry to its issuer via `registry`, then reads that
/// issuer's tokens from `issuer`.
#[derive(Debug, Deserialize, Serialize)]
pub struct VorpalCredentials {
    /// Token records keyed by issuer URL.
    pub issuer: BTreeMap<String, VorpalCredentialsContent>,
    /// Maps a registry address to the issuer URL whose tokens authenticate it.
    pub registry: BTreeMap<String, String>,
}
78
/// Default namespace when none is specified in an artifact alias.
///
/// Applied by [`parse_artifact_alias`] when the alias has no `<namespace>/` prefix.
pub const DEFAULT_NAMESPACE: &str = "library";
81
/// Default tag when none is specified in an artifact alias.
///
/// Applied by [`parse_artifact_alias`] when the alias has no `:<tag>` suffix.
pub const DEFAULT_TAG: &str = "latest";
84
/// Parsed components of an artifact alias.
///
/// Alias format: `[<namespace>/]<name>[:<tag>]`
/// - namespace defaults to [`DEFAULT_NAMESPACE`] when omitted
/// - tag defaults to [`DEFAULT_TAG`] when omitted
///
/// All fields are plain strings, so the type supports full equality (`Eq`) and can be
/// used wherever a total equivalence is required.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ArtifactAlias {
    /// Artifact name (always present).
    pub name: String,
    /// Namespace, defaulted when the alias omits it.
    pub namespace: String,
    /// Tag, defaulted when the alias omits it.
    pub tag: String,
}
96
/// Reports whether `s` is usable as an alias component (name, namespace or tag).
///
/// A valid component is non-empty and built only from ASCII letters and digits plus the
/// punctuation `-`, `.`, `_` and `+`. Any non-ASCII character makes it invalid.
fn is_valid_component(s: &str) -> bool {
    const PUNCTUATION: &[u8] = b"-._+";

    if s.is_empty() {
        return false;
    }

    // Non-ASCII characters encode to bytes >= 0x80, which match neither test below.
    s.bytes()
        .all(|byte| byte.is_ascii_alphanumeric() || PUNCTUATION.contains(&byte))
}
105
106/// Parses an artifact alias string into its components.
107///
108/// Format: `[<namespace>/]<name>[:<tag>]`
109/// - namespace is optional (defaults to [`DEFAULT_NAMESPACE`])
110/// - tag is optional (defaults to [`DEFAULT_TAG`])
111/// - name is required
112///
113/// Each component (name, namespace, tag) may only contain alphanumeric
114/// characters, hyphens, dots, underscores, and plus signs.
115///
116/// This mirrors the Go implementation in `sdk/go/pkg/config/context.go`.
117pub fn parse_artifact_alias(alias: &str) -> Result<ArtifactAlias> {
118    if alias.is_empty() {
119        bail!("alias cannot be empty");
120    }
121
122    if alias.len() > 255 {
123        bail!("alias too long (max 255 characters)");
124    }
125
126    // Step 1: Extract tag (split on rightmost ':')
127    let (base, tag) = match alias.rsplit_once(':') {
128        Some((_, "")) => bail!("tag cannot be empty"),
129        Some((b, t)) => (b, t.to_string()),
130        None => (alias, String::new()),
131    };
132
133    // Step 2: Extract namespace/name (split on '/')
134    let (namespace, name) = match base.split_once('/') {
135        Some(("", _)) => bail!("namespace cannot be empty"),
136        Some((_ns, rest)) if rest.contains('/') => {
137            bail!("invalid format: too many path separators")
138        }
139        Some((ns, name)) => (ns.to_string(), name.to_string()),
140        None => (String::new(), base.to_string()),
141    };
142
143    if name.is_empty() {
144        bail!("name is required");
145    }
146
147    // Step 3: Validate component characters
148    if !is_valid_component(&name) {
149        bail!("name contains invalid characters (allowed: alphanumeric, hyphens, dots, underscores, plus signs)");
150    }
151
152    if !namespace.is_empty() && !is_valid_component(&namespace) {
153        bail!("namespace contains invalid characters (allowed: alphanumeric, hyphens, dots, underscores, plus signs)");
154    }
155
156    if !tag.is_empty() && !is_valid_component(&tag) {
157        bail!("tag contains invalid characters (allowed: alphanumeric, hyphens, dots, underscores, plus signs)");
158    }
159
160    // Step 4: Apply defaults
161    let tag = if tag.is_empty() {
162        DEFAULT_TAG.to_string()
163    } else {
164        tag
165    };
166
167    let namespace = if namespace.is_empty() {
168        DEFAULT_NAMESPACE.to_string()
169    } else {
170        namespace
171    };
172
173    Ok(ArtifactAlias {
174        name,
175        namespace,
176        tag,
177    })
178}
179
180impl ConfigServer {
181    pub fn new(store: ConfigContextStore) -> Self {
182        Self { store }
183    }
184}
185
186#[tonic::async_trait]
187impl ContextService for ConfigServer {
188    async fn get_artifact(
189        &self,
190        request: Request<ArtifactRequest>,
191    ) -> Result<Response<Artifact>, Status> {
192        let request = request.into_inner();
193
194        if request.digest.is_empty() {
195            return Err(tonic::Status::invalid_argument("'digest' is required"));
196        }
197
198        let artifact = self.store.artifact.get(request.digest.as_str());
199
200        if artifact.is_none() {
201            return Err(tonic::Status::not_found("artifact not found"));
202        }
203
204        Ok(Response::new(artifact.unwrap().clone()))
205    }
206
207    async fn get_artifacts(
208        &self,
209        _: tonic::Request<ArtifactsRequest>,
210    ) -> Result<tonic::Response<ArtifactsResponse>, tonic::Status> {
211        let mut digests: Vec<String> = self.store.artifact.keys().cloned().collect();
212        digests.sort();
213
214        let response = ArtifactsResponse { digests };
215
216        Ok(Response::new(response))
217    }
218}
219
220pub async fn get_context() -> Result<ConfigContext> {
221    let args = Cli::parse();
222
223    match args.command {
224        Command::Start {
225            agent,
226            artifact,
227            artifact_context,
228            artifact_namespace,
229            artifact_system,
230            artifact_unlock,
231            artifact_variable,
232            port,
233            registry,
234        } => {
235            let client_agent_channel = build_channel(&agent).await?;
236            let client_registry_channel = build_channel(&registry).await?;
237
238            let client_agent = AgentServiceClient::new(client_agent_channel);
239            let client_artifact = ArtifactServiceClient::new(client_registry_channel);
240
241            Ok(ConfigContext::new(
242                artifact,
243                PathBuf::from(artifact_context),
244                artifact_namespace,
245                artifact_system,
246                artifact_unlock,
247                artifact_variable,
248                client_agent,
249                client_artifact,
250                port,
251                registry,
252            )?)
253        }
254    }
255}
256
257impl ConfigContext {
258    #[allow(clippy::too_many_arguments)]
259    pub fn new(
260        artifact: String,
261        artifact_context: PathBuf,
262        artifact_namespace: String,
263        artifact_system: String,
264        artifact_unlock: bool,
265        artifact_variable: Vec<String>,
266        client_agent: AgentServiceClient<Channel>,
267        client_artifact: ArtifactServiceClient<Channel>,
268        port: u16,
269        registry: String,
270    ) -> Result<Self> {
271        Ok(Self {
272            artifact,
273            artifact_context,
274            client_agent,
275            client_artifact,
276            artifact_namespace,
277            port,
278            registry,
279            store: ConfigContextStore {
280                artifact: HashMap::new(),
281                artifact_input_cache: HashMap::new(),
282                variable: artifact_variable
283                    .iter()
284                    .map(|v| {
285                        let mut parts = v.split('=');
286                        let name = parts.next().unwrap_or_default();
287                        let value = parts.next().unwrap_or_default();
288                        (name.to_string(), value.to_string())
289                    })
290                    .collect(),
291            },
292            artifact_system: get_system(&artifact_system)?,
293            artifact_unlock,
294        })
295    }
296
297    pub async fn add_artifact(&mut self, artifact: &Artifact) -> Result<String> {
298        if artifact.name.is_empty() {
299            bail!("name cannot be empty");
300        }
301
302        if artifact.steps.is_empty() {
303            bail!("steps cannot be empty");
304        }
305
306        if artifact.systems.is_empty() {
307            bail!("systems cannot be empty");
308        }
309
310        // Validate target is in systems list
311        if !artifact.systems.contains(&artifact.target) {
312            bail!(
313                "artifact '{}' does not support system '{:?}' (supported: {:?})",
314                artifact.name,
315                ArtifactSystem::try_from(artifact.target).unwrap_or(ArtifactSystem::UnknownSystem),
316                artifact
317                    .systems
318                    .iter()
319                    .filter_map(|&s| ArtifactSystem::try_from(s).ok())
320                    .collect::<Vec<_>>()
321            );
322        }
323
324        // Send raw sources to agent - agent will handle all lockfile operations
325        let artifact_json =
326            serde_json::to_vec(&artifact).expect("failed to serialize artifact to JSON");
327
328        let input_digest = digest(artifact_json.clone());
329
330        if self.store.artifact.contains_key(&input_digest) {
331            return Ok(input_digest);
332        }
333
334        if let Some(output_digest) = self.store.artifact_input_cache.get(&input_digest) {
335            if self.store.artifact.contains_key(output_digest) {
336                return Ok(output_digest.clone());
337            }
338        }
339
340        // TODO: make this run in parallel
341
342        let request = PrepareArtifactRequest {
343            artifact: Some(artifact.clone()),
344            artifact_context: self.artifact_context.display().to_string(),
345            artifact_namespace: self.artifact_namespace.clone(),
346            artifact_unlock: self.artifact_unlock,
347            registry: self.registry.clone(),
348        };
349
350        let mut request = Request::new(request);
351        let request_auth = client_auth_header(&self.registry).await?;
352
353        if let Some(header) = request_auth {
354            request.metadata_mut().insert("authorization", header);
355        }
356
357        let response = self
358            .client_agent
359            .prepare_artifact(request)
360            .await
361            .expect("failed to prepare artifact");
362
363        let mut response = response.into_inner();
364        let mut response_artifact = None;
365        let mut response_artifact_digest = None;
366
367        loop {
368            match response.message().await {
369                Ok(Some(message)) => {
370                    if let Some(artifact_output) = message.artifact_output {
371                        if self.port == 0 {
372                            info!("{} |> {}", artifact.name, artifact_output);
373                        } else {
374                            println!("{} |> {}", artifact.name, artifact_output);
375                        }
376                    }
377
378                    response_artifact = message.artifact;
379                    response_artifact_digest = message.artifact_digest;
380                }
381                Ok(None) => break,
382                Err(status) => {
383                    if status.code() != NotFound {
384                        bail!("{}", status.message());
385                    }
386
387                    break;
388                }
389            }
390        }
391
392        if response_artifact.is_none() {
393            bail!("artifact not returned from agent service");
394        }
395
396        if response_artifact_digest.is_none() {
397            bail!("artifact digest not returned from agent service");
398        }
399
400        let artifact = response_artifact.unwrap();
401        let artifact_digest = response_artifact_digest.unwrap();
402
403        self.store
404            .artifact
405            .insert(artifact_digest.clone(), artifact.clone());
406
407        self.store
408            .artifact_input_cache
409            .insert(input_digest, artifact_digest.clone());
410
411        Ok(artifact_digest)
412    }
413
414    pub async fn fetch_artifact(&mut self, digest: &str) -> Result<String> {
415        self.fetch_artifact_in_namespace(digest, &self.artifact_namespace.clone())
416            .await
417    }
418
419    async fn fetch_artifact_in_namespace(
420        &mut self,
421        digest: &str,
422        namespace: &str,
423    ) -> Result<String> {
424        if self.store.artifact.contains_key(digest) {
425            return Ok(digest.to_string());
426        }
427
428        // TODO: look in lockfile for artifact version
429
430        let request = ArtifactRequest {
431            digest: digest.to_string(),
432            namespace: namespace.to_string(),
433        };
434
435        let mut request = Request::new(request.clone());
436        let request_auth = client_auth_header(&self.registry).await?;
437
438        if let Some(header) = request_auth {
439            request.metadata_mut().insert("authorization", header);
440        }
441
442        match self.client_artifact.get_artifact(request).await {
443            Err(status) => {
444                if status.code() != NotFound {
445                    bail!("artifact service error: {:?}", status);
446                }
447
448                bail!("artifact not found: {}", digest);
449            }
450
451            Ok(response) => {
452                let artifact = response.into_inner();
453
454                self.store
455                    .artifact
456                    .insert(digest.to_string(), artifact.clone());
457
458                for step in artifact.steps.iter() {
459                    for dep in step.artifacts.iter() {
460                        Box::pin(self.fetch_artifact_in_namespace(dep, namespace)).await?;
461                    }
462                }
463
464                Ok(digest.to_string())
465            }
466        }
467    }
468
469    pub async fn fetch_artifact_alias(&mut self, alias: &str) -> Result<String> {
470        let alias_parsed = parse_artifact_alias(alias)?;
471
472        let request = GetArtifactAliasRequest {
473            system: self.artifact_system.into(),
474            name: alias_parsed.name,
475            namespace: alias_parsed.namespace.clone(),
476            tag: alias_parsed.tag,
477        };
478
479        let mut request = Request::new(request);
480        let request_auth = client_auth_header(&self.registry).await?;
481
482        if let Some(header) = request_auth {
483            request.metadata_mut().insert("authorization", header);
484        }
485
486        let response = self
487            .client_artifact
488            .get_artifact_alias(request)
489            .await
490            .map_err(|status| {
491                if status.code() == NotFound {
492                    anyhow!("alias not found in registry: {}", alias)
493                } else {
494                    anyhow!("registry error: {:?}", status)
495                }
496            })?;
497
498        let digest = response.into_inner().digest;
499
500        if digest.is_empty() {
501            bail!("registry returned empty digest for alias: {}", alias);
502        }
503
504        if self.store.artifact.contains_key(&digest) {
505            return Ok(digest);
506        }
507
508        self.fetch_artifact_in_namespace(&digest, &alias_parsed.namespace)
509            .await?;
510
511        Ok(digest)
512    }
513
514    pub fn get_artifact_store(&self) -> HashMap<String, Artifact> {
515        self.store.artifact.clone()
516    }
517
518    pub fn get_artifact(&self, digest: &str) -> Option<Artifact> {
519        self.store.artifact.get(digest).cloned()
520    }
521
522    pub fn get_artifact_context_path(&self) -> &PathBuf {
523        &self.artifact_context
524    }
525
526    pub fn get_artifact_name(&self) -> &str {
527        self.artifact.as_str()
528    }
529
530    pub fn get_artifact_namespace(&self) -> &str {
531        self.artifact_namespace.as_str()
532    }
533
534    pub fn get_system(&self) -> ArtifactSystem {
535        self.artifact_system
536    }
537
538    pub fn get_variable(&self, name: &str) -> Option<String> {
539        self.store.variable.get(name).cloned()
540    }
541
542    pub async fn run(&self) -> Result<()> {
543        let service = ContextServiceServer::new(ConfigServer::new(self.store.clone()));
544
545        let service_addr_str = format!("[::]:{}", self.port);
546        let service_addr = service_addr_str.parse().expect("failed to parse address");
547
548        println!("context service: {service_addr_str}");
549
550        Server::builder()
551            .add_service(service)
552            .serve(service_addr)
553            .await
554            .map_err(|e| anyhow::anyhow!("failed to serve: {}", e))
555    }
556}
557
/// Root directory for Vorpal's on-disk state: `/var/lib/vorpal`.
pub fn get_root_dir_path() -> PathBuf {
    PathBuf::from("/var/lib/vorpal")
}
561
562pub fn get_root_key_dir_path() -> PathBuf {
563    get_root_dir_path().join("key")
564}
565
566pub fn get_key_ca_path() -> PathBuf {
567    get_root_key_dir_path().join("ca").with_extension("pem")
568}
569
570pub fn get_key_credentials_path() -> PathBuf {
571    get_root_key_dir_path()
572        .join("credentials")
573        .with_extension("json")
574}
575
576async fn get_client_tls_config(uri: &str) -> Result<Option<ClientTlsConfig>> {
577    if uri.starts_with("http://") || uri.starts_with("unix://") {
578        return Ok(None);
579    }
580
581    let ca_pem_path = get_key_ca_path();
582
583    let mut client_tls_config = ClientTlsConfig::new().with_native_roots();
584
585    if ca_pem_path.exists() {
586        let ca_pem = read(&ca_pem_path)
587            .await
588            .with_context(|| format!("failed to read CA certificate: {}", ca_pem_path.display()))?;
589
590        client_tls_config = client_tls_config.ca_certificate(Certificate::from_pem(ca_pem));
591    }
592
593    Ok(Some(client_tls_config))
594}
595
596pub async fn build_channel(uri: &str) -> Result<Channel> {
597    // Handle Unix domain socket connections
598    if uri.starts_with("unix://") {
599        let socket_path = uri.strip_prefix("unix://").unwrap().to_string();
600
601        // Dummy URI required by tonic's channel builder; ignored when using a custom connector.
602        // Uses connect_with_connector_lazy so the channel is created immediately and the
603        // actual connection is deferred until the first RPC call, avoiding startup races
604        // when the client is created before the server socket is ready.
605        let channel = Channel::from_static("http://[::]:50051").connect_with_connector_lazy(
606            tower::service_fn(move |_: tonic::transport::Uri| {
607                let path = socket_path.clone();
608                async move {
609                    Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new(
610                        tokio::net::UnixStream::connect(path).await?,
611                    ))
612                }
613            }),
614        );
615
616        return Ok(channel);
617    }
618
619    if !uri.starts_with("http://") && !uri.starts_with("https://") {
620        bail!("URI must start with http://, https://, or unix://: {}", uri);
621    }
622
623    let parsed_uri = uri
624        .parse::<Uri>()
625        .map_err(|e: InvalidUri| anyhow!("invalid URI: {}", e))?;
626
627    let tls_config = get_client_tls_config(uri).await?;
628
629    let mut endpoint = Channel::builder(parsed_uri);
630
631    if let Some(tls) = tls_config {
632        endpoint = endpoint.tls_config(tls)?;
633    }
634
635    endpoint
636        .connect()
637        .await
638        .with_context(|| format!("failed to connect to {}", uri))
639}
640
641/// Refreshes an expired access token using the refresh token.
642///
643/// Returns `(access_token, expires_in, issued_at, rotated_refresh_token)`.
644/// `rotated_refresh_token` is `Some(new)` when the IdP rotated the refresh
645/// token (Zitadel default), `None` when it did not (caller should keep the
646/// existing refresh token).
647async fn refresh_access_token(
648    audience: Option<&str>,
649    client_id: &str,
650    issuer: &str,
651    refresh_token: &str,
652) -> Result<(String, u64, u64, Option<String>)> {
653    // Discover token endpoint
654    let discovery_url = format!("{}/.well-known/openid-configuration", issuer);
655    let doc: serde_json::Value = reqwest::get(&discovery_url).await?.json().await?;
656
657    let token_endpoint = doc
658        .get("token_endpoint")
659        .and_then(|v| v.as_str())
660        .ok_or_else(|| anyhow!("missing token_endpoint in OIDC discovery"))?;
661
662    // Create OAuth2 client
663    let client = BasicClient::new(ClientId::new(client_id.to_string()))
664        .set_auth_uri(AuthUrl::new(issuer.to_string())?)
665        .set_token_uri(TokenUrl::new(token_endpoint.to_string())?);
666
667    // Exchange refresh token
668    let http_client = reqwest::Client::new();
669    let refresh_token_obj = RefreshToken::new(refresh_token.to_string());
670    let mut request = client.exchange_refresh_token(&refresh_token_obj);
671
672    // Only add audience if provided (Auth0 requires it, others may not)
673    if let Some(aud) = audience {
674        request = request.add_extra_param("audience", aud);
675    }
676
677    let token_result = request.request_async(&http_client).await?;
678
679    let new_access_token = token_result.access_token().secret().to_string();
680    let new_expires_in = token_result
681        .expires_in()
682        .map(|d| d.as_secs())
683        .unwrap_or(3600);
684    let new_refresh_token = normalize_rotated_refresh_token(
685        token_result.refresh_token().map(|t| t.secret().to_string()),
686    );
687
688    let issued_at = std::time::SystemTime::now()
689        .duration_since(std::time::UNIX_EPOCH)?
690        .as_secs();
691
692    Ok((
693        new_access_token,
694        new_expires_in,
695        issued_at,
696        new_refresh_token,
697    ))
698}
699
/// Normalizes the `refresh_token` field of an OIDC token-refresh response.
///
/// Some IdPs reply with `"refresh_token": ""`, which the `oauth2` crate may surface as
/// `Some(RefreshToken(""))`. An empty value is mapped to `None` ("not rotated") so callers
/// never overwrite the stored refresh token with an empty string. This matches the Go and
/// TypeScript SDKs.
fn normalize_rotated_refresh_token(raw: Option<String>) -> Option<String> {
    match raw {
        Some(token) if !token.is_empty() => Some(token),
        _ => None,
    }
}
709
710/// Applies the result of a token-refresh response to an existing credentials
711/// record. The rotated refresh token, when present, replaces the stored one;
712/// when absent the existing refresh token is left untouched. Other fields
713/// (`audience`, `client_id`, `scopes`) are never modified here.
714fn apply_token_refresh(
715    creds: &mut VorpalCredentialsContent,
716    access_token: String,
717    expires_in: u64,
718    issued_at: u64,
719    rotated_refresh_token: Option<String>,
720) {
721    creds.access_token = access_token;
722    creds.expires_in = expires_in;
723    creds.issued_at = issued_at;
724    if let Some(new) = rotated_refresh_token {
725        creds.refresh_token = new;
726    }
727}
728
729/// Writes credential bytes to `path` enforcing mode 0o600 on file create.
730///
731/// `OpenOptions::mode()` only takes effect when the file is created — if the
732/// file already exists, the existing mode is preserved. Both Rust call sites
733/// for `credentials.json` (login at `cli/src/command.rs` and refresh here)
734/// must use this pattern so the file is born 0o600 and not 0o644 (umask 022).
735async fn write_credentials_secure(path: &Path, bytes: &[u8]) -> Result<()> {
736    let mut file = OpenOptions::new()
737        .write(true)
738        .create(true)
739        .truncate(true)
740        .mode(0o600)
741        .open(path)
742        .await?;
743    file.write_all(bytes).await?;
744    file.flush().await?;
745    Ok(())
746}
747
748pub async fn client_auth_header(registry: &str) -> Result<Option<MetadataValue<Ascii>>> {
749    let credentials_path = get_key_credentials_path();
750
751    if !credentials_path.exists() {
752        return Ok(None);
753    }
754
755    let credentials_data = read(&credentials_path).await?;
756    let mut credentials: VorpalCredentials = serde_json::from_slice(&credentials_data)?;
757
758    let registry_issuer = match credentials.registry.get(registry) {
759        Some(issuer) => issuer.clone(),
760        None => return Ok(None),
761    };
762
763    // Check if token needs refresh
764    let now = std::time::SystemTime::now()
765        .duration_since(std::time::UNIX_EPOCH)?
766        .as_secs();
767
768    let needs_refresh = {
769        let issuer_creds = credentials
770            .issuer
771            .get(&registry_issuer)
772            .ok_or_else(|| anyhow!("no credentials for issuer: {}", registry_issuer))?;
773
774        let token_age = now - issuer_creds.issued_at;
775        let expires_in = issuer_creds.expires_in;
776
777        // Refresh if token has less than 5 minutes left
778        token_age + 300 >= expires_in
779    };
780
781    if needs_refresh {
782        // Clone values needed for refresh
783        let (audience, client_id, refresh_token) = {
784            let issuer_creds = credentials
785                .issuer
786                .get(&registry_issuer)
787                .ok_or_else(|| anyhow!("no credentials for issuer: {}", registry_issuer))?;
788            (
789                issuer_creds.audience.clone(),
790                issuer_creds.client_id.clone(),
791                issuer_creds.refresh_token.clone(),
792            )
793        };
794
795        // Skip refresh if no refresh token available (user must re-login)
796        if refresh_token.is_empty() {
797            return Err(anyhow!(
798                "Access token expired and no refresh token available. Please run: vorpal login --issuer {}",
799                registry_issuer
800            ));
801        }
802
803        let (new_token, new_expires, new_issued_at, rotated_refresh) = refresh_access_token(
804            audience.as_deref(),
805            &client_id,
806            &registry_issuer,
807            &refresh_token,
808        )
809        .await?;
810
811        // Now update the credentials
812        let issuer_creds = credentials
813            .issuer
814            .get_mut(&registry_issuer)
815            .ok_or_else(|| anyhow!("no credentials for issuer: {}", registry_issuer))?;
816
817        apply_token_refresh(
818            issuer_creds,
819            new_token,
820            new_expires,
821            new_issued_at,
822            rotated_refresh,
823        );
824
825        // Save updated credentials with mode 0o600 enforced on create.
826        let credentials_json = serde_json::to_string_pretty(&credentials)?;
827        write_credentials_secure(&credentials_path, credentials_json.as_bytes()).await?;
828    }
829
830    // Get the access token
831    let access_token = credentials
832        .issuer
833        .get(&registry_issuer)
834        .ok_or_else(|| anyhow!("no credentials for issuer: {}", registry_issuer))?
835        .access_token
836        .clone();
837
838    let header = format!("Bearer {}", access_token)
839        .parse()
840        .map_err(|e| anyhow!("failed to parse Bearer token: {}", e))?;
841
842    Ok(Some(header))
843}
844
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_creds() -> VorpalCredentialsContent {
        VorpalCredentialsContent {
            access_token: "old-access".to_string(),
            audience: Some("aud-1".to_string()),
            client_id: "client-1".to_string(),
            expires_in: 3600,
            issued_at: 1_700_000_000,
            refresh_token: "old-refresh".to_string(),
            scopes: vec!["openid".to_string(), "offline_access".to_string()],
        }
    }

    #[test]
    fn apply_token_refresh_replaces_refresh_when_rotated() {
        let mut creds = sample_creds();

        apply_token_refresh(
            &mut creds,
            "new-access".to_string(),
            7200,
            1_700_000_500,
            Some("rotated-refresh".to_string()),
        );

        assert_eq!(creds.access_token, "new-access");
        assert_eq!(creds.expires_in, 7200);
        assert_eq!(creds.issued_at, 1_700_000_500);
        assert_eq!(creds.refresh_token, "rotated-refresh");
        assert_eq!(creds.audience.as_deref(), Some("aud-1"));
        assert_eq!(creds.client_id, "client-1");
        assert_eq!(creds.scopes, vec!["openid", "offline_access"]);
    }

    #[test]
    fn apply_token_refresh_keeps_refresh_when_not_rotated() {
        let mut creds = sample_creds();

        apply_token_refresh(
            &mut creds,
            "new-access".to_string(),
            7200,
            1_700_000_500,
            None,
        );

        assert_eq!(creds.access_token, "new-access");
        assert_eq!(creds.expires_in, 7200);
        assert_eq!(creds.issued_at, 1_700_000_500);
        assert_eq!(creds.refresh_token, "old-refresh");
        assert_eq!(creds.audience.as_deref(), Some("aud-1"));
        assert_eq!(creds.client_id, "client-1");
        assert_eq!(creds.scopes, vec!["openid", "offline_access"]);
    }

    #[test]
    fn normalize_rotated_refresh_token_some_nonempty_passes_through() {
        assert_eq!(
            normalize_rotated_refresh_token(Some("rotated-refresh".to_string())),
            Some("rotated-refresh".to_string())
        );
    }

    #[test]
    fn normalize_rotated_refresh_token_some_empty_becomes_none() {
        assert_eq!(
            normalize_rotated_refresh_token(Some(String::new())),
            None,
            "empty-string refresh_token must be treated as not-rotated for parity with Go/TS"
        );
    }

    #[test]
    fn normalize_rotated_refresh_token_none_passes_through() {
        assert_eq!(normalize_rotated_refresh_token(None), None);
    }

    #[test]
    fn write_credentials_secure_creates_file_with_mode_0o600() {
        use std::os::unix::fs::PermissionsExt;

        let dir = std::env::temp_dir().join(format!(
            "vorpal-creds-mode-test-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        std::fs::create_dir_all(&dir).expect("create temp dir");
        let path = dir.join("credentials.json");
        // Sanity: the path must not pre-exist — we are testing file birth, not
        // an inherited mode from a pre-created 0o600 file.
        assert!(!path.exists(), "test path must be previously-nonexistent");

        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("build runtime");
        runtime
            .block_on(write_credentials_secure(&path, b"{\"hello\":\"world\"}"))
            .expect("write credentials");

        let mode = std::fs::metadata(&path)
            .expect("stat credentials")
            .permissions()
            .mode();
        assert_eq!(
            mode & 0o777,
            0o600,
            "credentials file must be born 0o600, got {:o}",
            mode & 0o777
        );

        let _ = std::fs::remove_file(&path);
        let _ = std::fs::remove_dir(&dir);
    }

    #[test]
    fn apply_token_refresh_persists_through_serde_roundtrip() {
        let mut creds = sample_creds();

        apply_token_refresh(
            &mut creds,
            "new-access".to_string(),
            7200,
            1_700_000_500,
            Some("rotated-refresh".to_string()),
        );

        let json = serde_json::to_string(&creds).expect("serialize");
        let parsed: VorpalCredentialsContent = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.access_token, "new-access");
        assert_eq!(parsed.refresh_token, "rotated-refresh");
        assert_eq!(parsed.expires_in, 7200);
        assert_eq!(parsed.issued_at, 1_700_000_500);
        assert_eq!(parsed.audience.as_deref(), Some("aud-1"));
        assert_eq!(parsed.client_id, "client-1");
        assert_eq!(parsed.scopes, vec!["openid", "offline_access"]);
    }

    // --- alias parsing ---

    #[test]
    fn is_valid_component_accepts_only_allowed_characters() {
        assert!(is_valid_component("rust-1.83.0_beta+2"));
        assert!(!is_valid_component(""));
        assert!(!is_valid_component("a/b"));
        assert!(!is_valid_component("a:b"));
        assert!(!is_valid_component("a b"));
        assert!(!is_valid_component("café"));
    }

    #[test]
    fn parse_artifact_alias_applies_defaults() {
        let parsed = parse_artifact_alias("rust-toolchain").expect("parse");

        assert_eq!(
            parsed,
            ArtifactAlias {
                name: "rust-toolchain".to_string(),
                namespace: DEFAULT_NAMESPACE.to_string(),
                tag: DEFAULT_TAG.to_string(),
            }
        );
    }

    #[test]
    fn parse_artifact_alias_parses_all_components() {
        let parsed = parse_artifact_alias("team/rust-toolchain:1.83.0+build_1").expect("parse");

        assert_eq!(parsed.namespace, "team");
        assert_eq!(parsed.name, "rust-toolchain");
        assert_eq!(parsed.tag, "1.83.0+build_1");
    }

    #[test]
    fn parse_artifact_alias_rejects_malformed_input() {
        for alias in [
            "", ":tag", "name:", "/name", "ns/", "a/b/c", "na me", "ns/na@me", "a:b:c", "ns/name:t/g",
        ] {
            assert!(
                parse_artifact_alias(alias).is_err(),
                "{alias:?} should be rejected"
            );
        }
    }

    #[test]
    fn parse_artifact_alias_enforces_length_limit() {
        assert!(parse_artifact_alias(&"a".repeat(255)).is_ok());
        assert!(parse_artifact_alias(&"a".repeat(256)).is_err());
    }
}