systemprompt_users/repository/
federated_identity.rs1use chrono::Utc;
5use sqlx::Acquire;
6use systemprompt_identifiers::UserId;
7use systemprompt_traits::FederatedIdentityClaims;
8
9use crate::error::Result;
10use crate::models::{User, UserRole, UserStatus};
11use crate::repository::UserRepository;
12
13impl UserRepository {
14 pub async fn find_federated(&self, issuer: &str, external_sub: &str) -> Result<Option<UserId>> {
15 let row = sqlx::query!(
16 "SELECT user_id FROM federated_identities WHERE issuer = $1 AND external_sub = $2",
17 issuer,
18 external_sub
19 )
20 .fetch_optional(&*self.pool)
21 .await?;
22
23 Ok(row.map(|r| UserId::new(r.user_id)))
24 }
25
26 pub async fn find_or_create_federated(
27 &self,
28 issuer: &str,
29 external_sub: &str,
30 claims: &FederatedIdentityClaims,
31 ) -> Result<User> {
32 let mut conn = self.write_pool.acquire().await?;
33 let mut tx = conn.begin().await?;
34
35 if let Some(existing) = sqlx::query!(
36 "UPDATE federated_identities SET last_seen_at = CURRENT_TIMESTAMP WHERE issuer = $1 \
37 AND external_sub = $2 RETURNING user_id",
38 issuer,
39 external_sub
40 )
41 .fetch_optional(&mut *tx)
42 .await?
43 {
44 let user = sqlx::query_as!(
45 User,
46 r#"
47 SELECT id, name, email, full_name, display_name, status,
48 email_verified, roles, avatar_url, is_bot, is_scanner,
49 created_at, updated_at
50 FROM users WHERE id = $1
51 "#,
52 existing.user_id
53 )
54 .fetch_one(&mut *tx)
55 .await?;
56 tx.commit().await?;
57 return Ok(user);
58 }
59
60 let fields = NewFederatedUser::derive(issuer, external_sub, claims);
61
62 let user = sqlx::query_as!(
63 User,
64 r#"
65 INSERT INTO users (
66 id, name, email, full_name, display_name,
67 status, email_verified, roles, is_bot,
68 created_at, updated_at
69 )
70 VALUES ($1, $2, $3, $4, $5, $6, false, $7::TEXT[], false, $8, $8)
71 RETURNING id, name, email, full_name, display_name, status, email_verified,
72 roles, avatar_url, is_bot, is_scanner, created_at, updated_at
73 "#,
74 fields.id.as_str(),
75 fields.name,
76 fields.email,
77 fields.display_name.as_deref(),
78 fields.display_name.as_deref(),
79 fields.status,
80 &fields.roles,
81 fields.now,
82 )
83 .fetch_one(&mut *tx)
84 .await?;
85
86 sqlx::query!(
87 "INSERT INTO federated_identities (issuer, external_sub, user_id) VALUES ($1, $2, $3)",
88 issuer,
89 external_sub,
90 user.id.as_str()
91 )
92 .execute(&mut *tx)
93 .await?;
94
95 tx.commit().await?;
96 Ok(user)
97 }
98}
99
100struct NewFederatedUser {
101 id: UserId,
102 name: String,
103 email: String,
104 display_name: Option<String>,
105 status: &'static str,
106 roles: Vec<String>,
107 now: chrono::DateTime<Utc>,
108}
109
110impl NewFederatedUser {
111 fn derive(issuer: &str, external_sub: &str, claims: &FederatedIdentityClaims) -> Self {
112 let name = claims
113 .preferred_username
114 .clone()
115 .or_else(|| claims.name.clone())
116 .unwrap_or_else(|| format!("fed_{}_{}", short_hash(issuer), short_hash(external_sub)));
117 let synthetic_email = || {
118 format!(
119 "{}@{}.federated.local",
120 short_hash(external_sub),
121 short_host(issuer)
122 )
123 };
124 let email = match (claims.email.as_deref(), claims.email_verified) {
125 (Some(addr), true) => addr.to_owned(),
126 (Some(addr), false) => {
127 tracing::warn!(
128 issuer,
129 external_sub,
130 upstream_email = addr,
131 "upstream IdP did not assert email_verified; using synthetic local email to \
132 prevent account-claim attacks"
133 );
134 synthetic_email()
135 },
136 (None, _) => synthetic_email(),
137 };
138
139 Self {
140 id: UserId::new(uuid::Uuid::new_v4().to_string()),
141 name,
142 email,
143 display_name: claims.name.clone(),
144 status: UserStatus::Active.as_str(),
145 roles: normalised_roles(&claims.roles),
146 now: Utc::now(),
147 }
148 }
149}
150
151fn normalised_roles(claim_roles: &[String]) -> Vec<String> {
152 if claim_roles.is_empty() {
153 vec![UserRole::User.as_str().to_owned()]
154 } else {
155 claim_roles.to_vec()
156 }
157}
158
159fn short_hash(s: &str) -> String {
160 use sha2::{Digest, Sha256};
161 let digest = Sha256::digest(s.as_bytes());
162 hex::encode(&digest[..6])
163}
164
/// Derives a DNS-label-safe fragment from an issuer URL, for use in synthetic emails.
///
/// A leading `scheme://` is dropped, along with everything from the first `/` on.
/// Every remaining character that is not an ASCII letter, digit or `-` is replaced
/// with `-`. This turns `.`, `:`, `@` and similar into dashes, so the fragment can
/// never inject a second `@` into an email address. Leading and trailing dashes are
/// trimmed.
///
/// Returns `"issuer"` when no host characters remain (e.g. an empty issuer).
fn short_host(issuer: &str) -> String {
    let without_scheme = match issuer.split_once("://") {
        Some((scheme, rest)) if !scheme.contains('/') => rest,
        _ => issuer,
    };
    let authority = without_scheme.split('/').next().unwrap_or("");

    let sanitised: String = authority
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let label = sanitised.trim_matches('-');

    if label.is_empty() {
        "issuer".to_owned()
    } else {
        label.to_owned()
    }
}