stormchaser_model/
schema_cache.rs1use oci_distribution::client::{Client, ClientConfig, ClientProtocol};
2use oci_distribution::secrets::RegistryAuth;
3use oci_distribution::Reference;
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9#[derive(Clone)]
11pub struct SchemaCache {
12 schemas: Arc<RwLock<HashMap<String, Value>>>,
13}
14
15impl Default for SchemaCache {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
21impl SchemaCache {
22 pub fn new() -> Self {
24 Self {
25 schemas: Arc::new(RwLock::new(HashMap::new())),
26 }
27 }
28
29 pub async fn get(&self, schema_id: &str) -> Option<Value> {
31 let cache = self.schemas.read().await;
32 cache.get(schema_id).cloned()
33 }
34
35 pub async fn insert(&self, schema_id: String, schema: Value) {
37 let mut cache = self.schemas.write().await;
38 cache.insert(schema_id, schema);
39 }
40
41 pub fn start_background_sync(&self, oci_registry_url: String) {
44 let schemas = self.schemas.clone();
45 tokio::spawn(async move {
46 loop {
47 let fetched_schemas = Self::fetch_schemas_from_oci(&oci_registry_url).await;
48 for (id, schema) in fetched_schemas {
49 schemas.write().await.insert(id, schema);
50 }
51 tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
52 }
53 });
54 }
55
56 async fn fetch_schemas_from_oci(url: &str) -> Vec<(String, Value)> {
59 let protocol = if url.starts_with("http://") {
60 ClientProtocol::Http
61 } else {
62 ClientProtocol::Https
63 };
64 let reference_str = url
65 .trim_start_matches("http://")
66 .trim_start_matches("https://");
67
68 let reference: Reference = match reference_str.parse() {
69 Ok(r) => r,
70 Err(e) => {
71 tracing::error!(
72 "Failed to parse OCI registry URL '{}': {}",
73 reference_str,
74 e
75 );
76 return vec![];
77 }
78 };
79
80 let config = ClientConfig {
81 protocol,
82 ..Default::default()
83 };
84 let client = Client::new(config);
85 let auth = RegistryAuth::Anonymous;
86
87 let image_data = match client
88 .pull(
89 &reference,
90 &auth,
91 vec![
92 "application/vnd.oci.image.layer.v1.tar+gzip",
93 "application/vnd.oci.image.layer.v1.tar",
94 ],
95 )
96 .await
97 {
98 Ok(data) => data,
99 Err(e) => {
100 tracing::error!("Failed to pull OCI artifact from '{}': {}", url, e);
101 return vec![];
102 }
103 };
104
105 let mut results = Vec::new();
106 for layer in image_data.layers {
107 let bytes = layer.data;
109 let media_type = layer.media_type;
110
111 if let Ok(value) = serde_json::from_slice::<Value>(&bytes) {
115 if let Some(id) = value.get("$id").and_then(|v| v.as_str()) {
116 results.push((id.to_string(), value));
117 continue;
118 }
119 }
120
121 use std::io::Read;
123 use tar::Archive;
124
125 let is_gzip = media_type == "application/vnd.oci.image.layer.v1.tar+gzip";
128 let reader: Box<dyn Read> = if is_gzip {
129 use flate2::read::GzDecoder;
130 Box::new(GzDecoder::new(std::io::Cursor::new(bytes)))
131 } else {
132 Box::new(std::io::Cursor::new(bytes))
133 };
134 let mut archive = Archive::new(reader);
135
136 if let Ok(entries) = archive.entries() {
137 for file in entries.flatten() {
138 let path = file
139 .path()
140 .map(|p| p.to_string_lossy().to_string())
141 .unwrap_or_default();
142 if path.ends_with(".json") {
143 let mut content = Vec::new();
144 let mut file = file; if file.read_to_end(&mut content).is_ok() {
146 if let Ok(value) = serde_json::from_slice::<Value>(&content) {
147 if let Some(id) = value.get("$id").and_then(|v| v.as_str()) {
148 results.push((id.to_string(), value));
149 }
150 }
151 }
152 }
153 }
154 }
155 }
156
157 results
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Digest of the two-byte config blob `{}` served by the mock registry.
    const CONFIG_DIGEST: &str =
        "sha256:44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a";
    /// Digest of the 40-byte schema layer served by the mock registry.
    const LAYER_DIGEST: &str =
        "sha256:416c9c6f24b11975d1f224ea33076dc692289fda308f8ed61ca49f097186e6a1";

    /// A lookup misses before insertion and returns the stored value afterwards.
    #[tokio::test]
    async fn test_schema_cache_insert_and_get() {
        let cache = SchemaCache::default();
        let key = String::from("test_schema_id");
        let expected = serde_json::json!({"type": "object"});

        assert!(cache.get(&key).await.is_none());

        cache.insert(key.clone(), expected.clone()).await;

        assert_eq!(cache.get(&key).await, Some(expected));
    }

    /// Serves a minimal OCI registry from wiremock and checks that the single bare-JSON
    /// layer it exposes comes back as one `($id, schema)` pair.
    #[tokio::test]
    async fn test_fetch_schemas_from_oci_mock() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;

        // Registry API version probe.
        Mock::given(method("GET"))
            .and(path("/v2/"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let manifest_body = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": "application/vnd.oci.image.manifest.v1+json",
            "config": {
                "mediaType": "application/vnd.oci.image.config.v1+json",
                "digest": CONFIG_DIGEST,
                "size": 2
            },
            "layers": [
                {
                    "mediaType": "application/vnd.oci.image.layer.v1.tar+gzip",
                    "digest": LAYER_DIGEST,
                    "size": 40
                }
            ]
        });
        Mock::given(method("GET"))
            .and(path("/v2/mock-repo/manifests/latest"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(&manifest_body)
                    .insert_header("Docker-Content-Digest", "sha256:some-digest")
                    .insert_header("Content-Type", "application/vnd.oci.image.manifest.v1+json"),
            )
            .mount(&server)
            .await;

        // Blob endpoints: an empty JSON config, then the schema layer itself.
        let blobs = [
            (CONFIG_DIGEST, "{}"),
            (LAYER_DIGEST, r#"{"$id": "mock_schema", "type": "object"}"#),
        ];
        for (digest, body) in blobs {
            Mock::given(method("GET"))
                .and(path(format!("/v2/mock-repo/blobs/{}", digest)))
                .respond_with(ResponseTemplate::new(200).set_body_string(body))
                .mount(&server)
                .await;
        }

        let reference = format!("{}/mock-repo:latest", server.uri());
        let fetched = SchemaCache::fetch_schemas_from_oci(&reference).await;

        assert_eq!(fetched.len(), 1);
        let (schema_id, schema) = &fetched[0];
        assert_eq!(schema_id.as_str(), "mock_schema");
        assert_eq!(schema.get("type").and_then(Value::as_str), Some("object"));
    }
}