Skip to main content

sqlite_graphrag/
entity_type.rs

//! Canonical entity type taxonomy used across extraction, storage and CLI.
//!
//! `EntityType` is the single source of truth for the 13 graph entity kinds.
//! It derives `clap::ValueEnum` so CLI flags can use it directly, and derives
//! `serde::{Serialize, Deserialize}` with `rename_all = "snake_case"` so JSON
//! round-trips remain backward-compatible with the pre-enum string format
//! (e.g. `IssueTracker` <-> `"issue_tracker"`).
7
8use crate::errors::AppError;
9
10/// The 13 canonical graph entity classifications.
11///
12/// Values are serialized as lowercase strings (`"person"`, `"organization"`,
13/// etc.) matching the pre-enum wire format and the SQLite `type` column.
14#[derive(
15    Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, clap::ValueEnum,
16)]
17#[serde(rename_all = "snake_case")]
18#[clap(rename_all = "snake_case")]
19pub enum EntityType {
20    Concept,
21    Date,
22    Dashboard,
23    Decision,
24    File,
25    Incident,
26    IssueTracker,
27    Location,
28    Memory,
29    Organization,
30    Person,
31    Project,
32    Tool,
33}
34
35impl EntityType {
36    /// Returns the canonical lowercase string representation stored in SQLite.
37    pub fn as_str(self) -> &'static str {
38        match self {
39            EntityType::Concept => "concept",
40            EntityType::Date => "date",
41            EntityType::Dashboard => "dashboard",
42            EntityType::Decision => "decision",
43            EntityType::File => "file",
44            EntityType::Incident => "incident",
45            EntityType::IssueTracker => "issue_tracker",
46            EntityType::Location => "location",
47            EntityType::Memory => "memory",
48            EntityType::Organization => "organization",
49            EntityType::Person => "person",
50            EntityType::Project => "project",
51            EntityType::Tool => "tool",
52        }
53    }
54}
55
56impl std::fmt::Display for EntityType {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.write_str(self.as_str())
59    }
60}
61
62impl std::str::FromStr for EntityType {
63    type Err = AppError;
64
65    fn from_str(s: &str) -> Result<Self, Self::Err> {
66        match s.to_lowercase().as_str() {
67            "concept" => Ok(EntityType::Concept),
68            "date" => Ok(EntityType::Date),
69            "dashboard" => Ok(EntityType::Dashboard),
70            "decision" => Ok(EntityType::Decision),
71            "file" => Ok(EntityType::File),
72            "incident" => Ok(EntityType::Incident),
73            "issue_tracker" => Ok(EntityType::IssueTracker),
74            "location" => Ok(EntityType::Location),
75            "memory" => Ok(EntityType::Memory),
76            "organization" => Ok(EntityType::Organization),
77            "person" => Ok(EntityType::Person),
78            "project" => Ok(EntityType::Project),
79            "tool" => Ok(EntityType::Tool),
80            other => Err(AppError::Validation(format!(
81                "invalid entity type: {other}; expected one of: concept, date, dashboard, decision, file, incident, issue_tracker, location, memory, organization, person, project, tool"
82            ))),
83        }
84    }
85}
86
87impl rusqlite::types::FromSql for EntityType {
88    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
89        let s = String::column_result(value)?;
90        s.parse::<EntityType>().map_err(|e| {
91            rusqlite::types::FromSqlError::Other(Box::new(std::io::Error::other(e.to_string())))
92        })
93    }
94}
95
96impl rusqlite::types::ToSql for EntityType {
97    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
98        Ok(rusqlite::types::ToSqlOutput::from(self.as_str()))
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant in declaration order, for the exhaustive round-trip checks.
    const ALL: [EntityType; 13] = [
        EntityType::Concept,
        EntityType::Date,
        EntityType::Dashboard,
        EntityType::Decision,
        EntityType::File,
        EntityType::Incident,
        EntityType::IssueTracker,
        EntityType::Location,
        EntityType::Memory,
        EntityType::Organization,
        EntityType::Person,
        EntityType::Project,
        EntityType::Tool,
    ];

    #[test]
    fn from_str_lowercase_roundtrip() {
        assert_eq!("person".parse::<EntityType>().unwrap(), EntityType::Person);
        assert_eq!(
            "organization".parse::<EntityType>().unwrap(),
            EntityType::Organization
        );
        assert_eq!(
            "issue_tracker".parse::<EntityType>().unwrap(),
            EntityType::IssueTracker
        );
    }

    #[test]
    fn from_str_uppercase_is_case_insensitive() {
        assert_eq!("PERSON".parse::<EntityType>().unwrap(), EntityType::Person);
        assert_eq!(
            "Organization".parse::<EntityType>().unwrap(),
            EntityType::Organization
        );
    }

    #[test]
    fn from_str_invalid_returns_err() {
        let result = "invalid".parse::<EntityType>();
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("invalid entity type"));
    }

    #[test]
    fn from_str_rejects_empty_input() {
        assert!("".parse::<EntityType>().is_err());
    }

    #[test]
    fn as_str_returns_canonical_lowercase() {
        assert_eq!(EntityType::Person.as_str(), "person");
        assert_eq!(EntityType::IssueTracker.as_str(), "issue_tracker");
    }

    #[test]
    fn serde_json_serializes_as_lowercase_string() {
        let json = serde_json::to_string(&EntityType::Person).unwrap();
        assert_eq!(json, "\"person\"");
        let json = serde_json::to_string(&EntityType::IssueTracker).unwrap();
        assert_eq!(json, "\"issue_tracker\"");
    }

    #[test]
    fn serde_json_deserializes_from_lowercase_string() {
        let et: EntityType = serde_json::from_str("\"person\"").unwrap();
        assert_eq!(et, EntityType::Person);
    }

    // Guards the wire format: serde uses `snake_case`, so the plain-lowercase and
    // PascalCase spellings of multi-word variants must not be accepted.
    #[test]
    fn serde_json_rejects_non_snake_case_spellings() {
        assert!(serde_json::from_str::<EntityType>("\"issuetracker\"").is_err());
        assert!(serde_json::from_str::<EntityType>("\"IssueTracker\"").is_err());
    }

    // The SQLite (`as_str`/`FromStr`), Display and JSON spellings must agree for
    // every variant, not only the handful spot-checked above.
    #[test]
    fn every_variant_roundtrips_through_as_str_display_and_serde() {
        for variant in ALL {
            let name = variant.as_str();
            assert_eq!(
                name.parse::<EntityType>().unwrap(),
                variant,
                "FromStr for {name}"
            );
            assert_eq!(variant.to_string(), name, "Display for {name}");

            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, format!("\"{name}\""), "serde serialize for {name}");
            let back: EntityType = serde_json::from_str(&json).unwrap();
            assert_eq!(back, variant, "serde deserialize for {name}");
        }
    }
}