Skip to main content

spvirit_server/
pva_server.rs

1//! High-level PVAccess server — builder pattern for typed records.
2//!
3//! # Example
4//!
5//! ```rust,ignore
6//! use spvirit_server::PvaServer;
7//!
8//! let server = PvaServer::builder()
9//!     .ai("SIM:TEMPERATURE", 22.5)
10//!     .ao("SIM:SETPOINT", 25.0)
11//!     .bo("SIM:ENABLE", false)
12//!     .build();
13//!
14//! server.run().await?;
15//! ```
16
17use std::collections::HashMap;
18use std::net::IpAddr;
19use std::sync::Arc;
20use std::time::Duration;
21
22use regex::Regex;
23use tracing::info;
24
25use spvirit_types::{NtEnum, NtScalar, NtScalarArray, PvValue, ScalarArrayValue, ScalarValue};
26
27use crate::db::{load_db, parse_db};
28use crate::handler::PvListMode;
29use crate::monitor::MonitorRegistry;
30use crate::server::{PvaServerConfig, run_pva_server_with_registry};
31use crate::simple_store::{LinkDef, OnPutCallback, ScanCallback, SimplePvStore};
32use crate::types::{DbCommonState, OutputMode, RecordData, RecordInstance, RecordType};
33
34// ─── PvaServerBuilder ────────────────────────────────────────────────────
35
36/// Builder for [`PvaServer`].
37///
38/// ```rust,ignore
39/// let server = PvaServer::builder()
40///     .ai("TEMP:READBACK", 22.5)
41///     .ao("TEMP:SETPOINT", 25.0)
42///     .bo("HEATER:ON", false)
43///     .port(5075)
44///     .build();
45/// ```
46pub struct PvaServerBuilder {
47    records: HashMap<String, RecordInstance>,
48    on_put: HashMap<String, OnPutCallback>,
49    scans: Vec<(String, Duration, ScanCallback)>,
50    links: Vec<LinkDef>,
51    tcp_port: u16,
52    udp_port: u16,
53    listen_ip: Option<IpAddr>,
54    advertise_ip: Option<IpAddr>,
55    compute_alarms: bool,
56    beacon_period_secs: u64,
57    conn_timeout: Duration,
58    pvlist_mode: PvListMode,
59    pvlist_max: usize,
60    pvlist_allow_pattern: Option<Regex>,
61}
62
63impl PvaServerBuilder {
64    fn new() -> Self {
65        Self {
66            records: HashMap::new(),
67            on_put: HashMap::new(),
68            scans: Vec::new(),
69            links: Vec::new(),
70            tcp_port: 5075,
71            udp_port: 5076,
72            listen_ip: None,
73            advertise_ip: None,
74            compute_alarms: false,
75            beacon_period_secs: 15,
76            conn_timeout: Duration::from_secs(64000),
77            pvlist_mode: PvListMode::List,
78            pvlist_max: 1024,
79            pvlist_allow_pattern: None,
80        }
81    }
82
83    // ─── Typed record constructors ───────────────────────────────────
84
85    /// Add an `ai` (analog input, read-only) record.
86    pub fn ai(mut self, name: impl Into<String>, initial: f64) -> Self {
87        let name = name.into();
88        self.records.insert(
89            name.clone(),
90            make_scalar_record(&name, RecordType::Ai, ScalarValue::F64(initial)),
91        );
92        self
93    }
94
95    /// Add an `ao` (analog output, writable) record.
96    pub fn ao(mut self, name: impl Into<String>, initial: f64) -> Self {
97        let name = name.into();
98        self.records.insert(
99            name.clone(),
100            make_output_record(&name, RecordType::Ao, ScalarValue::F64(initial)),
101        );
102        self
103    }
104
105    /// Add a `bi` (binary input, read-only) record.
106    pub fn bi(mut self, name: impl Into<String>, initial: bool) -> Self {
107        let name = name.into();
108        self.records.insert(
109            name.clone(),
110            make_scalar_record(&name, RecordType::Bi, ScalarValue::Bool(initial)),
111        );
112        self
113    }
114
115    /// Add a `bo` (binary output, writable) record.
116    pub fn bo(mut self, name: impl Into<String>, initial: bool) -> Self {
117        let name = name.into();
118        self.records.insert(
119            name.clone(),
120            make_output_record(&name, RecordType::Bo, ScalarValue::Bool(initial)),
121        );
122        self
123    }
124
125    /// Add a `stringin` (string input, read-only) record.
126    pub fn string_in(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
127        let name = name.into();
128        self.records.insert(
129            name.clone(),
130            make_scalar_record(
131                &name,
132                RecordType::StringIn,
133                ScalarValue::Str(initial.into()),
134            ),
135        );
136        self
137    }
138
139    /// Add a `stringout` (string output, writable) record.
140    pub fn string_out(mut self, name: impl Into<String>, initial: impl Into<String>) -> Self {
141        let name = name.into();
142        self.records.insert(
143            name.clone(),
144            make_output_record(
145                &name,
146                RecordType::StringOut,
147                ScalarValue::Str(initial.into()),
148            ),
149        );
150        self
151    }
152
153    /// Add a `waveform` record (array) with the given initial data.
154    pub fn waveform(mut self, name: impl Into<String>, data: ScalarArrayValue) -> Self {
155        let name = name.into();
156        let ftvl = data.type_label().trim_end_matches("[]").to_string();
157        let nelm = data.len();
158        self.records.insert(
159            name.clone(),
160            RecordInstance {
161                name: name.clone(),
162                record_type: RecordType::Waveform,
163                common: DbCommonState::default(),
164                data: RecordData::Waveform {
165                    nt: NtScalarArray::from_value(data),
166                    inp: None,
167                    ftvl,
168                    nelm,
169                    nord: nelm,
170                },
171                raw_fields: HashMap::new(),
172            },
173        );
174        self
175    }
176
177    /// Add an `mbbi` (multi-bit binary input, read-only) NTEnum record.
178    pub fn mbbi(
179        mut self,
180        name: impl Into<String>,
181        choices: Vec<String>,
182        initial: i32,
183    ) -> Self {
184        let name = name.into();
185        self.records.insert(
186            name.clone(),
187            RecordInstance {
188                name: name.clone(),
189                record_type: RecordType::Mbbi,
190                common: DbCommonState::default(),
191                data: RecordData::NtEnum {
192                    nt: NtEnum::new(initial, choices),
193                    inp: None,
194                    out: None,
195                    omsl: OutputMode::Supervisory,
196                },
197                raw_fields: HashMap::new(),
198            },
199        );
200        self
201    }
202
203    /// Add an `mbbo` (multi-bit binary output, writable) NTEnum record.
204    pub fn mbbo(
205        mut self,
206        name: impl Into<String>,
207        choices: Vec<String>,
208        initial: i32,
209    ) -> Self {
210        let name = name.into();
211        self.records.insert(
212            name.clone(),
213            RecordInstance {
214                name: name.clone(),
215                record_type: RecordType::Mbbo,
216                common: DbCommonState::default(),
217                data: RecordData::NtEnum {
218                    nt: NtEnum::new(initial, choices),
219                    inp: None,
220                    out: None,
221                    omsl: OutputMode::Supervisory,
222                },
223                raw_fields: HashMap::new(),
224            },
225        );
226        self
227    }
228
229    /// Add a generic structure record with a custom struct ID and fields.
230    pub fn generic(
231        mut self,
232        name: impl Into<String>,
233        struct_id: impl Into<String>,
234        fields: Vec<(String, PvValue)>,
235    ) -> Self {
236        let name = name.into();
237        self.records.insert(
238            name.clone(),
239            RecordInstance {
240                name: name.clone(),
241                record_type: RecordType::Generic,
242                common: DbCommonState::default(),
243                data: RecordData::Generic {
244                    struct_id: struct_id.into(),
245                    fields,
246                    inp: None,
247                    out: None,
248                    omsl: OutputMode::Supervisory,
249                },
250                raw_fields: HashMap::new(),
251            },
252        );
253        self
254    }
255
256    // ─── .db file loading ────────────────────────────────────────────
257
258    /// Load records from an EPICS `.db` file.
259    pub fn db_file(mut self, path: impl AsRef<str>) -> Self {
260        match load_db(path.as_ref()) {
261            Ok(records) => {
262                self.records.extend(records);
263            }
264            Err(e) => {
265                tracing::error!("Failed to load db file '{}': {}", path.as_ref(), e);
266            }
267        }
268        self
269    }
270
271    /// Parse records from an EPICS `.db` string.
272    pub fn db_string(mut self, content: &str) -> Self {
273        match parse_db(content) {
274            Ok(records) => {
275                self.records.extend(records);
276            }
277            Err(e) => {
278                tracing::error!("Failed to parse db string: {}", e);
279            }
280        }
281        self
282    }
283
284    // ─── Callbacks ───────────────────────────────────────────────────
285
286    /// Register a callback invoked when a PUT is applied to the named PV.
287    pub fn on_put<F>(mut self, name: impl Into<String>, callback: F) -> Self
288    where
289        F: Fn(&str, &spvirit_codec::spvd_decode::DecodedValue) + Send + Sync + 'static,
290    {
291        self.on_put.insert(name.into(), Arc::new(callback));
292        self
293    }
294
295    /// Register a periodic scan callback that produces a new value for a PV.
296    pub fn scan<F>(mut self, name: impl Into<String>, period: Duration, callback: F) -> Self
297    where
298        F: Fn(&str) -> ScalarValue + Send + Sync + 'static,
299    {
300        self.scans.push((name.into(), period, Arc::new(callback)));
301        self
302    }
303
304    /// Link an output PV to one or more input PVs.
305    ///
306    /// Whenever any input PV changes (via `set_value`, protocol PUT, or
307    /// another link), the `compute` callback is invoked with the current
308    /// values of **all** inputs (in order) and the result is written to
309    /// the output PV.
310    ///
311    /// ```rust,ignore
312    /// .link("CALC:SUM", &["INPUT:A", "INPUT:B"], |values| {
313    ///     let a = values[0].as_f64().unwrap_or(0.0);
314    ///     let b = values[1].as_f64().unwrap_or(0.0);
315    ///     ScalarValue::F64(a + b)
316    /// })
317    /// ```
318    pub fn link<F>(mut self, output: impl Into<String>, inputs: &[&str], compute: F) -> Self
319    where
320        F: Fn(&[ScalarValue]) -> ScalarValue + Send + Sync + 'static,
321    {
322        self.links.push(LinkDef {
323            output: output.into(),
324            inputs: inputs.iter().map(|s| s.to_string()).collect(),
325            compute: Arc::new(compute),
326        });
327        self
328    }
329
330    // ─── Configuration ───────────────────────────────────────────────
331
332    /// Set the TCP port (default 5075).
333    pub fn port(mut self, port: u16) -> Self {
334        self.tcp_port = port;
335        self
336    }
337
338    /// Set the UDP search port (default 5076).
339    pub fn udp_port(mut self, port: u16) -> Self {
340        self.udp_port = port;
341        self
342    }
343
344    /// Set the IP address to listen on.
345    pub fn listen_ip(mut self, ip: IpAddr) -> Self {
346        self.listen_ip = Some(ip);
347        self
348    }
349
350    /// Set the IP address to advertise in search responses.
351    pub fn advertise_ip(mut self, ip: IpAddr) -> Self {
352        self.advertise_ip = Some(ip);
353        self
354    }
355
356    /// Enable alarm computation from limits.
357    pub fn compute_alarms(mut self, enabled: bool) -> Self {
358        self.compute_alarms = enabled;
359        self
360    }
361
362    /// Set the beacon broadcast period in seconds (default 15).
363    pub fn beacon_period(mut self, secs: u64) -> Self {
364        self.beacon_period_secs = secs;
365        self
366    }
367
368    /// Set the idle connection timeout (default ~18 hours).
369    pub fn conn_timeout(mut self, timeout: Duration) -> Self {
370        self.conn_timeout = timeout;
371        self
372    }
373
374    /// Set the PV list mode (default [`PvListMode::List`]).
375    pub fn pvlist_mode(mut self, mode: PvListMode) -> Self {
376        self.pvlist_mode = mode;
377        self
378    }
379
380    /// Set the maximum number of PV names in pvlist responses (default 1024).
381    pub fn pvlist_max(mut self, max: usize) -> Self {
382        self.pvlist_max = max;
383        self
384    }
385
386    /// Set a regex filter for PV names exposed by pvlist.
387    pub fn pvlist_allow_pattern(mut self, pattern: Regex) -> Self {
388        self.pvlist_allow_pattern = Some(pattern);
389        self
390    }
391
392    /// Build the [`PvaServer`].
393    pub fn build(self) -> PvaServer {
394        let store = Arc::new(SimplePvStore::new(
395            self.records,
396            self.on_put,
397            self.links,
398            self.compute_alarms,
399        ));
400
401        let mut config = PvaServerConfig::default();
402        config.tcp_port = self.tcp_port;
403        config.udp_port = self.udp_port;
404        config.compute_alarms = self.compute_alarms;
405        if let Some(ip) = self.listen_ip {
406            config.listen_ip = ip;
407        }
408        config.advertise_ip = self.advertise_ip;
409        config.beacon_period_secs = self.beacon_period_secs;
410        config.conn_timeout = self.conn_timeout;
411        config.pvlist_mode = self.pvlist_mode;
412        config.pvlist_max = self.pvlist_max;
413        config.pvlist_allow_pattern = self.pvlist_allow_pattern;
414
415        PvaServer {
416            store,
417            config,
418            scans: self.scans,
419        }
420    }
421}
422
423// ─── PvaServer ───────────────────────────────────────────────────────────
424
425/// High-level PVAccess server.
426///
427/// Built via [`PvaServer::builder()`] with typed record constructors,
428/// `.db_file()` loading, `.on_put()` / `.scan()` callbacks, and a
429/// simple `.run()` to start serving.
430///
431/// ```rust,ignore
432/// let server = PvaServer::builder()
433///     .ai("SIM:TEMP", 22.5)
434///     .ao("SIM:SP", 25.0)
435///     .build();
436///
437/// // Read/write PVs from another task:
438/// let store = server.store();
439/// store.set_value("SIM:TEMP", ScalarValue::F64(23.1)).await;
440///
441/// server.run().await?;
442/// ```
443pub struct PvaServer {
444    store: Arc<SimplePvStore>,
445    config: PvaServerConfig,
446    scans: Vec<(String, Duration, ScanCallback)>,
447}
448
449impl PvaServer {
450    /// Create a builder for configuring a [`PvaServer`].
451    pub fn builder() -> PvaServerBuilder {
452        PvaServerBuilder::new()
453    }
454
455    /// Get a reference to the underlying store for runtime get/put.
456    pub fn store(&self) -> &Arc<SimplePvStore> {
457        &self.store
458    }
459
460    /// Start the PVA server (UDP search + TCP handler + beacon + scan tasks).
461    ///
462    /// This blocks until the server is shut down or an error occurs.
463    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
464        // Create the monitor registry early so scan tasks can notify
465        // PVAccess monitor clients when values change.
466        let registry = Arc::new(MonitorRegistry::new());
467        self.store.set_registry(registry.clone()).await;
468
469        // Spawn scan tasks.
470        for (name, period, callback) in &self.scans {
471            let store = self.store.clone();
472            let name = name.clone();
473            let period = *period;
474            let callback = callback.clone();
475            tokio::spawn(async move {
476                let mut interval = tokio::time::interval(period);
477                loop {
478                    interval.tick().await;
479                    let new_val = callback(&name);
480                    store.set_value(&name, new_val).await;
481                }
482            });
483        }
484
485        let pv_count = self.store.pv_names().await.len();
486        info!(
487            "PvaServer starting: {} PVs on port {}",
488            pv_count, self.config.tcp_port
489        );
490
491        run_pva_server_with_registry(self.store, self.config, registry).await
492    }
493}
494
495// ─── Record construction helpers ─────────────────────────────────────────
496
497fn make_scalar_record(name: &str, record_type: RecordType, value: ScalarValue) -> RecordInstance {
498    let nt = NtScalar::from_value(value);
499    let data = match record_type {
500        RecordType::Ai => RecordData::Ai {
501            nt,
502            inp: None,
503            siml: None,
504            siol: None,
505            simm: false,
506        },
507        RecordType::Bi => RecordData::Bi {
508            nt,
509            inp: None,
510            znam: "Off".to_string(),
511            onam: "On".to_string(),
512            siml: None,
513            siol: None,
514            simm: false,
515        },
516        RecordType::StringIn => RecordData::StringIn {
517            nt,
518            inp: None,
519            siml: None,
520            siol: None,
521            simm: false,
522        },
523        _ => panic!("make_scalar_record: unsupported type {record_type:?}"),
524    };
525    RecordInstance {
526        name: name.to_string(),
527        record_type,
528        common: DbCommonState::default(),
529        data,
530        raw_fields: HashMap::new(),
531    }
532}
533
534fn make_output_record(name: &str, record_type: RecordType, value: ScalarValue) -> RecordInstance {
535    let nt = NtScalar::from_value(value);
536    let data = match record_type {
537        RecordType::Ao => RecordData::Ao {
538            nt,
539            out: None,
540            dol: None,
541            omsl: OutputMode::Supervisory,
542            drvl: None,
543            drvh: None,
544            oroc: None,
545            siml: None,
546            siol: None,
547            simm: false,
548        },
549        RecordType::Bo => RecordData::Bo {
550            nt,
551            out: None,
552            dol: None,
553            omsl: OutputMode::Supervisory,
554            znam: "Off".to_string(),
555            onam: "On".to_string(),
556            siml: None,
557            siol: None,
558            simm: false,
559        },
560        RecordType::StringOut => RecordData::StringOut {
561            nt,
562            out: None,
563            dol: None,
564            omsl: OutputMode::Supervisory,
565            siml: None,
566            siol: None,
567            simm: false,
568        },
569        _ => panic!("make_output_record: unsupported type {record_type:?}"),
570    };
571    RecordInstance {
572        name: name.to_string(),
573        record_type,
574        common: DbCommonState::default(),
575        data,
576        raw_fields: HashMap::new(),
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a single-threaded Tokio runtime for driving the async store API
    /// from synchronous `#[test]` functions.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build Tokio test runtime")
    }

    /// Extract an `f64` link input, treating any other variant as 0.0.
    fn f64_or_zero(value: &ScalarValue) -> f64 {
        match value {
            ScalarValue::F64(v) => *v,
            _ => 0.0,
        }
    }

    #[test]
    fn builder_creates_records() {
        let server = PvaServer::builder()
            .ai("T:AI", 1.0)
            .ao("T:AO", 2.0)
            .bi("T:BI", true)
            .bo("T:BO", false)
            .string_in("T:SI", "hello")
            .string_out("T:SO", "world")
            .build();

        let names = runtime().block_on(server.store.pv_names());
        assert_eq!(names.len(), 6);
    }

    #[test]
    fn builder_enum_records() {
        let choices = vec!["Off".to_string(), "On".to_string()];
        let server = PvaServer::builder()
            .mbbi("T:MBBI", choices.clone(), 0)
            .mbbo("T:MBBO", choices, 1)
            .build();

        let names = runtime().block_on(server.store.pv_names());
        assert!(names.contains(&"T:MBBI".to_string()));
        assert!(names.contains(&"T:MBBO".to_string()));
    }

    #[test]
    fn builder_duplicate_name_replaces_record() {
        let server = PvaServer::builder()
            .ai("DUP:V", 1.0)
            .ao("DUP:V", 2.0)
            .build();

        let store = server.store().clone();
        runtime().block_on(async {
            assert_eq!(store.pv_names().await.len(), 1);
            assert_eq!(store.get_value("DUP:V").await, Some(ScalarValue::F64(2.0)));
        });
    }

    #[test]
    fn builder_defaults() {
        let server = PvaServer::builder().build();
        assert_eq!(server.config.tcp_port, 5075);
        assert_eq!(server.config.udp_port, 5076);
        assert!(!server.config.compute_alarms);
    }

    #[test]
    fn builder_port_override() {
        let server = PvaServer::builder().port(9075).udp_port(9076).build();
        assert_eq!(server.config.tcp_port, 9075);
        assert_eq!(server.config.udp_port, 9076);
    }

    #[test]
    fn builder_db_string() {
        let db = r#"
            record(ai, "TEST:VAL") {
                field(VAL, "3.14")
            }
        "#;
        let server = PvaServer::builder().db_string(db).build();
        assert!(runtime().block_on(server.store.get_value("TEST:VAL")).is_some());
    }

    #[test]
    fn builder_waveform() {
        let data = ScalarArrayValue::F64(vec![1.0, 2.0, 3.0]);
        let server = PvaServer::builder().waveform("T:WF", data).build();
        let names = runtime().block_on(server.store.pv_names());
        assert!(names.contains(&"T:WF".to_string()));
    }

    #[test]
    fn builder_scan_callback() {
        let server = PvaServer::builder()
            .ai("SCAN:V", 0.0)
            .scan("SCAN:V", Duration::from_secs(1), |_name| {
                ScalarValue::F64(42.0)
            })
            .build();
        assert_eq!(server.scans.len(), 1);
    }

    #[test]
    fn builder_on_put_callback() {
        let server = PvaServer::builder()
            .ao("PUT:V", 0.0)
            .on_put("PUT:V", |_name, _val| {})
            .build();
        // on_put is stored in the SimplePvStore and is not directly
        // inspectable; check that the build succeeded and the PV exists.
        assert!(runtime().block_on(server.store.get_value("PUT:V")).is_some());
    }

    #[test]
    fn store_runtime_get_set() {
        let server = PvaServer::builder().ao("RT:V", 0.0).build();
        let store = server.store().clone();
        runtime().block_on(async {
            assert_eq!(store.get_value("RT:V").await, Some(ScalarValue::F64(0.0)));
            store.set_value("RT:V", ScalarValue::F64(99.0)).await;
            assert_eq!(store.get_value("RT:V").await, Some(ScalarValue::F64(99.0)));
        });
    }

    #[test]
    fn link_propagates_on_set_value() {
        let server = PvaServer::builder()
            .ao("INPUT:A", 1.0)
            .ao("INPUT:B", 2.0)
            .ai("CALC:SUM", 0.0)
            .link("CALC:SUM", &["INPUT:A", "INPUT:B"], |values| {
                ScalarValue::F64(f64_or_zero(&values[0]) + f64_or_zero(&values[1]))
            })
            .build();

        let store = server.store().clone();
        runtime().block_on(async {
            // Writing INPUT:A should recompute CALC:SUM = 10 + 2.
            store.set_value("INPUT:A", ScalarValue::F64(10.0)).await;
            assert_eq!(
                store.get_value("CALC:SUM").await,
                Some(ScalarValue::F64(12.0))
            );

            // Writing INPUT:B should recompute CALC:SUM = 10 + 5.
            store.set_value("INPUT:B", ScalarValue::F64(5.0)).await;
            assert_eq!(
                store.get_value("CALC:SUM").await,
                Some(ScalarValue::F64(15.0))
            );
        });
    }
}