Skip to main content

systemprompt_security/authz/repository/
rules.rs

1use std::collections::HashMap;
2use std::str::FromStr;
3
4use systemprompt_identifiers::RuleId;
5
6use super::{AccessControlRepository, ExportRuleRow, UpsertRuleParams};
7use crate::authz::error::AuthzResult;
8use crate::authz::types::{Access, AccessRule, EntityKind, RuleType};
9
10impl AccessControlRepository {
11    pub async fn list_role_rules_for_export(&self) -> AuthzResult<Vec<ExportRuleRow>> {
12        let rows = sqlx::query_as!(
13            ExportRuleRow,
14            r#"
15            SELECT entity_type, entity_id, rule_type, rule_value, access, justification
16            FROM access_control_rules
17            WHERE rule_type = 'role'
18            ORDER BY entity_type, entity_id, access, rule_type, rule_value
19            "#,
20        )
21        .fetch_all(&*self.pool)
22        .await?;
23        Ok(rows)
24    }
25
26    pub async fn list_rules_for_entity(
27        &self,
28        entity_type: EntityKind,
29        entity_id: &str,
30    ) -> AuthzResult<Vec<AccessRule>> {
31        let rows = sqlx::query!(
32            r#"
33            SELECT id, rule_type, rule_value, access, justification
34            FROM access_control_rules
35            WHERE entity_type = $1 AND entity_id = $2
36            ORDER BY rule_type, rule_value
37            "#,
38            entity_type.as_str(),
39            entity_id,
40        )
41        .fetch_all(&*self.pool)
42        .await?;
43
44        let mut out = Vec::with_capacity(rows.len());
45        for row in rows {
46            out.push(AccessRule {
47                id: RuleId::new(row.id),
48                rule_type: RuleType::from_str(&row.rule_type)?,
49                rule_value: row.rule_value,
50                access: Access::from_str(&row.access)?,
51                justification: row.justification,
52            });
53        }
54        Ok(out)
55    }
56
57    pub async fn list_rules_bulk(
58        &self,
59        entity_type: EntityKind,
60        entity_ids: &[String],
61    ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
62        let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
63        for id in entity_ids {
64            out.entry(id.clone()).or_default();
65        }
66        if entity_ids.is_empty() {
67            return Ok(out);
68        }
69
70        let rows = sqlx::query!(
71            r#"
72            SELECT entity_id, id, rule_type, rule_value, access, justification
73            FROM access_control_rules
74            WHERE entity_type = $1 AND entity_id = ANY($2)
75            ORDER BY entity_id, rule_type, rule_value
76            "#,
77            entity_type.as_str(),
78            entity_ids,
79        )
80        .fetch_all(&*self.pool)
81        .await?;
82
83        for row in rows {
84            let rule = AccessRule {
85                id: RuleId::new(row.id),
86                rule_type: RuleType::from_str(&row.rule_type)?,
87                rule_value: row.rule_value,
88                access: Access::from_str(&row.access)?,
89                justification: row.justification,
90            };
91            out.entry(row.entity_id).or_default().push(rule);
92        }
93        Ok(out)
94    }
95
96    /// Fails with a foreign-key violation if no entity catalog row exists for
97    /// `(entity_type, entity_id)` — register the entity via
98    /// [`Self::upsert_entity`] first.
99    pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
100        let id = RuleId::generate();
101        let rule_type_str = params.rule_type.to_string();
102        let access_str = params.access.to_string();
103        let row = sqlx::query!(
104            r#"
105            INSERT INTO access_control_rules
106                (id, entity_type, entity_id, rule_type, rule_value, access, justification)
107            VALUES ($1, $2, $3, $4, $5, $6, $7)
108            ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
109            DO UPDATE SET
110                access = EXCLUDED.access,
111                justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
112                updated_at = NOW()
113            RETURNING id, rule_type, rule_value, access, justification
114            "#,
115            id.as_str(),
116            params.entity_type.as_str(),
117            params.entity_id,
118            rule_type_str,
119            params.rule_value,
120            access_str,
121            params.justification,
122        )
123        .fetch_one(&*self.write_pool)
124        .await?;
125
126        Ok(AccessRule {
127            id: RuleId::new(row.id),
128            rule_type: RuleType::from_str(&row.rule_type)?,
129            rule_value: row.rule_value,
130            access: Access::from_str(&row.access)?,
131            justification: row.justification,
132        })
133    }
134
135    /// `None` clears the operator note.
136    pub async fn set_justification(
137        &self,
138        rule_id: &RuleId,
139        justification: Option<&str>,
140    ) -> AuthzResult<bool> {
141        let result = sqlx::query!(
142            r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
143            rule_id.as_str(),
144            justification,
145        )
146        .execute(&*self.write_pool)
147        .await?;
148        Ok(result.rows_affected() > 0)
149    }
150
151    pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
152        let result = sqlx::query!(
153            r#"DELETE FROM access_control_rules WHERE id = $1"#,
154            rule_id.as_str(),
155        )
156        .execute(&*self.write_pool)
157        .await?;
158        Ok(result.rows_affected() > 0)
159    }
160}