Skip to main content

stack_auth/
device_client.rs

1//! Post-login device client provisioning.
2//!
3//! After a device-code login, the caller must create a client in ZeroKMS and
4//! persist the resulting secret key to disk. This module provides the
5//! orchestration logic so that any consumer (not just the CLI) can perform
6//! this step.
7
8use stack_profile::{DeviceIdentity, ProfileStore};
9use uuid::Uuid;
10use zerokms_protocol::{CreateClientRequest, CreateClientResponse, ViturKeyMaterial, ViturRequest};
11
12use crate::{ensure_trailing_slash, http_client, ServiceToken, Token};
13
14fn user_agent() -> String {
15    format!(
16        "stack-auth/{} ({} {})",
17        env!("CARGO_PKG_VERSION"),
18        std::env::consts::OS,
19        std::env::consts::ARCH,
20    )
21}
22
23// ---------------------------------------------------------------------------
24// Secret key file (output)
25// ---------------------------------------------------------------------------
26
/// Name of the file, inside the profile store, that holds this device's
/// ZeroKMS client ID and client key.
const SECRET_KEY_FILENAME: &str = "secretkey.json";
/// Permission bits passed to `ProfileStore::save_with_mode` for the secret
/// key file: owner read (`0o400`) plus owner write (`0o200`), i.e. `0o600`.
const SECRET_KEY_MODE: u32 = 0o400 | 0o200;
29
/// The on-disk shape of `secretkey.json`.
///
/// Must stay in sync with `cipherstash_client::zerokms::SecretKey` which
/// deserializes this file. If that type moves to a shared crate, replace
/// this with a re-export.
///
/// There are no serde renames, so the field names below are written verbatim
/// as the JSON keys. Renaming a field changes the file format.
#[derive(serde::Serialize)]
struct SecretKeyFile {
    /// ID of the client that ZeroKMS created (taken from `CreateClientResponse::id`).
    client_id: Uuid,
    /// Key material ZeroKMS returned for that client (taken from
    /// `CreateClientResponse::client_key`).
    client_key: ViturKeyMaterial,
}
40
41// ---------------------------------------------------------------------------
42// Error type
43// ---------------------------------------------------------------------------
44
/// Errors that can occur during device client provisioning.
///
/// Returned by [`bind_client_device`].
#[derive(Debug, thiserror::Error)]
pub enum DeviceClientError {
    /// The profile store could not load or create required data.
    #[error("Profile error: {0}")]
    Profile(#[from] stack_profile::ProfileError),

    /// Authentication token could not be loaded or decoded.
    // NOTE(review): presumably produced when reading the ZeroKMS URL from the
    // access token fails — confirm against `ServiceToken::zerokms_url`.
    #[error("Auth error: {0}")]
    Auth(#[from] crate::AuthError),

    /// The HTTP request to ZeroKMS failed.
    ///
    /// This variant is also produced when a successful response body cannot be
    /// decoded into a create-client response, because that decoding is done by
    /// `reqwest`.
    #[error("ZeroKMS request failed: {0}")]
    Request(#[from] reqwest::Error),

    /// ZeroKMS returned a non-success, non-conflict status.
    #[error("ZeroKMS returned {status}: {body}")]
    Server {
        /// The HTTP status code returned by ZeroKMS.
        status: u16,
        /// The raw response body text, or an empty string if it could not be read.
        body: String,
    },

    /// Failed to construct the ZeroKMS endpoint URL, e.g. when joining
    /// `CreateClientRequest::ENDPOINT` onto the ZeroKMS base URL.
    #[error("Invalid ZeroKMS URL: {0}")]
    InvalidUrl(#[from] url::ParseError),
}
68
69// ---------------------------------------------------------------------------
70// Public API
71// ---------------------------------------------------------------------------
72
73/// Provision a device client after login.
74///
75/// Loads the auth token and device identity from disk, creates a client in
76/// ZeroKMS (on the workspace's default keyset), and persists the resulting
77/// secret key to the profile store.
78///
79/// If the secret key already exists on disk, or the server returns 409
80/// (conflict), this is a no-op.
81pub async fn bind_client_device(store: &ProfileStore) -> Result<(), DeviceClientError> {
82    if store.exists(SECRET_KEY_FILENAME) {
83        tracing::debug!("secret key already exists, skipping provisioning");
84        return Ok(());
85    }
86
87    let token: Token = store.load_profile()?;
88    let service_token = ServiceToken::new(token.access_token().clone());
89    let zerokms_url = ensure_trailing_slash(service_token.zerokms_url()?);
90
91    let identity = DeviceIdentity::load_or_create(store)?;
92
93    let request = CreateClientRequest {
94        keyset_id: None,
95        name: (&identity.device_name).into(),
96        description: (&identity.device_name).into(),
97    };
98
99    let url = zerokms_url.join(CreateClientRequest::ENDPOINT)?;
100
101    let response = http_client()
102        .post(url)
103        .header(reqwest::header::USER_AGENT, user_agent())
104        .bearer_auth(service_token.as_str())
105        .json(&request)
106        .send()
107        .await?;
108
109    let status = response.status();
110
111    if status == reqwest::StatusCode::CONFLICT {
112        // Another client was already provisioned server-side.
113        tracing::debug!("device client already exists, skipping");
114        return Ok(());
115    }
116
117    if !status.is_success() {
118        let body = response.text().await.unwrap_or_default();
119        return Err(DeviceClientError::Server {
120            status: status.as_u16(),
121            body,
122        });
123    }
124
125    let created: CreateClientResponse = response.json().await?;
126
127    let secret_key = SecretKeyFile {
128        client_id: created.id,
129        client_key: created.client_key,
130    };
131
132    store.save_with_mode(SECRET_KEY_FILENAME, &secret_key, SECRET_KEY_MODE)?;
133
134    Ok(())
135}
136
137// ---------------------------------------------------------------------------
138// Tests
139// ---------------------------------------------------------------------------
140
#[cfg(test)]
mod tests {
    // Tests for `bind_client_device`. A `mocktail` mock server stands in for
    // ZeroKMS, and a temporary directory backs the `ProfileStore`.

    use super::*;
    use crate::SecretToken;
    use mocktail::prelude::*;
    use tempfile::TempDir;

    /// Builds a JWT, signed with a fixed test secret, whose `services.zerokms`
    /// claim is `zerokms_url`. The token expires one hour after it is issued.
    // NOTE(review): the tests rely on `ServiceToken::zerokms_url` reading the
    // `services.zerokms` claim; this assumes the signature is not verified there.
    fn make_test_jwt(zerokms_url: impl std::fmt::Display) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};
        use std::time::{SystemTime, UNIX_EPOCH};

        let zerokms_url = zerokms_url.to_string();
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let claims = serde_json::json!({
            "iss": "https://cts.example.com/",
            "sub": "CS|test-user",
            "aud": "legacy-aud-value",
            "iat": now,
            "exp": now + 3600,
            "workspace": "ZVATKW3VHMFG27DY",
            "scope": "",
            "services": {
                "zerokms": zerokms_url,
            },
        });

        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(b"test-secret"),
        )
        .unwrap()
    }

    /// Saves a bearer `Token` wrapping `access_token` as the store's profile.
    /// The token expires in one hour and has no refresh token.
    fn save_test_token(store: &ProfileStore, access_token: &str) {
        use std::time::{SystemTime, UNIX_EPOCH};

        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let token = Token {
            access_token: SecretToken::new(access_token),
            refresh_token: None,
            token_type: "Bearer".into(),
            expires_at: now + 3600,
            region: None,
            client_id: None,
            device_instance_id: None,
        };
        store.save_profile(&token).unwrap();
    }

    /// Canned create-client success body. Its `id` and `client_key` are the
    /// values the provisioning test expects to find in `secretkey.json`.
    fn client_response_json() -> serde_json::Value {
        serde_json::json!({
            "id": "00000000-0000-0000-0000-000000000001",
            "dataset_id": "00000000-0000-0000-0000-000000000099",
            "name": "test-device",
            "description": "test-device",
            "client_key": "dGVzdC1rZXktbWF0ZXJpYWw="
        })
    }

    /// Starts a mock HTTP server that serves `mocks`.
    async fn start_server(mocks: MockSet) -> MockServer {
        let server = MockServer::new_http("device-client-test").with_mocks(mocks);
        server.start().await.unwrap();
        server
    }

    // Happy path: a 2xx response is persisted verbatim to `secretkey.json`.
    #[tokio::test]
    async fn provisions_and_saves_secret_key() {
        let dir = TempDir::new().unwrap();
        let store = ProfileStore::new(dir.path());

        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/create-client");
            then.json(client_response_json());
        });
        let server = start_server(mocks).await;

