systemprompt_security/authz/repository/
entities.rs1use std::str::FromStr;
2
3use super::AccessControlRepository;
4use crate::authz::error::AuthzResult;
5use crate::authz::types::{EntityKind, EntityRow};
6
7impl AccessControlRepository {
8 pub async fn get_entity(
12 &self,
13 entity_type: EntityKind,
14 entity_id: &str,
15 ) -> AuthzResult<Option<EntityRow>> {
16 let row = sqlx::query!(
17 r#"
18 SELECT entity_type, entity_id, default_included, source
19 FROM access_control_entities
20 WHERE entity_type = $1 AND entity_id = $2
21 "#,
22 entity_type.as_str(),
23 entity_id,
24 )
25 .fetch_optional(&*self.pool)
26 .await?;
27
28 let Some(row) = row else {
29 return Ok(None);
30 };
31 Ok(Some(EntityRow {
32 kind: EntityKind::from_str(&row.entity_type)?,
33 id: row.entity_id,
34 default_included: row.default_included,
35 source: row.source,
36 }))
37 }
38
39 pub async fn upsert_entity(
43 &self,
44 entity_type: EntityKind,
45 entity_id: &str,
46 default_included: bool,
47 source: &str,
48 ) -> AuthzResult<()> {
49 sqlx::query!(
50 r#"
51 INSERT INTO access_control_entities (entity_type, entity_id, default_included, source)
52 VALUES ($1, $2, $3, $4)
53 ON CONFLICT (entity_type, entity_id) DO UPDATE
54 SET default_included = EXCLUDED.default_included,
55 source = EXCLUDED.source,
56 updated_at = NOW()
57 "#,
58 entity_type.as_str(),
59 entity_id,
60 default_included,
61 source,
62 )
63 .execute(&*self.write_pool)
64 .await?;
65 Ok(())
66 }
67
68 pub async fn upsert_entities(
73 &self,
74 entity_type: EntityKind,
75 ids: &[&str],
76 default_included: bool,
77 source: &str,
78 ) -> AuthzResult<()> {
79 if ids.is_empty() {
80 return Ok(());
81 }
82 let ids_owned: Vec<String> = ids.iter().map(|id| (*id).to_owned()).collect();
83 sqlx::query!(
84 r#"
85 INSERT INTO access_control_entities (entity_type, entity_id, default_included, source)
86 SELECT $1, id, $3, $4
87 FROM UNNEST($2::text[]) AS id
88 ON CONFLICT (entity_type, entity_id) DO UPDATE
89 SET default_included = EXCLUDED.default_included,
90 source = EXCLUDED.source,
91 updated_at = NOW()
92 "#,
93 entity_type.as_str(),
94 &ids_owned,
95 default_included,
96 source,
97 )
98 .execute(&*self.write_pool)
99 .await?;
100 Ok(())
101 }
102
103 pub async fn list_entities(&self, entity_type: EntityKind) -> AuthzResult<Vec<EntityRow>> {
107 let rows = sqlx::query!(
108 r#"
109 SELECT entity_type, entity_id, default_included, source
110 FROM access_control_entities
111 WHERE entity_type = $1
112 ORDER BY entity_id
113 "#,
114 entity_type.as_str(),
115 )
116 .fetch_all(&*self.pool)
117 .await?;
118
119 let mut out = Vec::with_capacity(rows.len());
120 for row in rows {
121 out.push(EntityRow {
122 kind: EntityKind::from_str(&row.entity_type)?,
123 id: row.entity_id,
124 default_included: row.default_included,
125 source: row.source,
126 });
127 }
128 Ok(out)
129 }
130}