Skip to main content

sqlite_graphrag/
entity_type.rs

//! Canonical entity type taxonomy used across extraction, storage and CLI.
//!
//! `EntityType` is the single source of truth for the 13 graph entity kinds.
//! It derives `clap::ValueEnum` so CLI flags can use it directly, and derives
//! `serde::{Serialize, Deserialize}` with `rename_all = "snake_case"` so JSON
//! round-trips remain backward-compatible with the pre-enum string format.
7
8use crate::errors::AppError;
9
10/// The 13 canonical graph entity classifications.
11///
12/// Values are serialized as lowercase strings (`"person"`, `"organization"`,
13/// etc.) matching the pre-enum wire format and the SQLite `type` column.
14#[derive(
15    Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, clap::ValueEnum,
16)]
17#[serde(rename_all = "snake_case")]
18#[clap(rename_all = "snake_case")]
19pub enum EntityType {
20    Concept,
21    Date,
22    Dashboard,
23    Decision,
24    File,
25    Incident,
26    IssueTracker,
27    Location,
28    Memory,
29    Organization,
30    Person,
31    Project,
32    Tool,
33}
34
35impl EntityType {
36    /// Returns the canonical lowercase string representation stored in SQLite.
37    pub fn as_str(self) -> &'static str {
38        match self {
39            EntityType::Concept => "concept",
40            EntityType::Date => "date",
41            EntityType::Dashboard => "dashboard",
42            EntityType::Decision => "decision",
43            EntityType::File => "file",
44            EntityType::Incident => "incident",
45            EntityType::IssueTracker => "issue_tracker",
46            EntityType::Location => "location",
47            EntityType::Memory => "memory",
48            EntityType::Organization => "organization",
49            EntityType::Person => "person",
50            EntityType::Project => "project",
51            EntityType::Tool => "tool",
52        }
53    }
54}
55
56impl std::fmt::Display for EntityType {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.write_str(self.as_str())
59    }
60}
61
62impl std::str::FromStr for EntityType {
63    type Err = AppError;
64
65    fn from_str(s: &str) -> Result<Self, Self::Err> {
66        match s.to_lowercase().as_str() {
67            "concept" => Ok(EntityType::Concept),
68            "date" => Ok(EntityType::Date),
69            "dashboard" => Ok(EntityType::Dashboard),
70            "decision" => Ok(EntityType::Decision),
71            "file" => Ok(EntityType::File),
72            "incident" => Ok(EntityType::Incident),
73            "issue_tracker" => Ok(EntityType::IssueTracker),
74            "location" => Ok(EntityType::Location),
75            "memory" => Ok(EntityType::Memory),
76            "organization" => Ok(EntityType::Organization),
77            "person" => Ok(EntityType::Person),
78            "project" => Ok(EntityType::Project),
79            "tool" => Ok(EntityType::Tool),
80            other => {
81                let hint = match other {
82                    "reference" | "skill" | "note" | "feedback" => Some("concept"),
83                    "document" => Some("file"),
84                    "user" => Some("person"),
85                    _ => None,
86                };
87                let msg = if let Some(suggested) = hint {
88                    format!(
89                        "invalid entity_type '{other}'; '{other}' is a MEMORY type, not an entity type. \
90                         Try '{suggested}' instead. Valid entity types: concept, date, dashboard, \
91                         decision, file, incident, issue_tracker, location, memory, organization, \
92                         person, project, tool"
93                    )
94                } else {
95                    format!(
96                        "invalid entity type: {other}; expected one of: concept, date, dashboard, \
97                         decision, file, incident, issue_tracker, location, memory, organization, \
98                         person, project, tool"
99                    )
100                };
101                Err(AppError::Validation(msg))
102            }
103        }
104    }
105}
106
107impl rusqlite::types::FromSql for EntityType {
108    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
109        let s = String::column_result(value)?;
110        s.parse::<EntityType>().map_err(|e| {
111            rusqlite::types::FromSqlError::Other(Box::new(std::io::Error::other(e.to_string())))
112        })
113    }
114}
115
116impl rusqlite::types::ToSql for EntityType {
117    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
118        Ok(rusqlite::types::ToSqlOutput::from(self.as_str()))
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical lowercase names parse to the matching variant.
    #[test]
    fn from_str_lowercase_roundtrip() {
        let cases = [
            ("person", EntityType::Person),
            ("organization", EntityType::Organization),
            ("issue_tracker", EntityType::IssueTracker),
        ];
        for (text, expected) in cases {
            assert_eq!(text.parse::<EntityType>().unwrap(), expected);
        }
    }

    // Upper-case and capitalised spellings are accepted too.
    #[test]
    fn from_str_uppercase_is_case_insensitive() {
        let shouted: EntityType = "PERSON".parse().unwrap();
        assert_eq!(shouted, EntityType::Person);

        let capitalised: EntityType = "Organization".parse().unwrap();
        assert_eq!(capitalised, EntityType::Organization);
    }

    // Unknown names are rejected with the generic validation message.
    #[test]
    fn from_str_invalid_returns_err() {
        match "invalid".parse::<EntityType>() {
            Ok(parsed) => panic!("expected a validation error, got {parsed:?}"),
            Err(err) => assert!(err.to_string().contains("invalid entity type")),
        }
    }

    // `as_str` yields the canonical names, including the underscore form.
    #[test]
    fn as_str_returns_canonical_lowercase() {
        let cases = [
            (EntityType::Person, "person"),
            (EntityType::IssueTracker, "issue_tracker"),
        ];
        for (value, expected) in cases {
            assert_eq!(value.as_str(), expected);
        }
    }

    // JSON output is the quoted canonical name.
    #[test]
    fn serde_json_serializes_as_lowercase_string() {
        let cases = [
            (EntityType::Person, "\"person\""),
            (EntityType::IssueTracker, "\"issue_tracker\""),
        ];
        for (value, expected) in cases {
            assert_eq!(serde_json::to_string(&value).unwrap(), expected);
        }
    }

    // A quoted canonical name deserializes back to the variant.
    #[test]
    fn serde_json_deserializes_from_lowercase_string() {
        let decoded = serde_json::from_str::<EntityType>("\"person\"").unwrap();
        assert_eq!(decoded, EntityType::Person);
    }
}
175}