Skip to main content

systemprompt_security/authz/repository/
entities.rs

1use std::str::FromStr;
2
3use super::AccessControlRepository;
4use crate::authz::error::AuthzResult;
5use crate::authz::types::{EntityKind, EntityRow};
6
7impl AccessControlRepository {
8    /// `Ok(None)` means the entity has no catalog row at all (publish-pipeline
9    /// bootstrap gap) — the resolver turns this into
10    /// [`crate::authz::DenyReason::UnknownEntity`].
11    pub async fn get_entity(
12        &self,
13        entity_type: EntityKind,
14        entity_id: &str,
15    ) -> AuthzResult<Option<EntityRow>> {
16        let row = sqlx::query!(
17            r#"
18            SELECT entity_type, entity_id, default_included, source
19            FROM access_control_entities
20            WHERE entity_type = $1 AND entity_id = $2
21            "#,
22            entity_type.as_str(),
23            entity_id,
24        )
25        .fetch_optional(&*self.pool)
26        .await?;
27
28        let Some(row) = row else {
29            return Ok(None);
30        };
31        Ok(Some(EntityRow {
32            kind: EntityKind::from_str(&row.entity_type)?,
33            id: row.entity_id,
34            default_included: row.default_included,
35            source: row.source,
36        }))
37    }
38
39    /// Overwrites `default_included` and `source` on conflict so the most
40    /// recent bootstrap pass wins — the publish pipeline is the source of
41    /// truth and runs ahead of YAML grant ingestion.
42    pub async fn upsert_entity(
43        &self,
44        entity_type: EntityKind,
45        entity_id: &str,
46        default_included: bool,
47        source: &str,
48    ) -> AuthzResult<()> {
49        sqlx::query!(
50            r#"
51            INSERT INTO access_control_entities (entity_type, entity_id, default_included, source)
52            VALUES ($1, $2, $3, $4)
53            ON CONFLICT (entity_type, entity_id) DO UPDATE
54            SET default_included = EXCLUDED.default_included,
55                source = EXCLUDED.source,
56                updated_at = NOW()
57            "#,
58            entity_type.as_str(),
59            entity_id,
60            default_included,
61            source,
62        )
63        .execute(&*self.write_pool)
64        .await?;
65        Ok(())
66    }
67
68    /// One statement for the whole batch, instead of `ids.len()` awaits of
69    /// [`Self::upsert_entity`]; all rows share one `default_included` and
70    /// `source`.
71    pub async fn upsert_entities(
72        &self,
73        entity_type: EntityKind,
74        ids: &[&str],
75        default_included: bool,
76        source: &str,
77    ) -> AuthzResult<()> {
78        if ids.is_empty() {
79            return Ok(());
80        }
81        let ids_owned: Vec<String> = ids.iter().map(|id| (*id).to_owned()).collect();
82        sqlx::query!(
83            r#"
84            INSERT INTO access_control_entities (entity_type, entity_id, default_included, source)
85            SELECT $1, id, $3, $4
86            FROM UNNEST($2::text[]) AS id
87            ON CONFLICT (entity_type, entity_id) DO UPDATE
88            SET default_included = EXCLUDED.default_included,
89                source = EXCLUDED.source,
90                updated_at = NOW()
91            "#,
92            entity_type.as_str(),
93            &ids_owned,
94            default_included,
95            source,
96        )
97        .execute(&*self.write_pool)
98        .await?;
99        Ok(())
100    }
101
102    pub async fn list_entities(&self, entity_type: EntityKind) -> AuthzResult<Vec<EntityRow>> {
103        let rows = sqlx::query!(
104            r#"
105            SELECT entity_type, entity_id, default_included, source
106            FROM access_control_entities
107            WHERE entity_type = $1
108            ORDER BY entity_id
109            "#,
110            entity_type.as_str(),
111        )
112        .fetch_all(&*self.pool)
113        .await?;
114
115        let mut out = Vec::with_capacity(rows.len());
116        for row in rows {
117            out.push(EntityRow {
118                kind: EntityKind::from_str(&row.entity_type)?,
119                id: row.entity_id,
120                default_included: row.default_included,
121                source: row.source,
122            });
123        }
124        Ok(out)
125    }
126}