Skip to main content

systemprompt_security/authz/
repository.rs

1//! `AccessControlRepository` — sqlx-backed access to `access_control_rules`.
2//!
3//! Generic over `entity_type` so the same repository serves the gateway
4//! (`gateway_route`), MCP (`mcp_server`), and any future enforcement site.
5//! The `default_included` per-entity flag is encoded as a sentinel row
6//! (`rule_type='role'`, `rule_value='__default__'`) inside the same table;
7//! [`AccessControlRepository::list_rules_for_entity`] and
8//! [`AccessControlRepository::list_rules_bulk`] filter that sentinel out so
9//! callers only see real assignments.
10
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::sync::Arc;
14
15use sqlx::PgPool;
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::RuleId;
18
19use super::error::{AuthzError, AuthzResult};
20use super::types::{Access, AccessRule, EntityKind, RuleType};
21
/// `rule_value` of the sentinel row (stored with `rule_type = 'role'`) that
/// encodes an entity's `default_included` flag inside `access_control_rules`.
/// Rows carrying this value are never surfaced to callers as real rules.
const DEFAULT_SENTINEL_VALUE: &str = "__default__";
23
/// One `access_control_rules` row flattened to raw column strings, as
/// produced by `AccessControlRepository::list_role_department_rules_for_export`.
///
/// Values are not parsed into `EntityKind` / `RuleType` / `Access`, so an
/// export round-trips exactly what is stored in the table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExportRuleRow {
    /// Stored `entity_type` column (e.g. the enforcement site kind).
    pub entity_type: String,
    /// Identifier of the guarded entity within its `entity_type`.
    pub entity_id: String,
    /// Stored `rule_type` column; the export query only selects `role` and
    /// `department` rows.
    pub rule_type: String,
    /// The role or department name the rule applies to.
    pub rule_value: String,
    /// Stored `access` column, unparsed.
    pub access: String,
    /// Operator-supplied note, if one was recorded.
    pub justification: Option<String>,
}
33
34#[derive(Debug, Clone, Copy)]
35pub struct UpsertRuleParams<'a> {
36    pub entity_type: EntityKind,
37    pub entity_id: &'a str,
38    pub rule_type: RuleType,
39    pub rule_value: &'a str,
40    pub access: Access,
41    /// Operator-supplied note explaining *why* this rule exists.
42    /// Surfaced in the matrix tooltip and in the audit row's
43    /// `evaluated_rules` JSON when the rule decides. `None` means
44    /// the operator declined to give a reason.
45    pub justification: Option<&'a str>,
46}
47
48#[derive(Clone, Debug)]
49pub struct AccessControlRepository {
50    pool: Arc<PgPool>,
51    write_pool: Arc<PgPool>,
52}
53
54impl AccessControlRepository {
55    pub fn new(db: &DbPool) -> AuthzResult<Self> {
56        let pool = db
57            .pool_arc()
58            .map_err(|err| AuthzError::Validation(err.to_string()))?;
59        let write_pool = db
60            .write_pool_arc()
61            .map_err(|err| AuthzError::Validation(err.to_string()))?;
62        Ok(Self { pool, write_pool })
63    }
64
65    pub fn from_pool(pool: Arc<PgPool>) -> Self {
66        let write_pool = Arc::clone(&pool);
67        Self { pool, write_pool }
68    }
69
70    pub async fn list_role_department_rules_for_export(&self) -> AuthzResult<Vec<ExportRuleRow>> {
71        let rows = sqlx::query_as!(
72            ExportRuleRow,
73            r#"
74            SELECT entity_type, entity_id, rule_type, rule_value, access, justification
75            FROM access_control_rules
76            WHERE rule_type IN ('role', 'department')
77              AND rule_value <> '__default__'
78            ORDER BY entity_type, entity_id, access, rule_type, rule_value
79            "#,
80        )
81        .fetch_all(&*self.pool)
82        .await?;
83        Ok(rows)
84    }
85
86    pub async fn list_rules_for_entity(
87        &self,
88        entity_type: EntityKind,
89        entity_id: &str,
90    ) -> AuthzResult<Vec<AccessRule>> {
91        let rows = sqlx::query!(
92            r#"
93            SELECT id, rule_type, rule_value, access, default_included, justification
94            FROM access_control_rules
95            WHERE entity_type = $1 AND entity_id = $2
96            ORDER BY rule_type, rule_value
97            "#,
98            entity_type.as_str(),
99            entity_id,
100        )
101        .fetch_all(&*self.pool)
102        .await?;
103
104        let mut out = Vec::with_capacity(rows.len());
105        for row in rows {
106            if is_sentinel(&row.rule_type, &row.rule_value) {
107                continue;
108            }
109            out.push(AccessRule {
110                id: RuleId::new(row.id),
111                rule_type: RuleType::from_str(&row.rule_type)?,
112                rule_value: row.rule_value,
113                access: Access::from_str(&row.access)?,
114                default_included: row.default_included,
115                justification: row.justification,
116            });
117        }
118        Ok(out)
119    }
120
121    pub async fn list_rules_bulk(
122        &self,
123        entity_type: EntityKind,
124        entity_ids: &[String],
125    ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
126        let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
127        for id in entity_ids {
128            out.entry(id.clone()).or_default();
129        }
130        if entity_ids.is_empty() {
131            return Ok(out);
132        }
133
134        let rows = sqlx::query!(
135            r#"
136            SELECT entity_id, id, rule_type, rule_value, access, default_included, justification
137            FROM access_control_rules
138            WHERE entity_type = $1 AND entity_id = ANY($2)
139            ORDER BY entity_id, rule_type, rule_value
140            "#,
141            entity_type.as_str(),
142            entity_ids,
143        )
144        .fetch_all(&*self.pool)
145        .await?;
146
147        for row in rows {
148            if is_sentinel(&row.rule_type, &row.rule_value) {
149                continue;
150            }
151            let rule = AccessRule {
152                id: RuleId::new(row.id),
153                rule_type: RuleType::from_str(&row.rule_type)?,
154                rule_value: row.rule_value,
155                access: Access::from_str(&row.access)?,
156                default_included: row.default_included,
157                justification: row.justification,
158            };
159            out.entry(row.entity_id).or_default().push(rule);
160        }
161        Ok(out)
162    }
163
164    pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
165        let id = RuleId::generate();
166        let rule_type_str = params.rule_type.to_string();
167        let access_str = params.access.to_string();
168        let row = sqlx::query!(
169            r#"
170            INSERT INTO access_control_rules
171                (id, entity_type, entity_id, rule_type, rule_value, access, default_included, justification)
172            VALUES ($1, $2, $3, $4, $5, $6, false, $7)
173            ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
174            DO UPDATE SET
175                access = EXCLUDED.access,
176                justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
177                updated_at = NOW()
178            RETURNING id, rule_type, rule_value, access, default_included, justification
179            "#,
180            id.as_str(),
181            params.entity_type.as_str(),
182            params.entity_id,
183            rule_type_str,
184            params.rule_value,
185            access_str,
186            params.justification,
187        )
188        .fetch_one(&*self.write_pool)
189        .await?;
190
191        Ok(AccessRule {
192            id: RuleId::new(row.id),
193            rule_type: RuleType::from_str(&row.rule_type)?,
194            rule_value: row.rule_value,
195            access: Access::from_str(&row.access)?,
196            default_included: row.default_included,
197            justification: row.justification,
198        })
199    }
200
201    /// Update only the justification on an existing rule. Pass `None` to
202    /// clear the operator note.
203    pub async fn set_justification(
204        &self,
205        rule_id: &RuleId,
206        justification: Option<&str>,
207    ) -> AuthzResult<bool> {
208        let result = sqlx::query!(
209            r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
210            rule_id.as_str(),
211            justification,
212        )
213        .execute(&*self.write_pool)
214        .await?;
215        Ok(result.rows_affected() > 0)
216    }
217
218    pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
219        let result = sqlx::query!(
220            r#"DELETE FROM access_control_rules WHERE id = $1"#,
221            rule_id.as_str(),
222        )
223        .execute(&*self.write_pool)
224        .await?;
225        Ok(result.rows_affected() > 0)
226    }
227
228    pub async fn set_default_included(
229        &self,
230        entity_type: EntityKind,
231        entity_id: &str,
232        value: bool,
233    ) -> AuthzResult<()> {
234        if value {
235            let id = RuleId::generate();
236            sqlx::query!(
237                r#"
238                INSERT INTO access_control_rules
239                    (id, entity_type, entity_id, rule_type, rule_value, access, default_included)
240                VALUES ($1, $2, $3, 'role', $4, 'allow', true)
241                ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
242                DO UPDATE SET default_included = true, updated_at = NOW()
243                "#,
244                id.as_str(),
245                entity_type.as_str(),
246                entity_id,
247                DEFAULT_SENTINEL_VALUE,
248            )
249            .execute(&*self.write_pool)
250            .await?;
251        } else {
252            sqlx::query!(
253                r#"
254                DELETE FROM access_control_rules
255                WHERE entity_type = $1
256                  AND entity_id = $2
257                  AND rule_type = 'role'
258                  AND rule_value = $3
259                "#,
260                entity_type.as_str(),
261                entity_id,
262                DEFAULT_SENTINEL_VALUE,
263            )
264            .execute(&*self.write_pool)
265            .await?;
266        }
267        Ok(())
268    }
269
270    pub async fn get_default_included(
271        &self,
272        entity_type: EntityKind,
273        entity_id: &str,
274    ) -> AuthzResult<bool> {
275        let row = sqlx::query!(
276            r#"
277            SELECT default_included FROM access_control_rules
278            WHERE entity_type = $1
279              AND entity_id = $2
280              AND rule_type = 'role'
281              AND rule_value = $3
282            "#,
283            entity_type.as_str(),
284            entity_id,
285            DEFAULT_SENTINEL_VALUE,
286        )
287        .fetch_optional(&*self.pool)
288        .await?;
289        Ok(row.is_some_and(|r| r.default_included))
290    }
291}
292
293fn is_sentinel(rule_type: &str, rule_value: &str) -> bool {
294    rule_type == "role" && rule_value == DEFAULT_SENTINEL_VALUE
295}