1#![allow(clippy::useless_conversion)]
19
20use std::io::Cursor;
21
22use pyo3::exceptions::{PyIOError, PyValueError};
23use pyo3::prelude::*;
24use pyo3::types::{PyBytes, PyDict, PyList};
25use pythonize::{depythonize, pythonize};
26
27use crate::agentlog::{hash, parser, writer, Record};
28use crate::diff::{
29 compute_report,
30 cost::{ModelPricing, Pricing},
31 embedder::{BoxedEmbedder, Embedder},
32 semantic::compute_with_embedder,
33};
34
35#[pyfunction]
37fn parse_agentlog<'py>(
38 py: Python<'py>,
39 data: &Bound<'py, PyBytes>,
40) -> PyResult<Bound<'py, PyList>> {
41 let bytes = data.as_bytes();
42 let records =
43 parser::parse_all(Cursor::new(bytes)).map_err(|e| PyValueError::new_err(e.to_string()))?;
44 let out = PyList::empty_bound(py);
45 for r in records {
46 let v = serde_json::to_value(&r).map_err(|e| PyValueError::new_err(e.to_string()))?;
47 let obj = pythonize(py, &v).map_err(|e| PyValueError::new_err(e.to_string()))?;
48 out.append(obj)?;
49 }
50 Ok(out)
51}
52
53#[pyfunction]
55fn write_agentlog<'py>(
56 py: Python<'py>,
57 records: &Bound<'py, PyList>,
58) -> PyResult<Bound<'py, PyBytes>> {
59 let mut parsed: Vec<Record> = Vec::with_capacity(records.len());
60 for item in records.iter() {
61 let v: serde_json::Value =
62 depythonize(&item).map_err(|e| PyValueError::new_err(e.to_string()))?;
63 let r: Record =
64 serde_json::from_value(v).map_err(|e| PyValueError::new_err(e.to_string()))?;
65 parsed.push(r);
66 }
67 let mut buf = Vec::new();
68 writer::write_all(&mut buf, &parsed).map_err(|e| PyIOError::new_err(e.to_string()))?;
69 Ok(PyBytes::new_bound(py, &buf))
70}
71
72#[pyfunction]
74fn canonical_bytes<'py>(
75 py: Python<'py>,
76 payload: &Bound<'py, PyAny>,
77) -> PyResult<Bound<'py, PyBytes>> {
78 let v: serde_json::Value =
79 depythonize(payload).map_err(|e| PyValueError::new_err(e.to_string()))?;
80 let bytes = crate::agentlog::canonical::to_bytes(&v);
81 Ok(PyBytes::new_bound(py, &bytes))
82}
83
84#[pyfunction]
86fn content_id(payload: &Bound<'_, PyAny>) -> PyResult<String> {
87 let v: serde_json::Value =
88 depythonize(payload).map_err(|e| PyValueError::new_err(e.to_string()))?;
89 Ok(hash::content_id(&v))
90}
91
92#[pyfunction]
99#[pyo3(signature = (baseline, candidate, pricing=None, seed=None))]
100fn compute_diff_report<'py>(
101 py: Python<'py>,
102 baseline: &Bound<'py, PyList>,
103 candidate: &Bound<'py, PyList>,
104 pricing: Option<&Bound<'py, PyDict>>,
105 seed: Option<u64>,
106) -> PyResult<Bound<'py, PyAny>> {
107 let baseline_records = pylist_to_records(baseline)?;
108 let candidate_records = pylist_to_records(candidate)?;
109
110 let mut price_map = Pricing::new();
111 if let Some(dict) = pricing {
112 for (k, v) in dict.iter() {
113 let key: String = k
114 .extract()
115 .map_err(|e| PyValueError::new_err(format!("pricing key: {e}")))?;
116 let mp = if let Ok(pair) = v.extract::<(f64, f64)>() {
119 ModelPricing::simple(pair.0, pair.1)
120 } else {
121 let v_json: serde_json::Value = depythonize(&v)
122 .map_err(|e| PyValueError::new_err(format!("pricing value: {e}")))?;
123 serde_json::from_value(v_json)
124 .map_err(|e| PyValueError::new_err(format!("pricing value: {e}")))?
125 };
126 price_map.insert(key, mp);
127 }
128 }
129
130 let report = compute_report(&baseline_records, &candidate_records, &price_map, seed);
131 let v = serde_json::to_value(&report).map_err(|e| PyValueError::new_err(e.to_string()))?;
132 pythonize(py, &v).map_err(|e| PyValueError::new_err(e.to_string()))
133}
134
135#[pyfunction]
150#[pyo3(signature = (baseline, candidate, embedder, seed=None))]
151fn compute_semantic_axis_with_embedder<'py>(
152 py: Python<'py>,
153 baseline: &Bound<'py, PyList>,
154 candidate: &Bound<'py, PyList>,
155 embedder: &Bound<'py, PyAny>,
156 seed: Option<u64>,
157) -> PyResult<Bound<'py, PyAny>> {
158 let baseline_records = pylist_to_records(baseline)?;
159 let candidate_records = pylist_to_records(candidate)?;
160
161 if !embedder.is_callable() {
162 return Err(PyValueError::new_err(
163 "embedder must be callable: fn(list[str]) -> list[list[float]]",
164 ));
165 }
166
167 let pairs = pair_responses(&baseline_records, &candidate_records);
170
171 let embedder_obj: Py<PyAny> = embedder.clone().unbind();
176 let py_embedder = BoxedEmbedder::named(
177 move |texts: &[&str]| -> Vec<Vec<f32>> {
178 Python::with_gil(|py| {
179 let owned: Vec<String> = texts.iter().map(|s| (*s).to_string()).collect();
180 let py_list = PyList::new_bound(py, &owned);
181 let result = embedder_obj.call1(py, (py_list,));
182 let any = match result {
183 Ok(v) => v,
184 Err(_) => return Vec::new(),
185 };
186 let bound = any.bind(py);
187 bound.extract::<Vec<Vec<f32>>>().unwrap_or_default()
188 })
189 },
190 "py-callback",
191 );
192
193 let pair_refs: Vec<(&Record, &Record)> = pairs.iter().map(|(a, b)| (*a, *b)).collect();
197 let stat = compute_with_embedder(&pair_refs, &py_embedder as &dyn Embedder, seed);
198
199 let v = serde_json::to_value(&stat).map_err(|e| PyValueError::new_err(e.to_string()))?;
200 pythonize(py, &v).map_err(|e| PyValueError::new_err(e.to_string()))
201}
202
203fn pair_responses<'a>(
206 baseline: &'a [Record],
207 candidate: &'a [Record],
208) -> Vec<(&'a Record, &'a Record)> {
209 use crate::agentlog::Kind;
210 let b_resps: Vec<&Record> = baseline
211 .iter()
212 .filter(|r| r.kind == Kind::ChatResponse)
213 .collect();
214 let c_resps: Vec<&Record> = candidate
215 .iter()
216 .filter(|r| r.kind == Kind::ChatResponse)
217 .collect();
218 b_resps.into_iter().zip(c_resps).collect()
219}
220
221fn pylist_to_records(list: &Bound<'_, PyList>) -> PyResult<Vec<Record>> {
222 let mut out = Vec::with_capacity(list.len());
223 for item in list.iter() {
224 let v: serde_json::Value =
225 depythonize(&item).map_err(|e| PyValueError::new_err(e.to_string()))?;
226 let r: Record =
227 serde_json::from_value(v).map_err(|e| PyValueError::new_err(e.to_string()))?;
228 out.push(r);
229 }
230 Ok(out)
231}
232
233#[pymodule]
236fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
237 m.add("__version__", crate::VERSION)?;
238 m.add("SPEC_VERSION", crate::agentlog::CURRENT_VERSION)?;
239 m.add_function(wrap_pyfunction!(parse_agentlog, m)?)?;
240 m.add_function(wrap_pyfunction!(write_agentlog, m)?)?;
241 m.add_function(wrap_pyfunction!(canonical_bytes, m)?)?;
242 m.add_function(wrap_pyfunction!(content_id, m)?)?;
243 m.add_function(wrap_pyfunction!(compute_diff_report, m)?)?;
244 m.add_function(wrap_pyfunction!(compute_semantic_axis_with_embedder, m)?)?;
245 Ok(())
246}