Skip to main content

stormchaser_model/
schema_cache.rs

1use oci_distribution::client::{Client, ClientConfig, ClientProtocol};
2use oci_distribution::secrets::RegistryAuth;
3use oci_distribution::Reference;
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// In-memory cache for OCI-backed CloudEvent schemas.
10#[derive(Clone)]
11pub struct SchemaCache {
12    schemas: Arc<RwLock<HashMap<String, Value>>>,
13}
14
15impl Default for SchemaCache {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl SchemaCache {
22    /// Create a new empty schema cache.
23    pub fn new() -> Self {
24        Self {
25            schemas: Arc::new(RwLock::new(HashMap::new())),
26        }
27    }
28
29    /// Get a schema by its ID or Version.
30    pub async fn get(&self, schema_id: &str) -> Option<Value> {
31        let cache = self.schemas.read().await;
32        cache.get(schema_id).cloned()
33    }
34
35    /// Insert or update a schema in the cache.
36    pub async fn insert(&self, schema_id: String, schema: Value) {
37        let mut cache = self.schemas.write().await;
38        cache.insert(schema_id, schema);
39    }
40
41    /// Start a background sync process to fetch schemas from an OCI registry.
42    /// This is a permissive sync; it runs asynchronously and updates the cache.
43    pub fn start_background_sync(&self, oci_registry_url: String) {
44        let schemas = self.schemas.clone();
45        tokio::spawn(async move {
46            loop {
47                let fetched_schemas = Self::fetch_schemas_from_oci(&oci_registry_url).await;
48                for (id, schema) in fetched_schemas {
49                    schemas.write().await.insert(id, schema);
50                }
51                tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
52            }
53        });
54    }
55
56    /// Fetches schemas from the given OCI registry URL.
57    /// We attempt to pull image layers and parse them as JSON schemas.
58    async fn fetch_schemas_from_oci(url: &str) -> Vec<(String, Value)> {
59        let protocol = if url.starts_with("http://") {
60            ClientProtocol::Http
61        } else {
62            ClientProtocol::Https
63        };
64        let reference_str = url
65            .trim_start_matches("http://")
66            .trim_start_matches("https://");
67
68        let reference: Reference = match reference_str.parse() {
69            Ok(r) => r,
70            Err(e) => {
71                tracing::error!(
72                    "Failed to parse OCI registry URL '{}': {}",
73                    reference_str,
74                    e
75                );
76                return vec![];
77            }
78        };
79
80        let config = ClientConfig {
81            protocol,
82            ..Default::default()
83        };
84        let client = Client::new(config);
85        let auth = RegistryAuth::Anonymous;
86
87        let image_data = match client
88            .pull(
89                &reference,
90                &auth,
91                vec![
92                    "application/vnd.oci.image.layer.v1.tar+gzip",
93                    "application/vnd.oci.image.layer.v1.tar",
94                ],
95            )
96            .await
97        {
98            Ok(data) => data,
99            Err(e) => {
100                tracing::error!("Failed to pull OCI artifact from '{}': {}", url, e);
101                return vec![];
102            }
103        };
104
105        let mut results = Vec::new();
106        for layer in image_data.layers {
107            // Read layer data
108            let bytes = layer.data;
109            let media_type = layer.media_type;
110
111            // For now, assume it's just raw JSON or try to parse directly.
112            // If the layer is a tarball, we might need to extract it.
113            // A simple approach is to see if we can parse it as JSON first.
114            if let Ok(value) = serde_json::from_slice::<Value>(&bytes) {
115                if let Some(id) = value.get("$id").and_then(|v| v.as_str()) {
116                    results.push((id.to_string(), value));
117                    continue;
118                }
119            }
120
121            // Extract JSON files from tar or tar+gzip layer
122            use std::io::Read;
123            use tar::Archive;
124
125            // Detect if the layer is gzip-compressed. Other compression formats (e.g., zstd)
126            // and plain uncompressed tar both fall through to the uncompressed tar path.
127            let is_gzip = media_type == "application/vnd.oci.image.layer.v1.tar+gzip";
128            let reader: Box<dyn Read> = if is_gzip {
129                use flate2::read::GzDecoder;
130                Box::new(GzDecoder::new(std::io::Cursor::new(bytes)))
131            } else {
132                Box::new(std::io::Cursor::new(bytes))
133            };
134            let mut archive = Archive::new(reader);
135
136            if let Ok(entries) = archive.entries() {
137                for file in entries.flatten() {
138                    let path = file
139                        .path()
140                        .map(|p| p.to_string_lossy().to_string())
141                        .unwrap_or_default();
142                    if path.ends_with(".json") {
143                        let mut content = Vec::new();
144                        let mut file = file; // take ownership of mut entry
145                        if file.read_to_end(&mut content).is_ok() {
146                            if let Ok(value) = serde_json::from_slice::<Value>(&content) {
147                                if let Some(id) = value.get("$id").and_then(|v| v.as_str()) {
148                                    results.push((id.to_string(), value));
149                                }
150                            }
151                        }
152                    }
153                }
154            }
155        }
156
157        results
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// A schema stored under a key is returned unchanged by a later lookup, and a
    /// lookup before the insert finds nothing.
    #[tokio::test]
    async fn test_schema_cache_insert_and_get() {
        let cache = SchemaCache::default();
        let key = String::from("test_schema_id");
        let expected = serde_json::json!({"type": "object"});

        // Nothing has been stored yet.
        assert_eq!(cache.get(&key).await, None);

        cache.insert(key.clone(), expected.clone()).await;

        // The stored document comes back as-is.
        assert_eq!(cache.get(&key).await, Some(expected));
    }

    /// Pulls a single-layer artifact from a mocked registry whose layer body is a raw
    /// JSON schema and checks that the schema is returned under its `$id`.
    #[tokio::test]
    async fn test_fetch_schemas_from_oci_mock() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        const CONFIG_DIGEST: &str =
            "sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a";
        const LAYER_DIGEST: &str =
            "sha256:416c9c6f24b11975d1f224ea33076dc692289fda308f8ed61ca49f097186e6a1";

        let registry = MockServer::start().await;

        // Registry API version / auth probe.
        Mock::given(method("GET"))
            .and(path("/v2/"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&registry)
            .await;

        // Image manifest referencing one config blob and one layer blob.
        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": "application/vnd.oci.image.manifest.v1+json",
            "config": {
                "mediaType": "application/vnd.oci.image.config.v1+json",
                "digest": CONFIG_DIGEST,
                "size": 2
            },
            "layers": [
                {
                    "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip",
                    "digest": LAYER_DIGEST,
                    "size": 40
                }
            ]
        });
        let manifest_response = ResponseTemplate::new(200)
            .set_body_json(&manifest)
            .insert_header("Docker-Content-Digest", "sha256:some-digest")
            .insert_header("Content-Type", "application/vnd.oci.image.manifest.v1+json");
        Mock::given(method("GET"))
            .and(path("/v2/mock-repo/manifests/latest"))
            .respond_with(manifest_response)
            .mount(&registry)
            .await;

        // Config blob: an empty JSON object.
        Mock::given(method("GET"))
            .and(path(format!("/v2/mock-repo/blobs/{}", CONFIG_DIGEST)))
            .respond_with(ResponseTemplate::new(200).set_body_string("{}"))
            .mount(&registry)
            .await;

        // Layer blob: the schema itself, served as a raw JSON body.
        let layer_body = r#"{"$id": "mock_schema", "type": "object"}"#;
        Mock::given(method("GET"))
            .and(path(format!("/v2/mock-repo/blobs/{}", LAYER_DIGEST)))
            .respond_with(ResponseTemplate::new(200).set_body_string(layer_body))
            .mount(&registry)
            .await;

        let artifact_url = format!("{}/mock-repo:latest", registry.uri());
        let fetched = SchemaCache::fetch_schemas_from_oci(&artifact_url).await;

        assert_eq!(fetched.len(), 1);
        let (id, schema) = &fetched[0];
        assert_eq!(id.as_str(), "mock_schema");
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
    }
}
245            results[0].1.get("type").unwrap().as_str().unwrap(),
246            "object"
247        );
248    }
249}