Skip to main content

selene_graph/
candidate_state.rs

1//! Maintained graph-derived candidate sets.
2//!
3//! This module owns small, policy-neutral maintained node sets for graph/vector
4//! retrieval. A set can require node labels, require incoming/outgoing edge
5//! evidence, and exclude nodes that have disqualifying incoming/outgoing edges.
6//! That is enough to model active/current/unresolved memory subsets without
7//! hard-coding those application labels into the engine.
8
9use std::collections::BTreeSet;
10
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13
14use selene_core::{Change, DbString, NodeId};
15#[cfg(test)]
16use selene_core::{EdgeId, LabelSet};
17
18use crate::index_provider::{
19    IndexProvider, ProviderError, ProviderTag, SubTag, VectorCandidateStateInfo,
20};
21use crate::store::RowIndex;
22use crate::{SeleneGraph, VectorCandidateSet};
23
24#[path = "candidate_state/state.rs"]
25mod state;
26
27use state::{
28    CandidateState, CandidateStateSnapshot, TrackedEdge, canonicalize_labels, ensure_state_subtag,
29    inconsistent, insert_sorted_unique, invalid_payload, validate_unique_specs, watches_label,
30};
31
/// Provider tag for maintained graph candidate-state sections.
///
/// Wrapped in [`ProviderTag`] and returned by the provider's `provider_tag`.
pub const CANDIDATE_STATE_PROVIDER_TAG: [u8; 4] = *b"CSET";

/// Provider-owned snapshot section for maintained candidate-state data.
pub const CANDIDATE_STATE_SUB: [u8; 4] = *b"STAT";

