systemprompt_security/authz/
repository.rs1use std::collections::HashMap;
12use std::str::FromStr;
13use std::sync::Arc;
14
15use sqlx::PgPool;
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::RuleId;
18
19use super::error::{AuthzError, AuthzResult};
20use super::types::{Access, AccessRule, EntityKind, RuleType};
21
/// `rule_value` stored in the reserved `role` rule row that records an
/// entity's "default included" flag. Rows carrying this value are internal
/// bookkeeping and are skipped when rules are listed.
const DEFAULT_SENTINEL_VALUE: &str = "__default__";
23
24#[derive(Debug, Clone, Copy)]
25pub struct UpsertRuleParams<'a> {
26 pub entity_type: EntityKind,
27 pub entity_id: &'a str,
28 pub rule_type: RuleType,
29 pub rule_value: &'a str,
30 pub access: Access,
31 pub justification: Option<&'a str>,
36}
37
38#[derive(Clone, Debug)]
39pub struct AccessControlRepository {
40 pool: Arc<PgPool>,
41 write_pool: Arc<PgPool>,
42}
43
44impl AccessControlRepository {
45 pub fn new(db: &DbPool) -> AuthzResult<Self> {
46 let pool = db
47 .pool_arc()
48 .map_err(|err| AuthzError::Validation(err.to_string()))?;
49 let write_pool = db
50 .write_pool_arc()
51 .map_err(|err| AuthzError::Validation(err.to_string()))?;
52 Ok(Self { pool, write_pool })
53 }
54
55 pub fn from_pool(pool: Arc<PgPool>) -> Self {
56 let write_pool = Arc::clone(&pool);
57 Self { pool, write_pool }
58 }
59
60 pub async fn list_rules_for_entity(
61 &self,
62 entity_type: EntityKind,
63 entity_id: &str,
64 ) -> AuthzResult<Vec<AccessRule>> {
65 let rows = sqlx::query!(
66 r#"
67 SELECT id, rule_type, rule_value, access, default_included, justification
68 FROM access_control_rules
69 WHERE entity_type = $1 AND entity_id = $2
70 ORDER BY rule_type, rule_value
71 "#,
72 entity_type.as_str(),
73 entity_id,
74 )
75 .fetch_all(&*self.pool)
76 .await?;
77
78 let mut out = Vec::with_capacity(rows.len());
79 for row in rows {
80 if is_sentinel(&row.rule_type, &row.rule_value) {
81 continue;
82 }
83 out.push(AccessRule {
84 id: RuleId::new(row.id),
85 rule_type: RuleType::from_str(&row.rule_type)?,
86 rule_value: row.rule_value,
87 access: Access::from_str(&row.access)?,
88 default_included: row.default_included,
89 justification: row.justification,
90 });
91 }
92 Ok(out)
93 }
94
95 pub async fn list_rules_bulk(
96 &self,
97 entity_type: EntityKind,
98 entity_ids: &[String],
99 ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
100 let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
101 for id in entity_ids {
102 out.entry(id.clone()).or_default();
103 }
104 if entity_ids.is_empty() {
105 return Ok(out);
106 }
107
108 let rows = sqlx::query!(
109 r#"
110 SELECT entity_id, id, rule_type, rule_value, access, default_included, justification
111 FROM access_control_rules
112 WHERE entity_type = $1 AND entity_id = ANY($2)
113 ORDER BY entity_id, rule_type, rule_value
114 "#,
115 entity_type.as_str(),
116 entity_ids,
117 )
118 .fetch_all(&*self.pool)
119 .await?;
120
121 for row in rows {
122 if is_sentinel(&row.rule_type, &row.rule_value) {
123 continue;
124 }
125 let rule = AccessRule {
126 id: RuleId::new(row.id),
127 rule_type: RuleType::from_str(&row.rule_type)?,
128 rule_value: row.rule_value,
129 access: Access::from_str(&row.access)?,
130 default_included: row.default_included,
131 justification: row.justification,
132 };
133 out.entry(row.entity_id).or_default().push(rule);
134 }
135 Ok(out)
136 }
137
138 pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
139 let id = RuleId::generate();
140 let rule_type_str = params.rule_type.to_string();
141 let access_str = params.access.to_string();
142 let row = sqlx::query!(
143 r#"
144 INSERT INTO access_control_rules
145 (id, entity_type, entity_id, rule_type, rule_value, access, default_included, justification)
146 VALUES ($1, $2, $3, $4, $5, $6, false, $7)
147 ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
148 DO UPDATE SET
149 access = EXCLUDED.access,
150 justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
151 updated_at = NOW()
152 RETURNING id, rule_type, rule_value, access, default_included, justification
153 "#,
154 id.as_str(),
155 params.entity_type.as_str(),
156 params.entity_id,
157 rule_type_str,
158 params.rule_value,
159 access_str,
160 params.justification,
161 )
162 .fetch_one(&*self.write_pool)
163 .await?;
164
165 Ok(AccessRule {
166 id: RuleId::new(row.id),
167 rule_type: RuleType::from_str(&row.rule_type)?,
168 rule_value: row.rule_value,
169 access: Access::from_str(&row.access)?,
170 default_included: row.default_included,
171 justification: row.justification,
172 })
173 }
174
175 pub async fn set_justification(
178 &self,
179 rule_id: &RuleId,
180 justification: Option<&str>,
181 ) -> AuthzResult<bool> {
182 let result = sqlx::query!(
183 r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
184 rule_id.as_str(),
185 justification,
186 )
187 .execute(&*self.write_pool)
188 .await?;
189 Ok(result.rows_affected() > 0)
190 }
191
192 pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
193 let result = sqlx::query!(
194 r#"DELETE FROM access_control_rules WHERE id = $1"#,
195 rule_id.as_str(),
196 )
197 .execute(&*self.write_pool)
198 .await?;
199 Ok(result.rows_affected() > 0)
200 }
201
202 pub async fn set_default_included(
203 &self,
204 entity_type: EntityKind,
205 entity_id: &str,
206 value: bool,
207 ) -> AuthzResult<()> {
208 if value {
209 let id = RuleId::generate();
210 sqlx::query!(
211 r#"
212 INSERT INTO access_control_rules
213 (id, entity_type, entity_id, rule_type, rule_value, access, default_included)
214 VALUES ($1, $2, $3, 'role', $4, 'allow', true)
215 ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
216 DO UPDATE SET default_included = true, updated_at = NOW()
217 "#,
218 id.as_str(),
219 entity_type.as_str(),
220 entity_id,
221 DEFAULT_SENTINEL_VALUE,
222 )
223 .execute(&*self.write_pool)
224 .await?;
225 } else {
226 sqlx::query!(
227 r#"
228 DELETE FROM access_control_rules
229 WHERE entity_type = $1
230 AND entity_id = $2
231 AND rule_type = 'role'
232 AND rule_value = $3
233 "#,
234 entity_type.as_str(),
235 entity_id,
236 DEFAULT_SENTINEL_VALUE,
237 )
238 .execute(&*self.write_pool)
239 .await?;
240 }
241 Ok(())
242 }
243
244 pub async fn get_default_included(
245 &self,
246 entity_type: EntityKind,
247 entity_id: &str,
248 ) -> AuthzResult<bool> {
249 let row = sqlx::query!(
250 r#"
251 SELECT default_included FROM access_control_rules
252 WHERE entity_type = $1
253 AND entity_id = $2
254 AND rule_type = 'role'
255 AND rule_value = $3
256 "#,
257 entity_type.as_str(),
258 entity_id,
259 DEFAULT_SENTINEL_VALUE,
260 )
261 .fetch_optional(&*self.pool)
262 .await?;
263 Ok(row.is_some_and(|r| r.default_included))
264 }
265}
266
/// Reports whether a `(rule_type, rule_value)` pair is the internal
/// default-included marker rather than a real access rule.
///
/// Only the exact pair `("role", DEFAULT_SENTINEL_VALUE)` matches; the
/// comparison is case-sensitive.
fn is_sentinel(rule_type: &str, rule_value: &str) -> bool {
    matches!((rule_type, rule_value), ("role", DEFAULT_SENTINEL_VALUE))
}