Skip to main content

systemprompt_security/authz/
repository.rs

1//! `AccessControlRepository` — sqlx-backed access to `access_control_rules`.
2//!
3//! Generic over `entity_type` so the same repository serves the gateway
4//! (`gateway_route`), MCP (`mcp_server`), and any future enforcement site.
5//! The `default_included` per-entity flag is encoded as a sentinel row
6//! (`rule_type='role'`, `rule_value='__default__'`) inside the same table;
7//! [`AccessControlRepository::list_rules_for_entity`] and
8//! [`AccessControlRepository::list_rules_bulk`] filter that sentinel out so
9//! callers only see real assignments.
10
11use std::collections::HashMap;
12use std::str::FromStr;
13use std::sync::Arc;
14
15use sqlx::PgPool;
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::RuleId;
18
19use super::error::{AuthzError, AuthzResult};
20use super::types::{Access, AccessRule, EntityKind, RuleType};
21
/// `rule_value` of the sentinel row (stored with `rule_type = 'role'`) that
/// records an entity's `default_included` flag inside `access_control_rules`.
const DEFAULT_SENTINEL_VALUE: &str = "__default__";
23
24#[derive(Debug, Clone, Copy)]
25pub struct UpsertRuleParams<'a> {
26    pub entity_type: EntityKind,
27    pub entity_id: &'a str,
28    pub rule_type: RuleType,
29    pub rule_value: &'a str,
30    pub access: Access,
31    /// Operator-supplied note explaining *why* this rule exists.
32    /// Surfaced in the matrix tooltip and in the audit row's
33    /// `evaluated_rules` JSON when the rule decides. `None` means
34    /// the operator declined to give a reason.
35    pub justification: Option<&'a str>,
36}
37
38#[derive(Clone, Debug)]
39pub struct AccessControlRepository {
40    pool: Arc<PgPool>,
41    write_pool: Arc<PgPool>,
42}
43
44impl AccessControlRepository {
45    pub fn new(db: &DbPool) -> AuthzResult<Self> {
46        let pool = db
47            .pool_arc()
48            .map_err(|err| AuthzError::Validation(err.to_string()))?;
49        let write_pool = db
50            .write_pool_arc()
51            .map_err(|err| AuthzError::Validation(err.to_string()))?;
52        Ok(Self { pool, write_pool })
53    }
54
55    pub fn from_pool(pool: Arc<PgPool>) -> Self {
56        let write_pool = Arc::clone(&pool);
57        Self { pool, write_pool }
58    }
59
60    pub async fn list_rules_for_entity(
61        &self,
62        entity_type: EntityKind,
63        entity_id: &str,
64    ) -> AuthzResult<Vec<AccessRule>> {
65        let rows = sqlx::query!(
66            r#"
67            SELECT id, rule_type, rule_value, access, default_included, justification
68            FROM access_control_rules
69            WHERE entity_type = $1 AND entity_id = $2
70            ORDER BY rule_type, rule_value
71            "#,
72            entity_type.as_str(),
73            entity_id,
74        )
75        .fetch_all(&*self.pool)
76        .await?;
77
78        let mut out = Vec::with_capacity(rows.len());
79        for row in rows {
80            if is_sentinel(&row.rule_type, &row.rule_value) {
81                continue;
82            }
83            out.push(AccessRule {
84                id: RuleId::new(row.id),
85                rule_type: RuleType::from_str(&row.rule_type)?,
86                rule_value: row.rule_value,
87                access: Access::from_str(&row.access)?,
88                default_included: row.default_included,
89                justification: row.justification,
90            });
91        }
92        Ok(out)
93    }
94
95    pub async fn list_rules_bulk(
96        &self,
97        entity_type: EntityKind,
98        entity_ids: &[String],
99    ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
100        let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
101        for id in entity_ids {
102            out.entry(id.clone()).or_default();
103        }
104        if entity_ids.is_empty() {
105            return Ok(out);
106        }
107
108        let rows = sqlx::query!(
109            r#"
110            SELECT entity_id, id, rule_type, rule_value, access, default_included, justification
111            FROM access_control_rules
112            WHERE entity_type = $1 AND entity_id = ANY($2)
113            ORDER BY entity_id, rule_type, rule_value
114            "#,
115            entity_type.as_str(),
116            entity_ids,
117        )
118        .fetch_all(&*self.pool)
119        .await?;
120
121        for row in rows {
122            if is_sentinel(&row.rule_type, &row.rule_value) {
123                continue;
124            }
125            let rule = AccessRule {
126                id: RuleId::new(row.id),
127                rule_type: RuleType::from_str(&row.rule_type)?,
128                rule_value: row.rule_value,
129                access: Access::from_str(&row.access)?,
130                default_included: row.default_included,
131                justification: row.justification,
132            };
133            out.entry(row.entity_id).or_default().push(rule);
134        }
135        Ok(out)
136    }
137
138    pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
139        let id = RuleId::generate();
140        let rule_type_str = params.rule_type.to_string();
141        let access_str = params.access.to_string();
142        let row = sqlx::query!(
143            r#"
144            INSERT INTO access_control_rules
145                (id, entity_type, entity_id, rule_type, rule_value, access, default_included, justification)
146            VALUES ($1, $2, $3, $4, $5, $6, false, $7)
147            ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
148            DO UPDATE SET
149                access = EXCLUDED.access,
150                justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
151                updated_at = NOW()
152            RETURNING id, rule_type, rule_value, access, default_included, justification
153            "#,
154            id.as_str(),
155            params.entity_type.as_str(),
156            params.entity_id,
157            rule_type_str,
158            params.rule_value,
159            access_str,
160            params.justification,
161        )
162        .fetch_one(&*self.write_pool)
163        .await?;
164
165        Ok(AccessRule {
166            id: RuleId::new(row.id),
167            rule_type: RuleType::from_str(&row.rule_type)?,
168            rule_value: row.rule_value,
169            access: Access::from_str(&row.access)?,
170            default_included: row.default_included,
171            justification: row.justification,
172        })
173    }
174
175    /// Update only the justification on an existing rule. Pass `None` to
176    /// clear the operator note.
177    pub async fn set_justification(
178        &self,
179        rule_id: &RuleId,
180        justification: Option<&str>,
181    ) -> AuthzResult<bool> {
182        let result = sqlx::query!(
183            r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
184            rule_id.as_str(),
185            justification,
186        )
187        .execute(&*self.write_pool)
188        .await?;
189        Ok(result.rows_affected() > 0)
190    }
191
192    pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
193        let result = sqlx::query!(
194            r#"DELETE FROM access_control_rules WHERE id = $1"#,
195            rule_id.as_str(),
196        )
197        .execute(&*self.write_pool)
198        .await?;
199        Ok(result.rows_affected() > 0)
200    }
201
202    pub async fn set_default_included(
203        &self,
204        entity_type: EntityKind,
205        entity_id: &str,
206        value: bool,
207    ) -> AuthzResult<()> {
208        if value {
209            let id = RuleId::generate();
210            sqlx::query!(
211                r#"
212                INSERT INTO access_control_rules
213                    (id, entity_type, entity_id, rule_type, rule_value, access, default_included)
214                VALUES ($1, $2, $3, 'role', $4, 'allow', true)
215                ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
216                DO UPDATE SET default_included = true, updated_at = NOW()
217                "#,
218                id.as_str(),
219                entity_type.as_str(),
220                entity_id,
221                DEFAULT_SENTINEL_VALUE,
222            )
223            .execute(&*self.write_pool)
224            .await?;
225        } else {
226            sqlx::query!(
227                r#"
228                DELETE FROM access_control_rules
229                WHERE entity_type = $1
230                  AND entity_id = $2
231                  AND rule_type = 'role'
232                  AND rule_value = $3
233                "#,
234                entity_type.as_str(),
235                entity_id,
236                DEFAULT_SENTINEL_VALUE,
237            )
238            .execute(&*self.write_pool)
239            .await?;
240        }
241        Ok(())
242    }
243
244    pub async fn get_default_included(
245        &self,
246        entity_type: EntityKind,
247        entity_id: &str,
248    ) -> AuthzResult<bool> {
249        let row = sqlx::query!(
250            r#"
251            SELECT default_included FROM access_control_rules
252            WHERE entity_type = $1
253              AND entity_id = $2
254              AND rule_type = 'role'
255              AND rule_value = $3
256            "#,
257            entity_type.as_str(),
258            entity_id,
259            DEFAULT_SENTINEL_VALUE,
260        )
261        .fetch_optional(&*self.pool)
262        .await?;
263        Ok(row.is_some_and(|r| r.default_included))
264    }
265}
266
/// Reports whether a `(rule_type, rule_value)` pair is the row that encodes
/// the `default_included` flag rather than a real access rule.
fn is_sentinel(rule_type: &str, rule_value: &str) -> bool {
    matches!((rule_type, rule_value), ("role", DEFAULT_SENTINEL_VALUE))
}