Skip to main content

shadow_core/
python.rs

//! PyO3 bindings — the `shadow._core` Python extension module.
//!
//! Exposed surface:
//! - `parse_agentlog(bytes) -> list[dict]`
//! - `write_agentlog(list[dict]) -> bytes`
//! - `canonical_bytes(payload_dict) -> bytes`  (SPEC §5)
//! - `content_id(payload_dict) -> str`         (SPEC §6)
//! - `compute_diff_report(baseline, candidate, pricing, seed) -> dict`
//! - `compute_semantic_axis_with_embedder(baseline, candidate, embedder, seed) -> dict`
//!
//! Everything is dict-oriented on the Python side; `pythonize` handles
//! the serde_json::Value ↔ PyObject conversion. Type hints for Python
//! callers live in `python/src/shadow/_core.pyi`.
13//
14// clippy::useless_conversion fires on the `?` operator in PyResult chains
15// where `?` does an identity PyErr→PyErr conversion via `From`. That's a
16// standard PyO3 pattern (every PyO3 API returns PyResult); suppressing here
17// keeps PyO3-idiomatic code readable without sprinkling allows everywhere.
18#![allow(clippy::useless_conversion)]
19
20use std::io::Cursor;
21
22use pyo3::exceptions::{PyIOError, PyValueError};
23use pyo3::prelude::*;
24use pyo3::types::{PyBytes, PyDict, PyList};
25use pythonize::{depythonize, pythonize};
26
27use crate::agentlog::{hash, parser, writer, Record};
28use crate::diff::{
29    compute_report,
30    cost::{ModelPricing, Pricing},
31    embedder::{BoxedEmbedder, Embedder},
32    semantic::compute_with_embedder,
33};
34
35/// Parse a `.agentlog` byte blob into a list of record dicts.
36#[pyfunction]
37fn parse_agentlog<'py>(
38    py: Python<'py>,
39    data: &Bound<'py, PyBytes>,
40) -> PyResult<Bound<'py, PyList>> {
41    let bytes = data.as_bytes();
42    let records =
43        parser::parse_all(Cursor::new(bytes)).map_err(|e| PyValueError::new_err(e.to_string()))?;
44    let out = PyList::empty_bound(py);
45    for r in records {
46        let v = serde_json::to_value(&r).map_err(|e| PyValueError::new_err(e.to_string()))?;
47        let obj = pythonize(py, &v).map_err(|e| PyValueError::new_err(e.to_string()))?;
48        out.append(obj)?;
49    }
50    Ok(out)
51}
52
53/// Serialize a list of record dicts into `.agentlog` bytes.
54#[pyfunction]
55fn write_agentlog<'py>(
56    py: Python<'py>,
57    records: &Bound<'py, PyList>,
58) -> PyResult<Bound<'py, PyBytes>> {
59    let mut parsed: Vec<Record> = Vec::with_capacity(records.len());
60    for item in records.iter() {
61        let v: serde_json::Value =
62            depythonize(&item).map_err(|e| PyValueError::new_err(e.to_string()))?;
63        let r: Record =
64            serde_json::from_value(v).map_err(|e| PyValueError::new_err(e.to_string()))?;
65        parsed.push(r);
66    }
67    let mut buf = Vec::new();
68    writer::write_all(&mut buf, &parsed).map_err(|e| PyIOError::new_err(e.to_string()))?;
69    Ok(PyBytes::new_bound(py, &buf))
70}
71
72/// Canonical-JSON byte sequence for a payload (SPEC §5).
73#[pyfunction]
74fn canonical_bytes<'py>(
75    py: Python<'py>,
76    payload: &Bound<'py, PyAny>,
77) -> PyResult<Bound<'py, PyBytes>> {
78    let v: serde_json::Value =
79        depythonize(payload).map_err(|e| PyValueError::new_err(e.to_string()))?;
80    let bytes = crate::agentlog::canonical::to_bytes(&v);
81    Ok(PyBytes::new_bound(py, &bytes))
82}
83
84/// Content id for a payload dict (SPEC §6).
85#[pyfunction]
86fn content_id(payload: &Bound<'_, PyAny>) -> PyResult<String> {
87    let v: serde_json::Value =
88        depythonize(payload).map_err(|e| PyValueError::new_err(e.to_string()))?;
89    Ok(hash::content_id(&v))
90}
91
92/// Compute a nine-axis diff between two traces.
93///
94/// `baseline` and `candidate` are lists of record dicts (as produced by
95/// [`parse_agentlog`]). `pricing` is a `dict[str, tuple[float, float]]`
96/// mapping model name → (price_per_input_token, price_per_output_token).
97/// `seed` is an optional RNG seed for reproducible bootstrap CIs.
98#[pyfunction]
99#[pyo3(signature = (baseline, candidate, pricing=None, seed=None))]
100fn compute_diff_report<'py>(
101    py: Python<'py>,
102    baseline: &Bound<'py, PyList>,
103    candidate: &Bound<'py, PyList>,
104    pricing: Option<&Bound<'py, PyDict>>,
105    seed: Option<u64>,
106) -> PyResult<Bound<'py, PyAny>> {
107    let baseline_records = pylist_to_records(baseline)?;
108    let candidate_records = pylist_to_records(candidate)?;
109
110    let mut price_map = Pricing::new();
111    if let Some(dict) = pricing {
112        for (k, v) in dict.iter() {
113            let key: String = k
114                .extract()
115                .map_err(|e| PyValueError::new_err(format!("pricing key: {e}")))?;
116            // Accept either (input, output) tuple (legacy) or a dict
117            // {input, output, cached_input?, reasoning?, batch_discount?}.
118            let mp = if let Ok(pair) = v.extract::<(f64, f64)>() {
119                ModelPricing::simple(pair.0, pair.1)
120            } else {
121                let v_json: serde_json::Value = depythonize(&v)
122                    .map_err(|e| PyValueError::new_err(format!("pricing value: {e}")))?;
123                serde_json::from_value(v_json)
124                    .map_err(|e| PyValueError::new_err(format!("pricing value: {e}")))?
125            };
126            price_map.insert(key, mp);
127        }
128    }
129
130    let report = compute_report(&baseline_records, &candidate_records, &price_map, seed);
131    let v = serde_json::to_value(&report).map_err(|e| PyValueError::new_err(e.to_string()))?;
132    pythonize(py, &v).map_err(|e| PyValueError::new_err(e.to_string()))
133}
134
135/// Compute the semantic axis using a Python-supplied embedder callable.
136///
137/// The callable receives a `list[str]` of texts and must return a
138/// `list[list[float]]` of equal-length dense vectors. Typical wiring:
139///
140/// ```python
141/// from sentence_transformers import SentenceTransformer
142/// model = SentenceTransformer("all-MiniLM-L6-v2")
143/// def embed(texts: list[str]) -> list[list[float]]:
144///     return model.encode(texts).tolist()
145/// shadow._core.compute_semantic_axis_with_embedder(baseline, candidate, embed, seed=42)
146/// ```
147///
148/// Returns the AxisStat as a dict (same shape the diff report uses).
149#[pyfunction]
150#[pyo3(signature = (baseline, candidate, embedder, seed=None))]
151fn compute_semantic_axis_with_embedder<'py>(
152    py: Python<'py>,
153    baseline: &Bound<'py, PyList>,
154    candidate: &Bound<'py, PyList>,
155    embedder: &Bound<'py, PyAny>,
156    seed: Option<u64>,
157) -> PyResult<Bound<'py, PyAny>> {
158    let baseline_records = pylist_to_records(baseline)?;
159    let candidate_records = pylist_to_records(candidate)?;
160
161    if !embedder.is_callable() {
162        return Err(PyValueError::new_err(
163            "embedder must be callable: fn(list[str]) -> list[list[float]]",
164        ));
165    }
166
167    // Pair up baseline ↔ candidate response records the same way
168    // `compute_report` does, then route through `compute_with_embedder`.
169    let pairs = pair_responses(&baseline_records, &candidate_records);
170
171    // Wrap the Python callable into a Rust closure that the Embedder
172    // trait can call. We hold the GIL for the duration via Python::with_gil
173    // inside the closure, since BoxedEmbedder is Sync but the actual call
174    // path is single-threaded here.
175    let embedder_obj: Py<PyAny> = embedder.clone().unbind();
176    let py_embedder = BoxedEmbedder::named(
177        move |texts: &[&str]| -> Vec<Vec<f32>> {
178            Python::with_gil(|py| {
179                let owned: Vec<String> = texts.iter().map(|s| (*s).to_string()).collect();
180                let py_list = PyList::new_bound(py, &owned);
181                let result = embedder_obj.call1(py, (py_list,));
182                let any = match result {
183                    Ok(v) => v,
184                    Err(_) => return Vec::new(),
185                };
186                let bound = any.bind(py);
187                bound.extract::<Vec<Vec<f32>>>().unwrap_or_default()
188            })
189        },
190        "py-callback",
191    );
192
193    // SAFETY: compute_with_embedder takes &dyn Embedder; BoxedEmbedder is
194    // Send + Sync. We construct on this thread and use it on this thread,
195    // so no concurrency concerns.
196    let pair_refs: Vec<(&Record, &Record)> = pairs.iter().map(|(a, b)| (*a, *b)).collect();
197    let stat = compute_with_embedder(&pair_refs, &py_embedder as &dyn Embedder, seed);
198
199    let v = serde_json::to_value(&stat).map_err(|e| PyValueError::new_err(e.to_string()))?;
200    pythonize(py, &v).map_err(|e| PyValueError::new_err(e.to_string()))
201}
202
203/// Helper: pair baseline and candidate `chat_response` records by index,
204/// matching the existing diff pipeline's pairing semantics.
205fn pair_responses<'a>(
206    baseline: &'a [Record],
207    candidate: &'a [Record],
208) -> Vec<(&'a Record, &'a Record)> {
209    use crate::agentlog::Kind;
210    let b_resps: Vec<&Record> = baseline
211        .iter()
212        .filter(|r| r.kind == Kind::ChatResponse)
213        .collect();
214    let c_resps: Vec<&Record> = candidate
215        .iter()
216        .filter(|r| r.kind == Kind::ChatResponse)
217        .collect();
218    b_resps.into_iter().zip(c_resps).collect()
219}
220
221fn pylist_to_records(list: &Bound<'_, PyList>) -> PyResult<Vec<Record>> {
222    let mut out = Vec::with_capacity(list.len());
223    for item in list.iter() {
224        let v: serde_json::Value =
225            depythonize(&item).map_err(|e| PyValueError::new_err(e.to_string()))?;
226        let r: Record =
227            serde_json::from_value(v).map_err(|e| PyValueError::new_err(e.to_string()))?;
228        out.push(r);
229    }
230    Ok(out)
231}
232
233/// The entry point the Python interpreter calls when `shadow._core` is
234/// imported. Registers every function above.
235#[pymodule]
236fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
237    m.add("__version__", crate::VERSION)?;
238    m.add("SPEC_VERSION", crate::agentlog::CURRENT_VERSION)?;
239    m.add_function(wrap_pyfunction!(parse_agentlog, m)?)?;
240    m.add_function(wrap_pyfunction!(write_agentlog, m)?)?;
241    m.add_function(wrap_pyfunction!(canonical_bytes, m)?)?;
242    m.add_function(wrap_pyfunction!(content_id, m)?)?;
243    m.add_function(wrap_pyfunction!(compute_diff_report, m)?)?;
244    m.add_function(wrap_pyfunction!(compute_semantic_axis_with_embedder, m)?)?;
245    Ok(())
246}