shape_runtime/data/
provider_metadata.rs1use serde::Serialize;
7use std::collections::HashMap;
8use std::sync::OnceLock;
9
10#[derive(Debug, Clone, Serialize)]
12pub struct ProviderMetadata {
13 pub name: &'static str,
14 pub description: &'static str,
15 pub category: &'static str,
16 pub parameters: &'static [ProviderParam],
17 pub example: Option<&'static str>,
18}
19
20#[derive(Debug, Clone, Serialize)]
22pub struct ProviderParam {
23 pub name: &'static str,
24 pub param_type: &'static str,
25 pub required: bool,
26 pub description: &'static str,
27 pub default: Option<&'static str>,
28}
29
30pub struct ProviderMetadataRegistry {
32 providers: HashMap<String, &'static ProviderMetadata>,
33}
34
35impl ProviderMetadataRegistry {
36 pub fn load() -> Self {
38 let providers = HashMap::new();
39
40 Self { providers }
44 }
45
46 pub fn get(&self, name: &str) -> Option<&'static ProviderMetadata> {
48 self.providers.get(name).copied()
49 }
50
51 pub fn all(&self) -> Vec<&'static ProviderMetadata> {
53 self.providers.values().copied().collect()
54 }
55
56 pub fn has(&self, name: &str) -> bool {
58 self.providers.contains_key(name)
59 }
60
61 pub fn register(&mut self, metadata: &'static ProviderMetadata) {
63 self.providers.insert(metadata.name.to_string(), metadata);
64 }
65}
66
67static PROVIDER_METADATA_REGISTRY: OnceLock<ProviderMetadataRegistry> = OnceLock::new();
69
70pub fn provider_registry() -> &'static ProviderMetadataRegistry {
72 PROVIDER_METADATA_REGISTRY.get_or_init(ProviderMetadataRegistry::load)
73}
74
// Tests for the metadata generated by the `#[shape_macros::shape_provider]` attribute.
#[cfg(test)]
mod tests {
    // Fixture provider. `PROVIDER_METADATA_DATA`, asserted on below, is not declared in
    // this module and `super::*` is not imported, so it is presumably generated by the
    // attribute (TODO confirm against the macro crate).
    // NOTE(review): the description, parameters and example asserted below do not appear
    // on this function as shown; presumably the macro reads them from doc comments on
    // `data_provider` — confirm they exist. Keep reviewer notes here as `//` comments
    // (not `///`) so the macro's input is not altered.
    #[shape_macros::shape_provider(category = "Market Data")]
    pub fn data_provider() {
    }

    // Checks every field of the generated metadata against the fixture above.
    #[test]
    fn test_provider_metadata_generated() {
        // Invokes the fixture; its visible body is empty.
        data_provider();

        // Provider-level fields. The name "data" presumably comes from the function
        // name `data_provider` (TODO confirm the macro's naming rule).
        assert_eq!(PROVIDER_METADATA_DATA.name, "data");
        assert_eq!(
            PROVIDER_METADATA_DATA.description,
            "Test provider for market data"
        );
        assert_eq!(PROVIDER_METADATA_DATA.category, "Market Data");
        assert_eq!(PROVIDER_METADATA_DATA.parameters.len(), 2);

        // First parameter: mandatory `symbol`.
        assert_eq!(PROVIDER_METADATA_DATA.parameters[0].name, "symbol");
        assert_eq!(PROVIDER_METADATA_DATA.parameters[0].param_type, "String");
        assert_eq!(PROVIDER_METADATA_DATA.parameters[0].required, true);
        assert_eq!(
            PROVIDER_METADATA_DATA.parameters[0].description,
            "Stock symbol"
        );

        // Second parameter: optional `timeframe`.
        assert_eq!(PROVIDER_METADATA_DATA.parameters[1].name, "timeframe");
        assert_eq!(PROVIDER_METADATA_DATA.parameters[1].param_type, "String");
        assert_eq!(PROVIDER_METADATA_DATA.parameters[1].required, false);
        assert_eq!(
            PROVIDER_METADATA_DATA.parameters[1].description,
            "Time period (optional)"
        );

        // The expected example keeps a leading space exactly as the test asserts it.
        assert_eq!(
            PROVIDER_METADATA_DATA.example,
            Some(" data('data', {symbol: 'ES', timeframe: '1h'})")
        );
    }
}