/// Format version stamped into written snapshots; `read_section` rejects any other value.
const SNAPSHOT_VERSION: u8 = 1;
/// Sub-tags handed out by `declared_sub_tags`: only the candidate-state section.
const SUB_TAGS: &[SubTag] = &[SubTag(CANDIDATE_STATE_SUB)];
40
41/// Declarative rule for one maintained candidate set.
42#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
43pub struct CandidateStateSpec {
44    /// Stable set name used by callers to retrieve candidates.
45    pub name: DbString,
46    /// Optional node label required for membership.
47    pub required_label: Option<DbString>,
48    /// Outgoing edge labels required on the source node.
49    pub require_outgoing: Vec<DbString>,
50    /// Incoming edge labels required on the target node.
51    pub require_incoming: Vec<DbString>,
52    /// Outgoing edge labels that disqualify the source node.
53    pub exclude_outgoing: Vec<DbString>,
54    /// Incoming edge labels that disqualify the target node.
55    pub exclude_incoming: Vec<DbString>,
56}
57
58impl CandidateStateSpec {
59    /// Construct an unconstrained named candidate set.
60    #[must_use]
61    pub fn new(name: DbString) -> Self {
62        Self {
63            name,
64            required_label: None,
65            require_outgoing: Vec::new(),
66            require_incoming: Vec::new(),
67            exclude_outgoing: Vec::new(),
68            exclude_incoming: Vec::new(),
69        }
70    }
71
72    /// Require `label` for candidate membership.
73    #[must_use]
74    pub fn require_label(mut self, label: DbString) -> Self {
75        self.required_label = Some(label);
76        self
77    }
78
79    /// Require an outgoing edge carrying `label`.
80    #[must_use]
81    pub fn require_outgoing(mut self, label: DbString) -> Self {
82        insert_sorted_unique(&mut self.require_outgoing, label);
83        self
84    }
85
86    /// Require an incoming edge carrying `label`.
87    #[must_use]
88    pub fn require_incoming(mut self, label: DbString) -> Self {
89        insert_sorted_unique(&mut self.require_incoming, label);
90        self
91    }
92
93    /// Exclude nodes with an outgoing edge carrying `label`.
94    #[must_use]
95    pub fn exclude_outgoing(mut self, label: DbString) -> Self {
96        insert_sorted_unique(&mut self.exclude_outgoing, label);
97        self
98    }
99
100    /// Exclude nodes with an incoming edge carrying `label`.
101    #[must_use]
102    pub fn exclude_incoming(mut self, label: DbString) -> Self {
103        insert_sorted_unique(&mut self.exclude_incoming, label);
104        self
105    }
106}
107
108/// First-party provider maintaining named graph-derived candidate sets.
109pub struct MaintainedCandidateStateProvider {
110    specs: Vec<CandidateStateSpec>,
111    state: Mutex<CandidateState>,
112}
113
114impl MaintainedCandidateStateProvider {
115    /// Construct an empty provider for `specs`.
116    ///
117    /// # Errors
118    ///
119    /// Returns [`ProviderError`] when two specs use the same name.
120    pub fn new(specs: impl IntoIterator<Item = CandidateStateSpec>) -> Result<Self, ProviderError> {
121        let mut specs = specs.into_iter().collect::<Vec<_>>();
122        for spec in &mut specs {
123            canonicalize_labels(&mut spec.require_outgoing);
124            canonicalize_labels(&mut spec.require_incoming);
125            canonicalize_labels(&mut spec.exclude_outgoing);
126            canonicalize_labels(&mut spec.exclude_incoming);
127        }
128        validate_unique_specs(&specs)?;
129        Ok(Self {
130            state: Mutex::new(CandidateState::new(&specs)),
131            specs,
132        })
133    }
134
135    /// Construct a provider and initialize it from a graph snapshot.
136    ///
137    /// # Errors
138    ///
139    /// Returns [`ProviderError`] when specs are invalid or the graph snapshot is
140    /// internally inconsistent.
141    pub fn from_graph(
142        specs: impl IntoIterator<Item = CandidateStateSpec>,
143        graph: &SeleneGraph,
144    ) -> Result<Self, ProviderError> {
145        let provider = Self::new(specs)?;
146        provider.rebuild_from_graph(graph)?;
147        Ok(provider)
148    }
149
150    /// Rebuild all maintained state from `graph`.
151    ///
152    /// This is the safe attachment path when a provider is registered against an
153    /// already-populated graph instead of observing mutations from graph birth.
154    ///
155    /// # Errors
156    ///
157    /// Returns [`ProviderError`] if live row-to-id mappings are inconsistent.
158    pub fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
159        let mut rebuilt = CandidateState::new(&self.specs);
160        for row in graph.live_nodes() {
161            let row = RowIndex::new(row);
162            let id = graph.node_id_for_row(row).ok_or_else(|| {
163                inconsistent(format!("live node row {} has no external id", row.get()))
164            })?;
165            let labels = graph
166                .node_labels(id)
167                .ok_or_else(|| inconsistent(format!("live node {id} has no label column entry")))?;
168            rebuilt.node_labels.insert(id, labels.clone());
169        }
170        for row in graph.live_edges() {
171            let row = RowIndex::new(row);
172            let id = graph.edge_id_for_row(row).ok_or_else(|| {
173                inconsistent(format!("live edge row {} has no external id", row.get()))
174            })?;
175            let label = graph
176                .edge_label(id)
177                .ok_or_else(|| inconsistent(format!("live edge {id} has no label")))?;
178            if !watches_label(&self.specs, label) {
179                continue;
180            }
181            let (source, target) = graph
182                .edge_endpoints(id)
183                .ok_or_else(|| inconsistent(format!("live edge {id} has no endpoints")))?;
184            rebuilt.edges.insert(
185                id,
186                TrackedEdge {
187                    label: label.clone(),
188                    source,
189                    target,
190                },
191            );
192        }
193        rebuilt.rebuild_derived(&self.specs);
194        rebuilt.generation = graph.meta.generation;
195        *self.state.lock() = rebuilt;
196        Ok(())
197    }
198
199    /// Return the configured spec named `name`.
200    #[must_use]
201    pub fn spec(&self, name: &DbString) -> Option<&CandidateStateSpec> {
202        self.specs.iter().find(|spec| &spec.name == name)
203    }
204
205    /// Return the current candidate set for `name`.
206    #[must_use]
207    pub fn candidate_set(&self, name: &DbString) -> Option<VectorCandidateSet> {
208        let state = self.state.lock();
209        state.members.get(name).map(|members| {
210            VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
211        })
212    }
213
214    /// Return the provider generation watermark.
215    #[must_use]
216    pub fn generation(&self) -> u64 {
217        self.state.lock().generation
218    }
219
220    /// Return the current candidate set for `name` if it matches `generation`.
221    ///
222    /// # Errors
223    ///
224    /// Returns [`ProviderError`] when this provider has not applied every
225    /// mutation through `generation`.
226    pub fn candidate_set_at_generation(
227        &self,
228        name: &DbString,
229        generation: u64,
230    ) -> Result<Option<VectorCandidateSet>, ProviderError> {
231        let state = self.state.lock();
232        if state.generation != generation {
233            return Err(inconsistent(format!(
234                "candidate-state generation {} does not match graph generation {generation}",
235                state.generation
236            )));
237        }
238        Ok(state.members.get(name).map(|members| {
239            VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
240        }))
241    }
242
243    /// Return generation-checked metadata for every configured candidate set.
244    ///
245    /// # Errors
246    ///
247    /// Returns [`ProviderError`] when this provider has not applied every
248    /// mutation through `generation`.
249    pub fn candidate_state_infos_at_generation(
250        &self,
251        generation: u64,
252    ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
253        let state = self.state.lock();
254        if state.generation != generation {
255            return Err(inconsistent(format!(
256                "candidate-state generation {} does not match graph generation {generation}",
257                state.generation
258            )));
259        }
260        Ok(self
261            .specs
262            .iter()
263            .map(|spec| VectorCandidateStateInfo {
264                name: spec.name.clone(),
265                generation,
266                candidate_count: state.members.get(&spec.name).map_or(0, BTreeSet::len),
267                required_label: spec.required_label.clone(),
268                require_outgoing: spec.require_outgoing.clone(),
269                require_incoming: spec.require_incoming.clone(),
270                exclude_outgoing: spec.exclude_outgoing.clone(),
271                exclude_incoming: spec.exclude_incoming.clone(),
272            })
273            .collect())
274    }
275
276    /// Return true when `node` is currently a member of the named set.
277    #[must_use]
278    pub fn contains(&self, name: &DbString, node: NodeId) -> bool {
279        self.state
280            .lock()
281            .members
282            .get(name)
283            .is_some_and(|members| members.contains(&node))
284    }
285}
286
287impl IndexProvider for MaintainedCandidateStateProvider {
288    fn provider_tag(&self) -> ProviderTag {
289        ProviderTag(CANDIDATE_STATE_PROVIDER_TAG)
290    }
291
292    fn read_section(&self, sub_tag: SubTag, bytes: &[u8]) -> Result<(), ProviderError> {
293        ensure_state_subtag(sub_tag)?;
294        let snapshot: CandidateStateSnapshot = postcard::from_bytes(bytes).map_err(|error| {
295            invalid_payload(format!("CSET/STAT postcard decode failed: {error}"))
296        })?;
297        if snapshot.version != SNAPSHOT_VERSION {
298            return Err(invalid_payload(format!(
299                "unsupported CSET/STAT version {}",
300                snapshot.version
301            )));
302        }
303        if snapshot.specs != self.specs {
304            return Err(invalid_payload(
305                "CSET/STAT specs differ from provider configuration".to_owned(),
306            ));
307        }
308        let mut state = CandidateState::new(&self.specs);
309        state.generation = snapshot.generation;
310        for (id, labels) in snapshot.node_labels {
311            if state.node_labels.insert(id, labels).is_some() {
312                return Err(invalid_payload(format!(
313                    "duplicate node id {id} in CSET/STAT"
314                )));
315            }
316        }
317        for (id, edge) in snapshot.edges {
318            if !watches_label(&self.specs, &edge.label) {
319                return Err(invalid_payload(format!(
320                    "unwatched edge label {} in CSET/STAT",
321                    edge.label.as_str()
322                )));
323            }
324            if !state.node_labels.contains_key(&edge.source)
325                || !state.node_labels.contains_key(&edge.target)
326            {
327                return Err(invalid_payload(format!(
328                    "tracked edge {id} references missing endpoint in CSET/STAT"
329                )));
330            }
331            if state.edges.insert(id, edge).is_some() {
332                return Err(invalid_payload(format!(
333                    "duplicate edge id {id} in CSET/STAT"
334                )));
335            }
336        }
337        state.rebuild_derived(&self.specs);
338        *self.state.lock() = state;
339        Ok(())
340    }
341
342    fn write_section(&self, sub_tag: SubTag) -> Result<Vec<u8>, ProviderError> {
343        ensure_state_subtag(sub_tag)?;
344        let state = self.state.lock();
345        let snapshot = CandidateStateSnapshot {
346            version: SNAPSHOT_VERSION,
347            generation: state.generation,
348            specs: self.specs.clone(),
349            node_labels: state
350                .node_labels
351                .iter()
352                .map(|(id, labels)| (*id, labels.clone()))
353                .collect(),
354            edges: state
355                .edges
356                .iter()
357                .map(|(id, edge)| (*id, edge.clone()))
358                .collect(),
359        };
360        postcard::to_stdvec(&snapshot).map_err(|error| ProviderError::SerializationFailed {
361            reason: format!("CSET/STAT postcard encode failed: {error}"),
362        })
363    }
364
365    fn on_change(&self, change: &Change) -> Result<(), ProviderError> {
366        self.state.lock().apply_change(&self.specs, change)
367    }
368
369    fn handles_change_batches(&self) -> bool {
370        true
371    }
372
373    fn on_changes(&self, changes: &[Change]) -> Result<(), ProviderError> {
374        let mut state = self.state.lock();
375        for change in changes {
376            state.apply_change(&self.specs, change)?;
377        }
378        Ok(())
379    }
380
381    fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
382        MaintainedCandidateStateProvider::rebuild_from_graph(self, graph)
383    }
384
385    fn on_commit_applied(&self, generation: u64) -> Result<(), ProviderError> {
386        self.state.lock().generation = generation;
387        Ok(())
388    }
389
390    fn vector_candidate_set(
391        &self,
392        name: &DbString,
393        generation: u64,
394    ) -> Result<Option<VectorCandidateSet>, ProviderError> {
395        self.candidate_set_at_generation(name, generation)
396    }
397
398    fn vector_candidate_state_infos(
399        &self,
400        generation: u64,
401    ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
402        self.candidate_state_infos_at_generation(generation)
403    }
404
405    fn declared_sub_tags(&self) -> &[SubTag] {
406        SUB_TAGS
407    }
408}
409
410#[cfg(test)]
411#[path = "candidate_state/tests.rs"]
412mod tests;