1use std::collections::BTreeSet;
10
11use parking_lot::Mutex;
12use serde::{Deserialize, Serialize};
13
14use selene_core::{Change, DbString, NodeId};
15#[cfg(test)]
16use selene_core::{EdgeId, LabelSet};
17
18use crate::index_provider::{
19 IndexProvider, ProviderError, ProviderTag, SubTag, VectorCandidateStateInfo,
20};
21use crate::store::RowIndex;
22use crate::{SeleneGraph, VectorCandidateSet};
23
24#[path = "candidate_state/state.rs"]
25mod state;
26
27use state::{
28 CandidateState, CandidateStateSnapshot, TrackedEdge, canonicalize_labels, ensure_state_subtag,
29 inconsistent, insert_sorted_unique, invalid_payload, validate_unique_specs, watches_label,
30};
31
32pub const CANDIDATE_STATE_PROVIDER_TAG: [u8; 4] = *b"CSET";
34
35pub const CANDIDATE_STATE_SUB: [u8; 4] = *b"STAT";
37
38const SNAPSHOT_VERSION: u8 = 1;
39const SUB_TAGS: &[SubTag] = &[SubTag(CANDIDATE_STATE_SUB)];
40
41#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
43pub struct CandidateStateSpec {
44 pub name: DbString,
46 pub required_label: Option<DbString>,
48 pub require_outgoing: Vec<DbString>,
50 pub require_incoming: Vec<DbString>,
52 pub exclude_outgoing: Vec<DbString>,
54 pub exclude_incoming: Vec<DbString>,
56}
57
58impl CandidateStateSpec {
59 #[must_use]
61 pub fn new(name: DbString) -> Self {
62 Self {
63 name,
64 required_label: None,
65 require_outgoing: Vec::new(),
66 require_incoming: Vec::new(),
67 exclude_outgoing: Vec::new(),
68 exclude_incoming: Vec::new(),
69 }
70 }
71
72 #[must_use]
74 pub fn require_label(mut self, label: DbString) -> Self {
75 self.required_label = Some(label);
76 self
77 }
78
79 #[must_use]
81 pub fn require_outgoing(mut self, label: DbString) -> Self {
82 insert_sorted_unique(&mut self.require_outgoing, label);
83 self
84 }
85
86 #[must_use]
88 pub fn require_incoming(mut self, label: DbString) -> Self {
89 insert_sorted_unique(&mut self.require_incoming, label);
90 self
91 }
92
93 #[must_use]
95 pub fn exclude_outgoing(mut self, label: DbString) -> Self {
96 insert_sorted_unique(&mut self.exclude_outgoing, label);
97 self
98 }
99
100 #[must_use]
102 pub fn exclude_incoming(mut self, label: DbString) -> Self {
103 insert_sorted_unique(&mut self.exclude_incoming, label);
104 self
105 }
106}
107
108pub struct MaintainedCandidateStateProvider {
110 specs: Vec<CandidateStateSpec>,
111 state: Mutex<CandidateState>,
112}
113
114impl MaintainedCandidateStateProvider {
115 pub fn new(specs: impl IntoIterator<Item = CandidateStateSpec>) -> Result<Self, ProviderError> {
121 let mut specs = specs.into_iter().collect::<Vec<_>>();
122 for spec in &mut specs {
123 canonicalize_labels(&mut spec.require_outgoing);
124 canonicalize_labels(&mut spec.require_incoming);
125 canonicalize_labels(&mut spec.exclude_outgoing);
126 canonicalize_labels(&mut spec.exclude_incoming);
127 }
128 validate_unique_specs(&specs)?;
129 Ok(Self {
130 state: Mutex::new(CandidateState::new(&specs)),
131 specs,
132 })
133 }
134
135 pub fn from_graph(
142 specs: impl IntoIterator<Item = CandidateStateSpec>,
143 graph: &SeleneGraph,
144 ) -> Result<Self, ProviderError> {
145 let provider = Self::new(specs)?;
146 provider.rebuild_from_graph(graph)?;
147 Ok(provider)
148 }
149
150 pub fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
159 let mut rebuilt = CandidateState::new(&self.specs);
160 for row in graph.live_nodes() {
161 let row = RowIndex::new(row);
162 let id = graph.node_id_for_row(row).ok_or_else(|| {
163 inconsistent(format!("live node row {} has no external id", row.get()))
164 })?;
165 let labels = graph
166 .node_labels(id)
167 .ok_or_else(|| inconsistent(format!("live node {id} has no label column entry")))?;
168 rebuilt.node_labels.insert(id, labels.clone());
169 }
170 for row in graph.live_edges() {
171 let row = RowIndex::new(row);
172 let id = graph.edge_id_for_row(row).ok_or_else(|| {
173 inconsistent(format!("live edge row {} has no external id", row.get()))
174 })?;
175 let label = graph
176 .edge_label(id)
177 .ok_or_else(|| inconsistent(format!("live edge {id} has no label")))?;
178 if !watches_label(&self.specs, label) {
179 continue;
180 }
181 let (source, target) = graph
182 .edge_endpoints(id)
183 .ok_or_else(|| inconsistent(format!("live edge {id} has no endpoints")))?;
184 rebuilt.edges.insert(
185 id,
186 TrackedEdge {
187 label: label.clone(),
188 source,
189 target,
190 },
191 );
192 }
193 rebuilt.rebuild_derived(&self.specs);
194 rebuilt.generation = graph.meta.generation;
195 *self.state.lock() = rebuilt;
196 Ok(())
197 }
198
199 #[must_use]
201 pub fn spec(&self, name: &DbString) -> Option<&CandidateStateSpec> {
202 self.specs.iter().find(|spec| &spec.name == name)
203 }
204
205 #[must_use]
207 pub fn candidate_set(&self, name: &DbString) -> Option<VectorCandidateSet> {
208 let state = self.state.lock();
209 state.members.get(name).map(|members| {
210 VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
211 })
212 }
213
214 #[must_use]
216 pub fn generation(&self) -> u64 {
217 self.state.lock().generation
218 }
219
220 pub fn candidate_set_at_generation(
227 &self,
228 name: &DbString,
229 generation: u64,
230 ) -> Result<Option<VectorCandidateSet>, ProviderError> {
231 let state = self.state.lock();
232 if state.generation != generation {
233 return Err(inconsistent(format!(
234 "candidate-state generation {} does not match graph generation {generation}",
235 state.generation
236 )));
237 }
238 Ok(state.members.get(name).map(|members| {
239 VectorCandidateSet::from_canonical_nodes(members.iter().copied().collect())
240 }))
241 }
242
243 pub fn candidate_state_infos_at_generation(
250 &self,
251 generation: u64,
252 ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
253 let state = self.state.lock();
254 if state.generation != generation {
255 return Err(inconsistent(format!(
256 "candidate-state generation {} does not match graph generation {generation}",
257 state.generation
258 )));
259 }
260 Ok(self
261 .specs
262 .iter()
263 .map(|spec| VectorCandidateStateInfo {
264 name: spec.name.clone(),
265 generation,
266 candidate_count: state.members.get(&spec.name).map_or(0, BTreeSet::len),
267 required_label: spec.required_label.clone(),
268 require_outgoing: spec.require_outgoing.clone(),
269 require_incoming: spec.require_incoming.clone(),
270 exclude_outgoing: spec.exclude_outgoing.clone(),
271 exclude_incoming: spec.exclude_incoming.clone(),
272 })
273 .collect())
274 }
275
276 #[must_use]
278 pub fn contains(&self, name: &DbString, node: NodeId) -> bool {
279 self.state
280 .lock()
281 .members
282 .get(name)
283 .is_some_and(|members| members.contains(&node))
284 }
285}
286
287impl IndexProvider for MaintainedCandidateStateProvider {
288 fn provider_tag(&self) -> ProviderTag {
289 ProviderTag(CANDIDATE_STATE_PROVIDER_TAG)
290 }
291
292 fn read_section(&self, sub_tag: SubTag, bytes: &[u8]) -> Result<(), ProviderError> {
293 ensure_state_subtag(sub_tag)?;
294 let snapshot: CandidateStateSnapshot = postcard::from_bytes(bytes).map_err(|error| {
295 invalid_payload(format!("CSET/STAT postcard decode failed: {error}"))
296 })?;
297 if snapshot.version != SNAPSHOT_VERSION {
298 return Err(invalid_payload(format!(
299 "unsupported CSET/STAT version {}",
300 snapshot.version
301 )));
302 }
303 if snapshot.specs != self.specs {
304 return Err(invalid_payload(
305 "CSET/STAT specs differ from provider configuration".to_owned(),
306 ));
307 }
308 let mut state = CandidateState::new(&self.specs);
309 state.generation = snapshot.generation;
310 for (id, labels) in snapshot.node_labels {
311 if state.node_labels.insert(id, labels).is_some() {
312 return Err(invalid_payload(format!(
313 "duplicate node id {id} in CSET/STAT"
314 )));
315 }
316 }
317 for (id, edge) in snapshot.edges {
318 if !watches_label(&self.specs, &edge.label) {
319 return Err(invalid_payload(format!(
320 "unwatched edge label {} in CSET/STAT",
321 edge.label.as_str()
322 )));
323 }
324 if !state.node_labels.contains_key(&edge.source)
325 || !state.node_labels.contains_key(&edge.target)
326 {
327 return Err(invalid_payload(format!(
328 "tracked edge {id} references missing endpoint in CSET/STAT"
329 )));
330 }
331 if state.edges.insert(id, edge).is_some() {
332 return Err(invalid_payload(format!(
333 "duplicate edge id {id} in CSET/STAT"
334 )));
335 }
336 }
337 state.rebuild_derived(&self.specs);
338 *self.state.lock() = state;
339 Ok(())
340 }
341
342 fn write_section(&self, sub_tag: SubTag) -> Result<Vec<u8>, ProviderError> {
343 ensure_state_subtag(sub_tag)?;
344 let state = self.state.lock();
345 let snapshot = CandidateStateSnapshot {
346 version: SNAPSHOT_VERSION,
347 generation: state.generation,
348 specs: self.specs.clone(),
349 node_labels: state
350 .node_labels
351 .iter()
352 .map(|(id, labels)| (*id, labels.clone()))
353 .collect(),
354 edges: state
355 .edges
356 .iter()
357 .map(|(id, edge)| (*id, edge.clone()))
358 .collect(),
359 };
360 postcard::to_stdvec(&snapshot).map_err(|error| ProviderError::SerializationFailed {
361 reason: format!("CSET/STAT postcard encode failed: {error}"),
362 })
363 }
364
365 fn on_change(&self, change: &Change) -> Result<(), ProviderError> {
366 self.state.lock().apply_change(&self.specs, change)
367 }
368
369 fn handles_change_batches(&self) -> bool {
370 true
371 }
372
373 fn on_changes(&self, changes: &[Change]) -> Result<(), ProviderError> {
374 let mut state = self.state.lock();
375 for change in changes {
376 state.apply_change(&self.specs, change)?;
377 }
378 Ok(())
379 }
380
381 fn rebuild_from_graph(&self, graph: &SeleneGraph) -> Result<(), ProviderError> {
382 MaintainedCandidateStateProvider::rebuild_from_graph(self, graph)
383 }
384
385 fn on_commit_applied(&self, generation: u64) -> Result<(), ProviderError> {
386 self.state.lock().generation = generation;
387 Ok(())
388 }
389
390 fn vector_candidate_set(
391 &self,
392 name: &DbString,
393 generation: u64,
394 ) -> Result<Option<VectorCandidateSet>, ProviderError> {
395 self.candidate_set_at_generation(name, generation)
396 }
397
398 fn vector_candidate_state_infos(
399 &self,
400 generation: u64,
401 ) -> Result<Vec<VectorCandidateStateInfo>, ProviderError> {
402 self.candidate_state_infos_at_generation(generation)
403 }
404
405 fn declared_sub_tags(&self) -> &[SubTag] {
406 SUB_TAGS
407 }
408}
409
410#[cfg(test)]
411#[path = "candidate_state/tests.rs"]
412mod tests;