Skip to main content

semantic_memory_forge/
estimator.rs

//! Estimator and sidecar execution metadata.
//!
//! Records the methodological metadata needed to audit or reproduce estimator
//! and refuter runs, including the environment and request/response details of
//! Python sidecar executions when applicable.
5
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9/// Kind of estimator used.
10#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
11#[serde(rename_all = "snake_case")]
12pub enum EstimatorKind {
13    /// Difference-in-differences.
14    DiffInDiff,
15    /// Propensity score matching.
16    PropensityScore,
17    /// Instrumental variables.
18    InstrumentalVariables,
19    /// Ordinary least squares.
20    OLS,
21    /// Bayesian estimation.
22    Bayesian,
23    /// Simple before/after comparison.
24    BeforeAfter,
25    /// Custom estimator.
26    Custom(String),
27}
28
29impl EstimatorKind {
30    /// Returns the stable wire-format label for the estimator kind.
31    pub fn as_str(&self) -> &str {
32        match self {
33            Self::DiffInDiff => "diff_in_diff",
34            Self::PropensityScore => "propensity_score",
35            Self::InstrumentalVariables => "instrumental_variables",
36            Self::OLS => "ols",
37            Self::Bayesian => "bayesian",
38            Self::BeforeAfter => "before_after",
39            Self::Custom(s) => s,
40        }
41    }
42}
43
44/// Metadata about an estimator or refuter invocation.
45///
46/// Captures everything needed to reproduce or audit the estimation.
47#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
48pub struct EstimatorMeta {
49    /// The kind of estimator.
50    pub kind: EstimatorKind,
51    /// Version of the estimator (semver or commit hash).
52    pub version: String,
53    /// Parameters passed to the estimator.
54    pub parameters: serde_json::Value,
55    /// Random seed, if applicable for reproducibility.
56    pub random_seed: Option<u64>,
57    /// Environment fingerprint for the execution.
58    pub environment: Option<EnvironmentFingerprint>,
59    /// Timeout applied to the execution.
60    pub timeout_secs: Option<u64>,
61    /// How the estimator failed, if it did.
62    pub failure_mode: Option<String>,
63    /// Versioned request schema identifier.
64    pub request_schema_version: Option<String>,
65    /// Versioned response schema identifier.
66    pub response_schema_version: Option<String>,
67}
68
69/// Fingerprint of the execution environment.
70///
71/// Used to detect environment drift that could affect reproducibility.
72#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
73pub struct EnvironmentFingerprint {
74    /// Python version (if sidecar).
75    pub python_version: Option<String>,
76    /// Key package versions (e.g., {"dowhy": "0.11", "numpy": "1.26"}).
77    pub package_versions: serde_json::Value,
78    /// OS / platform identifier.
79    pub platform: Option<String>,
80    /// Hash of the full environment specification (e.g., pip freeze hash).
81    pub env_hash: Option<String>,
82}
83
84/// Record of a sidecar execution (e.g., Python estimation/refutation).
85///
86/// Preserves the full request/response chain for audit and replay.
87#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
88pub struct SidecarExecution {
89    /// Estimator metadata.
90    pub estimator: EstimatorMeta,
91    /// The request payload sent to the sidecar.
92    pub request: serde_json::Value,
93    /// The response payload received from the sidecar.
94    pub response: Option<serde_json::Value>,
95    /// Duration of the execution in milliseconds.
96    pub duration_ms: Option<u64>,
97    /// Whether the execution succeeded.
98    pub success: bool,
99    /// Error message if the execution failed.
100    pub error: Option<String>,
101    /// When the execution started.
102    pub started_at: String,
103    /// When the execution completed.
104    pub completed_at: Option<String>,
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a fully populated `EstimatorMeta` and checks every field,
    /// so a serde regression in any single field is caught.
    #[test]
    fn estimator_meta_serde() {
        let meta = EstimatorMeta {
            kind: EstimatorKind::DiffInDiff,
            version: "1.0.0".into(),
            parameters: serde_json::json!({"method": "linear"}),
            random_seed: Some(42),
            environment: Some(EnvironmentFingerprint {
                python_version: Some("3.11".into()),
                package_versions: serde_json::json!({"dowhy": "0.11"}),
                platform: Some("linux-x86_64".into()),
                env_hash: None,
            }),
            timeout_secs: Some(300),
            failure_mode: None,
            request_schema_version: Some("v1".into()),
            response_schema_version: Some("v1".into()),
        };

        let json = serde_json::to_string(&meta).unwrap();
        let back: EstimatorMeta = serde_json::from_str(&json).unwrap();

        assert_eq!(back.kind.as_str(), "diff_in_diff");
        assert_eq!(back.version, "1.0.0");
        assert_eq!(back.parameters, serde_json::json!({"method": "linear"}));
        assert_eq!(back.random_seed, Some(42));
        assert_eq!(back.timeout_secs, Some(300));
        assert!(back.failure_mode.is_none());
        assert_eq!(back.request_schema_version.as_deref(), Some("v1"));
        assert_eq!(back.response_schema_version.as_deref(), Some("v1"));

        let env = back.environment.expect("environment should round-trip");
        assert_eq!(env.python_version.as_deref(), Some("3.11"));
        assert_eq!(env.package_versions, serde_json::json!({"dowhy": "0.11"}));
        assert_eq!(env.platform.as_deref(), Some("linux-x86_64"));
        assert!(env.env_hash.is_none());
    }

    /// Round-trips a successful sidecar execution and checks every field.
    #[test]
    fn sidecar_execution_serde() {
        let exec = SidecarExecution {
            estimator: EstimatorMeta {
                kind: EstimatorKind::PropensityScore,
                version: "2.0.0".into(),
                parameters: serde_json::json!({}),
                random_seed: None,
                environment: None,
                timeout_secs: Some(60),
                failure_mode: None,
                request_schema_version: None,
                response_schema_version: None,
            },
            request: serde_json::json!({"data": [1, 2, 3]}),
            response: Some(serde_json::json!({"estimate": 0.5})),
            duration_ms: Some(1500),
            success: true,
            error: None,
            started_at: "2024-01-01T00:00:00Z".into(),
            completed_at: Some("2024-01-01T00:00:01Z".into()),
        };

        let json = serde_json::to_string(&exec).unwrap();
        let back: SidecarExecution = serde_json::from_str(&json).unwrap();

        assert!(back.success);
        assert_eq!(back.duration_ms, Some(1500));
        assert_eq!(back.estimator.kind.as_str(), "propensity_score");
        assert_eq!(back.estimator.version, "2.0.0");
        assert_eq!(back.estimator.timeout_secs, Some(60));
        assert!(back.estimator.environment.is_none());
        assert_eq!(back.request, serde_json::json!({"data": [1, 2, 3]}));
        assert_eq!(back.response, Some(serde_json::json!({"estimate": 0.5})));
        assert!(back.error.is_none());
        assert_eq!(back.started_at, "2024-01-01T00:00:00Z");
        assert_eq!(back.completed_at.as_deref(), Some("2024-01-01T00:00:01Z"));
    }

    /// A failed execution (no response, error recorded) must round-trip too.
    #[test]
    fn failed_sidecar_execution_serde() {
        let exec = SidecarExecution {
            estimator: EstimatorMeta {
                kind: EstimatorKind::Custom("synthetic_control".into()),
                version: "abc123".into(),
                parameters: serde_json::json!({"donors": 5}),
                random_seed: Some(7),
                environment: None,
                timeout_secs: Some(10),
                failure_mode: Some("timeout".into()),
                request_schema_version: Some("v2".into()),
                response_schema_version: None,
            },
            request: serde_json::json!({"data": []}),
            response: None,
            duration_ms: Some(10_000),
            success: false,
            error: Some("sidecar timed out".into()),
            started_at: "2024-01-01T00:00:00Z".into(),
            completed_at: None,
        };

        let json = serde_json::to_string(&exec).unwrap();
        let back: SidecarExecution = serde_json::from_str(&json).unwrap();

        assert!(!back.success);
        assert!(back.response.is_none());
        assert_eq!(back.error.as_deref(), Some("sidecar timed out"));
        assert!(back.completed_at.is_none());
        assert_eq!(back.estimator.kind.as_str(), "synthetic_control");
        assert_eq!(back.estimator.failure_mode.as_deref(), Some("timeout"));
        assert_eq!(back.estimator.random_seed, Some(7));
    }

    /// Optional fields may be omitted from the input entirely; they then
    /// deserialize to `None`.
    #[test]
    fn sidecar_execution_optional_fields_may_be_omitted() {
        let input = serde_json::json!({
            "estimator": {
                "kind": "bayesian",
                "version": "0.1.0",
                "parameters": {}
            },
            "request": {},
            "success": false,
            "started_at": "2024-01-01T00:00:00Z"
        });

        let exec: SidecarExecution = serde_json::from_value(input).unwrap();

        assert_eq!(exec.estimator.kind.as_str(), "bayesian");
        assert!(exec.estimator.random_seed.is_none());
        assert!(exec.estimator.environment.is_none());
        assert!(exec.estimator.timeout_secs.is_none());
        assert!(exec.response.is_none());
        assert!(exec.duration_ms.is_none());
        assert!(exec.error.is_none());
        assert!(exec.completed_at.is_none());
    }

    /// Every estimator kind, including `Custom`, must survive a serde round trip
    /// with its label intact.
    #[test]
    fn estimator_kind_round_trips_every_variant() {
        let kinds = [
            EstimatorKind::DiffInDiff,
            EstimatorKind::PropensityScore,
            EstimatorKind::InstrumentalVariables,
            EstimatorKind::OLS,
            EstimatorKind::Bayesian,
            EstimatorKind::BeforeAfter,
            EstimatorKind::Custom("synthetic_control".into()),
        ];

        for kind in &kinds {
            let json = serde_json::to_string(kind).unwrap();
            let back: EstimatorKind = serde_json::from_str(&json).unwrap();
            assert_eq!(back.as_str(), kind.as_str(), "round trip changed {json}");
        }
    }
}