Skip to main content

self_agent_sdk/
registration_flow.rs

1// SPDX-FileCopyrightText: 2025-2026 Social Connect Labs, Inc.
2// SPDX-License-Identifier: BUSL-1.1
3// NOTE: Converts to Apache-2.0 on 2029-06-11 per LICENSE.
4
//! REST-based registration, deregistration, and proof-refresh flows for AI agents.
6//!
7//! # Example
8//!
9//! ```no_run
10//! use self_agent_sdk::registration_flow::*;
11//!
12//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
13//! # tokio::runtime::Runtime::new()?.block_on(async {
14//! let session = RegistrationSession::request(RegistrationRequest {
15//!     mode: "linked".into(),
16//!     network: "mainnet".into(),
17//!     ..Default::default()
18//! }, None).await?;
19//!
20//! println!("QR: {}", session.qr_url);
21//! println!("Instructions: {:?}", session.human_instructions);
22//!
23//! let result = session.wait_for_completion(None, None).await?;
24//! println!("Agent ID: {}", result.agent_id);
25//! # Ok::<(), Box<dyn std::error::Error>>(())
26//! # })?;
27//! # Ok(())
28//! # }
29//! ```
30
31use reqwest::Client;
32use serde::Serialize;
33use std::time::{Duration, Instant};
34
/// Default API base URL (overridden by `SELF_AGENT_API_BASE` when set).
pub const DEFAULT_API_BASE: &str = "https://self-agent-id.vercel.app";

/// Default polling timeout in milliseconds (30 minutes).
pub const DEFAULT_TIMEOUT_MS: u64 = 30 * 60 * 1000;

/// Default polling interval in milliseconds (5 seconds).
pub const DEFAULT_POLL_INTERVAL_MS: u64 = 5000;

