Skip to main content

sqlite_graphrag/commands/
memory_entities.rs

1//! Handler for the `memory-entities` CLI subcommand.
2
3use crate::errors::AppError;
4use crate::output;
5use crate::paths::AppPaths;
6use crate::storage::connection::open_ro;
7use rusqlite::params;
8use serde::Serialize;
9
10#[derive(clap::Args)]
11#[command(
12    about = "List entities linked to a specific memory",
13    after_long_help = "EXAMPLES:\n  \
14    # List entities connected to a memory\n  \
15    sqlite-graphrag memory-entities --name my-memory\n\n  \
16    # With namespace\n  \
17    sqlite-graphrag memory-entities --name my-memory --namespace project"
18)]
19pub struct MemoryEntitiesArgs {
20    #[arg(value_name = "NAME", conflicts_with = "name", help = "Memory name")]
21    pub name_positional: Option<String>,
22    #[arg(long)]
23    pub name: Option<String>,
24    #[arg(
25        long,
26        help = "Namespace (env: SQLITE_GRAPHRAG_NAMESPACE, default: global)"
27    )]
28    pub namespace: Option<String>,
29    #[arg(long, hide = true)]
30    pub json: bool,
31    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
32    pub db: Option<String>,
33}
34
35#[derive(Serialize)]
36struct EntityBinding {
37    entity_id: i64,
38    name: String,
39    entity_type: String,
40}
41
42#[derive(Serialize)]
43struct MemoryEntitiesResponse {
44    memory_name: String,
45    entities: Vec<EntityBinding>,
46    count: usize,
47    elapsed_ms: u64,
48}
49
50pub fn run(args: MemoryEntitiesArgs) -> Result<(), AppError> {
51    let start = std::time::Instant::now();
52    let name = args.name_positional.or(args.name).ok_or_else(|| {
53        AppError::Validation("name required: pass as positional argument or via --name".to_string())
54    })?;
55    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
56    let paths = AppPaths::resolve(args.db.as_deref())?;
57    crate::storage::connection::ensure_db_ready(&paths)?;
58    let conn = open_ro(&paths.db)?;
59
60    let memory_id: i64 = conn
61        .query_row(
62            "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
63            params![namespace, name],
64            |r| r.get(0),
65        )
66        .map_err(|_| {
67            AppError::NotFound(crate::i18n::errors_msg::memory_not_found(&name, &namespace))
68        })?;
69
70    let mut stmt = conn.prepare(
71        "SELECT e.id, e.name, e.type AS entity_type
72         FROM memory_entities me
73         JOIN entities e ON e.id = me.entity_id
74         WHERE me.memory_id = ?1
75         ORDER BY e.name",
76    )?;
77
78    let entities: Vec<EntityBinding> = stmt
79        .query_map(params![memory_id], |r| {
80            Ok(EntityBinding {
81                entity_id: r.get(0)?,
82                name: r.get(1)?,
83                entity_type: r.get(2)?,
84            })
85        })?
86        .collect::<Result<Vec<_>, _>>()?;
87
88    let count = entities.len();
89
90    output::emit_json(&MemoryEntitiesResponse {
91        memory_name: name,
92        entities,
93        count,
94        elapsed_ms: start.elapsed().as_millis() as u64,
95    })?;
96
97    Ok(())
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the response, including nested entity fields, must
    /// appear in the JSON under its exact key.
    #[test]
    fn response_serializes_correctly() {
        let resp = MemoryEntitiesResponse {
            memory_name: "test-mem".to_string(),
            entities: vec![EntityBinding {
                entity_id: 1,
                name: "rust".to_string(),
                entity_type: "concept".to_string(),
            }],
            count: 1,
            elapsed_ms: 5,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["memory_name"], "test-mem");
        assert_eq!(json["count"], 1);
        assert_eq!(json["elapsed_ms"], 5);
        assert_eq!(json["entities"][0]["entity_id"], 1);
        assert_eq!(json["entities"][0]["name"], "rust");
        assert_eq!(json["entities"][0]["entity_type"], "concept");
        // Guard against accidentally added/renamed top-level keys.
        assert_eq!(json.as_object().map(|o| o.len()), Some(4));
    }

    /// A memory with no linked entities still yields an (empty) array.
    #[test]
    fn empty_entity_list_serializes_as_empty_array() {
        let resp = MemoryEntitiesResponse {
            memory_name: "lonely".to_string(),
            entities: Vec::new(),
            count: 0,
            elapsed_ms: 0,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["count"], 0);
        assert_eq!(json["entities"].as_array().map(Vec::len), Some(0));
    }
}