1use std::collections::HashMap;
18use std::net::IpAddr;
19use std::sync::Arc;
20use std::time::Duration;
21
22use regex::Regex;
23use tracing::info;
24
25use spvirit_types::{NtScalar, NtScalarArray, ScalarArrayValue, ScalarValue};
26
27use crate::db::{load_db, parse_db};
28use crate::handler::PvListMode;
29use crate::monitor::MonitorRegistry;
30use crate::server::{PvaServerConfig, run_pva_server_with_registry};
31use crate::simple_store::{LinkDef, OnPutCallback, ScanCallback, SimplePvStore};
32use crate::types::{DbCommonState, OutputMode, RecordData, RecordInstance, RecordType};
33
34pub struct PvaServerBuilder {
47 records: HashMap<String, RecordInstance>,
48 on_put: HashMap<String, OnPutCallback>,
49 scans: Vec<(String, Duration, ScanCallback)>,
50 links: Vec<LinkDef>,
51 tcp_port: u16,
52 udp_port: u16,
53 listen_ip: Option<IpAddr>,
54 advertise_ip: Option<IpAddr>,
55 compute_alarms: bool,
56 beacon_period_secs: u64,
57 conn_timeout: Duration,
58 pvlist_mode: PvListMode,
59 pvlist_max: usize,
60 pvlist_allow_pattern: Option<Regex>,
61}
62
63impl PvaServerBuilder {
64 fn new() -> Self {
65 Self {
66 records: HashMap::new(),
67 on_put: HashMap::new(),
68 scans: Vec::new(),
69 links: Vec::new(),
70 tcp_port: 5075,
71 udp_port: 5076,
72 listen_ip: None,
73 advertise_ip: None,
74 compute_alarms: false,
75 beacon_period_secs: 15,
76 conn_timeout: Duration::from_secs(64000),
77 pvlist_mode: PvListMode::List,
78 pvlist_max: 1024,
79 pvlist_allow_pattern: None,
80 }
81 }
82
83 pub fn ai(mut self, name: impl Into<String>, initial: f64) -> Self {
87 let name = name.into();
88 self.records.insert(
89 name.clone(),
90 make_scalar_record(&name, RecordType::Ai, ScalarValue::F64(initial)),
91 );
92 self
93 }
94
95 pub fn ao(mut self, name: impl Into<String>, initial: f64) -> Self {
97 let name = name.into();
98 self.records.insert(
99 name.clone(),
100 make_output_record(&name, RecordType::Ao, ScalarValue::F64(initial)),
101 );
102 self
103 }
104
105 pub fn bi(mut self, name: impl Into<String>, initial: bool) -> Self {
107 let name = name.into();
108 self.records.insert(
109 name.clone(),
110 make_scalar_record(&name, RecordType::Bi, ScalarValue::Bool(initial)),
111 );
112 self
113 }
114
115 pub fn bo(mut self, name: impl Into<String>, initial: bool) -> Self {
117 let name = name.into();
118 self.records.insert(
119 name.clone(),
120 make_output_record(&name, RecordType::Bo, ScalarValue::Bool(initial)),
121 );
122 self
123 }
124
125 pub fn string_in(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
127 let name = name.into();
128 self.records.insert(
129 name.clone(),
130 make_scalar_record(
131 &name,
132 RecordType::StringIn,
133 ScalarValue::Str(initial.into()),
134 ),
135 );
136 self
137 }
138
139 pub fn string_out(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
141 let name = name.into();
142 self.records.insert(
143 name.clone(),
144 make_output_record(
145 &name,
146 RecordType::StringOut,
147 ScalarValue::Str(initial.into()),
148 ),
149 );
150 self
151 }
152
153 pub fn waveform(mut self, name: impl Into<String>, data: ScalarArrayValue) -> Self {
155 let name = name.into();
156 let ftvl = data.type_label().trim_end_matches("[]").to_string();
157 let nelm = data.len();
158 self.records.insert(
159 name.clone(),
160 RecordInstance {
161 name: name.clone(),
162 record_type: RecordType::Waveform,
163 common: DbCommonState::default(),
164 data: RecordData::Waveform {
165 nt: NtScalarArray::from_value(data),
166 inp: None,
167 ftvl,
168 nelm,
169 nord: nelm,
170 },
171 raw_fields: HashMap::new(),
172 },
173 );
174 self
175 }
176
177 pub fn db_file(mut self, path: impl AsRef<str>) -> Self {
181 match load_db(path.as_ref()) {
182 Ok(records) => {
183 self.records.extend(records);
184 }
185 Err(e) => {
186 tracing::error!("Failed to load db file '{}': {}", path.as_ref(), e);
187 }
188 }
189 self
190 }
191
192 pub fn db_string(mut self, content: &str) -> Self {
194 match parse_db(content) {
195 Ok(records) => {
196 self.records.extend(records);
197 }
198 Err(e) => {
199 tracing::error!("Failed to parse db string: {}", e);
200 }
201 }
202 self
203 }
204
205 pub fn on_put<F>(mut self, name: impl Into<String>, callback: F) -> Self
209 where
210 F: Fn(&str, &spvirit_codec::spvd_decode::DecodedValue) + Send + Sync + 'static,
211 {
212 self.on_put.insert(name.into(), Arc::new(callback));
213 self
214 }
215
216 pub fn scan<F>(mut self, name: impl Into<String>, period: Duration, callback: F) -> Self
218 where
219 F: Fn(&str) -> ScalarValue + Send + Sync + 'static,
220 {
221 self.scans.push((name.into(), period, Arc::new(callback)));
222 self
223 }
224
225 pub fn link<F>(mut self, output: impl Into<String>, inputs: &[&str], compute: F) -> Self
240 where
241 F: Fn(&[ScalarValue]) -> ScalarValue + Send + Sync + 'static,
242 {
243 self.links.push(LinkDef {
244 output: output.into(),
245 inputs: inputs.iter().map(|s| s.to_string()).collect(),
246 compute: Arc::new(compute),
247 });
248 self
249 }
250
251 pub fn port(mut self, port: u16) -> Self {
255 self.tcp_port = port;
256 self
257 }
258
259 pub fn udp_port(mut self, port: u16) -> Self {
261 self.udp_port = port;
262 self
263 }
264
265 pub fn listen_ip(mut self, ip: IpAddr) -> Self {
267 self.listen_ip = Some(ip);
268 self
269 }
270
271 pub fn advertise_ip(mut self, ip: IpAddr) -> Self {
273 self.advertise_ip = Some(ip);
274 self
275 }
276
277 pub fn compute_alarms(mut self, enabled: bool) -> Self {
279 self.compute_alarms = enabled;
280 self
281 }
282
283 pub fn beacon_period(mut self, secs: u64) -> Self {
285 self.beacon_period_secs = secs;
286 self
287 }
288
289 pub fn conn_timeout(mut self, timeout: Duration) -> Self {
291 self.conn_timeout = timeout;
292 self
293 }
294
295 pub fn pvlist_mode(mut self, mode: PvListMode) -> Self {
297 self.pvlist_mode = mode;
298 self
299 }
300
301 pub fn pvlist_max(mut self, max: usize) -> Self {
303 self.pvlist_max = max;
304 self
305 }
306
307 pub fn pvlist_allow_pattern(mut self, pattern: Regex) -> Self {
309 self.pvlist_allow_pattern = Some(pattern);
310 self
311 }
312
313 pub fn build(self) -> PvaServer {
315 let store = Arc::new(SimplePvStore::new(
316 self.records,
317 self.on_put,
318 self.links,
319 self.compute_alarms,
320 ));
321
322 let mut config = PvaServerConfig::default();
323 config.tcp_port = self.tcp_port;
324 config.udp_port = self.udp_port;
325 config.compute_alarms = self.compute_alarms;
326 if let Some(ip) = self.listen_ip {
327 config.listen_ip = ip;
328 }
329 config.advertise_ip = self.advertise_ip;
330 config.beacon_period_secs = self.beacon_period_secs;
331 config.conn_timeout = self.conn_timeout;
332 config.pvlist_mode = self.pvlist_mode;
333 config.pvlist_max = self.pvlist_max;
334 config.pvlist_allow_pattern = self.pvlist_allow_pattern;
335
336 PvaServer {
337 store,
338 config,
339 scans: self.scans,
340 }
341 }
342}
343
/// A configured PVA server, assembled by `PvaServerBuilder::build` and
/// started with `run`.
pub struct PvaServer {
    /// Shared PV store; also reachable from outside through `store()`.
    store: Arc<SimplePvStore>,
    /// Server configuration handed to `run_pva_server_with_registry`.
    config: PvaServerConfig,
    /// Periodic scans as `(pv name, period, callback)`, spawned by `run`.
    scans: Vec<(String, Duration, ScanCallback)>,
}
369
370impl PvaServer {
371 pub fn builder() -> PvaServerBuilder {
373 PvaServerBuilder::new()
374 }
375
376 pub fn store(&self) -> &Arc<SimplePvStore> {
378 &self.store
379 }
380
381 pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
385 let registry = Arc::new(MonitorRegistry::new());
388 self.store.set_registry(registry.clone()).await;
389
390 for (name, period, callback) in &self.scans {
392 let store = self.store.clone();
393 let name = name.clone();
394 let period = *period;
395 let callback = callback.clone();
396 tokio::spawn(async move {
397 let mut interval = tokio::time::interval(period);
398 loop {
399 interval.tick().await;
400 let new_val = callback(&name);
401 store.set_value(&name, new_val).await;
402 }
403 });
404 }
405
406 let pv_count = self.store.pv_names().await.len();
407 info!(
408 "PvaServer starting: {} PVs on port {}",
409 pv_count, self.config.tcp_port
410 );
411
412 run_pva_server_with_registry(self.store, self.config, registry).await
413 }
414}
415
416fn make_scalar_record(name: &str, record_type: RecordType, value: ScalarValue) -> RecordInstance {
419 let nt = NtScalar::from_value(value);
420 let data = match record_type {
421 RecordType::Ai => RecordData::Ai {
422 nt,
423 inp: None,
424 siml: None,
425 siol: None,
426 simm: false,
427 },
428 RecordType::Bi => RecordData::Bi {
429 nt,
430 inp: None,
431 znam: "Off".to_string(),
432 onam: "On".to_string(),
433 siml: None,
434 siol: None,
435 simm: false,
436 },
437 RecordType::StringIn => RecordData::StringIn {
438 nt,
439 inp: None,
440 siml: None,
441 siol: None,
442 simm: false,
443 },
444 _ => panic!("make_scalar_record: unsupported type {record_type:?}"),
445 };
446 RecordInstance {
447 name: name.to_string(),
448 record_type,
449 common: DbCommonState::default(),
450 data,
451 raw_fields: HashMap::new(),
452 }
453}
454
455fn make_output_record(name: &str, record_type: RecordType, value: ScalarValue) -> RecordInstance {
456 let nt = NtScalar::from_value(value);
457 let data = match record_type {
458 RecordType::Ao => RecordData::Ao {
459 nt,
460 out: None,
461 dol: None,
462 omsl: OutputMode::Supervisory,
463 drvl: None,
464 drvh: None,
465 oroc: None,
466 siml: None,
467 siol: None,
468 simm: false,
469 },
470 RecordType::Bo => RecordData::Bo {
471 nt,
472 out: None,
473 dol: None,
474 omsl: OutputMode::Supervisory,
475 znam: "Off".to_string(),
476 onam: "On".to_string(),
477 siml: None,
478 siol: None,
479 simm: false,
480 },
481 RecordType::StringOut => RecordData::StringOut {
482 nt,
483 out: None,
484 dol: None,
485 omsl: OutputMode::Supervisory,
486 siml: None,
487 siol: None,
488 simm: false,
489 },
490 _ => panic!("make_output_record: unsupported type {record_type:?}"),
491 };
492 RecordInstance {
493 name: name.to_string(),
494 record_type,
495 common: DbCommonState::default(),
496 data,
497 raw_fields: HashMap::new(),
498 }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single-threaded Tokio runtime the tests use to drive the
    /// store's async API.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// Extracts the `f64` payload, treating every other variant as `0.0`.
    fn f64_or_zero(value: &ScalarValue) -> f64 {
        match value {
            ScalarValue::F64(inner) => *inner,
            _ => 0.0,
        }
    }

    #[test]
    fn builder_creates_records() {
        let server = PvaServer::builder()
            .ai("T:AI", 1.0)
            .ao("T:AO", 2.0)
            .bi("T:BI", true)
            .bo("T:BO", false)
            .string_in("T:SI", "hello")
            .string_out("T:SO", "world")
            .build();

        let rt = current_thread_runtime();
        let pv_names = rt.block_on(server.store.pv_names());
        assert_eq!(pv_names.len(), 6);
    }

    #[test]
    fn builder_defaults() {
        let server = PvaServer::builder().build();
        let config = &server.config;
        assert_eq!(config.tcp_port, 5075);
        assert_eq!(config.udp_port, 5076);
        assert!(!config.compute_alarms);
    }

    #[test]
    fn builder_port_override() {
        let server = PvaServer::builder().port(9075).udp_port(9076).build();
        let config = &server.config;
        assert_eq!(config.tcp_port, 9075);
        assert_eq!(config.udp_port, 9076);
    }

    #[test]
    fn builder_db_string() {
        let db = r#"
            record(ai, "TEST:VAL") {
                field(VAL, "3.14")
            }
        "#;
        let server = PvaServer::builder().db_string(db).build();

        let rt = current_thread_runtime();
        let value = rt.block_on(server.store.get_value("TEST:VAL"));
        assert!(value.is_some());
    }

    #[test]
    fn builder_waveform() {
        let samples = ScalarArrayValue::F64(vec![1.0, 2.0, 3.0]);
        let server = PvaServer::builder().waveform("T:WF", samples).build();

        let rt = current_thread_runtime();
        let pv_names = rt.block_on(server.store.pv_names());
        assert!(pv_names.contains(&"T:WF".to_string()));
    }

    #[test]
    fn builder_scan_callback() {
        let server = PvaServer::builder()
            .ai("SCAN:V", 0.0)
            .scan("SCAN:V", Duration::from_secs(1), |_name| {
                ScalarValue::F64(42.0)
            })
            .build();

        // The scan is only recorded here; it starts when `run` is called.
        assert_eq!(server.scans.len(), 1);
    }

    #[test]
    fn builder_on_put_callback() {
        let server = PvaServer::builder()
            .ao("PUT:V", 0.0)
            .on_put("PUT:V", |_name, _val| {})
            .build();

        // The callback is not invoked here; the record must still resolve.
        let rt = current_thread_runtime();
        let value = rt.block_on(server.store.get_value("PUT:V"));
        assert!(value.is_some());
    }

    #[test]
    fn store_runtime_get_set() {
        let server = PvaServer::builder().ao("RT:V", 0.0).build();
        let rt = current_thread_runtime();
        let store = server.store().clone();

        rt.block_on(async {
            let before = store.get_value("RT:V").await;
            assert_eq!(before, Some(ScalarValue::F64(0.0)));

            store.set_value("RT:V", ScalarValue::F64(99.0)).await;

            let after = store.get_value("RT:V").await;
            assert_eq!(after, Some(ScalarValue::F64(99.0)));
        });
    }

    #[test]
    fn link_propagates_on_set_value() {
        let server = PvaServer::builder()
            .ao("INPUT:A", 1.0)
            .ao("INPUT:B", 2.0)
            .ai("CALC:SUM", 0.0)
            .link("CALC:SUM", &["INPUT:A", "INPUT:B"], |values| {
                ScalarValue::F64(f64_or_zero(&values[0]) + f64_or_zero(&values[1]))
            })
            .build();

        let rt = current_thread_runtime();
        let store = server.store().clone();

        rt.block_on(async {
            // A = 10, B still 2 -> SUM = 12.
            store.set_value("INPUT:A", ScalarValue::F64(10.0)).await;
            let sum = store.get_value("CALC:SUM").await;
            assert_eq!(sum, Some(ScalarValue::F64(12.0)));

            // A = 10, B = 5 -> SUM = 15.
            store.set_value("INPUT:B", ScalarValue::F64(5.0)).await;
            let sum = store.get_value("CALC:SUM").await;
            assert_eq!(sum, Some(ScalarValue::F64(15.0)));
        });
    }
}