systemprompt_security/authz/repository/
entities.rs1use std::str::FromStr;
2
3use super::AccessControlRepository;
4use crate::authz::error::AuthzResult;
5use crate::authz::types::{EntityKind, EntityRow};
6
7impl AccessControlRepository {
8 pub async fn get_entity(
12 &self,
13 entity_type: EntityKind,
14 entity_id: &str,
15 ) -> AuthzResult<Option<EntityRow>> {
16 let row = sqlx::query!(
17 r#"
18 SELECT entity_type, entity_id, default_included, source
19 FROM access_control_entities
20 WHERE entity_type = $1 AND entity_id = $2
21 "#,
22 entity_type.as_str(),
23 entity_id,
24 )
25 .fetch_optional(&*self.pool)
26 .await?;
27
28 let Some(row) = row else {
29 return Ok(None);
30 };
31 Ok(Some(EntityRow {
32 kind: EntityKind::from_str(&row.entity_type)?,
33 id: row.entity_id,
34 default_included: row.default_included,
35 source: row.source,
36 }))
37 }
38
39 pub async fn upsert_entity(
43 &self,
44 entity_type: EntityKind,
45 entity_id: &str,
46 default_included: bool,
47 source: &str,
48 ) -> AuthzResult<()> {
49 sqlx::query!(
50 r#"
51 INSERT INTO access_control_entities (entity_type, entity_id, default_included, source)
52 VALUES ($1, $2, $3, $4)
53 ON CONFLICT (entity_type, entity_id) DO UPDATE
54 SET default_included = EXCLUDED.default_included,
55 source = EXCLUDED.source,
56 updated_at = NOW()
57 "#,
58 entity_type.as_str(),
59 entity_id,
60 default_included,
61 source,
62 )
63 .execute(&*self.write_pool)
64 .await?;
65 Ok(())
66 }
67
68 pub async fn upsert_entities(
72 &self,
73 entity_type: EntityKind,
74 ids: &[&str],
75 default_included: bool,
76 source: &str,
77 ) -> AuthzResult<()> {
78 if ids.is_empty() {
79 return Ok(());
80 }
81 let ids_owned: Vec<String> = ids.iter().map(|id| (*id).to_owned()).collect();
82 sqlx::query!(
83 r#"
84 INSERT INTO access_control_entities (entity_type, entity_id, default_included, source)
85 SELECT $1, id, $3, $4
86 FROM UNNEST($2::text[]) AS id
87 ON CONFLICT (entity_type, entity_id) DO UPDATE
88 SET default_included = EXCLUDED.default_included,
89 source = EXCLUDED.source,
90 updated_at = NOW()
91 "#,
92 entity_type.as_str(),
93 &ids_owned,
94 default_included,
95 source,
96 )
97 .execute(&*self.write_pool)
98 .await?;
99 Ok(())
100 }
101
102 pub async fn list_entities(&self, entity_type: EntityKind) -> AuthzResult<Vec<EntityRow>> {
103 let rows = sqlx::query!(
104 r#"
105 SELECT entity_type, entity_id, default_included, source
106 FROM access_control_entities
107 WHERE entity_type = $1
108 ORDER BY entity_id
109 "#,
110 entity_type.as_str(),
111 )
112 .fetch_all(&*self.pool)
113 .await?;
114
115 let mut out = Vec::with_capacity(rows.len());
116 for row in rows {
117 out.push(EntityRow {
118 kind: EntityKind::from_str(&row.entity_type)?,
119 id: row.entity_id,
120 default_included: row.default_included,
121 source: row.source,
122 });
123 }
124 Ok(out)
125 }
126}