        let jwt = make_test_jwt(server.url("/"));
        save_test_token(&store, &jwt);

        bind_client_device(&store).await.unwrap();

        let saved: serde_json::Value = store.load(SECRET_KEY_FILENAME).unwrap();
        assert_eq!(saved["client_id"], "00000000-0000-0000-0000-000000000001");
        assert_eq!(saved["client_key"], "dGVzdC1rZXktbWF0ZXJpYWw=");
    }

    // An existing key short-circuits before any token load or HTTP call; no
    // token is saved here, so reaching those steps would fail the test.
    #[tokio::test]
    async fn skips_when_secret_key_exists() {
        let dir = TempDir::new().unwrap();
        let store = ProfileStore::new(dir.path());

        // Pre-populate secretkey.json
        store
            .save_with_mode(
                SECRET_KEY_FILENAME,
                &serde_json::json!({"client_id": "old", "client_key": "old"}),
                SECRET_KEY_MODE,
            )
            .unwrap();

        // No mock server needed — the HTTP call should never happen.
        bind_client_device(&store).await.unwrap();

        let saved: serde_json::Value = store.load(SECRET_KEY_FILENAME).unwrap();
        assert_eq!(
            saved["client_id"], "old",
            "should not overwrite existing key"
        );
    }

    // 409 Conflict is treated as success, but no key file is written.
    #[tokio::test]
    async fn no_op_on_conflict() {
        let dir = TempDir::new().unwrap();
        let store = ProfileStore::new(dir.path());

        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/create-client");
            then.status(reqwest::StatusCode::CONFLICT)
                .json(serde_json::json!({"error": "conflict"}));
        });
        let server = start_server(mocks).await;

        let jwt = make_test_jwt(server.url("/"));
        save_test_token(&store, &jwt);

        bind_client_device(&store).await.unwrap();

        assert!(
            !store.exists(SECRET_KEY_FILENAME),
            "should not write secret key on conflict"
        );
    }

    // Any other non-success status surfaces as `DeviceClientError::Server`
    // carrying the numeric status.
    #[tokio::test]
    async fn returns_error_on_server_failure() {
        let dir = TempDir::new().unwrap();
        let store = ProfileStore::new(dir.path());

        let mut mocks = MockSet::new();
        mocks.mock(|when, then| {
            when.post().path("/create-client");
            then.status(reqwest::StatusCode::INTERNAL_SERVER_ERROR)
                .json(serde_json::json!({"error": "internal error"}));
        });
        let server = start_server(mocks).await;

        let jwt = make_test_jwt(server.url("/"));
        save_test_token(&store, &jwt);

        let err = bind_client_device(&store).await.unwrap_err();
        assert!(
            matches!(err, DeviceClientError::Server { status: 500, .. }),
            "expected Server error, got: {err:?}"
        );
    }
}