systemprompt_security/authz/repository/
rules.rs1use std::collections::HashMap;
2use std::str::FromStr;
3
4use systemprompt_identifiers::RuleId;
5
6use super::{AccessControlRepository, ExportRuleRow, UpsertRuleParams};
7use crate::authz::error::AuthzResult;
8use crate::authz::types::{Access, AccessRule, EntityKind, RuleType};
9
10impl AccessControlRepository {
11 pub async fn list_role_rules_for_export(&self) -> AuthzResult<Vec<ExportRuleRow>> {
12 let rows = sqlx::query_as!(
13 ExportRuleRow,
14 r#"
15 SELECT entity_type, entity_id, rule_type, rule_value, access, justification
16 FROM access_control_rules
17 WHERE rule_type = 'role'
18 ORDER BY entity_type, entity_id, access, rule_type, rule_value
19 "#,
20 )
21 .fetch_all(&*self.pool)
22 .await?;
23 Ok(rows)
24 }
25
26 pub async fn list_rules_for_entity(
27 &self,
28 entity_type: EntityKind,
29 entity_id: &str,
30 ) -> AuthzResult<Vec<AccessRule>> {
31 let rows = sqlx::query!(
32 r#"
33 SELECT id, rule_type, rule_value, access, justification
34 FROM access_control_rules
35 WHERE entity_type = $1 AND entity_id = $2
36 ORDER BY rule_type, rule_value
37 "#,
38 entity_type.as_str(),
39 entity_id,
40 )
41 .fetch_all(&*self.pool)
42 .await?;
43
44 let mut out = Vec::with_capacity(rows.len());
45 for row in rows {
46 out.push(AccessRule {
47 id: RuleId::new(row.id),
48 rule_type: RuleType::from_str(&row.rule_type)?,
49 rule_value: row.rule_value,
50 access: Access::from_str(&row.access)?,
51 justification: row.justification,
52 });
53 }
54 Ok(out)
55 }
56
57 pub async fn list_rules_bulk(
58 &self,
59 entity_type: EntityKind,
60 entity_ids: &[String],
61 ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
62 let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
63 for id in entity_ids {
64 out.entry(id.clone()).or_default();
65 }
66 if entity_ids.is_empty() {
67 return Ok(out);
68 }
69
70 let rows = sqlx::query!(
71 r#"
72 SELECT entity_id, id, rule_type, rule_value, access, justification
73 FROM access_control_rules
74 WHERE entity_type = $1 AND entity_id = ANY($2)
75 ORDER BY entity_id, rule_type, rule_value
76 "#,
77 entity_type.as_str(),
78 entity_ids,
79 )
80 .fetch_all(&*self.pool)
81 .await?;
82
83 for row in rows {
84 let rule = AccessRule {
85 id: RuleId::new(row.id),
86 rule_type: RuleType::from_str(&row.rule_type)?,
87 rule_value: row.rule_value,
88 access: Access::from_str(&row.access)?,
89 justification: row.justification,
90 };
91 out.entry(row.entity_id).or_default().push(rule);
92 }
93 Ok(out)
94 }
95
96 pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
100 let id = RuleId::generate();
101 let rule_type_str = params.rule_type.to_string();
102 let access_str = params.access.to_string();
103 let row = sqlx::query!(
104 r#"
105 INSERT INTO access_control_rules
106 (id, entity_type, entity_id, rule_type, rule_value, access, justification)
107 VALUES ($1, $2, $3, $4, $5, $6, $7)
108 ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
109 DO UPDATE SET
110 access = EXCLUDED.access,
111 justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
112 updated_at = NOW()
113 RETURNING id, rule_type, rule_value, access, justification
114 "#,
115 id.as_str(),
116 params.entity_type.as_str(),
117 params.entity_id,
118 rule_type_str,
119 params.rule_value,
120 access_str,
121 params.justification,
122 )
123 .fetch_one(&*self.write_pool)
124 .await?;
125
126 Ok(AccessRule {
127 id: RuleId::new(row.id),
128 rule_type: RuleType::from_str(&row.rule_type)?,
129 rule_value: row.rule_value,
130 access: Access::from_str(&row.access)?,
131 justification: row.justification,
132 })
133 }
134
135 pub async fn set_justification(
137 &self,
138 rule_id: &RuleId,
139 justification: Option<&str>,
140 ) -> AuthzResult<bool> {
141 let result = sqlx::query!(
142 r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
143 rule_id.as_str(),
144 justification,
145 )
146 .execute(&*self.write_pool)
147 .await?;
148 Ok(result.rows_affected() > 0)
149 }
150
151 pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
152 let result = sqlx::query!(
153 r#"DELETE FROM access_control_rules WHERE id = $1"#,
154 rule_id.as_str(),
155 )
156 .execute(&*self.write_pool)
157 .await?;
158 Ok(result.rows_affected() > 0)
159 }
160}