Skip to main content

uni_query/query/df_graph/
locy_model_invoke.rs

1//! Phase B A4: `LocyModelInvokeExec` — a DataFusion `ExecutionPlan`
2//! node that runs registered neural classifiers against its input
3//! batches.
4//!
5//! This is the structural successor to the legacy
6//! `apply_model_invocations` post-projection record-batch pass that
7//! lived inside `run_fixpoint_loop` / `LocyProgramExec::run_program`.
8//! The behavior is byte-identical — the inner implementation
9//! (`super::locy_fixpoint::apply_model_invocations`) is unchanged;
10//! only the call site moved into a proper plan node so the
11//! invocation is part of the DataFusion plan tree instead of a
12//! post-execute mutation.
13//!
14//! Async-in-stream pattern: the input stream is collected into a
15//! `Vec<RecordBatch>`, the async classifier pass mutates it, and the
16//! result is yielded as a `RecordBatchStreamAdapter`. Same shape as
17//! `mutation_common::execute_mutation_inner` — no novel async
18//! machinery in the codebase.
19
20use std::any::Any;
21use std::collections::HashMap;
22use std::sync::Arc;
23
24use arrow_array::RecordBatch;
25use arrow_schema::SchemaRef;
26use datafusion::error::Result as DFResult;
27use datafusion::execution::TaskContext;
28use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
29use datafusion::physical_plan::{
30    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
31};
32use futures::TryStreamExt;
33use parking_lot::RwLock;
34use uni_algo::algo::AlgorithmRegistry;
35use uni_locy::{ClassifierRegistry, ModelInvocation, ModelInvocationCache};
36use uni_store::runtime::L0Manager;
37use uni_store::runtime::property_manager::PropertyManager;
38use uni_store::storage::manager::StorageManager;
39use uni_xervo::runtime::ModelRuntime;
40
41use super::locy_fixpoint::apply_model_invocations;
42
43/// Phase D D1 graph-structural runtime: a Clone+Debug bundle of the
44/// pieces needed to invoke `uni.algo.*` procedures directly from the
45/// FEATURE pipeline (no Cypher CALL roundtrip) and to traverse
46/// one-hop neighborhoods for `avg_neighbor` / `max_neighbor` /
47/// `sum_neighbor`.
48///
49/// Built fresh at physical-plan lowering (`df_planner.rs`) from
50/// `GraphExecutionContext`; mirrors the `XervoRuntimeHandle` pattern
51/// (logical plan is graph_ctx-agnostic).
52#[derive(Clone, Default)]
53pub struct GraphAlgoHandle {
54    pub(crate) registry: Option<Arc<AlgorithmRegistry>>,
55    pub(crate) storage: Option<Arc<StorageManager>>,
56    pub(crate) l0_manager: Option<Arc<L0Manager>>,
57    pub(crate) property_manager: Option<Arc<PropertyManager>>,
58    /// Raw L0 buffers for building a fresh `QueryContext` when the
59    /// neighbor-aggregator path calls `PropertyManager::get_vertex_prop_with_ctx`.
60    /// L0-resident vertex properties are invisible to property reads
61    /// without a `QueryContext`; topology procedures don't need this
62    /// because they consume `L0Manager` directly via `AlgoContext`.
63    pub(crate) l0_buffers: Option<L0Buffers>,
64}
65
66#[derive(Clone)]
67pub(crate) struct L0Buffers {
68    pub(crate) current: Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>,
69    pub(crate) transaction: Option<Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
70    pub(crate) pending_flush: Vec<Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
71}
72
73impl std::fmt::Debug for GraphAlgoHandle {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        match (&self.registry, &self.storage) {
76            (Some(_), Some(_)) => write!(f, "GraphAlgoHandle(<configured>)"),
77            _ => write!(f, "GraphAlgoHandle(<none>)"),
78        }
79    }
80}
81
82impl GraphAlgoHandle {
83    pub fn is_configured(&self) -> bool {
84        self.registry.is_some() && self.storage.is_some()
85    }
86}
87
88/// Phase D D2 runtime: a Clone+Debug wrapper around the optional
89/// Uni-Xervo runtime. `ModelRuntime` doesn't derive Debug (its
90/// `providers: HashMap<String, Box<dyn ModelProvider>>` field
91/// contains trait objects that aren't Debug), so we need this
92/// shim to keep `LogicalPlan` derivable.
93#[derive(Clone, Default)]
94pub struct XervoRuntimeHandle(pub Option<Arc<ModelRuntime>>);
95
96impl std::fmt::Debug for XervoRuntimeHandle {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        match &self.0 {
99            Some(_) => write!(f, "XervoRuntimeHandle(<configured>)"),
100            None => write!(f, "XervoRuntimeHandle(<none>)"),
101        }
102    }
103}
104
105impl XervoRuntimeHandle {
106    pub fn as_ref(&self) -> Option<&Arc<ModelRuntime>> {
107        self.0.as_ref()
108    }
109}
110
111/// Phase D D3 runtime: a shared handle into a source rule's derived
112/// facts. The plan builder mints these for every `path_context.source_rule`
113/// referenced by any invocation in the same `LocyModelInvoke`, so the
114/// runtime can read the rule's `Vec<RecordBatch>` (already populated by
115/// the fixpoint loop in an earlier stratum) and join by VID without
116/// consulting the registry at exec time.
117#[derive(Debug, Clone)]
118pub struct PathContextHandle {
119    pub source_rule: String,
120    pub data: Arc<RwLock<Vec<RecordBatch>>>,
121    pub schema: SchemaRef,
122}
123
124/// `ExecutionPlan` wrapper that runs `apply_model_invocations` over
125/// the batches produced by `input`.
126#[derive(Debug)]
127pub struct LocyModelInvokeExec {
128    input: Arc<dyn ExecutionPlan>,
129    invocations: Vec<ModelInvocation>,
130    registry: Arc<ClassifierRegistry>,
131    cache: Option<Arc<ModelInvocationCache>>,
132    /// Phase D D3: one handle per distinct `path_context.source_rule`
133    /// referenced by the invocations on this node, indexed by rule name.
134    path_context_handles: HashMap<String, PathContextHandle>,
135    /// Phase D D2 runtime: Uni-Xervo runtime for auto-embedding
136    /// `semantic_match(prop, 'text')` query literals once per
137    /// `apply_model_invocations` call. `None` when no xervo runtime
138    /// is configured — `semantic_match` calls then error with a
139    /// clear message at row time.
140    xervo_runtime: XervoRuntimeHandle,
141    /// Phase D D1 graph-structural runtime: registry + storage handle
142    /// for invoking topology algorithms (degree/pagerank/closeness)
143    /// and walking one-hop neighborhoods. Built from `GraphExecutionContext`
144    /// at physical lowering.
145    graph_algo: GraphAlgoHandle,
146    /// Phase C B1-B3 follow-up: per-query side-channel store for
147    /// (raw, calibrated, confidence_band) tuples. Written per row
148    /// per invocation by `apply_model_invocations`; consumed by
149    /// EXPLAIN's `collect_neural_calls_for_row` to surface
150    /// `NeuralProvenance` regardless of whether the invocation
151    /// lives in YIELD / ALONG / FOLD position.
152    provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
153    /// Output schema: the input schema, post-invocation. Since
154    /// `apply_model_invocations` overwrites a placeholder column
155    /// emitted by the compiler at the same name and forces it to
156    /// `Float64`, the output schema equals the input schema with
157    /// each invocation's `output_column` retyped to `Float64`.
158    schema: SchemaRef,
159    plan_properties: Arc<PlanProperties>,
160}
161
162impl LocyModelInvokeExec {
163    #[allow(clippy::too_many_arguments)]
164    pub fn new(
165        input: Arc<dyn ExecutionPlan>,
166        invocations: Vec<ModelInvocation>,
167        registry: Arc<ClassifierRegistry>,
168        cache: Option<Arc<ModelInvocationCache>>,
169        provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
170        path_context_handles: HashMap<String, PathContextHandle>,
171        xervo_runtime: XervoRuntimeHandle,
172        graph_algo: GraphAlgoHandle,
173    ) -> Self {
174        let schema = compute_output_schema(input.schema(), &invocations);
175        let plan_properties = compute_plan_properties(&input, schema.clone());
176        Self {
177            input,
178            invocations,
179            registry,
180            cache,
181            provenance_store,
182            path_context_handles,
183            xervo_runtime,
184            graph_algo,
185            schema,
186            plan_properties,
187        }
188    }
189}
190
191fn compute_output_schema(input_schema: SchemaRef, invocations: &[ModelInvocation]) -> SchemaRef {
192    use arrow_schema::{DataType, Field, Schema};
193    if invocations.is_empty() {
194        return input_schema;
195    }
196    let mut fields: Vec<Arc<Field>> = input_schema.fields().iter().cloned().collect();
197    for invocation in invocations {
198        if let Some((idx, _)) = input_schema
199            .fields()
200            .iter()
201            .enumerate()
202            .find(|(_, f)| f.name() == &invocation.output_column)
203        {
204            fields[idx] = Arc::new(Field::new(
205                &invocation.output_column,
206                DataType::Float64,
207                true,
208            ));
209        } else {
210            fields.push(Arc::new(Field::new(
211                &invocation.output_column,
212                DataType::Float64,
213                true,
214            )));
215        }
216    }
217    Arc::new(Schema::new(fields))
218}
219
220fn compute_plan_properties(
221    input: &Arc<dyn ExecutionPlan>,
222    schema: SchemaRef,
223) -> Arc<PlanProperties> {
224    use datafusion::physical_expr::EquivalenceProperties;
225    use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
226
227    let eq = EquivalenceProperties::new(schema);
228    Arc::new(PlanProperties::new(
229        eq,
230        input.properties().output_partitioning().clone(),
231        EmissionType::Final,
232        Boundedness::Bounded,
233    ))
234}
235
236impl DisplayAs for LocyModelInvokeExec {
237    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
238        write!(
239            f,
240            "LocyModelInvokeExec: invocations=[{}]",
241            self.invocations
242                .iter()
243                .map(|inv| format!("{}→{}", inv.model_name, inv.output_column))
244                .collect::<Vec<_>>()
245                .join(", ")
246        )
247    }
248}
249
250impl ExecutionPlan for LocyModelInvokeExec {
251    fn name(&self) -> &str {
252        "LocyModelInvokeExec"
253    }
254
255    fn as_any(&self) -> &dyn Any {
256        self
257    }
258
259    fn properties(&self) -> &Arc<PlanProperties> {
260        &self.plan_properties
261    }
262
263    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
264        vec![&self.input]
265    }
266
267    fn with_new_children(
268        self: Arc<Self>,
269        children: Vec<Arc<dyn ExecutionPlan>>,
270    ) -> DFResult<Arc<dyn ExecutionPlan>> {
271        if children.len() != 1 {
272            return Err(datafusion::error::DataFusionError::Internal(format!(
273                "LocyModelInvokeExec expects exactly 1 child, got {}",
274                children.len()
275            )));
276        }
277        Ok(Arc::new(Self::new(
278            children.into_iter().next().unwrap(),
279            self.invocations.clone(),
280            Arc::clone(&self.registry),
281            self.cache.as_ref().map(Arc::clone),
282            self.provenance_store.as_ref().map(Arc::clone),
283            self.path_context_handles.clone(),
284            self.xervo_runtime.clone(),
285            self.graph_algo.clone(),
286        )))
287    }
288
289    fn execute(
290        &self,
291        partition: usize,
292        context: Arc<TaskContext>,
293    ) -> DFResult<SendableRecordBatchStream> {
294        let input_stream = self.input.execute(partition, context)?;
295        let invocations = self.invocations.clone();
296        let registry = Arc::clone(&self.registry);
297        let cache = self.cache.as_ref().map(Arc::clone);
298        let provenance_store = self.provenance_store.as_ref().map(Arc::clone);
299        let path_context_handles = self.path_context_handles.clone();
300        let xervo_runtime = self.xervo_runtime.clone();
301        let graph_algo = self.graph_algo.clone();
302        let schema = self.schema.clone();
303
304        let fut = async move {
305            let batches: Vec<RecordBatch> = input_stream.try_collect::<Vec<_>>().await?;
306            let out = apply_model_invocations(
307                batches,
308                &invocations,
309                &registry,
310                cache.as_ref(),
311                provenance_store.as_ref(),
312                &path_context_handles,
313                &xervo_runtime,
314                &graph_algo,
315            )
316            .await?;
317            // Wrap the Vec<RecordBatch> as a stream so try_flatten
318            // can splice it inline.
319            Ok::<_, datafusion::error::DataFusionError>(futures::stream::iter(
320                out.into_iter().map(Ok),
321            ))
322        };
323        let stream = futures::stream::once(fut).try_flatten();
324
325        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
326    }
327}