sqlite_graphrag/
memory_source.rs1use crate::errors::AppError;
27use serde::{Deserialize, Serialize};
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum MemorySource {
40 Agent,
42 User,
44 System,
46 Import,
48 Sync,
50}
51
52impl MemorySource {
53 pub const fn as_str(self) -> &'static str {
58 match self {
59 Self::Agent => "agent",
60 Self::User => "user",
61 Self::System => "system",
62 Self::Import => "import",
63 Self::Sync => "sync",
64 }
65 }
66
67 pub const ALL: &'static [MemorySource] = &[
69 Self::Agent,
70 Self::User,
71 Self::System,
72 Self::Import,
73 Self::Sync,
74 ];
75}
76
77impl std::fmt::Display for MemorySource {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.write_str(self.as_str())
80 }
81}
82
83impl TryFrom<&str> for MemorySource {
91 type Error = AppError;
92
93 fn try_from(value: &str) -> Result<Self, Self::Error> {
94 match value {
95 "agent" => Ok(Self::Agent),
96 "user" => Ok(Self::User),
97 "system" => Ok(Self::System),
98 "import" => Ok(Self::Import),
99 "sync" => Ok(Self::Sync),
100 other => Err(AppError::Validation(format!(
101 "invalid memory source: {other:?}; expected one of {}",
102 Self::ALL
103 .iter()
104 .map(|v| v.as_str())
105 .collect::<Vec<_>>()
106 .join(", ")
107 ))),
108 }
109 }
110}
111
112impl TryFrom<String> for MemorySource {
113 type Error = AppError;
114
115 fn try_from(value: String) -> Result<Self, Self::Error> {
116 Self::try_from(value.as_str())
117 }
118}
119
120pub fn validate_source(raw: &str) -> Result<&'static str, AppError> {
133 match raw {
134 "agent" => Ok("agent"),
135 "user" => Ok("user"),
136 "system" => Ok("system"),
137 "import" => Ok("import"),
138 "sync" => Ok("sync"),
139 other => Err(AppError::Validation(format!(
140 "invalid memory source: {other:?}; expected one of {}",
141 MemorySource::ALL
142 .iter()
143 .map(|v| v.as_str())
144 .collect::<Vec<_>>()
145 .join(", ")
146 ))),
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn as_str_returns_canonical_lowercase() {
        assert_eq!(MemorySource::Agent.as_str(), "agent");
        assert_eq!(MemorySource::User.as_str(), "user");
        assert_eq!(MemorySource::System.as_str(), "system");
        assert_eq!(MemorySource::Import.as_str(), "import");
        assert_eq!(MemorySource::Sync.as_str(), "sync");
    }

    #[test]
    fn try_from_valid_strings_succeeds() {
        assert_eq!(
            MemorySource::try_from("agent").unwrap(),
            MemorySource::Agent
        );
        assert_eq!(MemorySource::try_from("user").unwrap(), MemorySource::User);
        assert_eq!(
            MemorySource::try_from("system").unwrap(),
            MemorySource::System
        );
        assert_eq!(
            MemorySource::try_from("import").unwrap(),
            MemorySource::Import
        );
        assert_eq!(MemorySource::try_from("sync").unwrap(), MemorySource::Sync);
    }

    #[test]
    fn try_from_round_trips_every_variant() {
        for &v in MemorySource::ALL {
            assert_eq!(MemorySource::try_from(v.as_str()).unwrap(), v);
        }
    }

    #[test]
    fn try_from_invalid_string_returns_err() {
        let err = MemorySource::try_from("enrich").unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("invalid memory source"), "got: {msg}");
        assert!(msg.contains("\"enrich\""), "got: {msg}");
        assert!(msg.contains("agent"), "must list agent as valid: {msg}");
    }

    #[test]
    fn try_from_is_case_sensitive() {
        assert!(MemorySource::try_from("Agent").is_err());
        assert!(MemorySource::try_from("SYNC").is_err());
    }

    #[test]
    fn try_from_empty_string_returns_err() {
        assert!(MemorySource::try_from("").is_err());
    }

    #[test]
    fn try_from_string_owned_works() {
        let src: MemorySource = String::from("agent").try_into().unwrap();
        assert_eq!(src, MemorySource::Agent);
    }

    #[test]
    fn display_matches_as_str() {
        for v in MemorySource::ALL {
            assert_eq!(format!("{v}"), v.as_str());
        }
    }

    #[test]
    fn display_honours_width_flags() {
        for v in MemorySource::ALL {
            assert_eq!(format!("{v:>8}"), format!("{:>8}", v.as_str()));
        }
    }

    #[test]
    fn serialize_round_trip_preserves_variant() {
        let v = MemorySource::Import;
        let json = serde_json::to_string(&v).unwrap();
        assert_eq!(json, "\"import\"");
        let back: MemorySource = serde_json::from_str(&json).unwrap();
        assert_eq!(back, v);
    }

    #[test]
    fn all_slice_has_exactly_five_variants() {
        assert_eq!(MemorySource::ALL.len(), 5);
    }

    #[test]
    fn all_slice_has_no_duplicates() {
        use std::collections::HashSet;
        let unique: HashSet<_> = MemorySource::ALL.iter().collect();
        assert_eq!(unique.len(), MemorySource::ALL.len());
    }

    #[test]
    fn validate_source_accepts_every_canonical_value() {
        for v in MemorySource::ALL {
            assert_eq!(validate_source(v.as_str()).unwrap(), v.as_str());
        }
    }

    #[test]
    fn validate_source_rejects_unknown_value() {
        let err = validate_source("Agent").unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("invalid memory source"), "got: {msg}");
        assert!(msg.contains("\"Agent\""), "got: {msg}");
        assert!(msg.contains("sync"), "must list sync as valid: {msg}");
    }
}