1use std::collections::HashMap;
12use std::str::FromStr;
13use std::sync::Arc;
14
15use sqlx::PgPool;
16use systemprompt_database::DbPool;
17use systemprompt_identifiers::RuleId;
18
19use super::error::{AuthzError, AuthzResult};
20use super::types::{Access, AccessRule, EntityKind, RuleType};
21
/// `rule_value` of the marker row (always stored with `rule_type = 'role'`) whose
/// `default_included` column records whether an entity is included by default.
///
/// The marker is an implementation detail of `AccessControlRepository`: it is written by
/// `set_default_included`, read by `get_default_included`, and filtered out of every
/// rule listing.
const DEFAULT_SENTINEL_VALUE: &str = "__default__";
23
/// Flat, string-typed projection of one `access_control_rules` row, as produced by
/// [`AccessControlRepository::list_role_department_rules_for_export`].
///
/// Values are kept exactly as stored (no parsing into `EntityKind`/`RuleType`/`Access`).
#[derive(Debug, Clone)]
pub struct ExportRuleRow {
    pub entity_type: String,
    pub entity_id: String,
    pub rule_type: String,
    pub rule_value: String,
    pub access: String,
    /// Justification stored with the rule; a SQL `NULL` maps to `None`.
    pub justification: Option<String>,
}
33
/// Input for [`AccessControlRepository::upsert_rule`].
///
/// `(entity_type, entity_id, rule_type, rule_value)` is the conflict key that identifies
/// the rule; `access` and `justification` are the values written.
#[derive(Debug, Clone, Copy)]
pub struct UpsertRuleParams<'a> {
    pub entity_type: EntityKind,
    pub entity_id: &'a str,
    pub rule_type: RuleType,
    pub rule_value: &'a str,
    pub access: Access,
    /// When `None` and the rule already exists, the stored justification is kept.
    pub justification: Option<&'a str>,
}
47
/// Postgres-backed store for the `access_control_rules` table.
#[derive(Clone, Debug)]
pub struct AccessControlRepository {
    /// Pool used for read-only queries.
    pool: Arc<PgPool>,
    /// Pool used for inserts, updates and deletes (the same pool as `pool` when the
    /// repository is built with `from_pool`).
    write_pool: Arc<PgPool>,
}
53
54impl AccessControlRepository {
55 pub fn new(db: &DbPool) -> AuthzResult<Self> {
56 let pool = db
57 .pool_arc()
58 .map_err(|err| AuthzError::Validation(err.to_string()))?;
59 let write_pool = db
60 .write_pool_arc()
61 .map_err(|err| AuthzError::Validation(err.to_string()))?;
62 Ok(Self { pool, write_pool })
63 }
64
65 pub fn from_pool(pool: Arc<PgPool>) -> Self {
66 let write_pool = Arc::clone(&pool);
67 Self { pool, write_pool }
68 }
69
70 pub async fn list_role_department_rules_for_export(&self) -> AuthzResult<Vec<ExportRuleRow>> {
71 let rows = sqlx::query_as!(
72 ExportRuleRow,
73 r#"
74 SELECT entity_type, entity_id, rule_type, rule_value, access, justification
75 FROM access_control_rules
76 WHERE rule_type IN ('role', 'department')
77 AND rule_value <> '__default__'
78 ORDER BY entity_type, entity_id, access, rule_type, rule_value
79 "#,
80 )
81 .fetch_all(&*self.pool)
82 .await?;
83 Ok(rows)
84 }
85
86 pub async fn list_rules_for_entity(
87 &self,
88 entity_type: EntityKind,
89 entity_id: &str,
90 ) -> AuthzResult<Vec<AccessRule>> {
91 let rows = sqlx::query!(
92 r#"
93 SELECT id, rule_type, rule_value, access, default_included, justification
94 FROM access_control_rules
95 WHERE entity_type = $1 AND entity_id = $2
96 ORDER BY rule_type, rule_value
97 "#,
98 entity_type.as_str(),
99 entity_id,
100 )
101 .fetch_all(&*self.pool)
102 .await?;
103
104 let mut out = Vec::with_capacity(rows.len());
105 for row in rows {
106 if is_sentinel(&row.rule_type, &row.rule_value) {
107 continue;
108 }
109 out.push(AccessRule {
110 id: RuleId::new(row.id),
111 rule_type: RuleType::from_str(&row.rule_type)?,
112 rule_value: row.rule_value,
113 access: Access::from_str(&row.access)?,
114 default_included: row.default_included,
115 justification: row.justification,
116 });
117 }
118 Ok(out)
119 }
120
121 pub async fn list_rules_bulk(
122 &self,
123 entity_type: EntityKind,
124 entity_ids: &[String],
125 ) -> AuthzResult<HashMap<String, Vec<AccessRule>>> {
126 let mut out: HashMap<String, Vec<AccessRule>> = HashMap::with_capacity(entity_ids.len());
127 for id in entity_ids {
128 out.entry(id.clone()).or_default();
129 }
130 if entity_ids.is_empty() {
131 return Ok(out);
132 }
133
134 let rows = sqlx::query!(
135 r#"
136 SELECT entity_id, id, rule_type, rule_value, access, default_included, justification
137 FROM access_control_rules
138 WHERE entity_type = $1 AND entity_id = ANY($2)
139 ORDER BY entity_id, rule_type, rule_value
140 "#,
141 entity_type.as_str(),
142 entity_ids,
143 )
144 .fetch_all(&*self.pool)
145 .await?;
146
147 for row in rows {
148 if is_sentinel(&row.rule_type, &row.rule_value) {
149 continue;
150 }
151 let rule = AccessRule {
152 id: RuleId::new(row.id),
153 rule_type: RuleType::from_str(&row.rule_type)?,
154 rule_value: row.rule_value,
155 access: Access::from_str(&row.access)?,
156 default_included: row.default_included,
157 justification: row.justification,
158 };
159 out.entry(row.entity_id).or_default().push(rule);
160 }
161 Ok(out)
162 }
163
164 pub async fn upsert_rule(&self, params: UpsertRuleParams<'_>) -> AuthzResult<AccessRule> {
165 let id = RuleId::generate();
166 let rule_type_str = params.rule_type.to_string();
167 let access_str = params.access.to_string();
168 let row = sqlx::query!(
169 r#"
170 INSERT INTO access_control_rules
171 (id, entity_type, entity_id, rule_type, rule_value, access, default_included, justification)
172 VALUES ($1, $2, $3, $4, $5, $6, false, $7)
173 ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
174 DO UPDATE SET
175 access = EXCLUDED.access,
176 justification = COALESCE(EXCLUDED.justification, access_control_rules.justification),
177 updated_at = NOW()
178 RETURNING id, rule_type, rule_value, access, default_included, justification
179 "#,
180 id.as_str(),
181 params.entity_type.as_str(),
182 params.entity_id,
183 rule_type_str,
184 params.rule_value,
185 access_str,
186 params.justification,
187 )
188 .fetch_one(&*self.write_pool)
189 .await?;
190
191 Ok(AccessRule {
192 id: RuleId::new(row.id),
193 rule_type: RuleType::from_str(&row.rule_type)?,
194 rule_value: row.rule_value,
195 access: Access::from_str(&row.access)?,
196 default_included: row.default_included,
197 justification: row.justification,
198 })
199 }
200
201 pub async fn set_justification(
203 &self,
204 rule_id: &RuleId,
205 justification: Option<&str>,
206 ) -> AuthzResult<bool> {
207 let result = sqlx::query!(
208 r#"UPDATE access_control_rules SET justification = $2, updated_at = NOW() WHERE id = $1"#,
209 rule_id.as_str(),
210 justification,
211 )
212 .execute(&*self.write_pool)
213 .await?;
214 Ok(result.rows_affected() > 0)
215 }
216
217 pub async fn delete_rule(&self, rule_id: &RuleId) -> AuthzResult<bool> {
218 let result = sqlx::query!(
219 r#"DELETE FROM access_control_rules WHERE id = $1"#,
220 rule_id.as_str(),
221 )
222 .execute(&*self.write_pool)
223 .await?;
224 Ok(result.rows_affected() > 0)
225 }
226
227 pub async fn set_default_included(
228 &self,
229 entity_type: EntityKind,
230 entity_id: &str,
231 value: bool,
232 ) -> AuthzResult<()> {
233 if value {
234 let id = RuleId::generate();
235 sqlx::query!(
236 r#"
237 INSERT INTO access_control_rules
238 (id, entity_type, entity_id, rule_type, rule_value, access, default_included)
239 VALUES ($1, $2, $3, 'role', $4, 'allow', true)
240 ON CONFLICT (entity_type, entity_id, rule_type, rule_value)
241 DO UPDATE SET default_included = true, updated_at = NOW()
242 "#,
243 id.as_str(),
244 entity_type.as_str(),
245 entity_id,
246 DEFAULT_SENTINEL_VALUE,
247 )
248 .execute(&*self.write_pool)
249 .await?;
250 } else {
251 sqlx::query!(
252 r#"
253 DELETE FROM access_control_rules
254 WHERE entity_type = $1
255 AND entity_id = $2
256 AND rule_type = 'role'
257 AND rule_value = $3
258 "#,
259 entity_type.as_str(),
260 entity_id,
261 DEFAULT_SENTINEL_VALUE,
262 )
263 .execute(&*self.write_pool)
264 .await?;
265 }
266 Ok(())
267 }
268
269 pub async fn get_default_included(
270 &self,
271 entity_type: EntityKind,
272 entity_id: &str,
273 ) -> AuthzResult<bool> {
274 let row = sqlx::query!(
275 r#"
276 SELECT default_included FROM access_control_rules
277 WHERE entity_type = $1
278 AND entity_id = $2
279 AND rule_type = 'role'
280 AND rule_value = $3
281 "#,
282 entity_type.as_str(),
283 entity_id,
284 DEFAULT_SENTINEL_VALUE,
285 )
286 .fetch_optional(&*self.pool)
287 .await?;
288 Ok(row.is_some_and(|r| r.default_included))
289 }
290}
291
292fn is_sentinel(rule_type: &str, rule_value: &str) -> bool {
293 rule_type == "role" && rule_value == DEFAULT_SENTINEL_VALUE
294}