sqlite_graphrag/
entity_type.rs1use crate::errors::AppError;
9
10#[derive(
15 Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, clap::ValueEnum,
16)]
17#[serde(rename_all = "snake_case")]
18#[clap(rename_all = "snake_case")]
19pub enum EntityType {
20 Concept,
21 Date,
22 Dashboard,
23 Decision,
24 File,
25 Incident,
26 IssueTracker,
27 Location,
28 Memory,
29 Organization,
30 Person,
31 Project,
32 Tool,
33}
34
35impl EntityType {
36 pub fn as_str(self) -> &'static str {
38 match self {
39 EntityType::Concept => "concept",
40 EntityType::Date => "date",
41 EntityType::Dashboard => "dashboard",
42 EntityType::Decision => "decision",
43 EntityType::File => "file",
44 EntityType::Incident => "incident",
45 EntityType::IssueTracker => "issue_tracker",
46 EntityType::Location => "location",
47 EntityType::Memory => "memory",
48 EntityType::Organization => "organization",
49 EntityType::Person => "person",
50 EntityType::Project => "project",
51 EntityType::Tool => "tool",
52 }
53 }
54}
55
56impl std::fmt::Display for EntityType {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.write_str(self.as_str())
59 }
60}
61
62impl std::str::FromStr for EntityType {
63 type Err = AppError;
64
65 fn from_str(s: &str) -> Result<Self, Self::Err> {
66 match s.to_lowercase().as_str() {
67 "concept" => Ok(EntityType::Concept),
68 "date" => Ok(EntityType::Date),
69 "dashboard" => Ok(EntityType::Dashboard),
70 "decision" => Ok(EntityType::Decision),
71 "file" => Ok(EntityType::File),
72 "incident" => Ok(EntityType::Incident),
73 "issue_tracker" => Ok(EntityType::IssueTracker),
74 "location" => Ok(EntityType::Location),
75 "memory" => Ok(EntityType::Memory),
76 "organization" => Ok(EntityType::Organization),
77 "person" => Ok(EntityType::Person),
78 "project" => Ok(EntityType::Project),
79 "tool" => Ok(EntityType::Tool),
80 other => Err(AppError::Validation(format!(
81 "invalid entity type: {other}; expected one of: concept, date, dashboard, decision, file, incident, issue_tracker, location, memory, organization, person, project, tool"
82 ))),
83 }
84 }
85}
86
87impl rusqlite::types::FromSql for EntityType {
88 fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
89 let s = String::column_result(value)?;
90 s.parse::<EntityType>().map_err(|e| {
91 rusqlite::types::FromSqlError::Other(Box::new(std::io::Error::other(e.to_string())))
92 })
93 }
94}
95
96impl rusqlite::types::ToSql for EntityType {
97 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
98 Ok(rusqlite::types::ToSqlOutput::from(self.as_str()))
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use clap::ValueEnum;

    #[test]
    fn from_str_lowercase_roundtrip() {
        assert_eq!("person".parse::<EntityType>().unwrap(), EntityType::Person);
        assert_eq!(
            "organization".parse::<EntityType>().unwrap(),
            EntityType::Organization
        );
        assert_eq!(
            "issue_tracker".parse::<EntityType>().unwrap(),
            EntityType::IssueTracker
        );
    }

    #[test]
    fn from_str_uppercase_is_case_insensitive() {
        assert_eq!("PERSON".parse::<EntityType>().unwrap(), EntityType::Person);
        assert_eq!(
            "Organization".parse::<EntityType>().unwrap(),
            EntityType::Organization
        );
    }

    #[test]
    fn from_str_invalid_returns_err() {
        let result = "invalid".parse::<EntityType>();
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("invalid entity type"));
    }

    #[test]
    fn as_str_returns_canonical_lowercase() {
        assert_eq!(EntityType::Person.as_str(), "person");
        assert_eq!(EntityType::IssueTracker.as_str(), "issue_tracker");
    }

    #[test]
    fn serde_json_serializes_as_lowercase_string() {
        let json = serde_json::to_string(&EntityType::Person).unwrap();
        assert_eq!(json, "\"person\"");
        let json = serde_json::to_string(&EntityType::IssueTracker).unwrap();
        assert_eq!(json, "\"issue_tracker\"");
    }

    #[test]
    fn serde_json_deserializes_from_lowercase_string() {
        let et: EntityType = serde_json::from_str("\"person\"").unwrap();
        assert_eq!(et, EntityType::Person);
    }

    // The canonical names are maintained by hand in `as_str` and `FromStr`, and
    // derived separately by serde and clap. Checking every variant (not a sample)
    // catches any of these drifting apart when a variant is added or renamed.

    #[test]
    fn every_variant_roundtrips_through_as_str_and_from_str() {
        for &et in EntityType::value_variants() {
            assert_eq!(et.as_str().parse::<EntityType>().unwrap(), et);
            assert_eq!(et.to_string(), et.as_str());
        }
    }

    #[test]
    fn serde_and_clap_names_match_as_str_for_every_variant() {
        for &et in EntityType::value_variants() {
            let json = serde_json::to_string(&et).unwrap();
            assert_eq!(json, format!("\"{}\"", et.as_str()));
            let back: EntityType = serde_json::from_str(&json).unwrap();
            assert_eq!(back, et);

            let possible = et
                .to_possible_value()
                .expect("no EntityType variant is marked skip");
            assert_eq!(possible.get_name(), et.as_str());
        }
    }

    #[test]
    fn sqlite_roundtrip_for_every_variant() {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        for &et in EntityType::value_variants() {
            let back: EntityType = conn
                .query_row("SELECT ?1", rusqlite::params![et], |row| row.get(0))
                .unwrap();
            assert_eq!(back, et);
        }
    }

    #[test]
    fn sqlite_rejects_unknown_text() {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        let result = conn.query_row("SELECT 'bogus'", [], |row| row.get::<_, EntityType>(0));
        assert!(result.is_err());
    }
}