Skip to main content

systemprompt_security/authz/
repository.rs

1//! `AccessControlRepository` — sqlx-backed access to the two-table authz
2//! schema.
3//!
4//! `access_control_entities` owns one row per `(entity_type, entity_id)` and
5//! carries the `default_included` flag plus a `source` provenance string.
6//! `access_control_rules` is the per-(entity, subject) grant table, with a
7//! foreign key back to the entity catalog. Callers fetch the entity row
8//! first (a `None` result signals an entity unknown to access control), then
9//! list rules for it, and hand both to [`super::resolver::resolve`].
10
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::sync::Arc;
14
15use sqlx::PgPool;
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::RuleId;
18
19use super::error::{AuthzError, AuthzResult};
20use super::types::{Access, AccessRule, EntityKind, EntityRow, RuleType};
21
/// One `role` / `department` grant flattened to the raw strings stored in
/// `access_control_rules`, as produced by
/// [`AccessControlRepository::list_role_department_rules_for_export`].
#[derive(Clone, Debug)]
pub struct ExportRuleRow {
    /// Stored `entity_type` column (textual form of the entity kind).
    pub entity_type: String,
    /// Identifier of the entity the grant applies to.
    pub entity_id: String,
    /// Stored `rule_type` column (`role` or `department` for export rows).
    pub rule_type: String,
    /// Role or department name the grant matches on.
    pub rule_value: String,
    /// Stored `access` column, unparsed.
    pub access: String,
    /// Operator note explaining the grant, if one was recorded.
    pub justification: Option<String>,
}
31
32#[derive(Debug, Clone, Copy)]
33pub struct UpsertRuleParams<'a> {
34    pub entity_type: EntityKind,
35    pub entity_id: &'a str,
36    pub rule_type: RuleType,
37    pub rule_value: &'a str,
38    pub access: Access,
39    /// Operator-supplied note explaining *why* this rule exists. Surfaced in
40    /// the matrix tooltip and in the audit row's `evaluated_rules` JSON when
41    /// the rule decides. `None` means the operator declined to give a reason.
42    pub justification: Option<&'a str>,
43}
44
45#[derive(Clone, Debug)]
46pub struct AccessControlRepository {
47    pool: Arc<PgPool>,
48    write_pool: Arc<PgPool>,
49}
50
51impl AccessControlRepository {
52    pub fn new(db: &DbPool) -> AuthzResult<Self> {
53        let pool = db
54            .pool_arc()
55            .map_err(|err| AuthzError::Validation(err.to_string()))?;
56        let write_pool = db
57            .write_pool_arc()
58            .map_err(|err| AuthzError::Validation(err.to_string()))?;
59        Ok(Self { pool, write_pool })
60    }
61
62    pub fn from_pool(pool: Arc<PgPool>) -> Self {
63        let write_pool = Arc::clone(&pool);
64        Self { pool, write_pool }
65    }
66
67    /// Look up one entity catalog row. `Ok(None)` means the entity has no
68    /// catalog row at all (publish-pipeline bootstrap gap) — the resolver
69    /// turns this into [`super::DenyReason::UnknownEntity`].
70    pub async fn get_entity(
71        &self,
72        entity_type: EntityKind,
73        entity_id: &str,
74    ) -> AuthzResult<Option<EntityRow>> {
75        let row = sqlx::query!(
76            r#"
77            SELECT entity_type, entity_id, default_included, source
78            FROM access_control_entities
79            WHERE entity_type = $1 AND entity_id = $2
80            "#,
81            entity_type.as_str(),
82            entity_id,
83        )
84        .fetch_optional(&*self.pool)
85        .await?;
86
87        let Some(row) = row else {
88            return Ok(None);
89        };
90        Ok(Some(EntityRow {
91            kind: EntityKind::from_str(&row.entity_type)?,
92            id: row.entity_id,
93            default_included: row.default_included,
94            source: row.source,
95        }))
96    }
97
98    /// Upsert an entity catalog row. Always overwrites `default_included` and
99    /// `source` so the most recent bootstrap pass wins — the publish pipeline
100    /// is the source of truth and runs ahead of YAML grant ingestion.
101    pub async fn upsert_entity(
102        &self,
103        entity_type: EntityKind,
104        entity_id: &str,
105        default_included: bool,
106        source: &str,
107    ) -> AuthzResult<()> {
108        sqlx::query!(
109            r#"
110            INSERT INTO access_control_entities (entity_type, entity_id, default_included, source)
111            VALUES ($1, $2, $3, $4)
112            ON CONFLICT (entity_type, entity_id) DO UPDATE
113            SET default_included = EXCLUDED.default_included,
114                source = EXCLUDED.source,
115                updated_at = NOW()
116            "#,
117            entity_type.as_str(),
118            entity_id,
119            default_included,
120            source,
121        )
122        .execute(&*self.write_pool)
123        .await?;
124        Ok(())
125    }
126
127    /// Bulk-fetch every catalog row for a given kind. Used by the CLI lint and
128    /// the publish-pipeline validator to detect rules pointing at entities
129    /// the bootstrap pass never registered.
130    pub async fn list_entities(&self, entity_type: EntityKind) -> AuthzResult<Vec<EntityRow>> {
131        let rows = sqlx::query!(
132            r#"
133            SELECT entity_type, entity_id, default_included, source
134            FROM access_control_entities
135            WHERE entity_type = $1
136            ORDER BY entity_id
137            "#,
138            entity_type.as_str(),
139        )
140        .fetch_all(&*self.pool)
141        .await?;
142
143        let mut out = Vec::with_capacity(rows.len());
144        for row in rows {
145            out.push(EntityRow {
146                kind: EntityKind::from_str(&row.entity_type)?,
147                id: row.entity_id,
148                default_included: row.default_included,
149                source: row.source,
150            });
151        }
152        Ok(out)
153    }
154
155    pub async fn list_role_department_rules_for_export(&self) -> AuthzResult<Vec<ExportRuleRow>> {
156        let rows = sqlx::query_as!(
157            ExportRuleRow,
158            r#"
159            SELECT entity_type, entity_id, rule_type, rule_value, access, justification
160            FROM access_control_rules
161            WHERE rule_type IN ('role', 'department')
162            ORDER BY entity_type, entity_id, access, rule_type, rule_value
163            "#,
164        )
165        .fetch_all(&*self.pool)
166        .await?;
167        Ok(rows)
168    }
169
170    pub async fn list_rules_for_entity(
171        &self,
172        entity_type: EntityKind,
173        entity_id: &str,
174    ) -> AuthzResult<Vec<AccessRule>> {
175        let rows = sqlx::query!(
176            r#"
177            SELECT id, rule_type, rule_value, access, justification
178            FROM access_control_rules
179            WHERE entity_type = $1 AND entity_id = $2
180            ORDER BY rule_type, rule_value
181            "#,
182            entity_type.as_str(),
183            entity_id,
184        )
185        .fetch_all(&*self.pool)
186        .await?;
187
188        let mut out = Vec::with_capacity(rows.len());
189        for row in rows {
190            out.push(AccessRule {
191                id: RuleId::new(row.id),
192                rule_type: RuleType::from_str(&row.rule_type)?,
193                rule_value: row.rule_value,
194                access: Access::from_str(&row.access)?,
195                justification: row.justification,
196            });
197        }
198        Ok(out)
199    }
200
201    pub async fn list_rules_bulk(
202        &self,
203        entity_type: EntityKind,
204        entity_ids: &[String],
205    ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
206        let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
207        for id in entity_ids {
208            out.entry(id.clone()).or_default();
209        }
210        if entity_ids.is_empty() {
211            return Ok(out);
212        }
213
214        let rows = sqlx::query!(
215            r#"
216            SELECT entity_id, id, rule_type, rule_value, access, justification
217            FROM access_control_rules
218            WHERE entity_type = $1 AND entity_id = ANY($2)
219            ORDER BY entity_id, rule_type, rule_value
220            "#,
221            entity_type.as_str(),
222            entity_ids,
223        )
224        .fetch_all(&*self.pool)
225        .await?;
226
227        for row in rows {
228            let rule = AccessRule {
229                id: RuleId::new(row.id),
230                rule_type: RuleType::from_str(&row.rule_type)?,
231                rule_value: row.rule_value,
232                access: Access::from_str(&row.access)?,
233                justification: row.justification,
234            };
235            out.entry(row.entity_id).or_default().push(rule);
236        }
237        Ok(out)
238    }
239
240    /// Insert or update a grant row. Fails with a foreign-key violation if no
241    /// entity catalog row exists — register the entity via
242    /// [`Self::upsert_entity`] first.
243    pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
244        let id = RuleId::generate();
245        let rule_type_str = params.rule_type.to_string();
246        let access_str = params.access.to_string();
247        let row = sqlx::query!(
248            r#"
249            INSERT INTO access_control_rules
250                (id, entity_type, entity_id, rule_type, rule_value, access, justification)
251            VALUES ($1, $2, $3, $4, $5, $6, $7)
252            ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
253            DO UPDATE SET
254                access = EXCLUDED.access,
255                justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
256                updated_at = NOW()
257            RETURNING id, rule_type, rule_value, access, justification
258            "#,
259            id.as_str(),
260            params.entity_type.as_str(),
261            params.entity_id,
262            rule_type_str,
263            params.rule_value,
264            access_str,
265            params.justification,
266        )
267        .fetch_one(&*self.write_pool)
268        .await?;
269
270        Ok(AccessRule {
271            id: RuleId::new(row.id),
272            rule_type: RuleType::from_str(&row.rule_type)?,
273            rule_value: row.rule_value,
274            access: Access::from_str(&row.access)?,
275            justification: row.justification,
276        })
277    }
278
279    // None clears the operator note.
280    pub async fn set_justification(
281        &self,
282        rule_id: &RuleId,
283        justification: Option<&str>,
284    ) -> AuthzResult<bool> {
285        let result = sqlx::query!(
286            r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
287            rule_id.as_str(),
288            justification,
289        )
290        .execute(&*self.write_pool)
291        .await?;
292        Ok(result.rows_affected() > 0)
293    }
294
295    pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
296        let result = sqlx::query!(
297            r#"DELETE FROM access_control_rules WHERE id = $1"#,
298            rule_id.as_str(),
299        )
300        .execute(&*self.write_pool)
301        .await?;
302        Ok(result.rows_affected() > 0)
303    }
304}