use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use chrono::{NaiveDate, Utc};
use reqwest::{Client, StatusCode};
use serde::Deserialize;
use tokio::time::sleep;

use crate::{DataRecord, DataSource, FrameworkError, Relationship, Result};

const DEFAULT_RATE_LIMIT_DELAY_MS: u64 = 100;
const MAX_RETRIES: u32 = 3;
const RETRY_DELAY_MS: u64 = 1000;

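/// Lightweight, dependency-free text embedder that hashes whitespace-separated
/// words into a fixed-size bag-of-words vector and L2-normalizes the result.
/// Deterministic and cheap, but not semantically meaningful; it can serve as a
/// fallback when the `onnx-embeddings` feature is disabled.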
pub struct SimpleEmbedder {
    dimension: usize,
}

impl SimpleEmbedder {
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }

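    /// Embeds text by hashing each word longer than two characters into a
    /// bucket of the output vector, counting occurrences, and normalizing the
    /// vector to unit length. Returns an all-zero vector for empty input.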
    pub fn embed_text(&self, text: &str) -> Vec<f32> {
        let lowercase_text = text.to_lowercase();
        let words: Vec<&str> = lowercase_text
            .split_whitespace()
            .filter(|w| w.len() > 2)
            .collect();

        let mut embedding = vec![0.0f32; self.dimension];

        for word in words {
            let hash = self.hash_word(word);
            let idx = (hash % self.dimension as u64) as usize;
            embedding[idx] += 1.0;
        }

        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            for val in &mut embedding {
                *val /= magnitude;
            }
        }

        embedding
    }

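    /// djb2-style string hash (multiply-by-33 with wrapping arithmetic), used
    /// to map a word to a bucket index.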
    fn hash_word(&self, word: &str) -> u64 {
        let mut hash = 5381u64;
        for byte in word.bytes() {
            hash = hash.wrapping_mul(33).wrapping_add(byte as u64);
        }
        hash
    }

    pub fn embed_json(&self, value: &serde_json::Value) -> Vec<f32> {
        let text = self.extract_text_from_json(value);
        self.embed_text(&text)
    }

    fn extract_text_from_json(&self, value: &serde_json::Value) -> String {
        match value {
            serde_json::Value::String(s) => s.clone(),
            serde_json::Value::Object(map) => {
                let mut text = String::new();
                for (key, val) in map {
                    text.push_str(key);
                    text.push(' ');
                    text.push_str(&self.extract_text_from_json(val));
                    text.push(' ');
                }
                text
            }
            serde_json::Value::Array(arr) => {
                arr.iter()
                    .map(|v| self.extract_text_from_json(v))
                    .collect::<Vec<_>>()
                    .join(" ")
            }
            serde_json::Value::Number(n) => n.to_string(),
            serde_json::Value::Bool(b) => b.to_string(),
            serde_json::Value::Null => String::new(),
        }
    }
}

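/// Embedder backed by the `ruvector_onnx_embeddings` crate, compiled only when
/// the `onnx-embeddings` feature is enabled. The inner model is wrapped in an
/// `RwLock` because inference requires mutable access.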
#[cfg(feature = "onnx-embeddings")]
pub struct OnnxEmbedder {
    embedder: std::sync::RwLock<ruvector_onnx_embeddings::Embedder>,
}

#[cfg(feature = "onnx-embeddings")]
impl OnnxEmbedder {
    pub async fn new() -> std::result::Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let embedder = ruvector_onnx_embeddings::Embedder::default_model().await?;
        Ok(Self {
            embedder: std::sync::RwLock::new(embedder),
        })
    }

    pub async fn with_model(
        model: ruvector_onnx_embeddings::PretrainedModel,
    ) -> std::result::Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let embedder = ruvector_onnx_embeddings::Embedder::pretrained(model).await?;
        Ok(Self {
            embedder: std::sync::RwLock::new(embedder),
        })
    }

    pub fn embed_text(&self, text: &str) -> Vec<f32> {
        let mut embedder = self.embedder.write().unwrap();
        embedder.embed_one(text).unwrap_or_else(|_| vec![0.0; 384])
    }

    pub fn embed_batch(&self, texts: &[&str]) -> Vec<Vec<f32>> {
        let mut embedder = self.embedder.write().unwrap();
        match embedder.embed(texts) {
            Ok(output) => (0..texts.len())
                .map(|i| output.get(i).cloned().unwrap_or_else(|| vec![0.0; 384]))
                .collect(),
            Err(_) => texts.iter().map(|_| vec![0.0; 384]).collect(),
        }
    }

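    /// Embeds a large slice of texts in fixed-size chunks to bound memory use,
    /// padding with zero vectors if the backend returns fewer embeddings than
    /// inputs.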
    pub fn embed_batch_chunked(&self, texts: &[&str], batch_size: usize) -> Vec<Vec<f32>> {
        let batch_size = batch_size.max(1);
        let dim = self.dimension();
        let mut all_embeddings = Vec::with_capacity(texts.len());

        for chunk in texts.chunks(batch_size) {
            let chunk_embeddings = self.embed_batch(chunk);
            all_embeddings.extend(chunk_embeddings);
        }

        while all_embeddings.len() < texts.len() {
            all_embeddings.push(vec![0.0; dim]);
        }

        all_embeddings
    }

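    /// Like [`Self::embed_batch_chunked`], but calls `progress_fn(processed, total)`
    /// after each chunk so callers can report progress.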
    pub fn embed_batch_with_progress<F>(
        &self,
        texts: &[&str],
        batch_size: usize,
        mut progress_fn: F,
    ) -> Vec<Vec<f32>>
    where
        F: FnMut(usize, usize),
    {
        let batch_size = batch_size.max(1);
        let total = texts.len();
        let dim = self.dimension();
        let mut all_embeddings = Vec::with_capacity(total);
        let mut processed = 0;

        for chunk in texts.chunks(batch_size) {
            let chunk_embeddings = self.embed_batch(chunk);
            all_embeddings.extend(chunk_embeddings);
            processed += chunk.len();
            progress_fn(processed, total);
        }

        while all_embeddings.len() < total {
            all_embeddings.push(vec![0.0; dim]);
        }

        all_embeddings
    }

    pub fn dimension(&self) -> usize {
        let embedder = self.embedder.read().unwrap();
        embedder.dimension()
    }

    pub fn similarity(&self, text1: &str, text2: &str) -> f32 {
        let mut embedder = self.embedder.write().unwrap();
        embedder.similarity(text1, text2).unwrap_or(0.0)
    }

    pub fn embed_json(&self, value: &serde_json::Value) -> Vec<f32> {
        let text = extract_text_from_json(value);
        self.embed_text(&text)
    }
}

// Only used by the feature-gated OnnxEmbedder::embed_json, so gate it the same
// way to avoid dead-code warnings when the feature is disabled.
#[cfg(feature = "onnx-embeddings")]
fn extract_text_from_json(value: &serde_json::Value) -> String {
    match value {
        serde_json::Value::String(s) => s.clone(),
        serde_json::Value::Object(map) => {
            let mut text = String::new();
            for (key, val) in map {
                text.push_str(key);
                text.push(' ');
                text.push_str(&extract_text_from_json(val));
                text.push(' ');
            }
            text
        }
        serde_json::Value::Array(arr) => arr
            .iter()
            .map(extract_text_from_json)
            .collect::<Vec<_>>()
            .join(" "),
        serde_json::Value::Number(n) => n.to_string(),
        serde_json::Value::Bool(b) => b.to_string(),
        serde_json::Value::Null => String::new(),
    }
}

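/// Object-safe abstraction over the available embedding backends, allowing
/// callers to switch between [`SimpleEmbedder`] and `OnnxEmbedder` at runtime.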
pub trait Embedder: Send + Sync {
    fn embed(&self, text: &str) -> Vec<f32>;
    fn dim(&self) -> usize;
}

impl Embedder for SimpleEmbedder {
    fn embed(&self, text: &str) -> Vec<f32> {
        self.embed_text(text)
    }
    fn dim(&self) -> usize {
        self.dimension
    }
}

#[cfg(feature = "onnx-embeddings")]
impl Embedder for OnnxEmbedder {
    fn embed(&self, text: &str) -> Vec<f32> {
        self.embed_text(text)
    }
    fn dim(&self) -> usize {
        self.dimension()
    }
}

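// Minimal deserialization targets for the OpenAlex works and topics endpoints;
// only the fields used below are modeled.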
#[derive(Debug, Deserialize)]
struct OpenAlexWorksResponse {
    results: Vec<OpenAlexWork>,
    meta: OpenAlexMeta,
}

#[derive(Debug, Deserialize)]
struct OpenAlexWork {
    id: String,
    title: Option<String>,
    #[serde(rename = "display_name")]
    display_name: Option<String>,
    publication_date: Option<String>,
    #[serde(rename = "authorships")]
    authorships: Option<Vec<OpenAlexAuthorship>>,
    #[serde(rename = "cited_by_count")]
    cited_by_count: Option<i64>,
    #[serde(rename = "concepts")]
    concepts: Option<Vec<OpenAlexConcept>>,
    #[serde(rename = "abstract_inverted_index")]
    abstract_inverted_index: Option<HashMap<String, Vec<i32>>>,
}

#[derive(Debug, Deserialize)]
struct OpenAlexAuthorship {
    author: Option<OpenAlexAuthor>,
}

#[derive(Debug, Deserialize)]
struct OpenAlexAuthor {
    id: String,
    #[serde(rename = "display_name")]
    display_name: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAlexConcept {
    id: String,
    #[serde(rename = "display_name")]
    display_name: Option<String>,
    score: Option<f64>,
}

#[derive(Debug, Deserialize)]
struct OpenAlexMeta {
    count: i64,
}

#[derive(Debug, Deserialize)]
struct OpenAlexTopicsResponse {
    results: Vec<OpenAlexTopic>,
}

#[derive(Debug, Deserialize)]
struct OpenAlexTopic {
    id: String,
    #[serde(rename = "display_name")]
    display_name: String,
    description: Option<String>,
    #[serde(rename = "works_count")]
    works_count: Option<i64>,
}

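/// Client for the OpenAlex REST API (https://api.openalex.org). Passing a
/// contact email adds the `mailto` parameter, which OpenAlex uses to route
/// requests to its faster "polite pool".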
pub struct OpenAlexClient {
    client: Client,
    base_url: String,
    rate_limit_delay: Duration,
    embedder: Arc<SimpleEmbedder>,
    user_email: Option<String>,
}

impl OpenAlexClient {
    pub fn new(user_email: Option<String>) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://api.openalex.org".to_string(),
            rate_limit_delay: Duration::from_millis(DEFAULT_RATE_LIMIT_DELAY_MS),
            embedder: Arc::new(SimpleEmbedder::new(128)),
            user_email,
        })
    }

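    /// Searches OpenAlex works for `query` and converts up to `limit` results
    /// (capped at the API's 200-per-page maximum) into `DataRecord`s with
    /// author and concept relationships.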
    pub async fn fetch_works(&self, query: &str, limit: usize) -> Result<Vec<DataRecord>> {
        let mut url = format!("{}/works?search={}", self.base_url, urlencoding::encode(query));
        url.push_str(&format!("&per-page={}", limit.min(200)));

        if let Some(email) = &self.user_email {
            url.push_str(&format!("&mailto={}", email));
        }

        let response = self.fetch_with_retry(&url).await?;
        let works_response: OpenAlexWorksResponse = response.json().await?;

        let mut records = Vec::new();
        for work in works_response.results {
            let record = self.work_to_record(work)?;
            records.push(record);
            sleep(self.rate_limit_delay).await;
        }

        Ok(records)
    }

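    /// Searches OpenAlex topics for `domain` (up to 50 per page) and converts
    /// them into `DataRecord`s.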
    pub async fn fetch_topics(&self, domain: &str) -> Result<Vec<DataRecord>> {
        let mut url = format!(
            "{}/topics?search={}",
            self.base_url,
            urlencoding::encode(domain)
        );
        url.push_str("&per-page=50");

        if let Some(email) = &self.user_email {
            url.push_str(&format!("&mailto={}", email));
        }

        let response = self.fetch_with_retry(&url).await?;
        let topics_response: OpenAlexTopicsResponse = response.json().await?;

        let mut records = Vec::new();
        for topic in topics_response.results {
            let record = self.topic_to_record(topic)?;
            records.push(record);
            sleep(self.rate_limit_delay).await;
        }

        Ok(records)
    }

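    /// Converts a single OpenAlex work into a `DataRecord`: reconstructs the
    /// abstract from its inverted index, embeds "title + abstract", and adds
    /// `authored_by` / `has_concept` relationships.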
    fn work_to_record(&self, work: OpenAlexWork) -> Result<DataRecord> {
        let title = work
            .display_name
            .or(work.title)
            .unwrap_or_else(|| "Untitled".to_string());

        let abstract_text = work
            .abstract_inverted_index
            .as_ref()
            .map(|index| self.reconstruct_abstract(index))
            .unwrap_or_default();

        let text = format!("{} {}", title, abstract_text);
        let embedding = self.embedder.embed_text(&text);

        let timestamp = work
            .publication_date
            .as_ref()
            .and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok())
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);

        let mut relationships = Vec::new();

        if let Some(authorships) = work.authorships {
            for authorship in authorships {
                if let Some(author) = authorship.author {
                    relationships.push(Relationship {
                        target_id: author.id,
                        rel_type: "authored_by".to_string(),
                        weight: 1.0,
                        properties: {
                            let mut props = HashMap::new();
                            if let Some(name) = author.display_name {
                                props.insert("author_name".to_string(), serde_json::json!(name));
                            }
                            props
                        },
                    });
                }
            }
        }

        if let Some(concepts) = work.concepts {
            for concept in concepts {
                relationships.push(Relationship {
                    target_id: concept.id,
                    rel_type: "has_concept".to_string(),
                    weight: concept.score.unwrap_or(0.0),
                    properties: {
                        let mut props = HashMap::new();
                        if let Some(name) = concept.display_name {
                            props.insert("concept_name".to_string(), serde_json::json!(name));
                        }
                        props
                    },
                });
            }
        }

        let mut data_map = serde_json::Map::new();
        data_map.insert("title".to_string(), serde_json::json!(title));
        data_map.insert("abstract".to_string(), serde_json::json!(abstract_text));
        if let Some(citations) = work.cited_by_count {
            data_map.insert("citations".to_string(), serde_json::json!(citations));
        }

        Ok(DataRecord {
            id: work.id,
            source: "openalex".to_string(),
            record_type: "work".to_string(),
            timestamp,
            data: serde_json::Value::Object(data_map),
            embedding: Some(embedding),
            relationships,
        })
    }

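    /// OpenAlex stores abstracts as an inverted index (word -> positions).
    /// Rebuild the plain text by sorting every (position, word) pair by
    /// position and joining the words.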
    fn reconstruct_abstract(&self, inverted_index: &HashMap<String, Vec<i32>>) -> String {
        let mut positions: Vec<(i32, String)> = Vec::new();
        for (word, indices) in inverted_index {
            for &pos in indices {
                positions.push((pos, word.clone()));
            }
        }
        positions.sort_by_key(|&(pos, _)| pos);
        positions
            .into_iter()
            .map(|(_, word)| word)
            .collect::<Vec<_>>()
            .join(" ")
    }

    fn topic_to_record(&self, topic: OpenAlexTopic) -> Result<DataRecord> {
        let text = format!(
            "{} {}",
            topic.display_name,
            topic.description.as_deref().unwrap_or("")
        );
        let embedding = self.embedder.embed_text(&text);

        let mut data_map = serde_json::Map::new();
        data_map.insert(
            "display_name".to_string(),
            serde_json::json!(topic.display_name),
        );
        if let Some(desc) = topic.description {
            data_map.insert("description".to_string(), serde_json::json!(desc));
        }
        if let Some(count) = topic.works_count {
            data_map.insert("works_count".to_string(), serde_json::json!(count));
        }

        Ok(DataRecord {
            id: topic.id,
            source: "openalex".to_string(),
            record_type: "topic".to_string(),
            timestamp: Utc::now(),
            data: serde_json::Value::Object(data_map),
            embedding: Some(embedding),
            relationships: Vec::new(),
        })
    }

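    /// Issues a GET request, retrying up to `MAX_RETRIES` times with a linearly
    /// increasing delay on network errors or HTTP 429 (Too Many Requests).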
    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
                    {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}

#[async_trait]
impl DataSource for OpenAlexClient {
    fn source_id(&self) -> &str {
        "openalex"
    }

    async fn fetch_batch(
        &self,
        cursor: Option<String>,
        batch_size: usize,
    ) -> Result<(Vec<DataRecord>, Option<String>)> {
        let query = cursor.as_deref().unwrap_or("machine learning");
        let records = self.fetch_works(query, batch_size).await?;
        Ok((records, None))
    }

    async fn total_count(&self) -> Result<Option<u64>> {
        Ok(None)
    }

    async fn health_check(&self) -> Result<bool> {
        let response = self.client.get(&self.base_url).send().await?;
        Ok(response.status().is_success())
    }
}

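// Deserialization targets for NOAA's Climate Data Online (CDO) v2 API.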
#[derive(Debug, Deserialize)]
struct NoaaResponse {
    results: Vec<NoaaObservation>,
}

#[derive(Debug, Deserialize)]
struct NoaaObservation {
    station: String,
    date: String,
    datatype: String,
    value: f64,
    #[serde(default)]
    attributes: String,
}

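/// Client for NOAA's Climate Data Online API. When no API token is provided,
/// `fetch_climate_data` falls back to generating a small set of synthetic
/// observations so downstream code can still be exercised.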
pub struct NoaaClient {
    client: Client,
    base_url: String,
    api_token: Option<String>,
    rate_limit_delay: Duration,
    embedder: Arc<SimpleEmbedder>,
}

impl NoaaClient {
    pub fn new(api_token: Option<String>) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://www.ncei.noaa.gov/cdo-web/api/v2".to_string(),
            api_token,
            rate_limit_delay: Duration::from_millis(200),
            embedder: Arc::new(SimpleEmbedder::new(128)),
        })
    }

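    /// Fetches GHCND observations for `station_id` between `start_date` and
    /// `end_date` (YYYY-MM-DD). Without an API token, returns synthetic data
    /// instead of calling the API.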
    pub async fn fetch_climate_data(
        &self,
        station_id: &str,
        start_date: &str,
        end_date: &str,
    ) -> Result<Vec<DataRecord>> {
        if self.api_token.is_none() {
            return self.generate_synthetic_climate_data(station_id, start_date, end_date);
        }

        let url = format!(
            "{}/data?datasetid=GHCND&stationid={}&startdate={}&enddate={}&limit=1000",
            self.base_url, station_id, start_date, end_date
        );

        let mut request = self.client.get(&url);
        if let Some(token) = &self.api_token {
            request = request.header("token", token);
        }

        let response = self.fetch_with_retry(request).await?;
        let noaa_response: NoaaResponse = response.json().await?;

        let mut records = Vec::new();
        for observation in noaa_response.results {
            let record = self.observation_to_record(observation)?;
            records.push(record);
        }

        Ok(records)
    }

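    /// Produces one synthetic TMAX, TMIN, and PRCP observation (values in
    /// tenths, matching NOAA conventions) so the pipeline can run without an
    /// API token.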
    fn generate_synthetic_climate_data(
        &self,
        station_id: &str,
        start_date: &str,
        _end_date: &str,
    ) -> Result<Vec<DataRecord>> {
        let mut records = Vec::new();
        let datatypes = vec!["TMAX", "TMIN", "PRCP"];

        for (i, datatype) in datatypes.iter().enumerate() {
            let value = match *datatype {
                "TMAX" => 250.0 + (i as f64 * 10.0),
                "TMIN" => 150.0 + (i as f64 * 10.0),
                "PRCP" => 5.0 + (i as f64),
                _ => 0.0,
            };

            let text = format!("{} {} {}", station_id, datatype, value);
            let embedding = self.embedder.embed_text(&text);

            let mut data_map = serde_json::Map::new();
            data_map.insert("station".to_string(), serde_json::json!(station_id));
            data_map.insert("datatype".to_string(), serde_json::json!(datatype));
            data_map.insert("value".to_string(), serde_json::json!(value));
            data_map.insert("unit".to_string(), serde_json::json!("tenths"));

            records.push(DataRecord {
                id: format!("{}_{}_{}_{}", station_id, datatype, start_date, i),
                source: "noaa".to_string(),
                record_type: "observation".to_string(),
                timestamp: Utc::now(),
                data: serde_json::Value::Object(data_map),
                embedding: Some(embedding),
                relationships: Vec::new(),
            });
        }

        Ok(records)
    }

    fn observation_to_record(&self, obs: NoaaObservation) -> Result<DataRecord> {
        let text = format!("{} {} {}", obs.station, obs.datatype, obs.value);
        let embedding = self.embedder.embed_text(&text);

        let timestamp = NaiveDate::parse_from_str(&obs.date, "%Y-%m-%dT%H:%M:%S")
            .or_else(|_| NaiveDate::parse_from_str(&obs.date, "%Y-%m-%d"))
            .ok()
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);

        let mut data_map = serde_json::Map::new();
        data_map.insert("station".to_string(), serde_json::json!(obs.station));
        data_map.insert("datatype".to_string(), serde_json::json!(obs.datatype));
        data_map.insert("value".to_string(), serde_json::json!(obs.value));
        data_map.insert("attributes".to_string(), serde_json::json!(obs.attributes));

        Ok(DataRecord {
            id: format!("{}_{}", obs.station, obs.date),
            source: "noaa".to_string(),
            record_type: "observation".to_string(),
            timestamp,
            data: serde_json::Value::Object(data_map),
            embedding: Some(embedding),
            relationships: Vec::new(),
        })
    }

    async fn fetch_with_retry(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            let req = request
                .try_clone()
                .ok_or_else(|| FrameworkError::Config("Failed to clone request".to_string()))?;

            match req.send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
                    {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}

#[async_trait]
impl DataSource for NoaaClient {
    fn source_id(&self) -> &str {
        "noaa"
    }

    async fn fetch_batch(
        &self,
        _cursor: Option<String>,
        _batch_size: usize,
    ) -> Result<(Vec<DataRecord>, Option<String>)> {
        let records = self
            .fetch_climate_data("GHCND:USW00094728", "2024-01-01", "2024-01-31")
            .await?;
        Ok((records, None))
    }

    async fn total_count(&self) -> Result<Option<u64>> {
        Ok(None)
    }

    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }
}

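// Deserialization targets for the SEC EDGAR company submissions JSON, which
// stores a company's recent filings as parallel arrays.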
#[derive(Debug, Deserialize)]
struct EdgarFilingData {
    #[serde(default)]
    filings: EdgarFilings,
}

#[derive(Debug, Default, Deserialize)]
struct EdgarFilings {
    #[serde(default)]
    recent: EdgarRecent,
}

#[derive(Debug, Default, Deserialize)]
struct EdgarRecent {
    #[serde(rename = "accessionNumber", default)]
    accession_number: Vec<String>,
    #[serde(rename = "filingDate", default)]
    filing_date: Vec<String>,
    #[serde(rename = "reportDate", default)]
    report_date: Vec<String>,
    #[serde(default)]
    form: Vec<String>,
    #[serde(rename = "primaryDocument", default)]
    primary_document: Vec<String>,
}

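/// Client for SEC EDGAR (https://data.sec.gov). The SEC requires a descriptive
/// User-Agent identifying the requester, which is why the constructor takes
/// one explicitly.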
pub struct EdgarClient {
    client: Client,
    base_url: String,
    rate_limit_delay: Duration,
    embedder: Arc<SimpleEmbedder>,
    user_agent: String,
}

impl EdgarClient {
    pub fn new(user_agent: String) -> Result<Self> {
        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .user_agent(&user_agent)
            .build()
            .map_err(FrameworkError::Network)?;

        Ok(Self {
            client,
            base_url: "https://data.sec.gov".to_string(),
            rate_limit_delay: Duration::from_millis(100),
            embedder: Arc::new(SimpleEmbedder::new(128)),
            user_agent,
        })
    }

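    /// Fetches the recent-filings index for a company by CIK (zero-padded to
    /// ten digits), optionally filtering by form type, and converts up to 50
    /// filings into `DataRecord`s.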
    pub async fn fetch_filings(
        &self,
        cik: &str,
        form_type: Option<&str>,
    ) -> Result<Vec<DataRecord>> {
        let padded_cik = format!("{:0>10}", cik);

        let url = format!(
            "{}/submissions/CIK{}.json",
            self.base_url, padded_cik
        );

        let response = self.fetch_with_retry(&url).await?;
        let filing_data: EdgarFilingData = response.json().await?;

        let mut records = Vec::new();
        let recent = filing_data.filings.recent;

        let count = recent.accession_number.len();
        for i in 0..count.min(50) {
            if let Some(filter_form) = form_type {
                if i < recent.form.len() && recent.form[i] != filter_form {
                    continue;
                }
            }

            let filing = EdgarFiling {
                cik: padded_cik.clone(),
                accession_number: recent.accession_number.get(i).cloned().unwrap_or_default(),
                filing_date: recent.filing_date.get(i).cloned().unwrap_or_default(),
                report_date: recent.report_date.get(i).cloned().unwrap_or_default(),
                form: recent.form.get(i).cloned().unwrap_or_default(),
                primary_document: recent.primary_document.get(i).cloned().unwrap_or_default(),
            };

            let record = self.filing_to_record(filing)?;
            records.push(record);
            sleep(self.rate_limit_delay).await;
        }

        Ok(records)
    }

    fn filing_to_record(&self, filing: EdgarFiling) -> Result<DataRecord> {
        let text = format!(
            "CIK {} Form {} filed on {} report date {}",
            filing.cik, filing.form, filing.filing_date, filing.report_date
        );
        let embedding = self.embedder.embed_text(&text);

        let timestamp = NaiveDate::parse_from_str(&filing.filing_date, "%Y-%m-%d")
            .ok()
            .and_then(|d| d.and_hms_opt(0, 0, 0))
            .map(|dt| dt.and_utc())
            .unwrap_or_else(Utc::now);

        let mut data_map = serde_json::Map::new();
        data_map.insert("cik".to_string(), serde_json::json!(filing.cik));
        data_map.insert(
            "accession_number".to_string(),
            serde_json::json!(filing.accession_number),
        );
        data_map.insert(
            "filing_date".to_string(),
            serde_json::json!(filing.filing_date),
        );
        data_map.insert(
            "report_date".to_string(),
            serde_json::json!(filing.report_date),
        );
        data_map.insert("form".to_string(), serde_json::json!(filing.form));
        data_map.insert(
            "primary_document".to_string(),
            serde_json::json!(filing.primary_document),
        );

        let filing_url = format!(
            "https://www.sec.gov/cgi-bin/viewer?action=view&cik={}&accession_number={}&xbrl_type=v",
            filing.cik, filing.accession_number
        );
        data_map.insert("filing_url".to_string(), serde_json::json!(filing_url));

        Ok(DataRecord {
            id: format!("{}_{}", filing.cik, filing.accession_number),
            source: "edgar".to_string(),
            record_type: filing.form,
            timestamp,
            data: serde_json::Value::Object(data_map),
            embedding: Some(embedding),
            relationships: Vec::new(),
        })
    }

    async fn fetch_with_retry(&self, url: &str) -> Result<reqwest::Response> {
        let mut retries = 0;
        loop {
            match self.client.get(url).send().await {
                Ok(response) => {
                    if response.status() == StatusCode::TOO_MANY_REQUESTS && retries < MAX_RETRIES
                    {
                        retries += 1;
                        sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(_) if retries < MAX_RETRIES => {
                    retries += 1;
                    sleep(Duration::from_millis(RETRY_DELAY_MS * retries as u64)).await;
                }
                Err(e) => return Err(FrameworkError::Network(e)),
            }
        }
    }
}

struct EdgarFiling {
    cik: String,
    accession_number: String,
    filing_date: String,
    report_date: String,
    form: String,
    primary_document: String,
}

#[async_trait]
impl DataSource for EdgarClient {
    fn source_id(&self) -> &str {
        "edgar"
    }

    async fn fetch_batch(
        &self,
        cursor: Option<String>,
        _batch_size: usize,
    ) -> Result<(Vec<DataRecord>, Option<String>)> {
        let cik = cursor.as_deref().unwrap_or("320193");
        let records = self.fetch_filings(cik, None).await?;
        Ok((records, None))
    }

    async fn total_count(&self) -> Result<Option<u64>> {
        Ok(None)
    }

    async fn health_check(&self) -> Result<bool> {
        Ok(true)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_embedder() {
        let embedder = SimpleEmbedder::new(128);
        let embedding = embedder.embed_text("machine learning artificial intelligence");

        assert_eq!(embedding.len(), 128);

        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_embedder_json() {
        let embedder = SimpleEmbedder::new(64);
        let json = serde_json::json!({
            "title": "Test Document",
            "content": "Some interesting content here"
        });

        let embedding = embedder.embed_json(&json);
        assert_eq!(embedding.len(), 64);
    }

    #[tokio::test]
    async fn test_openalex_client_creation() {
        let client = OpenAlexClient::new(Some("test@example.com".to_string()));
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_noaa_client_creation() {
        let client = NoaaClient::new(None);
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_noaa_synthetic_data() {
        let client = NoaaClient::new(None).unwrap();
        let records = client
            .fetch_climate_data("GHCND:TEST", "2024-01-01", "2024-01-31")
            .await
            .unwrap();

        assert!(!records.is_empty());
        assert_eq!(records[0].source, "noaa");
        assert!(records[0].embedding.is_some());
    }

    #[tokio::test]
    async fn test_edgar_client_creation() {
        let client = EdgarClient::new("test-agent test@example.com".to_string());
        assert!(client.is_ok());
    }
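
    // The hash-based SimpleEmbedder should be deterministic, and text whose
    // words are all two characters or shorter should yield an all-zero vector.
    #[test]
    fn test_simple_embedder_deterministic_and_empty() {
        let embedder = SimpleEmbedder::new(32);

        let a = embedder.embed_text("graph neural networks");
        let b = embedder.embed_text("graph neural networks");
        assert_eq!(a, b);

        let empty = embedder.embed_text("a an of");
        assert!(empty.iter().all(|&x| x == 0.0));
    }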
}