Skip to main content

sqlite_graphrag/commands/
memory_entities.rs

1//! Handler for the `memory-entities` CLI subcommand.
2
3use crate::errors::AppError;
4use crate::output;
5use crate::paths::AppPaths;
6use crate::storage::connection::open_ro;
7use rusqlite::params;
8use serde::Serialize;
9
10#[derive(clap::Args)]
11#[command(
12    about = "List entities linked to a memory, or memories linked to an entity",
13    after_long_help = "EXAMPLES:\n  \
14    # List entities connected to a memory\n  \
15    sqlite-graphrag memory-entities --name my-memory\n\n  \
16    # Reverse: list memories bound to an entity\n  \
17    sqlite-graphrag memory-entities --entity rust-lang\n\n  \
18    # With namespace\n  \
19    sqlite-graphrag memory-entities --name my-memory --namespace project"
20)]
21pub struct MemoryEntitiesArgs {
22    #[arg(value_name = "NAME", conflicts_with = "name", help = "Memory name")]
23    pub name_positional: Option<String>,
24    #[arg(long, conflicts_with_all = ["entity"])]
25    pub name: Option<String>,
26    /// Entity name — list memories bound to this entity (reverse lookup).
27    #[arg(long, conflicts_with_all = ["name", "name_positional"])]
28    pub entity: Option<String>,
29    #[arg(
30        long,
31        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
32    )]
33    pub namespace: Option<String>,
34    #[arg(long, hide = true)]
35    pub json: bool,
36    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
37    pub db: Option<String>,
38}
39
/// One entity linked to a memory; an element of `MemoryEntitiesResponse::entities`.
///
/// Serialized as-is to JSON, so field names and order are part of the output format.
#[derive(Serialize)]
struct EntityBinding {
    /// Row id of the entity (selected from `entities.id`).
    entity_id: i64,
    /// Entity name (selected from `entities.name`).
    name: String,
    /// Entity type (selected from `entities.type`).
    entity_type: String,
}
46
/// JSON payload for the forward lookup (memory -> entities).
#[derive(Serialize)]
struct MemoryEntitiesResponse {
    /// Name of the memory that was looked up.
    memory_name: String,
    /// Entities linked to the memory, ordered by entity name.
    entities: Vec<EntityBinding>,
    /// Number of entries in `entities`.
    count: usize,
    /// Wall-clock milliseconds spent in the command.
    elapsed_ms: u64,
}
54
/// One memory linked to an entity; an element of `EntityMemoriesResponse::memories`.
///
/// Serialized as-is to JSON, so field names and order are part of the output format.
#[derive(Serialize)]
struct MemoryBinding {
    /// Row id of the memory (selected from `memories.id`).
    memory_id: i64,
    /// Memory name (selected from `memories.name`).
    name: String,
    /// Memory description (selected from `memories.description`).
    description: String,
    /// Memory type (selected from `memories.type`).
    memory_type: String,
}
62
/// JSON payload for the reverse lookup (entity -> memories).
#[derive(Serialize)]
struct EntityMemoriesResponse {
    /// Name of the entity that was looked up.
    entity_name: String,
    /// Non-deleted memories linked to the entity, ordered by memory name.
    memories: Vec<MemoryBinding>,
    /// Number of entries in `memories`.
    count: usize,
    /// Wall-clock milliseconds spent in the command.
    elapsed_ms: u64,
}
70
71pub fn run(args: MemoryEntitiesArgs) -> Result<(), AppError> {
72    let start = std::time::Instant::now();
73    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
74    let paths = AppPaths::resolve(args.db.as_deref())?;
75    crate::storage::connection::ensure_db_ready(&paths)?;
76    let conn = open_ro(&paths.db)?;
77
78    if let Some(entity_name) = args.entity {
79        let entity_id = crate::storage::entities::find_entity_id(&conn, &namespace, &entity_name)?
80            .ok_or_else(|| {
81                AppError::NotFound(crate::i18n::errors_msg::entity_not_found(
82                    &entity_name,
83                    &namespace,
84                ))
85            })?;
86
87        let mut stmt = conn.prepare(
88            "SELECT m.id, m.name, m.description, m.type
89             FROM memory_entities me
90             JOIN memories m ON m.id = me.memory_id
91             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL
92             ORDER BY m.name",
93        )?;
94
95        let memories: Vec<MemoryBinding> = stmt
96            .query_map(params![entity_id], |r| {
97                Ok(MemoryBinding {
98                    memory_id: r.get(0)?,
99                    name: r.get(1)?,
100                    description: r.get(2)?,
101                    memory_type: r.get(3)?,
102                })
103            })?
104            .collect::<Result<Vec<_>, _>>()?;
105
106        let count = memories.len();
107        output::emit_json(&EntityMemoriesResponse {
108            entity_name,
109            memories,
110            count,
111            elapsed_ms: start.elapsed().as_millis() as u64,
112        })?;
113        return Ok(());
114    }
115
116    let name = args.name_positional.or(args.name).ok_or_else(|| {
117        AppError::Validation(
118            "name required: pass as positional argument, via --name, or use --entity for reverse lookup".to_string(),
119        )
120    })?;
121
122    let memory_id: i64 = conn
123        .query_row(
124            "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
125            params![namespace, name],
126            |r| r.get(0),
127        )
128        .map_err(|_| {
129            AppError::NotFound(crate::i18n::errors_msg::memory_not_found(&name, &namespace))
130        })?;
131
132    let mut stmt = conn.prepare(
133        "SELECT e.id, e.name, e.type AS entity_type
134         FROM memory_entities me
135         JOIN entities e ON e.id = me.entity_id
136         WHERE me.memory_id = ?1
137         ORDER BY e.name",
138    )?;
139
140    let entities: Vec<EntityBinding> = stmt
141        .query_map(params![memory_id], |r| {
142            Ok(EntityBinding {
143                entity_id: r.get(0)?,
144                name: r.get(1)?,
145                entity_type: r.get(2)?,
146            })
147        })?
148        .collect::<Result<Vec<_>, _>>()?;
149
150    let count = entities.len();
151
152    output::emit_json(&MemoryEntitiesResponse {
153        memory_name: name,
154        entities,
155        count,
156        elapsed_ms: start.elapsed().as_millis() as u64,
157    })?;
158
159    Ok(())
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward-lookup payload exposes every field under its expected JSON key.
    #[test]
    fn response_serializes_correctly() {
        let resp = MemoryEntitiesResponse {
            memory_name: "test-mem".to_string(),
            entities: vec![EntityBinding {
                entity_id: 1,
                name: "rust".to_string(),
                entity_type: "concept".to_string(),
            }],
            count: 1,
            elapsed_ms: 5,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["memory_name"], "test-mem");
        assert_eq!(json["count"], 1);
        assert_eq!(json["elapsed_ms"], 5);
        assert_eq!(json["entities"][0]["name"], "rust");
        assert_eq!(json["entities"][0]["entity_id"], 1);
        assert_eq!(json["entities"][0]["entity_type"], "concept");
    }

    /// Forward-lookup payload with no linked entities serializes an empty array.
    #[test]
    fn memory_entities_response_empty_list() {
        let resp = MemoryEntitiesResponse {
            memory_name: "lonely-mem".to_string(),
            entities: vec![],
            count: 0,
            elapsed_ms: 0,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["memory_name"], "lonely-mem");
        assert_eq!(json["count"], 0);
        assert!(json["entities"].as_array().unwrap().is_empty());
    }

    /// Reverse-lookup payload exposes every field under its expected JSON key.
    #[test]
    fn entity_memories_response_serializes_correctly() {
        let resp = EntityMemoriesResponse {
            entity_name: "rust-lang".to_string(),
            memories: vec![MemoryBinding {
                memory_id: 42,
                name: "design-auth".to_string(),
                description: "JWT auth design".to_string(),
                memory_type: "decision".to_string(),
            }],
            count: 1,
            elapsed_ms: 3,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["entity_name"], "rust-lang");
        assert_eq!(json["count"], 1);
        assert_eq!(json["memories"][0]["name"], "design-auth");
        assert_eq!(json["memories"][0]["description"], "JWT auth design");
        assert_eq!(json["memories"][0]["memory_type"], "decision");
        assert_eq!(json["memories"][0]["memory_id"], 42);
    }

    /// Reverse-lookup payload with no bound memories serializes an empty array.
    #[test]
    fn entity_memories_response_empty_list() {
        let resp = EntityMemoriesResponse {
            entity_name: "orphan-entity".to_string(),
            memories: vec![],
            count: 0,
            elapsed_ms: 1,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["entity_name"], "orphan-entity");
        assert_eq!(json["count"], 0);
        assert!(json["memories"].as_array().unwrap().is_empty());
    }
}