Skip to main content

spvirit_server/
simple_store.rs

1//! A simple in-memory [`PvStore`] implementation backed by `RecordInstance`.
2//!
3//! Used by [`PvaServer`](crate::pva_server::PvaServer) to serve PVs without
4//! requiring an external database.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use tokio::sync::{RwLock, mpsc};
10use tracing::debug;
11
12use spvirit_codec::spvd_decode::{DecodedValue, FieldDesc, FieldType, StructureDesc, TypeCode};
13use spvirit_types::{NtPayload, ScalarArrayValue, ScalarValue};
14
15use crate::apply::{
16    apply_alarm_update, apply_control_update, apply_display_update, apply_scalar_array_put,
17    apply_value_update,
18};
19use crate::monitor::MonitorRegistry;
20use crate::pvstore::PvStore;
21use crate::types::{RecordData, RecordInstance};
22
23/// Callback invoked after a PUT value is applied to a record.
24pub type OnPutCallback = Arc<dyn Fn(&str, &DecodedValue) + Send + Sync>;
25
26/// Callback invoked by the scan scheduler; returns the new value for the PV.
27pub type ScanCallback = Arc<dyn Fn(&str) -> ScalarValue + Send + Sync>;
28
29/// Callback that computes a derived PV value from its input values.
30pub type LinkCallback = Arc<dyn Fn(&[ScalarValue]) -> ScalarValue + Send + Sync>;
31
32/// A link from one or more input PVs to a computed output PV.
33pub(crate) struct LinkDef {
34    pub output: String,
35    pub inputs: Vec<String>,
36    pub compute: LinkCallback,
37}
38
39struct PvEntry {
40    record: RecordInstance,
41    subscribers: Vec<mpsc::Sender<NtPayload>>,
42}
43
44/// A simple in-memory PV store.
45pub struct SimplePvStore {
46    pvs: RwLock<HashMap<String, PvEntry>>,
47    on_put: HashMap<String, OnPutCallback>,
48    links: Vec<LinkDef>,
49    compute_alarms: bool,
50    registry: RwLock<Option<Arc<MonitorRegistry>>>,
51}
52
53impl SimplePvStore {
54    pub(crate) fn new(
55        records: HashMap<String, RecordInstance>,
56        on_put: HashMap<String, OnPutCallback>,
57        links: Vec<LinkDef>,
58        compute_alarms: bool,
59    ) -> Self {
60        let pvs = records
61            .into_iter()
62            .map(|(name, record)| {
63                (
64                    name,
65                    PvEntry {
66                        record,
67                        subscribers: Vec::new(),
68                    },
69                )
70            })
71            .collect();
72        Self {
73            pvs: RwLock::new(pvs),
74            on_put,
75            links,
76            compute_alarms,
77            registry: RwLock::new(None),
78        }
79    }
80
81    /// Attach the [`MonitorRegistry`] so that `set_value` can push updates
82    /// to PVAccess monitor clients.  Called automatically by [`PvaServer::run`].
83    pub async fn set_registry(&self, registry: Arc<MonitorRegistry>) {
84        *self.registry.write().await = Some(registry);
85    }
86
87    /// Insert or replace a PV record at runtime.
88    pub async fn insert(&self, name: String, record: RecordInstance) {
89        let mut pvs = self.pvs.write().await;
90        pvs.insert(
91            name,
92            PvEntry {
93                record,
94                subscribers: Vec::new(),
95            },
96        );
97    }
98
99    /// Read the current [`ScalarValue`] of a PV.
100    pub async fn get_value(&self, name: &str) -> Option<ScalarValue> {
101        let pvs = self.pvs.read().await;
102        pvs.get(name).map(|e| e.record.current_value())
103    }
104
105    /// Read the full [`NtPayload`] of a PV.
106    pub async fn get_nt(&self, name: &str) -> Option<NtPayload> {
107        let pvs = self.pvs.read().await;
108        pvs.get(name).map(|e| e.record.to_ntpayload())
109    }
110
111    /// Write a [`ScalarValue`] to a PV (bypasses on_put).
112    pub async fn set_value(&self, name: &str, value: ScalarValue) -> bool {
113        if self.set_value_inner(name, value).await {
114            self.evaluate_links(name).await;
115            true
116        } else {
117            false
118        }
119    }
120
121    /// Write a [`ScalarArrayValue`] to an array PV (bypasses on_put).
122    pub async fn set_array_value(&self, name: &str, value: ScalarArrayValue) -> bool {
123        if self.set_array_value_inner(name, value).await {
124            self.evaluate_links(name).await;
125            true
126        } else {
127            false
128        }
129    }
130
131    /// Write a full [`NtPayload`] to a PV (bypasses on_put).
132    pub async fn put_nt(&self, name: &str, payload: NtPayload) -> bool {
133        if self.put_nt_inner(name, payload).await {
134            self.evaluate_links(name).await;
135            true
136        } else {
137            false
138        }
139    }
140
141    /// Core write logic — updates the value, notifies subscribers and monitors,
142    /// but does **not** trigger link evaluation (to avoid recursion).
143    async fn set_value_inner(&self, name: &str, value: ScalarValue) -> bool {
144        let payload = {
145            let mut pvs = self.pvs.write().await;
146            if let Some(entry) = pvs.get_mut(name) {
147                let changed = entry.record.set_scalar_value(value, self.compute_alarms);
148                if changed {
149                    let payload = entry.record.to_ntpayload();
150                    entry
151                        .subscribers
152                        .retain(|tx| tx.try_send(payload.clone()).is_ok());
153                    Some(payload)
154                } else {
155                    None
156                }
157            } else {
158                return false;
159            }
160        };
161
162        if let Some(payload) = payload {
163            // Notify PVAccess monitor clients (if the registry is attached).
164            let reg = self.registry.read().await;
165            if let Some(registry) = reg.as_ref() {
166                registry.notify_monitors(name, &payload).await;
167            }
168            true
169        } else {
170            false
171        }
172    }
173
174    /// Core array write logic — updates the value, notifies subscribers and monitors,
175    /// but does **not** trigger link evaluation (to avoid recursion).
176    async fn set_array_value_inner(&self, name: &str, value: ScalarArrayValue) -> bool {
177        let payload = {
178            let mut pvs = self.pvs.write().await;
179            if let Some(entry) = pvs.get_mut(name) {
180                let changed = entry.record.set_array_value(value);
181                if changed {
182                    let payload = entry.record.to_ntpayload();
183                    entry
184                        .subscribers
185                        .retain(|tx| tx.try_send(payload.clone()).is_ok());
186                    Some(payload)
187                } else {
188                    None
189                }
190            } else {
191                return false;
192            }
193        };
194
195        if let Some(payload) = payload {
196            // Notify PVAccess monitor clients (if the registry is attached).
197            let reg = self.registry.read().await;
198            if let Some(registry) = reg.as_ref() {
199                registry.notify_monitors(name, &payload).await;
200            }
201            true
202        } else {
203            false
204        }
205    }
206
207    /// Core NtPayload write logic — updates the payload, notifies subscribers
208    /// and monitors, but does **not** trigger link evaluation.
209    async fn put_nt_inner(&self, name: &str, payload: NtPayload) -> bool {
210        let payload = {
211            let mut pvs = self.pvs.write().await;
212            if let Some(entry) = pvs.get_mut(name) {
213                let changed = entry.record.set_nt_payload(payload);
214                if changed {
215                    let payload = entry.record.to_ntpayload();
216                    entry
217                        .subscribers
218                        .retain(|tx| tx.try_send(payload.clone()).is_ok());
219                    Some(payload)
220                } else {
221                    None
222                }
223            } else {
224                return false;
225            }
226        };
227
228        if let Some(payload) = payload {
229            // Notify PVAccess monitor clients (if the registry is attached).
230            let reg = self.registry.read().await;
231            if let Some(registry) = reg.as_ref() {
232                registry.notify_monitors(name, &payload).await;
233            }
234            true
235        } else {
236            false
237        }
238    }
239
240    /// Walk every link whose inputs include `changed_pv`, compute the output,
241    /// and propagate (BFS with cycle detection).
242    async fn evaluate_links(&self, changed_pv: &str) {
243        if self.links.is_empty() {
244            return;
245        }
246        let mut queue = vec![changed_pv.to_string()];
247        let mut visited = HashSet::new();
248
249        while let Some(pv) = queue.pop() {
250            if !visited.insert(pv.clone()) {
251                debug!("Circular link detected for PV '{}', skipping", pv);
252                continue;
253            }
254            for link in &self.links {
255                if !link.inputs.iter().any(|i| i == &pv) {
256                    continue;
257                }
258                // Gather current values of all inputs.
259                let values = {
260                    let pvs = self.pvs.read().await;
261                    link.inputs
262                        .iter()
263                        .map(|n| {
264                            pvs.get(n)
265                                .map(|e| e.record.current_value())
266                                .unwrap_or(ScalarValue::F64(0.0))
267                        })
268                        .collect::<Vec<_>>()
269                };
270                let new_val = (link.compute)(&values);
271                if self.set_value_inner(&link.output, new_val).await {
272                    queue.push(link.output.clone());
273                }
274            }
275        }
276    }
277
278    /// List all PV names.
279    pub async fn pv_names(&self) -> Vec<String> {
280        let pvs = self.pvs.read().await;
281        pvs.keys().cloned().collect()
282    }
283}
284
285impl PvStore for SimplePvStore {
286    fn has_pv(&self, name: &str) -> impl Future<Output = bool> + Send {
287        async move {
288            let pvs = self.pvs.read().await;
289            pvs.contains_key(name)
290        }
291    }
292
293    fn get_snapshot(&self, name: &str) -> impl Future<Output = Option<NtPayload>> + Send {
294        async move {
295            let pvs = self.pvs.read().await;
296            pvs.get(name).map(|e| e.record.to_ntpayload())
297        }
298    }
299
300    fn get_descriptor(&self, name: &str) -> impl Future<Output = Option<StructureDesc>> + Send {
301        async move {
302            let pvs = self.pvs.read().await;
303            pvs.get(name)
304                .map(|e| descriptor_for_payload(&e.record.to_ntpayload()))
305        }
306    }
307
308    fn put_value(
309        &self,
310        name: &str,
311        value: &DecodedValue,
312    ) -> impl Future<Output = Result<Vec<(String, NtPayload)>, String>> + Send {
313        let name = name.to_string();
314        let value = value.clone();
315        async move {
316            let result = {
317                let mut pvs = self.pvs.write().await;
318                let entry = pvs
319                    .get_mut(&name)
320                    .ok_or_else(|| format!("PV '{}' not found", name))?;
321
322                if !entry.record.writable() {
323                    return Err(format!("PV '{}' is not writable", name));
324                }
325
326                let changed = apply_put_to_record(&mut entry.record, &value, self.compute_alarms);
327                if !changed {
328                    return Ok(vec![]);
329                }
330
331                let payload = entry.record.to_ntpayload();
332                entry
333                    .subscribers
334                    .retain(|tx| tx.try_send(payload.clone()).is_ok());
335
336                (name.clone(), payload)
337            }; // pvs lock dropped
338
339            // Fire on_put callback (non-blocking).
340            if let Some(cb) = self.on_put.get(&name) {
341                let cb = cb.clone();
342                let n = name.clone();
343                let v = value.clone();
344                tokio::spawn(async move { cb(&n, &v) });
345            }
346
347            // Propagate linked PV updates.
348            self.evaluate_links(&name).await;
349
350            Ok(vec![result])
351        }
352    }
353
354    fn is_writable(&self, name: &str) -> impl Future<Output = bool> + Send {
355        async move {
356            let pvs = self.pvs.read().await;
357            pvs.get(name).is_some_and(|e| e.record.writable())
358        }
359    }
360
361    fn list_pvs(&self) -> impl Future<Output = Vec<String>> + Send {
362        async move {
363            let pvs = self.pvs.read().await;
364            pvs.keys().cloned().collect()
365        }
366    }
367
368    fn subscribe(
369        &self,
370        name: &str,
371    ) -> impl Future<Output = Option<mpsc::Receiver<NtPayload>>> + Send {
372        let name = name.to_string();
373        async move {
374            let mut pvs = self.pvs.write().await;
375            let entry = pvs.get_mut(&name)?;
376            let (tx, rx) = mpsc::channel(64);
377            entry.subscribers.push(tx);
378            Some(rx)
379        }
380    }
381}
382
383// ── Helpers ──────────────────────────────────────────────────────────────
384
385/// Apply a decoded PUT value to a RecordInstance, returning whether it changed.
386fn apply_put_to_record(
387    record: &mut RecordInstance,
388    value: &DecodedValue,
389    compute_alarms: bool,
390) -> bool {
391    let fields = match value {
392        DecodedValue::Structure(f) => f,
393        other => {
394            // Bare scalar — wrap as value field.
395            return apply_put_to_record(
396                record,
397                &DecodedValue::Structure(vec![("value".to_string(), other.clone())]),
398                compute_alarms,
399            );
400        }
401    };
402
403    let mut changed = false;
404
405    match &mut record.data {
406        RecordData::Ai { nt, .. }
407        | RecordData::Ao { nt, .. }
408        | RecordData::Bi { nt, .. }
409        | RecordData::Bo { nt, .. }
410        | RecordData::StringIn { nt, .. }
411        | RecordData::StringOut { nt, .. } => {
412            for (name, val) in fields {
413                match name.as_str() {
414                    "value" => {
415                        changed |= apply_value_update(nt, val, compute_alarms);
416                    }
417                    "alarm" => {
418                        changed |= apply_alarm_update(nt, val);
419                    }
420                    "display" => {
421                        changed |= apply_display_update(nt, val);
422                    }
423                    "control" => {
424                        changed |= apply_control_update(nt, val);
425                    }
426                    _ => {}
427                }
428            }
429        }
430        RecordData::Waveform { nt, nord, .. }
431        | RecordData::Aai { nt, nord, .. }
432        | RecordData::Aao { nt, nord, .. }
433        | RecordData::SubArray { nt, nord, .. } => {
434            changed = apply_scalar_array_put(nt, nord, value);
435        }
436        RecordData::NtTable { .. } | RecordData::NtNdArray { .. } => {
437            // Table/NdArray PUT not supported via high-level API yet.
438            debug!("PUT to NtTable/NtNdArray not yet supported in SimplePvStore");
439        }
440        RecordData::NtEnum { nt, .. } => {
441            // Accept index updates for NtEnum PVs.
442            for (name, val) in fields {
443                if name == "value" {
444                    let idx = match val {
445                        DecodedValue::Int32(v) => Some(*v),
446                        DecodedValue::Int64(v) => Some(*v as i32),
447                        DecodedValue::Int16(v) => Some(*v as i32),
448                        DecodedValue::Int8(v) => Some(*v as i32),
449                        DecodedValue::Float64(v) => Some(*v as i32),
450                        _ => None,
451                    };
452                    if let Some(idx) = idx {
453                        if nt.index != idx {
454                            nt.index = idx;
455                            changed = true;
456                        }
457                    }
458                }
459            }
460        }
461        RecordData::Generic { .. } => {
462            debug!("PUT to Generic not yet supported in SimplePvStore");
463        }
464    }
465
466    changed
467}
468
469// ── NtPayload → StructureDesc ────────────────────────────────────────────
470
471pub(crate) fn descriptor_for_payload(payload: &NtPayload) -> StructureDesc {
472    match payload {
473        NtPayload::Scalar(nt) => nt_scalar_desc(&nt.value),
474        NtPayload::ScalarArray(arr) => nt_scalar_array_desc(&arr.value),
475        _ => StructureDesc::new(),
476    }
477}
478
479fn value_type_code(sv: &ScalarValue) -> TypeCode {
480    match sv {
481        ScalarValue::Bool(_) => TypeCode::Boolean,
482        ScalarValue::I8(_) => TypeCode::Int8,
483        ScalarValue::I16(_) => TypeCode::Int16,
484        ScalarValue::I32(_) => TypeCode::Int32,
485        ScalarValue::I64(_) => TypeCode::Int64,
486        ScalarValue::U8(_) => TypeCode::UInt8,
487        ScalarValue::U16(_) => TypeCode::UInt16,
488        ScalarValue::U32(_) => TypeCode::UInt32,
489        ScalarValue::U64(_) => TypeCode::UInt64,
490        ScalarValue::F32(_) => TypeCode::Float32,
491        ScalarValue::F64(_) => TypeCode::Float64,
492        ScalarValue::Str(_) => TypeCode::String,
493    }
494}
495
496fn array_type_code(sav: &ScalarArrayValue) -> TypeCode {
497    match sav {
498        ScalarArrayValue::Bool(_) => TypeCode::Boolean,
499        ScalarArrayValue::I8(_) => TypeCode::Int8,
500        ScalarArrayValue::I16(_) => TypeCode::Int16,
501        ScalarArrayValue::I32(_) => TypeCode::Int32,
502        ScalarArrayValue::I64(_) => TypeCode::Int64,
503        ScalarArrayValue::U8(_) => TypeCode::UInt8,
504        ScalarArrayValue::U16(_) => TypeCode::UInt16,
505        ScalarArrayValue::U32(_) => TypeCode::UInt32,
506        ScalarArrayValue::U64(_) => TypeCode::UInt64,
507        ScalarArrayValue::F32(_) => TypeCode::Float32,
508        ScalarArrayValue::F64(_) => TypeCode::Float64,
509        ScalarArrayValue::Str(_) => TypeCode::String,
510    }
511}
512
513fn nt_scalar_desc(sv: &ScalarValue) -> StructureDesc {
514    let tc = value_type_code(sv);
515    StructureDesc {
516        struct_id: Some("epics:nt/NTScalar:1.0".to_string()),
517        fields: vec![
518            FieldDesc {
519                name: "value".to_string(),
520                field_type: FieldType::Scalar(tc),
521            },
522            FieldDesc {
523                name: "alarm".to_string(),
524                field_type: FieldType::Structure(alarm_desc()),
525            },
526            FieldDesc {
527                name: "timeStamp".to_string(),
528                field_type: FieldType::Structure(timestamp_desc()),
529            },
530            FieldDesc {
531                name: "display".to_string(),
532                field_type: FieldType::Structure(display_desc()),
533            },
534            FieldDesc {
535                name: "control".to_string(),
536                field_type: FieldType::Structure(control_desc()),
537            },
538            FieldDesc {
539                name: "valueAlarm".to_string(),
540                field_type: FieldType::Structure(value_alarm_desc()),
541            },
542        ],
543    }
544}
545
546fn nt_scalar_array_desc(sav: &ScalarArrayValue) -> StructureDesc {
547    let tc = array_type_code(sav);
548    StructureDesc {
549        struct_id: Some("epics:nt/NTScalarArray:1.0".to_string()),
550        fields: vec![
551            FieldDesc {
552                name: "value".to_string(),
553                field_type: FieldType::ScalarArray(tc),
554            },
555            FieldDesc {
556                name: "alarm".to_string(),
557                field_type: FieldType::Structure(alarm_desc()),
558            },
559            FieldDesc {
560                name: "timeStamp".to_string(),
561                field_type: FieldType::Structure(timestamp_desc()),
562            },
563            FieldDesc {
564                name: "display".to_string(),
565                field_type: FieldType::Structure(display_desc()),
566            },
567            FieldDesc {
568                name: "control".to_string(),
569                field_type: FieldType::Structure(control_desc()),
570            },
571        ],
572    }
573}
574
575fn alarm_desc() -> StructureDesc {
576    StructureDesc {
577        struct_id: Some("alarm_t".to_string()),
578        fields: vec![
579            FieldDesc {
580                name: "severity".to_string(),
581                field_type: FieldType::Scalar(TypeCode::Int32),
582            },
583            FieldDesc {
584                name: "status".to_string(),
585                field_type: FieldType::Scalar(TypeCode::Int32),
586            },
587            FieldDesc {
588                name: "message".to_string(),
589                field_type: FieldType::String,
590            },
591        ],
592    }
593}
594
595fn timestamp_desc() -> StructureDesc {
596    StructureDesc {
597        struct_id: Some("time_t".to_string()),
598        fields: vec![
599            FieldDesc {
600                name: "secondsPastEpoch".to_string(),
601                field_type: FieldType::Scalar(TypeCode::Int64),
602            },
603            FieldDesc {
604                name: "nanoseconds".to_string(),
605                field_type: FieldType::Scalar(TypeCode::Int32),
606            },
607            FieldDesc {
608                name: "userTag".to_string(),
609                field_type: FieldType::Scalar(TypeCode::Int32),
610            },
611        ],
612    }
613}
614
615fn display_desc() -> StructureDesc {
616    StructureDesc {
617        struct_id: Some("display_t".to_string()),
618        fields: vec![
619            FieldDesc {
620                name: "limitLow".to_string(),
621                field_type: FieldType::Scalar(TypeCode::Float64),
622            },
623            FieldDesc {
624                name: "limitHigh".to_string(),
625                field_type: FieldType::Scalar(TypeCode::Float64),
626            },
627            FieldDesc {
628                name: "description".to_string(),
629                field_type: FieldType::String,
630            },
631            FieldDesc {
632                name: "units".to_string(),
633                field_type: FieldType::String,
634            },
635            FieldDesc {
636                name: "precision".to_string(),
637                field_type: FieldType::Scalar(TypeCode::Int32),
638            },
639            FieldDesc {
640                name: "form".to_string(),
641                field_type: FieldType::Structure(StructureDesc {
642                    struct_id: Some("enum_t".to_string()),
643                    fields: vec![
644                        FieldDesc {
645                            name: "index".to_string(),
646                            field_type: FieldType::Scalar(TypeCode::Int32),
647                        },
648                        FieldDesc {
649                            name: "choices".to_string(),
650                            field_type: FieldType::StringArray,
651                        },
652                    ],
653                }),
654            },
655        ],
656    }
657}
658
659fn control_desc() -> StructureDesc {
660    StructureDesc {
661        struct_id: Some("control_t".to_string()),
662        fields: vec![
663            FieldDesc {
664                name: "limitLow".to_string(),
665                field_type: FieldType::Scalar(TypeCode::Float64),
666            },
667            FieldDesc {
668                name: "limitHigh".to_string(),
669                field_type: FieldType::Scalar(TypeCode::Float64),
670            },
671            FieldDesc {
672                name: "minStep".to_string(),
673                field_type: FieldType::Scalar(TypeCode::Float64),
674            },
675        ],
676    }
677}
678
679fn value_alarm_desc() -> StructureDesc {
680    StructureDesc {
681        struct_id: Some("valueAlarm_t".to_string()),
682        fields: vec![
683            FieldDesc {
684                name: "active".to_string(),
685                field_type: FieldType::Scalar(TypeCode::Boolean),
686            },
687            FieldDesc {
688                name: "lowAlarmLimit".to_string(),
689                field_type: FieldType::Scalar(TypeCode::Float64),
690            },
691            FieldDesc {
692                name: "lowWarningLimit".to_string(),
693                field_type: FieldType::Scalar(TypeCode::Float64),
694            },
695            FieldDesc {
696                name: "highWarningLimit".to_string(),
697                field_type: FieldType::Scalar(TypeCode::Float64),
698            },
699            FieldDesc {
700                name: "highAlarmLimit".to_string(),
701                field_type: FieldType::Scalar(TypeCode::Float64),
702            },
703            FieldDesc {
704                name: "lowAlarmSeverity".to_string(),
705                field_type: FieldType::Scalar(TypeCode::Int32),
706            },
707            FieldDesc {
708                name: "lowWarningSeverity".to_string(),
709                field_type: FieldType::Scalar(TypeCode::Int32),
710            },
711            FieldDesc {
712                name: "highWarningSeverity".to_string(),
713                field_type: FieldType::Scalar(TypeCode::Int32),
714            },
715            FieldDesc {
716                name: "highAlarmSeverity".to_string(),
717                field_type: FieldType::Scalar(TypeCode::Int32),
718            },
719            FieldDesc {
720                name: "hysteresis".to_string(),
721                field_type: FieldType::Scalar(TypeCode::UInt8),
722            },
723        ],
724    }
725}
726
727#[cfg(test)]
728mod tests {
729    use super::*;
730    use crate::types::{DbCommonState, RecordType};
731    use spvirit_types::{
732        NdCodec, NdDimension, NtNdArray, NtPayload, NtScalar, NtScalarArray, NtTable,
733        NtTableColumn, ScalarArrayValue, ScalarValue,
734    };
735
736    fn make_ai(name: &str, val: f64) -> RecordInstance {
737        RecordInstance {
738            name: name.to_string(),
739            record_type: RecordType::Ai,
740            common: DbCommonState::default(),
741            data: RecordData::Ai {
742                nt: NtScalar::from_value(ScalarValue::F64(val)),
743                inp: None,
744                siml: None,
745                siol: None,
746                simm: false,
747            },
748            raw_fields: HashMap::new(),
749        }
750    }
751
752    fn make_ao(name: &str, val: f64) -> RecordInstance {
753        RecordInstance {
754            name: name.to_string(),
755            record_type: RecordType::Ao,
756            common: DbCommonState::default(),
757            data: RecordData::Ao {
758                nt: NtScalar::from_value(ScalarValue::F64(val)),
759                out: None,
760                dol: None,
761                omsl: crate::types::OutputMode::Supervisory,
762                drvl: None,
763                drvh: None,
764                oroc: None,
765                siml: None,
766                siol: None,
767                simm: false,
768            },
769            raw_fields: HashMap::new(),
770        }
771    }
772
773    fn make_waveform(name: &str, value: ScalarArrayValue) -> RecordInstance {
774        let nelm = value.len();
775        RecordInstance {
776            name: name.to_string(),
777            record_type: RecordType::Waveform,
778            common: DbCommonState::default(),
779            data: RecordData::Waveform {
780                nt: NtScalarArray::from_value(value),
781                inp: None,
782                ftvl: "DOUBLE".to_string(),
783                nelm,
784                nord: nelm,
785            },
786            raw_fields: HashMap::new(),
787        }
788    }
789
790    fn make_nt_table(name: &str) -> RecordInstance {
791        RecordInstance {
792            name: name.to_string(),
793            record_type: RecordType::NtTable,
794            common: DbCommonState::default(),
795            data: RecordData::NtTable {
796                nt: NtTable {
797                    labels: vec!["X".to_string(), "Y".to_string()],
798                    columns: vec![
799                        NtTableColumn {
800                            name: "x".to_string(),
801                            values: ScalarArrayValue::F64(vec![1.0, 2.0]),
802                        },
803                        NtTableColumn {
804                            name: "y".to_string(),
805                            values: ScalarArrayValue::F64(vec![10.0, 20.0]),
806                        },
807                    ],
808                    descriptor: Some("table".to_string()),
809                    alarm: None,
810                    time_stamp: None,
811                },
812                inp: None,
813                out: None,
814                omsl: crate::types::OutputMode::Supervisory,
815            },
816            raw_fields: HashMap::new(),
817        }
818    }
819
820    fn make_nt_ndarray(name: &str) -> RecordInstance {
821        RecordInstance {
822            name: name.to_string(),
823            record_type: RecordType::NtNdArray,
824            common: DbCommonState::default(),
825            data: RecordData::NtNdArray {
826                nt: NtNdArray {
827                    value: ScalarArrayValue::U8(vec![0; 4]),
828                    codec: NdCodec {
829                        name: "none".to_string(),
830                        parameters: HashMap::new(),
831                    },
832                    compressed_size: 4,
833                    uncompressed_size: 4,
834                    dimension: vec![NdDimension {
835                        size: 2,
836                        offset: 0,
837                        full_size: 2,
838                        binning: 1,
839                        reverse: false,
840                    }],
841                    unique_id: 1,
842                    data_time_stamp: Default::default(),
843                    attribute: vec![],
844                    descriptor: Some("ndarray".to_string()),
845                    alarm: None,
846                    time_stamp: None,
847                    display: None,
848                },
849                inp: None,
850                out: None,
851                omsl: crate::types::OutputMode::Supervisory,
852            },
853            raw_fields: HashMap::new(),
854        }
855    }
856
857    #[tokio::test]
858    async fn has_pv_returns_true_for_existing() {
859        let mut records = HashMap::new();
860        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
861        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
862        assert!(store.has_pv("TEST:AI").await);
863        assert!(!store.has_pv("MISSING").await);
864    }
865
866    #[tokio::test]
867    async fn get_snapshot_returns_payload() {
868        let mut records = HashMap::new();
869        records.insert("TEST:AI".into(), make_ai("TEST:AI", 42.0));
870        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
871        let snap = store.get_snapshot("TEST:AI").await.unwrap();
872        match snap {
873            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(42.0)),
874            _ => panic!("expected scalar"),
875        }
876    }
877
878    #[tokio::test]
879    async fn put_value_updates_writable_record() {
880        let mut records = HashMap::new();
881        records.insert("TEST:AO".into(), make_ao("TEST:AO", 0.0));
882        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
883
884        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(99.5))]);
885        let result = store.put_value("TEST:AO", &val).await.unwrap();
886        assert_eq!(result.len(), 1);
887        assert_eq!(result[0].0, "TEST:AO");
888
889        let snap = store.get_snapshot("TEST:AO").await.unwrap();
890        match snap {
891            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(99.5)),
892            _ => panic!("expected scalar"),
893        }
894    }
895
896    #[tokio::test]
897    async fn put_value_rejects_readonly() {
898        let mut records = HashMap::new();
899        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
900        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
901
902        let val = DecodedValue::Float64(5.0);
903        let err = store.put_value("TEST:AI", &val).await.unwrap_err();
904        assert!(err.contains("not writable"));
905    }
906
907    #[tokio::test]
908    async fn set_value_bypasses_writable_check() {
909        let mut records = HashMap::new();
910        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
911        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
912        assert!(store.set_value("TEST:AI", ScalarValue::F64(10.0)).await);
913        let val = store.get_value("TEST:AI").await.unwrap();
914        assert_eq!(val, ScalarValue::F64(10.0));
915    }
916
917    #[tokio::test]
918    async fn set_array_value_updates_all_scalar_array_types() {
919        let cases: Vec<ScalarArrayValue> = vec![
920            ScalarArrayValue::Bool(vec![false, true]),
921            ScalarArrayValue::I8(vec![1, 2]),
922            ScalarArrayValue::I16(vec![1, 2]),
923            ScalarArrayValue::I32(vec![1, 2]),
924            ScalarArrayValue::I64(vec![1, 2]),
925            ScalarArrayValue::U8(vec![1, 2]),
926            ScalarArrayValue::U16(vec![1, 2]),
927            ScalarArrayValue::U32(vec![1, 2]),
928            ScalarArrayValue::U64(vec![1, 2]),
929            ScalarArrayValue::F32(vec![1.0, 2.0]),
930            ScalarArrayValue::F64(vec![1.0, 2.0]),
931            ScalarArrayValue::Str(vec!["a".to_string(), "b".to_string()]),
932        ];
933
934        for (idx, updated) in cases.into_iter().enumerate() {
935            let pv = format!("TEST:WF:{idx}");
936            let mut records = HashMap::new();
937            records.insert(pv.clone(), make_waveform(&pv, updated.clone()));
938            let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
939
940            assert!(!store.set_array_value(&pv, updated.clone()).await);
941
942            let second = match updated {
943                ScalarArrayValue::Bool(_) => ScalarArrayValue::Bool(vec![true, false]),
944                ScalarArrayValue::I8(_) => ScalarArrayValue::I8(vec![3, 4]),
945                ScalarArrayValue::I16(_) => ScalarArrayValue::I16(vec![3, 4]),
946                ScalarArrayValue::I32(_) => ScalarArrayValue::I32(vec![3, 4]),
947                ScalarArrayValue::I64(_) => ScalarArrayValue::I64(vec![3, 4]),
948                ScalarArrayValue::U8(_) => ScalarArrayValue::U8(vec![3, 4]),
949                ScalarArrayValue::U16(_) => ScalarArrayValue::U16(vec![3, 4]),
950                ScalarArrayValue::U32(_) => ScalarArrayValue::U32(vec![3, 4]),
951                ScalarArrayValue::U64(_) => ScalarArrayValue::U64(vec![3, 4]),
952                ScalarArrayValue::F32(_) => ScalarArrayValue::F32(vec![3.0, 4.0]),
953                ScalarArrayValue::F64(_) => ScalarArrayValue::F64(vec![3.0, 4.0]),
954                ScalarArrayValue::Str(_) => {
955                    ScalarArrayValue::Str(vec!["x".to_string(), "y".to_string()])
956                }
957            };
958
959            assert!(store.set_array_value(&pv, second.clone()).await);
960            let snap = store.get_snapshot(&pv).await.unwrap();
961            match snap {
962                NtPayload::ScalarArray(nt) => assert_eq!(nt.value, second),
963                _ => panic!("expected scalar array"),
964            }
965        }
966    }
967
968    #[tokio::test]
969    async fn get_nt_returns_full_payload() {
970        let mut records = HashMap::new();
971        records.insert("TEST:AI".into(), make_ai("TEST:AI", 12.5));
972        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
973
974        let nt = store.get_nt("TEST:AI").await.unwrap();
975        match nt {
976            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(12.5)),
977            _ => panic!("expected scalar payload"),
978        }
979    }
980
981    #[tokio::test]
982    async fn put_nt_updates_scalar_array_table_and_ndarray() {
983        let mut records = HashMap::new();
984        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
985        records.insert(
986            "TEST:WF".into(),
987            make_waveform("TEST:WF", ScalarArrayValue::F64(vec![0.0, 0.0])),
988        );
989        records.insert("TEST:TBL".into(), make_nt_table("TEST:TBL"));
990        records.insert("TEST:NDA".into(), make_nt_ndarray("TEST:NDA"));
991        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
992
993        assert!(
994            store
995                .put_nt(
996                    "TEST:AI",
997                    NtPayload::Scalar(NtScalar::from_value(ScalarValue::F64(5.0))),
998                )
999                .await
1000        );
1001        assert!(
1002            store
1003                .put_nt(
1004                    "TEST:WF",
1005                    NtPayload::ScalarArray(NtScalarArray::from_value(ScalarArrayValue::F64(vec![
1006                        3.0, 4.0
1007                    ],))),
1008                )
1009                .await
1010        );
1011
1012        let table = NtTable {
1013            labels: vec!["X".to_string(), "Y".to_string()],
1014            columns: vec![
1015                NtTableColumn {
1016                    name: "x".to_string(),
1017                    values: ScalarArrayValue::F64(vec![2.0, 3.0]),
1018                },
1019                NtTableColumn {
1020                    name: "y".to_string(),
1021                    values: ScalarArrayValue::F64(vec![20.0, 30.0]),
1022                },
1023            ],
1024            descriptor: Some("updated table".to_string()),
1025            alarm: None,
1026            time_stamp: None,
1027        };
1028        assert!(
1029            store
1030                .put_nt("TEST:TBL", NtPayload::Table(table.clone()))
1031                .await
1032        );
1033
1034        let ndarray = NtNdArray {
1035            value: ScalarArrayValue::U8(vec![1, 2, 3, 4]),
1036            codec: NdCodec {
1037                name: "none".to_string(),
1038                parameters: HashMap::new(),
1039            },
1040            compressed_size: 4,
1041            uncompressed_size: 4,
1042            dimension: vec![NdDimension {
1043                size: 4,
1044                offset: 0,
1045                full_size: 4,
1046                binning: 1,
1047                reverse: false,
1048            }],
1049            unique_id: 2,
1050            data_time_stamp: Default::default(),
1051            attribute: vec![],
1052            descriptor: Some("updated ndarray".to_string()),
1053            alarm: None,
1054            time_stamp: None,
1055            display: None,
1056        };
1057        assert!(
1058            store
1059                .put_nt("TEST:NDA", NtPayload::NdArray(ndarray.clone()))
1060                .await
1061        );
1062
1063        assert!(
1064            !store
1065                .put_nt(
1066                    "TEST:AI",
1067                    NtPayload::ScalarArray(NtScalarArray::from_value(ScalarArrayValue::F64(vec![
1068                        1.0
1069                    ]))),
1070                )
1071                .await
1072        );
1073
1074        match store.get_nt("TEST:TBL").await.unwrap() {
1075            NtPayload::Table(nt) => assert_eq!(nt, table),
1076            _ => panic!("expected table payload"),
1077        }
1078        match store.get_nt("TEST:NDA").await.unwrap() {
1079            NtPayload::NdArray(nt) => assert_eq!(nt, ndarray),
1080            _ => panic!("expected ndarray payload"),
1081        }
1082    }
1083
1084    #[tokio::test]
1085    async fn descriptor_matches_value_type() {
1086        let mut records = HashMap::new();
1087        records.insert("TEST:AI".into(), make_ai("TEST:AI", 0.0));
1088        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
1089        let desc = store.get_descriptor("TEST:AI").await.unwrap();
1090        assert_eq!(desc.struct_id.as_deref(), Some("epics:nt/NTScalar:1.0"));
1091        let value_field = desc.field("value").unwrap();
1092        assert!(matches!(
1093            value_field.field_type,
1094            FieldType::Scalar(TypeCode::Float64)
1095        ));
1096    }
1097
1098    #[tokio::test]
1099    async fn subscribe_receives_updates() {
1100        let mut records = HashMap::new();
1101        records.insert("TEST:AO".into(), make_ao("TEST:AO", 0.0));
1102        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
1103
1104        let mut rx = store.subscribe("TEST:AO").await.unwrap();
1105
1106        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(7.7))]);
1107        store.put_value("TEST:AO", &val).await.unwrap();
1108
1109        let update = rx.recv().await.unwrap();
1110        match update {
1111            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(7.7)),
1112            _ => panic!("expected scalar"),
1113        }
1114    }
1115
1116    #[tokio::test]
1117    async fn on_put_callback_is_invoked() {
1118        use std::sync::atomic::{AtomicBool, Ordering};
1119
1120        let called = Arc::new(AtomicBool::new(false));
1121        let called2 = called.clone();
1122
1123        let mut records = HashMap::new();
1124        records.insert("CB:AO".into(), make_ao("CB:AO", 0.0));
1125
1126        let mut on_put = HashMap::new();
1127        let cb: OnPutCallback = Arc::new(move |_name, _val| {
1128            called2.store(true, Ordering::SeqCst);
1129        });
1130        on_put.insert("CB:AO".into(), cb);
1131
1132        let store = SimplePvStore::new(records, on_put, vec![], false);
1133        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(1.0))]);
1134        store.put_value("CB:AO", &val).await.unwrap();
1135
1136        // Give the spawned task time to run.
1137        tokio::task::yield_now().await;
1138        tokio::task::yield_now().await;
1139
1140        assert!(called.load(Ordering::SeqCst));
1141    }
1142}