systemprompt_security/authz/
repository.rs1use std::collections::HashMap;
12use std::str::FromStr;
13use std::sync::Arc;
14
15use sqlx::PgPool;
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::RuleId;
18
19use super::error::{AuthzError, AuthzResult};
20use super::types::{Access, AccessRule, EntityKind, EntityRow, RuleType};
21
/// A raw `access_control_rules` row, as read for export.
///
/// Every column is kept as the string stored in the database; `rule_type`
/// and `access` are *not* parsed into their typed counterparts, so this row
/// can be serialized back out verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExportRuleRow {
    /// Stored textual entity kind (the value written from `EntityKind::as_str`).
    pub entity_type: String,
    /// Identifier of the entity the rule is attached to.
    pub entity_id: String,
    /// Stored textual rule type (e.g. `role` or `department`).
    pub rule_type: String,
    /// Value the rule applies to for its `rule_type`.
    pub rule_value: String,
    /// Stored textual access value.
    pub access: String,
    /// Optional free-text justification; `NULL` in the database maps to `None`.
    pub justification: Option<String>,
}
31
32#[derive(Debug, Clone, Copy)]
33pub struct UpsertRuleParams<'a> {
34 pub entity_type: EntityKind,
35 pub entity_id: &'a str,
36 pub rule_type: RuleType,
37 pub rule_value: &'a str,
38 pub access: Access,
39 pub justification: Option<&'a str>,
43}
44
45#[derive(Clone, Debug)]
46pub struct AccessControlRepository {
47 pool: Arc<PgPool>,
48 write_pool: Arc<PgPool>,
49}
50
51impl AccessControlRepository {
52 pub fn new(db: &DbPool) -> AuthzResult<Self> {
53 let pool = db
54 .pool_arc()
55 .map_err(|err| AuthzError::Validation(err.to_string()))?;
56 let write_pool = db
57 .write_pool_arc()
58 .map_err(|err| AuthzError::Validation(err.to_string()))?;
59 Ok(Self { pool, write_pool })
60 }
61
62 pub fn from_pool(pool: Arc<PgPool>) -> Self {
63 let write_pool = Arc::clone(&pool);
64 Self { pool, write_pool }
65 }
66
67 pub async fn get_entity(
71 &self,
72 entity_type: EntityKind,
73 entity_id: &str,
74 ) -> AuthzResult<Option<EntityRow>> {
75 let row = sqlx::query!(
76 r#"
77 SELECT entity_type, entity_id, default_included, source
78 FROM access_control_entities
79 WHERE entity_type = $1 AND entity_id = $2
80 "#,
81 entity_type.as_str(),
82 entity_id,
83 )
84 .fetch_optional(&*self.pool)
85 .await?;
86
87 let Some(row) = row else {
88 return Ok(None);
89 };
90 Ok(Some(EntityRow {
91 kind: EntityKind::from_str(&row.entity_type)?,
92 id: row.entity_id,
93 default_included: row.default_included,
94 source: row.source,
95 }))
96 }
97
98 pub async fn upsert_entity(
102 &self,
103 entity_type: EntityKind,
104 entity_id: &str,
105 default_included: bool,
106 source: &str,
107 ) -> AuthzResult<()> {
108 sqlx::query!(
109 r#"
110 INSERT INTO access_control_entities (entity_type, entity_id, default_included, source)
111 VALUES ($1, $2, $3, $4)
112 ON CONFLICT (entity_type, entity_id) DO UPDATE
113 SET default_included = EXCLUDED.default_included,
114 source = EXCLUDED.source,
115 updated_at = NOW()
116 "#,
117 entity_type.as_str(),
118 entity_id,
119 default_included,
120 source,
121 )
122 .execute(&*self.write_pool)
123 .await?;
124 Ok(())
125 }
126
127 pub async fn list_entities(&self, entity_type: EntityKind) -> AuthzResult<Vec<EntityRow>> {
131 let rows = sqlx::query!(
132 r#"
133 SELECT entity_type, entity_id, default_included, source
134 FROM access_control_entities
135 WHERE entity_type = $1
136 ORDER BY entity_id
137 "#,
138 entity_type.as_str(),
139 )
140 .fetch_all(&*self.pool)
141 .await?;
142
143 let mut out = Vec::with_capacity(rows.len());
144 for row in rows {
145 out.push(EntityRow {
146 kind: EntityKind::from_str(&row.entity_type)?,
147 id: row.entity_id,
148 default_included: row.default_included,
149 source: row.source,
150 });
151 }
152 Ok(out)
153 }
154
155 pub async fn list_role_department_rules_for_export(&self) -> AuthzResult<Vec<ExportRuleRow>> {
156 let rows = sqlx::query_as!(
157 ExportRuleRow,
158 r#"
159 SELECT entity_type, entity_id, rule_type, rule_value, access, justification
160 FROM access_control_rules
161 WHERE rule_type IN ('role', 'department')
162 ORDER BY entity_type, entity_id, access, rule_type, rule_value
163 "#,
164 )
165 .fetch_all(&*self.pool)
166 .await?;
167 Ok(rows)
168 }
169
170 pub async fn list_rules_for_entity(
171 &self,
172 entity_type: EntityKind,
173 entity_id: &str,
174 ) -> AuthzResult<Vec<AccessRule>> {
175 let rows = sqlx::query!(
176 r#"
177 SELECT id, rule_type, rule_value, access, justification
178 FROM access_control_rules
179 WHERE entity_type = $1 AND entity_id = $2
180 ORDER BY rule_type, rule_value
181 "#,
182 entity_type.as_str(),
183 entity_id,
184 )
185 .fetch_all(&*self.pool)
186 .await?;
187
188 let mut out = Vec::with_capacity(rows.len());
189 for row in rows {
190 out.push(AccessRule {
191 id: RuleId::new(row.id),
192 rule_type: RuleType::from_str(&row.rule_type)?,
193 rule_value: row.rule_value,
194 access: Access::from_str(&row.access)?,
195 justification: row.justification,
196 });
197 }
198 Ok(out)
199 }
200
201 pub async fn list_rules_bulk(
202 &self,
203 entity_type: EntityKind,
204 entity_ids: &[String],
205 ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
206 let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
207 for id in entity_ids {
208 out.entry(id.clone()).or_default();
209 }
210 if entity_ids.is_empty() {
211 return Ok(out);
212 }
213
214 let rows = sqlx::query!(
215 r#"
216 SELECT entity_id, id, rule_type, rule_value, access, justification
217 FROM access_control_rules
218 WHERE entity_type = $1 AND entity_id = ANY($2)
219 ORDER BY entity_id, rule_type, rule_value
220 "#,
221 entity_type.as_str(),
222 entity_ids,
223 )
224 .fetch_all(&*self.pool)
225 .await?;
226
227 for row in rows {
228 let rule = AccessRule {
229 id: RuleId::new(row.id),
230 rule_type: RuleType::from_str(&row.rule_type)?,
231 rule_value: row.rule_value,
232 access: Access::from_str(&row.access)?,
233 justification: row.justification,
234 };
235 out.entry(row.entity_id).or_default().push(rule);
236 }
237 Ok(out)
238 }
239
240 pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
244 let id = RuleId::generate();
245 let rule_type_str = params.rule_type.to_string();
246 let access_str = params.access.to_string();
247 let row = sqlx::query!(
248 r#"
249 INSERT INTO access_control_rules
250 (id, entity_type, entity_id, rule_type, rule_value, access, justification)
251 VALUES ($1, $2, $3, $4, $5, $6, $7)
252 ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
253 DO UPDATE SET
254 access = EXCLUDED.access,
255 justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
256 updated_at = NOW()
257 RETURNING id, rule_type, rule_value, access, justification
258 "#,
259 id.as_str(),
260 params.entity_type.as_str(),
261 params.entity_id,
262 rule_type_str,
263 params.rule_value,
264 access_str,
265 params.justification,
266 )
267 .fetch_one(&*self.write_pool)
268 .await?;
269
270 Ok(AccessRule {
271 id: RuleId::new(row.id),
272 rule_type: RuleType::from_str(&row.rule_type)?,
273 rule_value: row.rule_value,
274 access: Access::from_str(&row.access)?,
275 justification: row.justification,
276 })
277 }
278
279 pub async fn set_justification(
281 &self,
282 rule_id: &RuleId,
283 justification: Option<&str>,
284 ) -> AuthzResult<bool> {
285 let result = sqlx::query!(
286 r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
287 rule_id.as_str(),
288 justification,
289 )
290 .execute(&*self.write_pool)
291 .await?;
292 Ok(result.rows_affected() > 0)
293 }
294
295 pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
296 let result = sqlx::query!(
297 r#"DELETE FROM access_control_rules WHERE id = $1"#,
298 rule_id.as_str(),
299 )
300 .execute(&*self.write_pool)
301 .await?;
302 Ok(result.rows_affected() > 0)
303 }
304}