/// Resolve the API base URL for a request.
///
/// Precedence: the explicit `api_base` argument, then the `SELF_AGENT_API_BASE`
/// environment variable, then [`DEFAULT_API_BASE`].
///
/// Every candidate is trimmed of surrounding whitespace and of trailing `/`.
/// Callers append `/api/...`, so a trailing slash would otherwise produce
/// `//api/...`. A candidate that is empty after trimming is skipped instead of
/// yielding a relative URL.
fn resolve_api_base(api_base: Option<&str>) -> String {
    /// Trim a candidate base URL; `None` if nothing usable remains.
    fn normalize(raw: &str) -> Option<&str> {
        let trimmed = raw.trim().trim_end_matches('/');
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed)
        }
    }

    if let Some(base) = api_base.and_then(normalize) {
        return base.to_string();
    }
    if let Ok(base) = std::env::var("SELF_AGENT_API_BASE") {
        if let Some(base) = normalize(&base) {
            return base.to_string();
        }
    }
    DEFAULT_API_BASE.to_string()
}
56
57/// Errors specific to the registration flow.
58#[derive(Debug, thiserror::Error)]
59pub enum RegistrationError {
60    #[error("session expired — call request_registration() again")]
61    ExpiredSession,
62    #[error("registration failed: {0}")]
63    Failed(String),
64    #[error("registration timed out")]
65    Timeout,
66    #[error("HTTP error: {0}")]
67    Http(String),
68    #[error("API error: {0}")]
69    Api(String),
70}
71
72/// Request payload for initiating a registration.
73#[derive(Debug, Clone, Serialize, Default)]
74#[serde(rename_all = "camelCase")]
75pub struct RegistrationRequest {
76    pub mode: String,
77    pub network: String,
78    #[serde(default)]
79    pub disclosures: serde_json::Value,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub human_address: Option<String>,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub agent_name: Option<String>,
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub agent_description: Option<String>,
86}
87
88/// Request payload for initiating a deregistration.
89#[derive(Debug, Clone, Serialize)]
90#[serde(rename_all = "camelCase")]
91pub struct DeregistrationRequest {
92    pub network: String,
93    pub agent_address: String,
94}
95
96/// Successful registration result.
97#[derive(Debug, Clone)]
98pub struct RegistrationResult {
99    pub agent_id: u64,
100    pub agent_address: String,
101    pub credentials: Option<serde_json::Value>,
102    pub tx_hash: Option<String>,
103}
104
105/// An in-progress registration session.
106#[derive(Debug, Clone)]
107pub struct RegistrationSession {
108    pub session_token: String,
109    pub stage: String,
110    pub qr_url: String,
111    pub deep_link: String,
112    pub agent_address: String,
113    pub expires_at: String,
114    pub time_remaining_ms: u64,
115    pub human_instructions: Vec<String>,
116    api_base: String,
117    http: Client,
118}
119
120impl RegistrationSession {
121    /// Initiate a registration via the REST API.
122    pub async fn request(
123        req: RegistrationRequest,
124        api_base: Option<&str>,
125    ) -> Result<Self, RegistrationError> {
126        let base = resolve_api_base(api_base);
127        let http = Client::new();
128        let resp = http
129            .post(format!("{}/api/agent/register", base))
130            .json(&req)
131            .send()
132            .await
133            .map_err(|e| RegistrationError::Http(e.to_string()))?;
134
135        let data: serde_json::Value = resp
136            .json()
137            .await
138            .map_err(|e| RegistrationError::Http(e.to_string()))?;
139
140        if let Some(err) = data.get("error").and_then(|v| v.as_str()) {
141            return Err(RegistrationError::Api(err.to_string()));
142        }
143
144        Ok(Self {
145            session_token: json_str(&data, "sessionToken"),
146            stage: json_str(&data, "stage"),
147            qr_url: json_str(&data, "qrUrl"),
148            deep_link: json_str(&data, "deepLink"),
149            agent_address: json_str(&data, "agentAddress"),
150            expires_at: json_str(&data, "expiresAt"),
151            time_remaining_ms: data
152                .get("timeRemainingMs")
153                .and_then(|v| v.as_u64())
154                .unwrap_or(0),
155            human_instructions: json_str_array(&data, "humanInstructions"),
156            api_base: base,
157            http,
158        })
159    }
160
161    /// Poll until registration completes or times out.
162    pub async fn wait_for_completion(
163        &self,
164        timeout_ms: Option<u64>,
165        poll_interval_ms: Option<u64>,
166    ) -> Result<RegistrationResult, RegistrationError> {
167        let timeout = Duration::from_millis(timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS));
168        let interval = Duration::from_millis(poll_interval_ms.unwrap_or(DEFAULT_POLL_INTERVAL_MS));
169        let deadline = Instant::now() + timeout;
170        let mut token = self.session_token.clone();
171
172        while Instant::now() < deadline {
173            let resp = self
174                .http
175                .get(format!("{}/api/agent/register/status", self.api_base))
176                .query(&[("token", &token)])
177                .send()
178                .await
179                .map_err(|e| RegistrationError::Http(e.to_string()))?;
180
181            let data: serde_json::Value = resp
182                .json()
183                .await
184                .map_err(|e| RegistrationError::Http(e.to_string()))?;
185
186            if let Some(err) = data.get("error").and_then(|v| v.as_str()) {
187                if err.to_lowercase().contains("expired") {
188                    return Err(RegistrationError::ExpiredSession);
189                }
190            }
191
192            let stage = json_str(&data, "stage");
193            if let Some(t) = data.get("sessionToken").and_then(|v| v.as_str()) {
194                token = t.to_string();
195            }
196
197            match stage.as_str() {
198                "completed" => {
199                    return Ok(RegistrationResult {
200                        agent_id: data
201                            .get("agentId")
202                            .and_then(|v| v.as_u64())
203                            .unwrap_or(0),
204                        agent_address: json_str(&data, "agentAddress"),
205                        credentials: data.get("credentials").cloned(),
206                        tx_hash: data.get("txHash").and_then(|v| v.as_str()).map(String::from),
207                    });
208                }
209                "failed" => {
210                    let err = json_str(&data, "error");
211                    return Err(RegistrationError::Failed(
212                        if err.is_empty() { "Registration failed".into() } else { err },
213                    ));
214                }
215                "expired" => return Err(RegistrationError::ExpiredSession),
216                _ => {}
217            }
218
219            tokio::time::sleep(interval).await;
220        }
221
222        Err(RegistrationError::Timeout)
223    }
224
225    /// Export the agent private key generated during registration.
226    ///
227    /// Only available for modes that created a new keypair (e.g. linked).
228    pub async fn export_key(&self) -> Result<String, RegistrationError> {
229        let resp = self
230            .http
231            .post(format!("{}/api/agent/register/export", self.api_base))
232            .json(&serde_json::json!({ "token": self.session_token }))
233            .send()
234            .await
235            .map_err(|e| RegistrationError::Http(e.to_string()))?;
236
237        let data: serde_json::Value = resp
238            .json()
239            .await
240            .map_err(|e| RegistrationError::Http(e.to_string()))?;
241
242        if let Some(err) = data.get("error").and_then(|v| v.as_str()) {
243            return Err(RegistrationError::Api(err.to_string()));
244        }
245
246        Ok(json_str(&data, "privateKey"))
247    }
248}
249
250/// An in-progress deregistration session.
251#[derive(Debug, Clone)]
252pub struct DeregistrationSession {
253    pub session_token: String,
254    pub stage: String,
255    pub qr_url: String,
256    pub deep_link: String,
257    pub expires_at: String,
258    pub time_remaining_ms: u64,
259    pub human_instructions: Vec<String>,
260    api_base: String,
261    http: Client,
262}
263
264impl DeregistrationSession {
265    /// Initiate a deregistration via the REST API.
266    pub async fn request(
267        req: DeregistrationRequest,
268        api_base: Option<&str>,
269    ) -> Result<Self, RegistrationError> {
270        let base = resolve_api_base(api_base);
271        let http = Client::new();
272        let resp = http
273            .post(format!("{}/api/agent/deregister", base))
274            .json(&req)
275            .send()
276            .await
277            .map_err(|e| RegistrationError::Http(e.to_string()))?;
278
279        let data: serde_json::Value = resp
280            .json()
281            .await
282            .map_err(|e| RegistrationError::Http(e.to_string()))?;
283
284        if let Some(err) = data.get("error").and_then(|v| v.as_str()) {
285            return Err(RegistrationError::Api(err.to_string()));
286        }
287
288        Ok(Self {
289            session_token: json_str(&data, "sessionToken"),
290            stage: json_str(&data, "stage"),
291            qr_url: json_str(&data, "qrUrl"),
292            deep_link: json_str(&data, "deepLink"),
293            expires_at: json_str(&data, "expiresAt"),
294            time_remaining_ms: data
295                .get("timeRemainingMs")
296                .and_then(|v| v.as_u64())
297                .unwrap_or(0),
298            human_instructions: json_str_array(&data, "humanInstructions"),
299            api_base: base,
300            http,
301        })
302    }
303
304    /// Poll until deregistration completes or times out.
305    pub async fn wait_for_completion(
306        &self,
307        timeout_ms: Option<u64>,
308        poll_interval_ms: Option<u64>,
309    ) -> Result<(), RegistrationError> {
310        let timeout = Duration::from_millis(timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS));
311        let interval = Duration::from_millis(poll_interval_ms.unwrap_or(DEFAULT_POLL_INTERVAL_MS));
312        let deadline = Instant::now() + timeout;
313        let mut token = self.session_token.clone();
314
315        while Instant::now() < deadline {
316            let resp = self
317                .http
318                .get(format!("{}/api/agent/deregister/status", self.api_base))
319                .query(&[("token", &token)])
320                .send()
321                .await
322                .map_err(|e| RegistrationError::Http(e.to_string()))?;
323
324            let data: serde_json::Value = resp
325                .json()
326                .await
327                .map_err(|e| RegistrationError::Http(e.to_string()))?;
328
329            let stage = json_str(&data, "stage");
330            if let Some(t) = data.get("sessionToken").and_then(|v| v.as_str()) {
331                token = t.to_string();
332            }
333
334            match stage.as_str() {
335                "completed" => return Ok(()),
336                "failed" => {
337                    let err = json_str(&data, "error");
338                    return Err(RegistrationError::Failed(
339                        if err.is_empty() { "Deregistration failed".into() } else { err },
340                    ));
341                }
342                "expired" => {
343                    return Err(RegistrationError::Failed(
344                        "Deregistration session expired".into(),
345                    ))
346                }
347                _ => {}
348            }
349
350            tokio::time::sleep(interval).await;
351        }
352
353        Err(RegistrationError::Timeout)
354    }
355}
356
357// ---------------------------------------------------------------------------
358// Proof Refresh
359// ---------------------------------------------------------------------------
360
361/// Request payload for initiating a proof refresh.
362#[derive(Debug, Clone, Serialize)]
363#[serde(rename_all = "camelCase")]
364pub struct ProofRefreshRequest {
365    /// Agent ID (token ID) to refresh the proof for.
366    pub agent_id: u64,
367    /// Network: "mainnet" (default) or "testnet".
368    pub network: String,
369    /// Credential disclosures to request (should match original registration).
370    #[serde(default, skip_serializing_if = "Option::is_none")]
371    pub disclosures: Option<serde_json::Value>,
372}
373
/// Outcome of a proof refresh whose status reached the `completed` stage.
#[derive(Debug, Clone)]
pub struct ProofRefreshResult {
    /// Expiry of the new proof, as Unix time in seconds.
    pub proof_expires_at: u64,
}
380
381/// An in-progress proof refresh session.
382#[derive(Debug, Clone)]
383pub struct RefreshSession {
384    pub session_token: String,
385    pub stage: String,
386    pub deep_link: String,
387    pub expires_at: String,
388    pub time_remaining_ms: u64,
389    pub human_instructions: Vec<String>,
390    api_base: String,
391    http: Client,
392}
393
394impl RefreshSession {
395    /// Initiate a proof refresh via the REST API.
396    pub async fn request(
397        req: ProofRefreshRequest,
398        api_base: Option<&str>,
399    ) -> Result<Self, RegistrationError> {
400        let base = resolve_api_base(api_base);
401        let http = Client::new();
402        let resp = http
403            .post(format!("{}/api/agent/refresh", base))
404            .json(&req)
405            .send()
406            .await
407            .map_err(|e| RegistrationError::Http(e.to_string()))?;
408
409        let data: serde_json::Value = resp
410            .json()
411            .await
412            .map_err(|e| RegistrationError::Http(e.to_string()))?;
413
414        if let Some(err) = data.get("error").and_then(|v| v.as_str()) {
415            return Err(RegistrationError::Api(err.to_string()));
416        }
417
418        Ok(Self {
419            session_token: json_str(&data, "sessionToken"),
420            stage: json_str(&data, "stage"),
421            deep_link: json_str(&data, "deepLink"),
422            expires_at: json_str(&data, "expiresAt"),
423            time_remaining_ms: data
424                .get("timeRemainingMs")
425                .and_then(|v| v.as_u64())
426                .unwrap_or(0),
427            human_instructions: json_str_array(&data, "humanInstructions"),
428            api_base: base,
429            http,
430        })
431    }
432
433    /// Poll until proof refresh completes or times out.
434    pub async fn wait_for_completion(
435        &self,
436        timeout_ms: Option<u64>,
437        poll_interval_ms: Option<u64>,
438    ) -> Result<ProofRefreshResult, RegistrationError> {
439        let timeout = Duration::from_millis(timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS));
440        let interval = Duration::from_millis(poll_interval_ms.unwrap_or(DEFAULT_POLL_INTERVAL_MS));
441        let deadline = Instant::now() + timeout;
442        let mut token = self.session_token.clone();
443
444        while Instant::now() < deadline {
445            let resp = self
446                .http
447                .get(format!("{}/api/agent/refresh/status", self.api_base))
448                .query(&[("token", &token)])
449                .send()
450                .await
451                .map_err(|e| RegistrationError::Http(e.to_string()))?;
452
453            let data: serde_json::Value = resp
454                .json()
455                .await
456                .map_err(|e| RegistrationError::Http(e.to_string()))?;
457
458            if let Some(err) = data.get("error").and_then(|v| v.as_str()) {
459                if err.to_lowercase().contains("expired") {
460                    return Err(RegistrationError::ExpiredSession);
461                }
462            }
463
464            let stage = json_str(&data, "stage");
465            if let Some(t) = data.get("sessionToken").and_then(|v| v.as_str()) {
466                token = t.to_string();
467            }
468
469            match stage.as_str() {
470                "completed" => {
471                    // The status response may include proofExpiresAt as a unix
472                    // timestamp (number) or an ISO date string. Try number first,
473                    // then fall back to string-as-number, then 1 year default.
474                    let proof_expires_at = data
475                        .get("proofExpiresAt")
476                        .and_then(|v| {
477                            v.as_u64().or_else(|| {
478                                v.as_str().and_then(|s| s.parse::<u64>().ok())
479                            })
480                        })
481                        // Fallback: 1 year from now
482                        .unwrap_or_else(|| {
483                            std::time::SystemTime::now()
484                                .duration_since(std::time::UNIX_EPOCH)
485                                .expect("system clock before UNIX epoch")
486                                .as_secs()
487                                + 365 * 24 * 60 * 60
488                        });
489                    return Ok(ProofRefreshResult { proof_expires_at });
490                }
491                "failed" => {
492                    return Err(RegistrationError::Failed(
493                        "Proof refresh failed on-chain".into(),
494                    ));
495                }
496                "expired" => return Err(RegistrationError::ExpiredSession),
497                _ => {}
498            }
499
500            tokio::time::sleep(interval).await;
501        }
502
503        Err(RegistrationError::Timeout)
504    }
505}
506
507/// Initiate a proof refresh for an existing agent through the Self Agent ID REST API.
508///
509/// Returns a session object with a deep link for the human to scan in the Self app,
510/// and a polling method to wait for the new proof to be recorded on-chain.
511pub async fn request_proof_refresh(
512    req: ProofRefreshRequest,
513    api_base: Option<&str>,
514) -> Result<RefreshSession, RegistrationError> {
515    RefreshSession::request(req, api_base).await
516}
517
518/// Helper: extract a string from a JSON value, defaulting to empty string.
519fn json_str(data: &serde_json::Value, key: &str) -> String {
520    data.get(key)
521        .and_then(|v| v.as_str())
522        .unwrap_or("")
523        .to_string()
524}
525
526/// Helper: extract a string array from a JSON value.
527fn json_str_array(data: &serde_json::Value, key: &str) -> Vec<String> {
528    data.get(key)
529        .and_then(|v| v.as_array())
530        .map(|arr| {
531            arr.iter()
532                .filter_map(|v| v.as_str().map(String::from))
533                .collect()
534        })
535        .unwrap_or_default()
536}