Skip to main content

systemprompt_users/repository/
federated_identity.rs

//! Repository for `federated_identities` — the `{issuer, external_sub} ->
//! users.id` mapping used to provision local users on first touch during
//! RFC 8693 token exchange.
3
4use chrono::Utc;
5use sqlx::Acquire;
6use systemprompt_identifiers::UserId;
7use systemprompt_traits::FederatedIdentityClaims;
8
9use crate::error::Result;
10use crate::models::{User, UserRole, UserStatus};
11use crate::repository::UserRepository;
12
13impl UserRepository {
14    /// Look up the local `UserId` for an external `(issuer, external_sub)`
15    /// without side effects. Returns `Ok(None)` if no mapping exists yet.
16    pub async fn find_federated(&self, issuer: &str, external_sub: &str) -> Result<Option<UserId>> {
17        let row = sqlx::query!(
18            "SELECT user_id FROM federated_identities WHERE issuer = $1 AND external_sub = $2",
19            issuer,
20            external_sub
21        )
22        .fetch_optional(&*self.pool)
23        .await?;
24
25        Ok(row.map(|r| UserId::new(r.user_id)))
26    }
27
28    /// Resolve a federated identity to a local `User`, creating both the
29    /// `users` row and the `federated_identities` mapping on first touch.
30    ///
31    /// All writes happen in a single transaction so a race between two
32    /// concurrent first-touch requests for the same `(issuer, external_sub)`
33    /// cannot produce two local users — the second loser observes the
34    /// primary-key conflict and re-reads the mapping.
35    pub async fn find_or_create_federated(
36        &self,
37        issuer: &str,
38        external_sub: &str,
39        claims: &FederatedIdentityClaims,
40    ) -> Result<User> {
41        let mut conn = self.write_pool.acquire().await?;
42        let mut tx = conn.begin().await?;
43
44        if let Some(existing) = sqlx::query!(
45            "UPDATE federated_identities SET last_seen_at = CURRENT_TIMESTAMP WHERE issuer = $1 \
46             AND external_sub = $2 RETURNING user_id",
47            issuer,
48            external_sub
49        )
50        .fetch_optional(&mut *tx)
51        .await?
52        {
53            let user = sqlx::query_as!(
54                User,
55                r#"
56                SELECT id, name, email, full_name, display_name, status,
57                       email_verified, roles, avatar_url, is_bot, is_scanner,
58                       created_at, updated_at
59                FROM users WHERE id = $1
60                "#,
61                existing.user_id
62            )
63            .fetch_one(&mut *tx)
64            .await?;
65            tx.commit().await?;
66            return Ok(user);
67        }
68
69        let fields = NewFederatedUser::derive(issuer, external_sub, claims);
70
71        let user = sqlx::query_as!(
72            User,
73            r#"
74            INSERT INTO users (
75                id, name, email, full_name, display_name,
76                status, email_verified, roles, is_bot,
77                created_at, updated_at
78            )
79            VALUES ($1, $2, $3, $4, $5, $6, false, $7::TEXT[], false, $8, $8)
80            RETURNING id, name, email, full_name, display_name, status, email_verified,
81                      roles, avatar_url, is_bot, is_scanner, created_at, updated_at
82            "#,
83            fields.id.as_str(),
84            fields.name,
85            fields.email,
86            fields.display_name.as_deref(),
87            fields.display_name.as_deref(),
88            fields.status,
89            &fields.roles,
90            fields.now,
91        )
92        .fetch_one(&mut *tx)
93        .await?;
94
95        sqlx::query!(
96            "INSERT INTO federated_identities (issuer, external_sub, user_id) VALUES ($1, $2, $3)",
97            issuer,
98            external_sub,
99            user.id.as_str()
100        )
101        .execute(&mut *tx)
102        .await?;
103
104        tx.commit().await?;
105        Ok(user)
106    }
107}
108
109struct NewFederatedUser {
110    id: UserId,
111    name: String,
112    email: String,
113    display_name: Option<String>,
114    status: &'static str,
115    roles: Vec<String>,
116    now: chrono::DateTime<Utc>,
117}
118
119impl NewFederatedUser {
120    fn derive(issuer: &str, external_sub: &str, claims: &FederatedIdentityClaims) -> Self {
121        let name = claims
122            .preferred_username
123            .clone()
124            .or_else(|| claims.name.clone())
125            .unwrap_or_else(|| format!("fed_{}_{}", short_hash(issuer), short_hash(external_sub)));
126        let synthetic_email = || {
127            format!(
128                "{}@{}.federated.local",
129                short_hash(external_sub),
130                short_host(issuer)
131            )
132        };
133        let email = match (claims.email.as_deref(), claims.email_verified) {
134            (Some(addr), true) => addr.to_owned(),
135            (Some(addr), false) => {
136                tracing::warn!(
137                    issuer,
138                    external_sub,
139                    upstream_email = addr,
140                    "upstream IdP did not assert email_verified; using synthetic local email to \
141                     prevent account-claim attacks"
142                );
143                synthetic_email()
144            },
145            (None, _) => synthetic_email(),
146        };
147
148        Self {
149            id: UserId::new(uuid::Uuid::new_v4().to_string()),
150            name,
151            email,
152            display_name: claims.name.clone(),
153            status: UserStatus::Active.as_str(),
154            roles: normalised_roles(&claims.roles),
155            now: Utc::now(),
156        }
157    }
158}
159
160fn normalised_roles(claim_roles: &[String]) -> Vec<String> {
161    if claim_roles.is_empty() {
162        vec![UserRole::User.as_str().to_owned()]
163    } else {
164        claim_roles.to_vec()
165    }
166}
167
168fn short_hash(s: &str) -> String {
169    use sha2::{Digest, Sha256};
170    let digest = Sha256::digest(s.as_bytes());
171    hex::encode(&digest[..6])
172}
173
/// Reduce an issuer URL to a token usable in the domain part of a synthetic
/// e-mail address.
///
/// Leading `https://` / `http://` prefixes are stripped, and everything from
/// the first `/` on is dropped. In what remains, every character other than
/// an ASCII letter, digit or `-` is replaced with `-`; this covers the `.` of
/// the host, the `:` of a port and the `@` of userinfo. An issuer with no
/// host part at all (e.g. `""` or `"https://"`) yields the placeholder
/// `"issuer"`.
///
/// Example: `https://login.example.com:8443/v2.0` -> `login-example-com-8443`.
fn short_host(issuer: &str) -> String {
    let authority = issuer
        .trim_start_matches("https://")
        .trim_start_matches("http://")
        .split('/')
        .next()
        // `split` always yields at least one (possibly empty) piece, so this
        // emptiness check is what makes the fallback reachable.
        .filter(|host| !host.is_empty())
        .unwrap_or("issuer");

    authority
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '-' { c } else { '-' })
        .collect()
}