1use std::collections::HashMap;
18use std::net::IpAddr;
19use std::sync::Arc;
20use std::time::Duration;
21
22use regex::Regex;
23use tracing::info;
24
25use spvirit_types::{NtEnum, NtScalar, NtScalarArray, PvValue, ScalarArrayValue, ScalarValue};
26
27use crate::db::{load_db, parse_db};
28use crate::handler::PvListMode;
29use crate::monitor::MonitorRegistry;
30use crate::server::{PvaServerConfig, run_pva_server_with_registry};
31use crate::simple_store::{LinkDef, OnPutCallback, ScanCallback, SimplePvStore};
32use crate::types::{DbCommonState, OutputMode, RecordData, RecordInstance, RecordType};
33
34pub struct PvaServerBuilder {
47 records: HashMap<String, RecordInstance>,
48 on_put: HashMap<String, OnPutCallback>,
49 scans: Vec<(String, Duration, ScanCallback)>,
50 links: Vec<LinkDef>,
51 tcp_port: u16,
52 udp_port: u16,
53 listen_ip: Option<IpAddr>,
54 advertise_ip: Option<IpAddr>,
55 compute_alarms: bool,
56 beacon_period_secs: u64,
57 conn_timeout: Duration,
58 pvlist_mode: PvListMode,
59 pvlist_max: usize,
60 pvlist_allow_pattern: Option<Regex>,
61}
62
63impl PvaServerBuilder {
64 fn new() -> Self {
65 Self {
66 records: HashMap::new(),
67 on_put: HashMap::new(),
68 scans: Vec::new(),
69 links: Vec::new(),
70 tcp_port: 5075,
71 udp_port: 5076,
72 listen_ip: None,
73 advertise_ip: None,
74 compute_alarms: false,
75 beacon_period_secs: 15,
76 conn_timeout: Duration::from_secs(64000),
77 pvlist_mode: PvListMode::List,
78 pvlist_max: 1024,
79 pvlist_allow_pattern: None,
80 }
81 }
82
83 pub fn ai(mut self, name: impl Into<String>, initial: f64) -> Self {
87 let name = name.into();
88 self.records.insert(
89 name.clone(),
90 make_scalar_record(&name, RecordType::Ai, ScalarValue::F64(initial)),
91 );
92 self
93 }
94
95 pub fn ao(mut self, name: impl Into<String>, initial: f64) -> Self {
97 let name = name.into();
98 self.records.insert(
99 name.clone(),
100 make_output_record(&name, RecordType::Ao, ScalarValue::F64(initial)),
101 );
102 self
103 }
104
105 pub fn bi(mut self, name: impl Into<String>, initial: bool) -> Self {
107 let name = name.into();
108 self.records.insert(
109 name.clone(),
110 make_scalar_record(&name, RecordType::Bi, ScalarValue::Bool(initial)),
111 );
112 self
113 }
114
115 pub fn bo(mut self, name: impl Into<String>, initial: bool) -> Self {
117 let name = name.into();
118 self.records.insert(
119 name.clone(),
120 make_output_record(&name, RecordType::Bo, ScalarValue::Bool(initial)),
121 );
122 self
123 }
124
125 pub fn string_in(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
127 let name = name.into();
128 self.records.insert(
129 name.clone(),
130 make_scalar_record(
131 &name,
132 RecordType::StringIn,
133 ScalarValue::Str(initial.into()),
134 ),
135 );
136 self
137 }
138
139 pub fn string_out(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
141 let name = name.into();
142 self.records.insert(
143 name.clone(),
144 make_output_record(
145 &name,
146 RecordType::StringOut,
147 ScalarValue::Str(initial.into()),
148 ),
149 );
150 self
151 }
152
153 pub fn waveform(mut self, name: impl Into<String>, data: ScalarArrayValue) -> Self {
155 let name = name.into();
156 let ftvl = data.type_label().trim_end_matches("[]").to_string();
157 let nelm = data.len();
158 self.records.insert(
159 name.clone(),
160 RecordInstance {
161 name: name.clone(),
162 record_type: RecordType::Waveform,
163 common: DbCommonState::default(),
164 data: RecordData::Waveform {
165 nt: NtScalarArray::from_value(data),
166 inp: None,
167 ftvl,
168 nelm,
169 nord: nelm,
170 },
171 raw_fields: HashMap::new(),
172 },
173 );
174 self
175 }
176
177 pub fn mbbi(
179 mut self,
180 name: impl Into<String>,
181 choices: Vec<String>,
182 initial: i32,
183 ) -> Self {
184 let name = name.into();
185 self.records.insert(
186 name.clone(),
187 RecordInstance {
188 name: name.clone(),
189 record_type: RecordType::Mbbi,
190 common: DbCommonState::default(),
191 data: RecordData::NtEnum {
192 nt: NtEnum::new(initial, choices),
193 inp: None,
194 out: None,
195 omsl: OutputMode::Supervisory,
196 },
197 raw_fields: HashMap::new(),
198 },
199 );
200 self
201 }
202
203 pub fn mbbo(
205 mut self,
206 name: impl Into<String>,
207 choices: Vec<String>,
208 initial: i32,
209 ) -> Self {
210 let name = name.into();
211 self.records.insert(
212 name.clone(),
213 RecordInstance {
214 name: name.clone(),
215 record_type: RecordType::Mbbo,
216 common: DbCommonState::default(),
217 data: RecordData::NtEnum {
218 nt: NtEnum::new(initial, choices),
219 inp: None,
220 out: None,
221 omsl: OutputMode::Supervisory,
222 },
223 raw_fields: HashMap::new(),
224 },
225 );
226 self
227 }
228
229 pub fn generic(
231 mut self,
232 name: impl Into<String>,
233 struct_id: impl Into<String>,
234 fields: Vec<(String, PvValue)>,
235 ) -> Self {
236 let name = name.into();
237 self.records.insert(
238 name.clone(),
239 RecordInstance {
240 name: name.clone(),
241 record_type: RecordType::Generic,
242 common: DbCommonState::default(),
243 data: RecordData::Generic {
244 struct_id: struct_id.into(),
245 fields,
246 inp: None,
247 out: None,
248 omsl: OutputMode::Supervisory,
249 },
250 raw_fields: HashMap::new(),
251 },
252 );
253 self
254 }
255
256 pub fn db_file(mut self, path: impl AsRef<str>) -> Self {
260 match load_db(path.as_ref()) {
261 Ok(records) => {
262 self.records.extend(records);
263 }
264 Err(e) => {
265 tracing::error!("Failed to load db file '{}': {}", path.as_ref(), e);
266 }
267 }
268 self
269 }
270
271 pub fn db_string(mut self, content: &str) -> Self {
273 match parse_db(content) {
274 Ok(records) => {
275 self.records.extend(records);
276 }
277 Err(e) => {
278 tracing::error!("Failed to parse db string: {}", e);
279 }
280 }
281 self
282 }
283
284 pub fn on_put<F>(mut self, name: impl Into<String>, callback: F) -> Self
288 where
289 F: Fn(&str, &spvirit_codec::spvd_decode::DecodedValue) + Send + Sync + 'static,
290 {
291 self.on_put.insert(name.into(), Arc::new(callback));
292 self
293 }
294
295 pub fn scan<F>(mut self, name: impl Into<String>, period: Duration, callback: F) -> Self
297 where
298 F: Fn(&str) -> ScalarValue + Send + Sync + 'static,
299 {
300 self.scans.push((name.into(), period, Arc::new(callback)));
301 self
302 }
303
304 pub fn link<F>(mut self, output: impl Into<String>, inputs: &[&str], compute: F) -> Self
319 where
320 F: Fn(&[ScalarValue]) -> ScalarValue + Send + Sync + 'static,
321 {
322 self.links.push(LinkDef {
323 output: output.into(),
324 inputs: inputs.iter().map(|s| s.to_string()).collect(),
325 compute: Arc::new(compute),
326 });
327 self
328 }
329
330 pub fn port(mut self, port: u16) -> Self {
334 self.tcp_port = port;
335 self
336 }
337
338 pub fn udp_port(mut self, port: u16) -> Self {
340 self.udp_port = port;
341 self
342 }
343
344 pub fn listen_ip(mut self, ip: IpAddr) -> Self {
346 self.listen_ip = Some(ip);
347 self
348 }
349
350 pub fn advertise_ip(mut self, ip: IpAddr) -> Self {
352 self.advertise_ip = Some(ip);
353 self
354 }
355
356 pub fn compute_alarms(mut self, enabled: bool) -> Self {
358 self.compute_alarms = enabled;
359 self
360 }
361
362 pub fn beacon_period(mut self, secs: u64) -> Self {
364 self.beacon_period_secs = secs;
365 self
366 }
367
368 pub fn conn_timeout(mut self, timeout: Duration) -> Self {
370 self.conn_timeout = timeout;
371 self
372 }
373
374 pub fn pvlist_mode(mut self, mode: PvListMode) -> Self {
376 self.pvlist_mode = mode;
377 self
378 }
379
380 pub fn pvlist_max(mut self, max: usize) -> Self {
382 self.pvlist_max = max;
383 self
384 }
385
386 pub fn pvlist_allow_pattern(mut self, pattern: Regex) -> Self {
388 self.pvlist_allow_pattern = Some(pattern);
389 self
390 }
391
392 pub fn build(self) -> PvaServer {
394 let store = Arc::new(SimplePvStore::new(
395 self.records,
396 self.on_put,
397 self.links,
398 self.compute_alarms,
399 ));
400
401 let mut config = PvaServerConfig::default();
402 config.tcp_port = self.tcp_port;
403 config.udp_port = self.udp_port;
404 config.compute_alarms = self.compute_alarms;
405 if let Some(ip) = self.listen_ip {
406 config.listen_ip = ip;
407 }
408 config.advertise_ip = self.advertise_ip;
409 config.beacon_period_secs = self.beacon_period_secs;
410 config.conn_timeout = self.conn_timeout;
411 config.pvlist_mode = self.pvlist_mode;
412 config.pvlist_max = self.pvlist_max;
413 config.pvlist_allow_pattern = self.pvlist_allow_pattern;
414
415 PvaServer {
416 store,
417 config,
418 scans: self.scans,
419 }
420 }
421}
422
423pub struct PvaServer {
444 store: Arc<SimplePvStore>,
445 config: PvaServerConfig,
446 scans: Vec<(String, Duration, ScanCallback)>,
447}
448
449impl PvaServer {
450 pub fn builder() -> PvaServerBuilder {
452 PvaServerBuilder::new()
453 }
454
455 pub fn store(&self) -> &Arc<SimplePvStore> {
457 &self.store
458 }
459
460 pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
464 let registry = Arc::new(MonitorRegistry::new());
467 self.store.set_registry(registry.clone()).await;
468
469 for (name, period, callback) in &self.scans {
471 let store = self.store.clone();
472 let name = name.clone();
473 let period = *period;
474 let callback = callback.clone();
475 tokio::spawn(async move {
476 let mut interval = tokio::time::interval(period);
477 loop {
478 interval.tick().await;
479 let new_val = callback(&name);
480 store.set_value(&name, new_val).await;
481 }
482 });
483 }
484
485 let pv_count = self.store.pv_names().await.len();
486 info!(
487 "PvaServer starting: {} PVs on port {}",
488 pv_count, self.config.tcp_port
489 );
490
491 run_pva_server_with_registry(self.store, self.config, registry).await
492 }
493}
494
495fn make_scalar_record(name: &str, record_type: RecordType, value: ScalarValue) -> RecordInstance {
498 let nt = NtScalar::from_value(value);
499 let data = match record_type {
500 RecordType::Ai => RecordData::Ai {
501 nt,
502 inp: None,
503 siml: None,
504 siol: None,
505 simm: false,
506 },
507 RecordType::Bi => RecordData::Bi {
508 nt,
509 inp: None,
510 znam: "Off".to_string(),
511 onam: "On".to_string(),
512 siml: None,
513 siol: None,
514 simm: false,
515 },
516 RecordType::StringIn => RecordData::StringIn {
517 nt,
518 inp: None,
519 siml: None,
520 siol: None,
521 simm: false,
522 },
523 _ => panic!("make_scalar_record: unsupported type {record_type:?}"),
524 };
525 RecordInstance {
526 name: name.to_string(),
527 record_type,
528 common: DbCommonState::default(),
529 data,
530 raw_fields: HashMap::new(),
531 }
532}
533
534fn make_output_record(name: &str, record_type: RecordType, value: ScalarValue) -> RecordInstance {
535 let nt = NtScalar::from_value(value);
536 let data = match record_type {
537 RecordType::Ao => RecordData::Ao {
538 nt,
539 out: None,
540 dol: None,
541 omsl: OutputMode::Supervisory,
542 drvl: None,
543 drvh: None,
544 oroc: None,
545 siml: None,
546 siol: None,
547 simm: false,
548 },
549 RecordType::Bo => RecordData::Bo {
550 nt,
551 out: None,
552 dol: None,
553 omsl: OutputMode::Supervisory,
554 znam: "Off".to_string(),
555 onam: "On".to_string(),
556 siml: None,
557 siol: None,
558 simm: false,
559 },
560 RecordType::StringOut => RecordData::StringOut {
561 nt,
562 out: None,
563 dol: None,
564 omsl: OutputMode::Supervisory,
565 siml: None,
566 siol: None,
567 simm: false,
568 },
569 _ => panic!("make_output_record: unsupported type {record_type:?}"),
570 };
571 RecordInstance {
572 name: name.to_string(),
573 record_type,
574 common: DbCommonState::default(),
575 data,
576 raw_fields: HashMap::new(),
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a single-threaded Tokio runtime for driving the store's async API
    /// from synchronous tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build current-thread Tokio runtime")
    }

    #[test]
    fn builder_creates_records() {
        let server = PvaServer::builder()
            .ai("T:AI", 1.0)
            .ao("T:AO", 2.0)
            .bi("T:BI", true)
            .bo("T:BO", false)
            .string_in("T:SI", "hello")
            .string_out("T:SO", "world")
            .build();

        let names = runtime().block_on(server.store.pv_names());
        assert_eq!(names.len(), 6);
    }

    #[test]
    fn builder_enum_records() {
        let choices = vec!["Off".to_string(), "On".to_string()];
        let server = PvaServer::builder()
            .mbbi("T:MBBI", choices.clone(), 0)
            .mbbo("T:MBBO", choices, 1)
            .build();

        let names = runtime().block_on(server.store.pv_names());
        assert!(names.contains(&"T:MBBI".to_string()));
        assert!(names.contains(&"T:MBBO".to_string()));
    }

    #[test]
    fn builder_defaults() {
        let server = PvaServer::builder().build();
        assert_eq!(server.config.tcp_port, 5075);
        assert_eq!(server.config.udp_port, 5076);
        assert!(!server.config.compute_alarms);
    }

    #[test]
    fn builder_port_override() {
        let server = PvaServer::builder().port(9075).udp_port(9076).build();
        assert_eq!(server.config.tcp_port, 9075);
        assert_eq!(server.config.udp_port, 9076);
    }

    #[test]
    fn builder_network_overrides() {
        let listen = IpAddr::from([127, 0, 0, 1]);
        let advertise = IpAddr::from([192, 168, 1, 10]);
        let server = PvaServer::builder()
            .listen_ip(listen)
            .advertise_ip(advertise)
            .compute_alarms(true)
            .beacon_period(30)
            .conn_timeout(Duration::from_secs(5))
            .pvlist_max(16)
            .build();

        assert_eq!(server.config.listen_ip, listen);
        assert_eq!(server.config.advertise_ip, Some(advertise));
        assert!(server.config.compute_alarms);
        assert_eq!(server.config.beacon_period_secs, 30);
        assert_eq!(server.config.conn_timeout, Duration::from_secs(5));
        assert_eq!(server.config.pvlist_max, 16);
    }

    #[test]
    fn builder_db_string() {
        let db = r#"
        record(ai, "TEST:VAL") {
            field(VAL, "3.14")
        }
        "#;
        let server = PvaServer::builder().db_string(db).build();
        assert!(runtime()
            .block_on(server.store.get_value("TEST:VAL"))
            .is_some());
    }

    #[test]
    fn builder_waveform() {
        let data = ScalarArrayValue::F64(vec![1.0, 2.0, 3.0]);
        let server = PvaServer::builder().waveform("T:WF", data).build();
        let names = runtime().block_on(server.store.pv_names());
        assert!(names.contains(&"T:WF".to_string()));
    }

    #[test]
    fn builder_scan_callback() {
        let server = PvaServer::builder()
            .ai("SCAN:V", 0.0)
            .scan("SCAN:V", Duration::from_secs(1), |_name| {
                ScalarValue::F64(42.0)
            })
            .build();
        assert_eq!(server.scans.len(), 1);
    }

    #[test]
    fn builder_on_put_callback() {
        let server = PvaServer::builder()
            .ao("PUT:V", 0.0)
            .on_put("PUT:V", |_name, _val| {})
            .build();
        assert!(runtime()
            .block_on(server.store.get_value("PUT:V"))
            .is_some());
    }

    #[test]
    fn store_runtime_get_set() {
        let server = PvaServer::builder().ao("RT:V", 0.0).build();
        let rt = runtime();
        let store = server.store().clone();
        rt.block_on(async {
            assert_eq!(store.get_value("RT:V").await, Some(ScalarValue::F64(0.0)));
            store.set_value("RT:V", ScalarValue::F64(99.0)).await;
            assert_eq!(store.get_value("RT:V").await, Some(ScalarValue::F64(99.0)));
        });
    }

    #[test]
    fn link_propagates_on_set_value() {
        let server = PvaServer::builder()
            .ao("INPUT:A", 1.0)
            .ao("INPUT:B", 2.0)
            .ai("CALC:SUM", 0.0)
            .link("CALC:SUM", &["INPUT:A", "INPUT:B"], |values| {
                let a = match &values[0] {
                    ScalarValue::F64(v) => *v,
                    _ => 0.0,
                };
                let b = match &values[1] {
                    ScalarValue::F64(v) => *v,
                    _ => 0.0,
                };
                ScalarValue::F64(a + b)
            })
            .build();

        let rt = runtime();
        let store = server.store().clone();
        rt.block_on(async {
            store.set_value("INPUT:A", ScalarValue::F64(10.0)).await;
            assert_eq!(
                store.get_value("CALC:SUM").await,
                Some(ScalarValue::F64(12.0))
            );

            store.set_value("INPUT:B", ScalarValue::F64(5.0)).await;
            assert_eq!(
                store.get_value("CALC:SUM").await,
                Some(ScalarValue::F64(15.0))
            );
        });
    }
}