1use std::collections::HashMap;
12use std::str::FromStr;
13use std::sync::Arc;
14
15use sqlx::PgPool;
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::RuleId;
18
19use super::error::{AuthzError, AuthzResult};
20use super::types::{Access, AccessRule, EntityKind, RuleType};
21
/// Reserved `rule_value` of the per-entity marker row.
///
/// `set_default_included` stores an entity's `default_included` flag on a
/// `role` row carrying this value; the listing and export queries leave such
/// rows out of their results.
const DEFAULT_SENTINEL_VALUE: &str = "__default__";
23
/// Row shape produced by
/// `AccessControlRepository::list_role_department_rules_for_export`.
///
/// Each field holds the corresponding `access_control_rules` column as the
/// string read from the database; nothing is parsed into `EntityKind`,
/// `RuleType` or `Access` here.
#[derive(Debug, Clone)]
pub struct ExportRuleRow {
    /// Value of the `entity_type` column.
    pub entity_type: String,
    /// Value of the `entity_id` column.
    pub entity_id: String,
    /// Value of the `rule_type` column (the export query only selects
    /// `'role'` and `'department'` rows).
    pub rule_type: String,
    /// Value of the `rule_value` column.
    pub rule_value: String,
    /// Value of the `access` column.
    pub access: String,
    /// Value of the `justification` column; `None` when the column is NULL.
    pub justification: Option<String>,
}
33
/// Borrowed argument bundle for `AccessControlRepository::upsert_rule`.
///
/// `entity_type`, `entity_id`, `rule_type` and `rule_value` together form the
/// conflict key of the upsert; `access` and `justification` are the values
/// written.
#[derive(Debug, Clone, Copy)]
pub struct UpsertRuleParams<'a> {
    /// Kind of entity the rule is attached to.
    pub entity_type: EntityKind,
    /// Identifier of the entity the rule is attached to.
    pub entity_id: &'a str,
    /// Kind of rule being written.
    pub rule_type: RuleType,
    /// Value the rule matches on.
    pub rule_value: &'a str,
    /// Access decision stored for the rule.
    pub access: Access,
    /// Optional justification text. When an existing rule is updated, `None`
    /// keeps the justification already stored (the upsert uses `COALESCE`).
    pub justification: Option<&'a str>,
}
47
/// Repository over the `access_control_rules` table.
///
/// Holds two pool handles: `SELECT` queries run on `pool`, while `INSERT`,
/// `UPDATE` and `DELETE` statements run on `write_pool`.
#[derive(Clone, Debug)]
pub struct AccessControlRepository {
    /// Pool used by the read-only queries.
    pool: Arc<PgPool>,
    /// Pool used by the statements that modify rows.
    write_pool: Arc<PgPool>,
}
53
54impl AccessControlRepository {
55 pub fn new(db: &DbPool) -> AuthzResult<Self> {
56 let pool = db
57 .pool_arc()
58 .map_err(|err| AuthzError::Validation(err.to_string()))?;
59 let write_pool = db
60 .write_pool_arc()
61 .map_err(|err| AuthzError::Validation(err.to_string()))?;
62 Ok(Self { pool, write_pool })
63 }
64
65 pub fn from_pool(pool: Arc<PgPool>) -> Self {
66 let write_pool = Arc::clone(&pool);
67 Self { pool, write_pool }
68 }
69
70 pub async fn list_role_department_rules_for_export(&self) -> AuthzResult<Vec<ExportRuleRow>> {
71 let rows = sqlx::query_as!(
72 ExportRuleRow,
73 r#"
74 SELECT entity_type, entity_id, rule_type, rule_value, access, justification
75 FROM access_control_rules
76 WHERE rule_type IN ('role', 'department')
77 AND rule_value <> '__default__'
78 ORDER BY entity_type, entity_id, access, rule_type, rule_value
79 "#,
80 )
81 .fetch_all(&*self.pool)
82 .await?;
83 Ok(rows)
84 }
85
86 pub async fn list_rules_for_entity(
87 &self,
88 entity_type: EntityKind,
89 entity_id: &str,
90 ) -> AuthzResult<Vec<AccessRule>> {
91 let rows = sqlx::query!(
92 r#"
93 SELECT id, rule_type, rule_value, access, default_included, justification
94 FROM access_control_rules
95 WHERE entity_type = $1 AND entity_id = $2
96 ORDER BY rule_type, rule_value
97 "#,
98 entity_type.as_str(),
99 entity_id,
100 )
101 .fetch_all(&*self.pool)
102 .await?;
103
104 let mut out = Vec::with_capacity(rows.len());
105 for row in rows {
106 if is_sentinel(&row.rule_type, &row.rule_value) {
107 continue;
108 }
109 out.push(AccessRule {
110 id: RuleId::new(row.id),
111 rule_type: RuleType::from_str(&row.rule_type)?,
112 rule_value: row.rule_value,
113 access: Access::from_str(&row.access)?,
114 default_included: row.default_included,
115 justification: row.justification,
116 });
117 }
118 Ok(out)
119 }
120
121 pub async fn list_rules_bulk(
122 &self,
123 entity_type: EntityKind,
124 entity_ids: &[String],
125 ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
126 let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
127 for id in entity_ids {
128 out.entry(id.clone()).or_default();
129 }
130 if entity_ids.is_empty() {
131 return Ok(out);
132 }
133
134 let rows = sqlx::query!(
135 r#"
136 SELECT entity_id, id, rule_type, rule_value, access, default_included, justification
137 FROM access_control_rules
138 WHERE entity_type = $1 AND entity_id = ANY($2)
139 ORDER BY entity_id, rule_type, rule_value
140 "#,
141 entity_type.as_str(),
142 entity_ids,
143 )
144 .fetch_all(&*self.pool)
145 .await?;
146
147 for row in rows {
148 if is_sentinel(&row.rule_type, &row.rule_value) {
149 continue;
150 }
151 let rule = AccessRule {
152 id: RuleId::new(row.id),
153 rule_type: RuleType::from_str(&row.rule_type)?,
154 rule_value: row.rule_value,
155 access: Access::from_str(&row.access)?,
156 default_included: row.default_included,
157 justification: row.justification,
158 };
159 out.entry(row.entity_id).or_default().push(rule);
160 }
161 Ok(out)
162 }
163
164 pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
165 let id = RuleId::generate();
166 let rule_type_str = params.rule_type.to_string();
167 let access_str = params.access.to_string();
168 let row = sqlx::query!(
169 r#"
170 INSERT INTO access_control_rules
171 (id, entity_type, entity_id, rule_type, rule_value, access, default_included, justification)
172 VALUES ($1, $2, $3, $4, $5, $6, false, $7)
173 ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
174 DO UPDATE SET
175 access = EXCLUDED.access,
176 justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
177 updated_at = NOW()
178 RETURNING id, rule_type, rule_value, access, default_included, justification
179 "#,
180 id.as_str(),
181 params.entity_type.as_str(),
182 params.entity_id,
183 rule_type_str,
184 params.rule_value,
185 access_str,
186 params.justification,
187 )
188 .fetch_one(&*self.write_pool)
189 .await?;
190
191 Ok(AccessRule {
192 id: RuleId::new(row.id),
193 rule_type: RuleType::from_str(&row.rule_type)?,
194 rule_value: row.rule_value,
195 access: Access::from_str(&row.access)?,
196 default_included: row.default_included,
197 justification: row.justification,
198 })
199 }
200
201 pub async fn set_justification(
204 &self,
205 rule_id: &RuleId,
206 justification: Option<&str>,
207 ) -> AuthzResult<bool> {
208 let result = sqlx::query!(
209 r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
210 rule_id.as_str(),
211 justification,
212 )
213 .execute(&*self.write_pool)
214 .await?;
215 Ok(result.rows_affected() > 0)
216 }
217
218 pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
219 let result = sqlx::query!(
220 r#"DELETE FROM access_control_rules WHERE id = $1"#,
221 rule_id.as_str(),
222 )
223 .execute(&*self.write_pool)
224 .await?;
225 Ok(result.rows_affected() > 0)
226 }
227
228 pub async fn set_default_included(
229 &self,
230 entity_type: EntityKind,
231 entity_id: &str,
232 value: bool,
233 ) -> AuthzResult<()> {
234 if value {
235 let id = RuleId::generate();
236 sqlx::query!(
237 r#"
238 INSERT INTO access_control_rules
239 (id, entity_type, entity_id, rule_type, rule_value, access, default_included)
240 VALUES ($1, $2, $3, 'role', $4, 'allow', true)
241 ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
242 DO UPDATE SET default_included = true, updated_at = NOW()
243 "#,
244 id.as_str(),
245 entity_type.as_str(),
246 entity_id,
247 DEFAULT_SENTINEL_VALUE,
248 )
249 .execute(&*self.write_pool)
250 .await?;
251 } else {
252 sqlx::query!(
253 r#"
254 DELETE FROM access_control_rules
255 WHERE entity_type = $1
256 AND entity_id = $2
257 AND rule_type = 'role'
258 AND rule_value = $3
259 "#,
260 entity_type.as_str(),
261 entity_id,
262 DEFAULT_SENTINEL_VALUE,
263 )
264 .execute(&*self.write_pool)
265 .await?;
266 }
267 Ok(())
268 }
269
270 pub async fn get_default_included(
271 &self,
272 entity_type: EntityKind,
273 entity_id: &str,
274 ) -> AuthzResult<bool> {
275 let row = sqlx::query!(
276 r#"
277 SELECT default_included FROM access_control_rules
278 WHERE entity_type = $1
279 AND entity_id = $2
280 AND rule_type = 'role'
281 AND rule_value = $3
282 "#,
283 entity_type.as_str(),
284 entity_id,
285 DEFAULT_SENTINEL_VALUE,
286 )
287 .fetch_optional(&*self.pool)
288 .await?;
289 Ok(row.is_some_and(|r| r.default_included))
290 }
291}
292
/// Tells whether a `(rule_type, rule_value)` pair identifies the reserved
/// marker row: a `role` rule whose value is `DEFAULT_SENTINEL_VALUE`.
fn is_sentinel(rule_type: &str, rule_value: &str) -> bool {
    matches!((rule_type, rule_value), ("role", DEFAULT_SENTINEL_VALUE))
}