1use std::collections::HashMap;
18use std::net::IpAddr;
19use std::sync::Arc;
20use std::time::Duration;
21
22use regex::Regex;
23use tracing::info;
24
25use spvirit_types::{NtScalar, NtScalarArray, ScalarArrayValue, ScalarValue};
26
27use crate::db::{load_db, parse_db};
28use crate::handler::PvListMode;
29use crate::monitor::MonitorRegistry;
30use crate::server::{run_pva_server_with_registry, PvaServerConfig};
31use crate::simple_store::{LinkDef, OnPutCallback, ScanCallback, SimplePvStore};
32use crate::types::{
33 DbCommonState, OutputMode, RecordData, RecordInstance, RecordType,
34};
35
36pub struct PvaServerBuilder {
49 records: HashMap<String, RecordInstance>,
50 on_put: HashMap<String, OnPutCallback>,
51 scans: Vec<(String, Duration, ScanCallback)>,
52 links: Vec<LinkDef>,
53 tcp_port: u16,
54 udp_port: u16,
55 listen_ip: Option<IpAddr>,
56 advertise_ip: Option<IpAddr>,
57 compute_alarms: bool,
58 beacon_period_secs: u64,
59 conn_timeout: Duration,
60 pvlist_mode: PvListMode,
61 pvlist_max: usize,
62 pvlist_allow_pattern: Option<Regex>,
63}
64
65impl PvaServerBuilder {
66 fn new() -> Self {
67 Self {
68 records: HashMap::new(),
69 on_put: HashMap::new(),
70 scans: Vec::new(),
71 links: Vec::new(),
72 tcp_port: 5075,
73 udp_port: 5076,
74 listen_ip: None,
75 advertise_ip: None,
76 compute_alarms: false,
77 beacon_period_secs: 15,
78 conn_timeout: Duration::from_secs(64000),
79 pvlist_mode: PvListMode::List,
80 pvlist_max: 1024,
81 pvlist_allow_pattern: None,
82 }
83 }
84
85 pub fn ai(mut self, name: impl Into<String>, initial: f64) -> Self {
89 let name = name.into();
90 self.records.insert(
91 name.clone(),
92 make_scalar_record(&name, RecordType::Ai, ScalarValue::F64(initial)),
93 );
94 self
95 }
96
97 pub fn ao(mut self, name: impl Into<String>, initial: f64) -> Self {
99 let name = name.into();
100 self.records.insert(
101 name.clone(),
102 make_output_record(&name, RecordType::Ao, ScalarValue::F64(initial)),
103 );
104 self
105 }
106
107 pub fn bi(mut self, name: impl Into<String>, initial: bool) -> Self {
109 let name = name.into();
110 self.records.insert(
111 name.clone(),
112 make_scalar_record(&name, RecordType::Bi, ScalarValue::Bool(initial)),
113 );
114 self
115 }
116
117 pub fn bo(mut self, name: impl Into<String>, initial: bool) -> Self {
119 let name = name.into();
120 self.records.insert(
121 name.clone(),
122 make_output_record(&name, RecordType::Bo, ScalarValue::Bool(initial)),
123 );
124 self
125 }
126
127 pub fn string_in(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
129 let name = name.into();
130 self.records.insert(
131 name.clone(),
132 make_scalar_record(&name, RecordType::StringIn, ScalarValue::Str(initial.into())),
133 );
134 self
135 }
136
137 pub fn string_out(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
139 let name = name.into();
140 self.records.insert(
141 name.clone(),
142 make_output_record(&name, RecordType::StringOut, ScalarValue::Str(initial.into())),
143 );
144 self
145 }
146
147 pub fn waveform(
149 mut self,
150 name: impl Into<String>,
151 data: ScalarArrayValue,
152 ) -> Self {
153 let name = name.into();
154 let ftvl = data.type_label().trim_end_matches("[]").to_string();
155 let nelm = data.len();
156 self.records.insert(
157 name.clone(),
158 RecordInstance {
159 name: name.clone(),
160 record_type: RecordType::Waveform,
161 common: DbCommonState::default(),
162 data: RecordData::Waveform {
163 nt: NtScalarArray::from_value(data),
164 inp: None,
165 ftvl,
166 nelm,
167 nord: nelm,
168 },
169 raw_fields: HashMap::new(),
170 },
171 );
172 self
173 }
174
175 pub fn db_file(mut self, path: impl AsRef<str>) -> Self {
179 match load_db(path.as_ref()) {
180 Ok(records) => {
181 self.records.extend(records);
182 }
183 Err(e) => {
184 tracing::error!("Failed to load db file '{}': {}", path.as_ref(), e);
185 }
186 }
187 self
188 }
189
190 pub fn db_string(mut self, content: &str) -> Self {
192 match parse_db(content) {
193 Ok(records) => {
194 self.records.extend(records);
195 }
196 Err(e) => {
197 tracing::error!("Failed to parse db string: {}", e);
198 }
199 }
200 self
201 }
202
203 pub fn on_put<F>(mut self, name: impl Into<String>, callback: F) -> Self
207 where
208 F: Fn(&str, &spvirit_codec::spvd_decode::DecodedValue) + Send + Sync + 'static,
209 {
210 self.on_put.insert(name.into(), Arc::new(callback));
211 self
212 }
213
214 pub fn scan<F>(
216 mut self,
217 name: impl Into<String>,
218 period: Duration,
219 callback: F,
220 ) -> Self
221 where
222 F: Fn(&str) -> ScalarValue + Send + Sync + 'static,
223 {
224 self.scans
225 .push((name.into(), period, Arc::new(callback)));
226 self
227 }
228
229 pub fn link<F>(
244 mut self,
245 output: impl Into<String>,
246 inputs: &[&str],
247 compute: F,
248 ) -> Self
249 where
250 F: Fn(&[ScalarValue]) -> ScalarValue + Send + Sync + 'static,
251 {
252 self.links.push(LinkDef {
253 output: output.into(),
254 inputs: inputs.iter().map(|s| s.to_string()).collect(),
255 compute: Arc::new(compute),
256 });
257 self
258 }
259
260 pub fn port(mut self, port: u16) -> Self {
264 self.tcp_port = port;
265 self
266 }
267
268 pub fn udp_port(mut self, port: u16) -> Self {
270 self.udp_port = port;
271 self
272 }
273
274 pub fn listen_ip(mut self, ip: IpAddr) -> Self {
276 self.listen_ip = Some(ip);
277 self
278 }
279
280 pub fn advertise_ip(mut self, ip: IpAddr) -> Self {
282 self.advertise_ip = Some(ip);
283 self
284 }
285
286 pub fn compute_alarms(mut self, enabled: bool) -> Self {
288 self.compute_alarms = enabled;
289 self
290 }
291
292 pub fn beacon_period(mut self, secs: u64) -> Self {
294 self.beacon_period_secs = secs;
295 self
296 }
297
298 pub fn conn_timeout(mut self, timeout: Duration) -> Self {
300 self.conn_timeout = timeout;
301 self
302 }
303
304 pub fn pvlist_mode(mut self, mode: PvListMode) -> Self {
306 self.pvlist_mode = mode;
307 self
308 }
309
310 pub fn pvlist_max(mut self, max: usize) -> Self {
312 self.pvlist_max = max;
313 self
314 }
315
316 pub fn pvlist_allow_pattern(mut self, pattern: Regex) -> Self {
318 self.pvlist_allow_pattern = Some(pattern);
319 self
320 }
321
322 pub fn build(self) -> PvaServer {
324 let store = Arc::new(SimplePvStore::new(
325 self.records,
326 self.on_put,
327 self.links,
328 self.compute_alarms,
329 ));
330
331 let mut config = PvaServerConfig::default();
332 config.tcp_port = self.tcp_port;
333 config.udp_port = self.udp_port;
334 config.compute_alarms = self.compute_alarms;
335 if let Some(ip) = self.listen_ip {
336 config.listen_ip = ip;
337 }
338 config.advertise_ip = self.advertise_ip;
339 config.beacon_period_secs = self.beacon_period_secs;
340 config.conn_timeout = self.conn_timeout;
341 config.pvlist_mode = self.pvlist_mode;
342 config.pvlist_max = self.pvlist_max;
343 config.pvlist_allow_pattern = self.pvlist_allow_pattern;
344
345 PvaServer {
346 store,
347 config,
348 scans: self.scans,
349 }
350 }
351}
352
353pub struct PvaServer {
374 store: Arc<SimplePvStore>,
375 config: PvaServerConfig,
376 scans: Vec<(String, Duration, ScanCallback)>,
377}
378
379impl PvaServer {
380 pub fn builder() -> PvaServerBuilder {
382 PvaServerBuilder::new()
383 }
384
385 pub fn store(&self) -> &Arc<SimplePvStore> {
387 &self.store
388 }
389
390 pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
394 let registry = Arc::new(MonitorRegistry::new());
397 self.store.set_registry(registry.clone()).await;
398
399 for (name, period, callback) in &self.scans {
401 let store = self.store.clone();
402 let name = name.clone();
403 let period = *period;
404 let callback = callback.clone();
405 tokio::spawn(async move {
406 let mut interval = tokio::time::interval(period);
407 loop {
408 interval.tick().await;
409 let new_val = callback(&name);
410 store.set_value(&name, new_val).await;
411 }
412 });
413 }
414
415 let pv_count = self.store.pv_names().await.len();
416 info!(
417 "PvaServer starting: {} PVs on port {}",
418 pv_count, self.config.tcp_port
419 );
420
421 run_pva_server_with_registry(self.store, self.config, registry).await
422 }
423}
424
425fn make_scalar_record(
428 name: &str,
429 record_type: RecordType,
430 value: ScalarValue,
431) -> RecordInstance {
432 let nt = NtScalar::from_value(value);
433 let data = match record_type {
434 RecordType::Ai => RecordData::Ai {
435 nt,
436 inp: None,
437 siml: None,
438 siol: None,
439 simm: false,
440 },
441 RecordType::Bi => RecordData::Bi {
442 nt,
443 inp: None,
444 znam: "Off".to_string(),
445 onam: "On".to_string(),
446 siml: None,
447 siol: None,
448 simm: false,
449 },
450 RecordType::StringIn => RecordData::StringIn {
451 nt,
452 inp: None,
453 siml: None,
454 siol: None,
455 simm: false,
456 },
457 _ => panic!("make_scalar_record: unsupported type {record_type:?}"),
458 };
459 RecordInstance {
460 name: name.to_string(),
461 record_type,
462 common: DbCommonState::default(),
463 data,
464 raw_fields: HashMap::new(),
465 }
466}
467
468fn make_output_record(
469 name: &str,
470 record_type: RecordType,
471 value: ScalarValue,
472) -> RecordInstance {
473 let nt = NtScalar::from_value(value);
474 let data = match record_type {
475 RecordType::Ao => RecordData::Ao {
476 nt,
477 out: None,
478 dol: None,
479 omsl: OutputMode::Supervisory,
480 drvl: None,
481 drvh: None,
482 oroc: None,
483 siml: None,
484 siol: None,
485 simm: false,
486 },
487 RecordType::Bo => RecordData::Bo {
488 nt,
489 out: None,
490 dol: None,
491 omsl: OutputMode::Supervisory,
492 znam: "Off".to_string(),
493 onam: "On".to_string(),
494 siml: None,
495 siol: None,
496 simm: false,
497 },
498 RecordType::StringOut => RecordData::StringOut {
499 nt,
500 out: None,
501 dol: None,
502 omsl: OutputMode::Supervisory,
503 siml: None,
504 siol: None,
505 simm: false,
506 },
507 _ => panic!("make_output_record: unsupported type {record_type:?}"),
508 };
509 RecordInstance {
510 name: name.to_string(),
511 record_type,
512 common: DbCommonState::default(),
513 data,
514 raw_fields: HashMap::new(),
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Current-thread Tokio runtime used to drive the async store API from sync tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    #[test]
    fn builder_creates_records() {
        let server = PvaServer::builder()
            .ai("T:AI", 1.0)
            .ao("T:AO", 2.0)
            .bi("T:BI", true)
            .bo("T:BO", false)
            .string_in("T:SI", "hello")
            .string_out("T:SO", "world")
            .build();

        let pv_count = runtime().block_on(server.store.pv_names()).len();
        assert_eq!(pv_count, 6);
    }

    #[test]
    fn builder_defaults() {
        let config = PvaServer::builder().build().config;
        assert_eq!(config.tcp_port, 5075);
        assert_eq!(config.udp_port, 5076);
        assert!(!config.compute_alarms);
    }

    #[test]
    fn builder_port_override() {
        let config = PvaServer::builder().port(9075).udp_port(9076).build().config;
        assert_eq!(config.tcp_port, 9075);
        assert_eq!(config.udp_port, 9076);
    }

    #[test]
    fn builder_db_string() {
        let db = r#"
        record(ai, "TEST:VAL") {
            field(VAL, "3.14")
        }
        "#;
        let server = PvaServer::builder().db_string(db).build();
        let value = runtime().block_on(server.store.get_value("TEST:VAL"));
        assert!(value.is_some());
    }

    #[test]
    fn builder_waveform() {
        let samples = ScalarArrayValue::F64(vec![1.0, 2.0, 3.0]);
        let server = PvaServer::builder().waveform("T:WF", samples).build();
        let names = runtime().block_on(server.store.pv_names());
        assert!(names.iter().any(|n| n == "T:WF"));
    }

    #[test]
    fn builder_scan_callback() {
        let server = PvaServer::builder()
            .ai("SCAN:V", 0.0)
            .scan("SCAN:V", Duration::from_secs(1), |_name| ScalarValue::F64(42.0))
            .build();
        assert_eq!(server.scans.len(), 1);
    }

    #[test]
    fn builder_on_put_callback() {
        let server = PvaServer::builder()
            .ao("PUT:V", 0.0)
            .on_put("PUT:V", |_name, _val| {})
            .build();
        let value = runtime().block_on(server.store.get_value("PUT:V"));
        assert!(value.is_some());
    }

    #[test]
    fn store_runtime_get_set() {
        let server = PvaServer::builder().ao("RT:V", 0.0).build();
        let store = Arc::clone(server.store());
        runtime().block_on(async {
            assert_eq!(store.get_value("RT:V").await, Some(ScalarValue::F64(0.0)));
            store.set_value("RT:V", ScalarValue::F64(99.0)).await;
            assert_eq!(store.get_value("RT:V").await, Some(ScalarValue::F64(99.0)));
        });
    }

    #[test]
    fn link_propagates_on_set_value() {
        let server = PvaServer::builder()
            .ao("INPUT:A", 1.0)
            .ao("INPUT:B", 2.0)
            .ai("CALC:SUM", 0.0)
            .link("CALC:SUM", &["INPUT:A", "INPUT:B"], |values| {
                let as_f64 = |v: &ScalarValue| {
                    if let ScalarValue::F64(x) = v {
                        *x
                    } else {
                        0.0
                    }
                };
                ScalarValue::F64(as_f64(&values[0]) + as_f64(&values[1]))
            })
            .build();

        let store = Arc::clone(server.store());
        runtime().block_on(async {
            store.set_value("INPUT:A", ScalarValue::F64(10.0)).await;
            assert_eq!(store.get_value("CALC:SUM").await, Some(ScalarValue::F64(12.0)));

            store.set_value("INPUT:B", ScalarValue::F64(5.0)).await;
            assert_eq!(store.get_value("CALC:SUM").await, Some(ScalarValue::F64(15.0)));
        });
    }
}