Skip to main content

spvirit_server/
simple_store.rs

1//! A simple in-memory [`PvStore`] implementation backed by `RecordInstance`.
2//!
3//! Used by [`PvaServer`](crate::pva_server::PvaServer) to serve PVs without
4//! requiring an external database.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use tokio::sync::{RwLock, mpsc};
10use tracing::debug;
11
12use spvirit_codec::spvd_decode::{DecodedValue, FieldDesc, FieldType, StructureDesc, TypeCode};
13use spvirit_types::{NtPayload, ScalarArrayValue, ScalarValue};
14
15use crate::apply::{
16    apply_alarm_update, apply_control_update, apply_display_update, apply_scalar_array_put,
17    apply_value_update,
18};
19use crate::monitor::MonitorRegistry;
20use crate::pvstore::PvStore;
21use crate::types::{RecordData, RecordInstance};
22
/// Callback invoked after a PUT value is applied to a record.
///
/// Receives the PV name and the decoded value exactly as sent by the client.
/// [`SimplePvStore`] only invokes it when the PUT actually changed the record,
/// and runs it on a spawned Tokio task so the PUT path does not wait for it.
pub type OnPutCallback = Arc<dyn Fn(&str, &DecodedValue) + Send + Sync>;

/// Callback invoked by the scan scheduler; returns the new value for the PV.
///
/// NOTE(review): not used in this file; the `&str` argument is presumably the
/// PV name being scanned — confirm against the scheduler.
pub type ScanCallback = Arc<dyn Fn(&str) -> ScalarValue + Send + Sync>;

/// Callback that computes a derived PV value from its input values.
///
/// The slice holds the current value of every input PV of the link, in the
/// same order as `LinkDef::inputs`. An input naming a PV that does not exist
/// in the store is supplied as `ScalarValue::F64(0.0)`.
pub type LinkCallback = Arc<dyn Fn(&[ScalarValue]) -> ScalarValue + Send + Sync>;
31
/// A link from one or more input PVs to a computed output PV.
///
/// Whenever one of `inputs` changes, `compute` is called with the current
/// values of all `inputs` and the result is written to `output`.
pub(crate) struct LinkDef {
    /// Name of the PV that receives the computed value.
    pub output: String,
    /// Names of the PVs whose values feed `compute`, in call order.
    pub inputs: Vec<String>,
    /// Function mapping the input values to the new output value.
    pub compute: LinkCallback,
}
38
/// One PV held by [`SimplePvStore`]: its record plus its channel subscribers.
struct PvEntry {
    /// The record holding the PV's current state.
    record: RecordInstance,
    /// Senders handed out by `subscribe`; every change to `record` is pushed
    /// to them with a non-blocking `try_send`.
    subscribers: Vec<mpsc::Sender<NtPayload>>,
}
43
/// A simple in-memory PV store.
pub struct SimplePvStore {
    /// All PVs keyed by name, guarded by an async read/write lock.
    pvs: RwLock<HashMap<String, PvEntry>>,
    /// Per-PV callbacks run after a client PUT changes the record.
    on_put: HashMap<String, OnPutCallback>,
    /// Computed-PV links, re-evaluated when one of their inputs changes.
    links: Vec<LinkDef>,
    /// Forwarded to the record update routines when values are written.
    compute_alarms: bool,
    /// PVAccess monitor registry; `None` until `set_registry` is called.
    registry: RwLock<Option<Arc<MonitorRegistry>>>,
}
52
53impl SimplePvStore {
54    pub(crate) fn new(
55        records: HashMap<String, RecordInstance>,
56        on_put: HashMap<String, OnPutCallback>,
57        links: Vec<LinkDef>,
58        compute_alarms: bool,
59    ) -> Self {
60        let pvs = records
61            .into_iter()
62            .map(|(name, record)| {
63                (
64                    name,
65                    PvEntry {
66                        record,
67                        subscribers: Vec::new(),
68                    },
69                )
70            })
71            .collect();
72        Self {
73            pvs: RwLock::new(pvs),
74            on_put,
75            links,
76            compute_alarms,
77            registry: RwLock::new(None),
78        }
79    }
80
81    /// Attach the [`MonitorRegistry`] so that `set_value` can push updates
82    /// to PVAccess monitor clients.  Called automatically by [`PvaServer::run`].
83    pub async fn set_registry(&self, registry: Arc<MonitorRegistry>) {
84        *self.registry.write().await = Some(registry);
85    }
86
87    /// Insert or replace a PV record at runtime.
88    pub async fn insert(&self, name: String, record: RecordInstance) {
89        let mut pvs = self.pvs.write().await;
90        pvs.insert(
91            name,
92            PvEntry {
93                record,
94                subscribers: Vec::new(),
95            },
96        );
97    }
98
99    /// Read the current [`ScalarValue`] of a PV.
100    pub async fn get_value(&self, name: &str) -> Option<ScalarValue> {
101        let pvs = self.pvs.read().await;
102        pvs.get(name).map(|e| e.record.current_value())
103    }
104
105    /// Read the full [`NtPayload`] of a PV.
106    pub async fn get_nt(&self, name: &str) -> Option<NtPayload> {
107        let pvs = self.pvs.read().await;
108        pvs.get(name).map(|e| e.record.to_ntpayload())
109    }
110
111    /// Write a [`ScalarValue`] to a PV (bypasses on_put).
112    pub async fn set_value(&self, name: &str, value: ScalarValue) -> bool {
113        if self.set_value_inner(name, value).await {
114            self.evaluate_links(name).await;
115            true
116        } else {
117            false
118        }
119    }
120
121    /// Write a [`ScalarArrayValue`] to an array PV (bypasses on_put).
122    pub async fn set_array_value(&self, name: &str, value: ScalarArrayValue) -> bool {
123        if self.set_array_value_inner(name, value).await {
124            self.evaluate_links(name).await;
125            true
126        } else {
127            false
128        }
129    }
130
131    /// Write a full [`NtPayload`] to a PV (bypasses on_put).
132    pub async fn put_nt(&self, name: &str, payload: NtPayload) -> bool {
133        if self.put_nt_inner(name, payload).await {
134            self.evaluate_links(name).await;
135            true
136        } else {
137            false
138        }
139    }
140
141    /// Core write logic — updates the value, notifies subscribers and monitors,
142    /// but does **not** trigger link evaluation (to avoid recursion).
143    async fn set_value_inner(&self, name: &str, value: ScalarValue) -> bool {
144        let payload = {
145            let mut pvs = self.pvs.write().await;
146            if let Some(entry) = pvs.get_mut(name) {
147                let changed = entry.record.set_scalar_value(value, self.compute_alarms);
148                if changed {
149                    let payload = entry.record.to_ntpayload();
150                    entry
151                        .subscribers
152                        .retain(|tx| tx.try_send(payload.clone()).is_ok());
153                    Some(payload)
154                } else {
155                    None
156                }
157            } else {
158                return false;
159            }
160        };
161
162        if let Some(payload) = payload {
163            // Notify PVAccess monitor clients (if the registry is attached).
164            let reg = self.registry.read().await;
165            if let Some(registry) = reg.as_ref() {
166                registry.notify_monitors(name, &payload).await;
167            }
168            true
169        } else {
170            false
171        }
172    }
173
174    /// Core array write logic — updates the value, notifies subscribers and monitors,
175    /// but does **not** trigger link evaluation (to avoid recursion).
176    async fn set_array_value_inner(&self, name: &str, value: ScalarArrayValue) -> bool {
177        let payload = {
178            let mut pvs = self.pvs.write().await;
179            if let Some(entry) = pvs.get_mut(name) {
180                let changed = entry.record.set_array_value(value);
181                if changed {
182                    let payload = entry.record.to_ntpayload();
183                    entry
184                        .subscribers
185                        .retain(|tx| tx.try_send(payload.clone()).is_ok());
186                    Some(payload)
187                } else {
188                    None
189                }
190            } else {
191                return false;
192            }
193        };
194
195        if let Some(payload) = payload {
196            // Notify PVAccess monitor clients (if the registry is attached).
197            let reg = self.registry.read().await;
198            if let Some(registry) = reg.as_ref() {
199                registry.notify_monitors(name, &payload).await;
200            }
201            true
202        } else {
203            false
204        }
205    }
206
207    /// Core NtPayload write logic — updates the payload, notifies subscribers
208    /// and monitors, but does **not** trigger link evaluation.
209    async fn put_nt_inner(&self, name: &str, payload: NtPayload) -> bool {
210        let payload = {
211            let mut pvs = self.pvs.write().await;
212            if let Some(entry) = pvs.get_mut(name) {
213                let changed = entry.record.set_nt_payload(payload);
214                if changed {
215                    let payload = entry.record.to_ntpayload();
216                    entry
217                        .subscribers
218                        .retain(|tx| tx.try_send(payload.clone()).is_ok());
219                    Some(payload)
220                } else {
221                    None
222                }
223            } else {
224                return false;
225            }
226        };
227
228        if let Some(payload) = payload {
229            // Notify PVAccess monitor clients (if the registry is attached).
230            let reg = self.registry.read().await;
231            if let Some(registry) = reg.as_ref() {
232                registry.notify_monitors(name, &payload).await;
233            }
234            true
235        } else {
236            false
237        }
238    }
239
240    /// Walk every link whose inputs include `changed_pv`, compute the output,
241    /// and propagate (BFS with cycle detection).
242    async fn evaluate_links(&self, changed_pv: &str) {
243        if self.links.is_empty() {
244            return;
245        }
246        let mut queue = vec![changed_pv.to_string()];
247        let mut visited = HashSet::new();
248
249        while let Some(pv) = queue.pop() {
250            if !visited.insert(pv.clone()) {
251                debug!("Circular link detected for PV '{}', skipping", pv);
252                continue;
253            }
254            for link in &self.links {
255                if !link.inputs.iter().any(|i| i == &pv) {
256                    continue;
257                }
258                // Gather current values of all inputs.
259                let values = {
260                    let pvs = self.pvs.read().await;
261                    link.inputs
262                        .iter()
263                        .map(|n| {
264                            pvs.get(n)
265                                .map(|e| e.record.current_value())
266                                .unwrap_or(ScalarValue::F64(0.0))
267                        })
268                        .collect::<Vec<_>>()
269                };
270                let new_val = (link.compute)(&values);
271                if self.set_value_inner(&link.output, new_val).await {
272                    queue.push(link.output.clone());
273                }
274            }
275        }
276    }
277
278    /// List all PV names.
279    pub async fn pv_names(&self) -> Vec<String> {
280        let pvs = self.pvs.read().await;
281        pvs.keys().cloned().collect()
282    }
283}
284
285impl PvStore for SimplePvStore {
286    fn has_pv(&self, name: &str) -> impl Future<Output = bool> + Send {
287        async move {
288            let pvs = self.pvs.read().await;
289            pvs.contains_key(name)
290        }
291    }
292
293    fn get_snapshot(&self, name: &str) -> impl Future<Output = Option<NtPayload>> + Send {
294        async move {
295            let pvs = self.pvs.read().await;
296            pvs.get(name).map(|e| e.record.to_ntpayload())
297        }
298    }
299
300    fn get_descriptor(&self, name: &str) -> impl Future<Output = Option<StructureDesc>> + Send {
301        async move {
302            let pvs = self.pvs.read().await;
303            pvs.get(name)
304                .map(|e| descriptor_for_payload(&e.record.to_ntpayload()))
305        }
306    }
307
308    fn put_value(
309        &self,
310        name: &str,
311        value: &DecodedValue,
312    ) -> impl Future<Output = Result<Vec<(String, NtPayload)>, String>> + Send {
313        let name = name.to_string();
314        let value = value.clone();
315        async move {
316            let result = {
317                let mut pvs = self.pvs.write().await;
318                let entry = pvs
319                    .get_mut(&name)
320                    .ok_or_else(|| format!("PV '{}' not found", name))?;
321
322                if !entry.record.writable() {
323                    return Err(format!("PV '{}' is not writable", name));
324                }
325
326                let changed = apply_put_to_record(&mut entry.record, &value, self.compute_alarms);
327                if !changed {
328                    return Ok(vec![]);
329                }
330
331                let payload = entry.record.to_ntpayload();
332                entry
333                    .subscribers
334                    .retain(|tx| tx.try_send(payload.clone()).is_ok());
335
336                (name.clone(), payload)
337            }; // pvs lock dropped
338
339            // Fire on_put callback (non-blocking).
340            if let Some(cb) = self.on_put.get(&name) {
341                let cb = cb.clone();
342                let n = name.clone();
343                let v = value.clone();
344                tokio::spawn(async move { cb(&n, &v) });
345            }
346
347            // Propagate linked PV updates.
348            self.evaluate_links(&name).await;
349
350            Ok(vec![result])
351        }
352    }
353
354    fn is_writable(&self, name: &str) -> impl Future<Output = bool> + Send {
355        async move {
356            let pvs = self.pvs.read().await;
357            pvs.get(name).is_some_and(|e| e.record.writable())
358        }
359    }
360
361    fn list_pvs(&self) -> impl Future<Output = Vec<String>> + Send {
362        async move {
363            let pvs = self.pvs.read().await;
364            pvs.keys().cloned().collect()
365        }
366    }
367
368    fn subscribe(
369        &self,
370        name: &str,
371    ) -> impl Future<Output = Option<mpsc::Receiver<NtPayload>>> + Send {
372        let name = name.to_string();
373        async move {
374            let mut pvs = self.pvs.write().await;
375            let entry = pvs.get_mut(&name)?;
376            let (tx, rx) = mpsc::channel(64);
377            entry.subscribers.push(tx);
378            Some(rx)
379        }
380    }
381}
382
383// ── Helpers ──────────────────────────────────────────────────────────────
384
385/// Apply a decoded PUT value to a RecordInstance, returning whether it changed.
386fn apply_put_to_record(
387    record: &mut RecordInstance,
388    value: &DecodedValue,
389    compute_alarms: bool,
390) -> bool {
391    let fields = match value {
392        DecodedValue::Structure(f) => f,
393        other => {
394            // Bare scalar — wrap as value field.
395            return apply_put_to_record(
396                record,
397                &DecodedValue::Structure(vec![("value".to_string(), other.clone())]),
398                compute_alarms,
399            );
400        }
401    };
402
403    let mut changed = false;
404
405    match &mut record.data {
406        RecordData::Ai { nt, .. }
407        | RecordData::Ao { nt, .. }
408        | RecordData::Bi { nt, .. }
409        | RecordData::Bo { nt, .. }
410        | RecordData::StringIn { nt, .. }
411        | RecordData::StringOut { nt, .. } => {
412            for (name, val) in fields {
413                match name.as_str() {
414                    "value" => {
415                        changed |= apply_value_update(nt, val, compute_alarms);
416                    }
417                    "alarm" => {
418                        changed |= apply_alarm_update(nt, val);
419                    }
420                    "display" => {
421                        changed |= apply_display_update(nt, val);
422                    }
423                    "control" => {
424                        changed |= apply_control_update(nt, val);
425                    }
426                    _ => {}
427                }
428            }
429        }
430        RecordData::Waveform { nt, nord, .. }
431        | RecordData::Aai { nt, nord, .. }
432        | RecordData::Aao { nt, nord, .. }
433        | RecordData::SubArray { nt, nord, .. } => {
434            changed = apply_scalar_array_put(nt, nord, value);
435        }
436        RecordData::NtTable { .. } | RecordData::NtNdArray { .. } => {
437            // Table/NdArray PUT not supported via high-level API yet.
438            debug!("PUT to NtTable/NtNdArray not yet supported in SimplePvStore");
439        }
440    }
441
442    changed
443}
444
445// ── NtPayload → StructureDesc ────────────────────────────────────────────
446
447pub(crate) fn descriptor_for_payload(payload: &NtPayload) -> StructureDesc {
448    match payload {
449        NtPayload::Scalar(nt) => nt_scalar_desc(&nt.value),
450        NtPayload::ScalarArray(arr) => nt_scalar_array_desc(&arr.value),
451        _ => StructureDesc::new(),
452    }
453}
454
455fn value_type_code(sv: &ScalarValue) -> TypeCode {
456    match sv {
457        ScalarValue::Bool(_) => TypeCode::Boolean,
458        ScalarValue::I8(_) => TypeCode::Int8,
459        ScalarValue::I16(_) => TypeCode::Int16,
460        ScalarValue::I32(_) => TypeCode::Int32,
461        ScalarValue::I64(_) => TypeCode::Int64,
462        ScalarValue::U8(_) => TypeCode::UInt8,
463        ScalarValue::U16(_) => TypeCode::UInt16,
464        ScalarValue::U32(_) => TypeCode::UInt32,
465        ScalarValue::U64(_) => TypeCode::UInt64,
466        ScalarValue::F32(_) => TypeCode::Float32,
467        ScalarValue::F64(_) => TypeCode::Float64,
468        ScalarValue::Str(_) => TypeCode::String,
469    }
470}
471
472fn array_type_code(sav: &ScalarArrayValue) -> TypeCode {
473    match sav {
474        ScalarArrayValue::Bool(_) => TypeCode::Boolean,
475        ScalarArrayValue::I8(_) => TypeCode::Int8,
476        ScalarArrayValue::I16(_) => TypeCode::Int16,
477        ScalarArrayValue::I32(_) => TypeCode::Int32,
478        ScalarArrayValue::I64(_) => TypeCode::Int64,
479        ScalarArrayValue::U8(_) => TypeCode::UInt8,
480        ScalarArrayValue::U16(_) => TypeCode::UInt16,
481        ScalarArrayValue::U32(_) => TypeCode::UInt32,
482        ScalarArrayValue::U64(_) => TypeCode::UInt64,
483        ScalarArrayValue::F32(_) => TypeCode::Float32,
484        ScalarArrayValue::F64(_) => TypeCode::Float64,
485        ScalarArrayValue::Str(_) => TypeCode::String,
486    }
487}
488
489fn nt_scalar_desc(sv: &ScalarValue) -> StructureDesc {
490    let tc = value_type_code(sv);
491    StructureDesc {
492        struct_id: Some("epics:nt/NTScalar:1.0".to_string()),
493        fields: vec![
494            FieldDesc {
495                name: "value".to_string(),
496                field_type: FieldType::Scalar(tc),
497            },
498            FieldDesc {
499                name: "alarm".to_string(),
500                field_type: FieldType::Structure(alarm_desc()),
501            },
502            FieldDesc {
503                name: "timeStamp".to_string(),
504                field_type: FieldType::Structure(timestamp_desc()),
505            },
506            FieldDesc {
507                name: "display".to_string(),
508                field_type: FieldType::Structure(display_desc()),
509            },
510            FieldDesc {
511                name: "control".to_string(),
512                field_type: FieldType::Structure(control_desc()),
513            },
514            FieldDesc {
515                name: "valueAlarm".to_string(),
516                field_type: FieldType::Structure(value_alarm_desc()),
517            },
518        ],
519    }
520}
521
522fn nt_scalar_array_desc(sav: &ScalarArrayValue) -> StructureDesc {
523    let tc = array_type_code(sav);
524    StructureDesc {
525        struct_id: Some("epics:nt/NTScalarArray:1.0".to_string()),
526        fields: vec![
527            FieldDesc {
528                name: "value".to_string(),
529                field_type: FieldType::ScalarArray(tc),
530            },
531            FieldDesc {
532                name: "alarm".to_string(),
533                field_type: FieldType::Structure(alarm_desc()),
534            },
535            FieldDesc {
536                name: "timeStamp".to_string(),
537                field_type: FieldType::Structure(timestamp_desc()),
538            },
539            FieldDesc {
540                name: "display".to_string(),
541                field_type: FieldType::Structure(display_desc()),
542            },
543            FieldDesc {
544                name: "control".to_string(),
545                field_type: FieldType::Structure(control_desc()),
546            },
547        ],
548    }
549}
550
551fn alarm_desc() -> StructureDesc {
552    StructureDesc {
553        struct_id: Some("alarm_t".to_string()),
554        fields: vec![
555            FieldDesc {
556                name: "severity".to_string(),
557                field_type: FieldType::Scalar(TypeCode::Int32),
558            },
559            FieldDesc {
560                name: "status".to_string(),
561                field_type: FieldType::Scalar(TypeCode::Int32),
562            },
563            FieldDesc {
564                name: "message".to_string(),
565                field_type: FieldType::String,
566            },
567        ],
568    }
569}
570
571fn timestamp_desc() -> StructureDesc {
572    StructureDesc {
573        struct_id: Some("time_t".to_string()),
574        fields: vec![
575            FieldDesc {
576                name: "secondsPastEpoch".to_string(),
577                field_type: FieldType::Scalar(TypeCode::Int64),
578            },
579            FieldDesc {
580                name: "nanoseconds".to_string(),
581                field_type: FieldType::Scalar(TypeCode::Int32),
582            },
583            FieldDesc {
584                name: "userTag".to_string(),
585                field_type: FieldType::Scalar(TypeCode::Int32),
586            },
587        ],
588    }
589}
590
591fn display_desc() -> StructureDesc {
592    StructureDesc {
593        struct_id: Some("display_t".to_string()),
594        fields: vec![
595            FieldDesc {
596                name: "limitLow".to_string(),
597                field_type: FieldType::Scalar(TypeCode::Float64),
598            },
599            FieldDesc {
600                name: "limitHigh".to_string(),
601                field_type: FieldType::Scalar(TypeCode::Float64),
602            },
603            FieldDesc {
604                name: "description".to_string(),
605                field_type: FieldType::String,
606            },
607            FieldDesc {
608                name: "units".to_string(),
609                field_type: FieldType::String,
610            },
611            FieldDesc {
612                name: "precision".to_string(),
613                field_type: FieldType::Scalar(TypeCode::Int32),
614            },
615            FieldDesc {
616                name: "form".to_string(),
617                field_type: FieldType::Structure(StructureDesc {
618                    struct_id: Some("enum_t".to_string()),
619                    fields: vec![
620                        FieldDesc {
621                            name: "index".to_string(),
622                            field_type: FieldType::Scalar(TypeCode::Int32),
623                        },
624                        FieldDesc {
625                            name: "choices".to_string(),
626                            field_type: FieldType::StringArray,
627                        },
628                    ],
629                }),
630            },
631        ],
632    }
633}
634
635fn control_desc() -> StructureDesc {
636    StructureDesc {
637        struct_id: Some("control_t".to_string()),
638        fields: vec![
639            FieldDesc {
640                name: "limitLow".to_string(),
641                field_type: FieldType::Scalar(TypeCode::Float64),
642            },
643            FieldDesc {
644                name: "limitHigh".to_string(),
645                field_type: FieldType::Scalar(TypeCode::Float64),
646            },
647            FieldDesc {
648                name: "minStep".to_string(),
649                field_type: FieldType::Scalar(TypeCode::Float64),
650            },
651        ],
652    }
653}
654
655fn value_alarm_desc() -> StructureDesc {
656    StructureDesc {
657        struct_id: Some("valueAlarm_t".to_string()),
658        fields: vec![
659            FieldDesc {
660                name: "active".to_string(),
661                field_type: FieldType::Scalar(TypeCode::Boolean),
662            },
663            FieldDesc {
664                name: "lowAlarmLimit".to_string(),
665                field_type: FieldType::Scalar(TypeCode::Float64),
666            },
667            FieldDesc {
668                name: "lowWarningLimit".to_string(),
669                field_type: FieldType::Scalar(TypeCode::Float64),
670            },
671            FieldDesc {
672                name: "highWarningLimit".to_string(),
673                field_type: FieldType::Scalar(TypeCode::Float64),
674            },
675            FieldDesc {
676                name: "highAlarmLimit".to_string(),
677                field_type: FieldType::Scalar(TypeCode::Float64),
678            },
679            FieldDesc {
680                name: "lowAlarmSeverity".to_string(),
681                field_type: FieldType::Scalar(TypeCode::Int32),
682            },
683            FieldDesc {
684                name: "lowWarningSeverity".to_string(),
685                field_type: FieldType::Scalar(TypeCode::Int32),
686            },
687            FieldDesc {
688                name: "highWarningSeverity".to_string(),
689                field_type: FieldType::Scalar(TypeCode::Int32),
690            },
691            FieldDesc {
692                name: "highAlarmSeverity".to_string(),
693                field_type: FieldType::Scalar(TypeCode::Int32),
694            },
695            FieldDesc {
696                name: "hysteresis".to_string(),
697                field_type: FieldType::Scalar(TypeCode::UInt8),
698            },
699        ],
700    }
701}
702
703#[cfg(test)]
704mod tests {
705    use super::*;
706    use crate::types::{DbCommonState, RecordType};
707    use spvirit_types::{
708        NdCodec, NdDimension, NtNdArray, NtPayload, NtScalar, NtScalarArray, NtTable,
709        NtTableColumn, ScalarArrayValue, ScalarValue,
710    };
711
712    fn make_ai(name: &str, val: f64) -> RecordInstance {
713        RecordInstance {
714            name: name.to_string(),
715            record_type: RecordType::Ai,
716            common: DbCommonState::default(),
717            data: RecordData::Ai {
718                nt: NtScalar::from_value(ScalarValue::F64(val)),
719                inp: None,
720                siml: None,
721                siol: None,
722                simm: false,
723            },
724            raw_fields: HashMap::new(),
725        }
726    }
727
728    fn make_ao(name: &str, val: f64) -> RecordInstance {
729        RecordInstance {
730            name: name.to_string(),
731            record_type: RecordType::Ao,
732            common: DbCommonState::default(),
733            data: RecordData::Ao {
734                nt: NtScalar::from_value(ScalarValue::F64(val)),
735                out: None,
736                dol: None,
737                omsl: crate::types::OutputMode::Supervisory,
738                drvl: None,
739                drvh: None,
740                oroc: None,
741                siml: None,
742                siol: None,
743                simm: false,
744            },
745            raw_fields: HashMap::new(),
746        }
747    }
748
749    fn make_waveform(name: &str, value: ScalarArrayValue) -> RecordInstance {
750        let nelm = value.len();
751        RecordInstance {
752            name: name.to_string(),
753            record_type: RecordType::Waveform,
754            common: DbCommonState::default(),
755            data: RecordData::Waveform {
756                nt: NtScalarArray::from_value(value),
757                inp: None,
758                ftvl: "DOUBLE".to_string(),
759                nelm,
760                nord: nelm,
761            },
762            raw_fields: HashMap::new(),
763        }
764    }
765
766    fn make_nt_table(name: &str) -> RecordInstance {
767        RecordInstance {
768            name: name.to_string(),
769            record_type: RecordType::NtTable,
770            common: DbCommonState::default(),
771            data: RecordData::NtTable {
772                nt: NtTable {
773                    labels: vec!["X".to_string(), "Y".to_string()],
774                    columns: vec![
775                        NtTableColumn {
776                            name: "x".to_string(),
777                            values: ScalarArrayValue::F64(vec![1.0, 2.0]),
778                        },
779                        NtTableColumn {
780                            name: "y".to_string(),
781                            values: ScalarArrayValue::F64(vec![10.0, 20.0]),
782                        },
783                    ],
784                    descriptor: Some("table".to_string()),
785                    alarm: None,
786                    time_stamp: None,
787                },
788                inp: None,
789                out: None,
790                omsl: crate::types::OutputMode::Supervisory,
791            },
792            raw_fields: HashMap::new(),
793        }
794    }
795
796    fn make_nt_ndarray(name: &str) -> RecordInstance {
797        RecordInstance {
798            name: name.to_string(),
799            record_type: RecordType::NtNdArray,
800            common: DbCommonState::default(),
801            data: RecordData::NtNdArray {
802                nt: NtNdArray {
803                    value: ScalarArrayValue::U8(vec![0; 4]),
804                    codec: NdCodec {
805                        name: "none".to_string(),
806                        parameters: HashMap::new(),
807                    },
808                    compressed_size: 4,
809                    uncompressed_size: 4,
810                    dimension: vec![NdDimension {
811                        size: 2,
812                        offset: 0,
813                        full_size: 2,
814                        binning: 1,
815                        reverse: false,
816                    }],
817                    unique_id: 1,
818                    data_time_stamp: Default::default(),
819                    attribute: vec![],
820                    descriptor: Some("ndarray".to_string()),
821                    alarm: None,
822                    time_stamp: None,
823                    display: None,
824                },
825                inp: None,
826                out: None,
827                omsl: crate::types::OutputMode::Supervisory,
828            },
829            raw_fields: HashMap::new(),
830        }
831    }
832
833    #[tokio::test]
834    async fn has_pv_returns_true_for_existing() {
835        let mut records = HashMap::new();
836        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
837        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
838        assert!(store.has_pv("TEST:AI").await);
839        assert!(!store.has_pv("MISSING").await);
840    }
841
842    #[tokio::test]
843    async fn get_snapshot_returns_payload() {
844        let mut records = HashMap::new();
845        records.insert("TEST:AI".into(), make_ai("TEST:AI", 42.0));
846        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
847        let snap = store.get_snapshot("TEST:AI").await.unwrap();
848        match snap {
849            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(42.0)),
850            _ => panic!("expected scalar"),
851        }
852    }
853
854    #[tokio::test]
855    async fn put_value_updates_writable_record() {
856        let mut records = HashMap::new();
857        records.insert("TEST:AO".into(), make_ao("TEST:AO", 0.0));
858        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
859
860        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(99.5))]);
861        let result = store.put_value("TEST:AO", &val).await.unwrap();
862        assert_eq!(result.len(), 1);
863        assert_eq!(result[0].0, "TEST:AO");
864
865        let snap = store.get_snapshot("TEST:AO").await.unwrap();
866        match snap {
867            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(99.5)),
868            _ => panic!("expected scalar"),
869        }
870    }
871
872    #[tokio::test]
873    async fn put_value_rejects_readonly() {
874        let mut records = HashMap::new();
875        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
876        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
877
878        let val = DecodedValue::Float64(5.0);
879        let err = store.put_value("TEST:AI", &val).await.unwrap_err();
880        assert!(err.contains("not writable"));
881    }
882
883    #[tokio::test]
884    async fn set_value_bypasses_writable_check() {
885        let mut records = HashMap::new();
886        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
887        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
888        assert!(store.set_value("TEST:AI", ScalarValue::F64(10.0)).await);
889        let val = store.get_value("TEST:AI").await.unwrap();
890        assert_eq!(val, ScalarValue::F64(10.0));
891    }
892
893    #[tokio::test]
894    async fn set_array_value_updates_all_scalar_array_types() {
895        let cases: Vec<ScalarArrayValue> = vec![
896            ScalarArrayValue::Bool(vec![false, true]),
897            ScalarArrayValue::I8(vec![1, 2]),
898            ScalarArrayValue::I16(vec![1, 2]),
899            ScalarArrayValue::I32(vec![1, 2]),
900            ScalarArrayValue::I64(vec![1, 2]),
901            ScalarArrayValue::U8(vec![1, 2]),
902            ScalarArrayValue::U16(vec![1, 2]),
903            ScalarArrayValue::U32(vec![1, 2]),
904            ScalarArrayValue::U64(vec![1, 2]),
905            ScalarArrayValue::F32(vec![1.0, 2.0]),
906            ScalarArrayValue::F64(vec![1.0, 2.0]),
907            ScalarArrayValue::Str(vec!["a".to_string(), "b".to_string()]),
908        ];
909
910        for (idx, updated) in cases.into_iter().enumerate() {
911            let pv = format!("TEST:WF:{idx}");
912            let mut records = HashMap::new();
913            records.insert(pv.clone(), make_waveform(&pv, updated.clone()));
914            let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
915
916            assert!(!store.set_array_value(&pv, updated.clone()).await);
917
918            let second = match updated {
919                ScalarArrayValue::Bool(_) => ScalarArrayValue::Bool(vec![true, false]),
920                ScalarArrayValue::I8(_) => ScalarArrayValue::I8(vec![3, 4]),
921                ScalarArrayValue::I16(_) => ScalarArrayValue::I16(vec![3, 4]),
922                ScalarArrayValue::I32(_) => ScalarArrayValue::I32(vec![3, 4]),
923                ScalarArrayValue::I64(_) => ScalarArrayValue::I64(vec![3, 4]),
924                ScalarArrayValue::U8(_) => ScalarArrayValue::U8(vec![3, 4]),
925                ScalarArrayValue::U16(_) => ScalarArrayValue::U16(vec![3, 4]),
926                ScalarArrayValue::U32(_) => ScalarArrayValue::U32(vec![3, 4]),
927                ScalarArrayValue::U64(_) => ScalarArrayValue::U64(vec![3, 4]),
928                ScalarArrayValue::F32(_) => ScalarArrayValue::F32(vec![3.0, 4.0]),
929                ScalarArrayValue::F64(_) => ScalarArrayValue::F64(vec![3.0, 4.0]),
930                ScalarArrayValue::Str(_) => {
931                    ScalarArrayValue::Str(vec!["x".to_string(), "y".to_string()])
932                }
933            };
934
935            assert!(store.set_array_value(&pv, second.clone()).await);
936            let snap = store.get_snapshot(&pv).await.unwrap();
937            match snap {
938                NtPayload::ScalarArray(nt) => assert_eq!(nt.value, second),
939                _ => panic!("expected scalar array"),
940            }
941        }
942    }
943
944    #[tokio::test]
945    async fn get_nt_returns_full_payload() {
946        let mut records = HashMap::new();
947        records.insert("TEST:AI".into(), make_ai("TEST:AI", 12.5));
948        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
949
950        let nt = store.get_nt("TEST:AI").await.unwrap();
951        match nt {
952            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(12.5)),
953            _ => panic!("expected scalar payload"),
954        }
955    }
956
957    #[tokio::test]
958    async fn put_nt_updates_scalar_array_table_and_ndarray() {
959        let mut records = HashMap::new();
960        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
961        records.insert(
962            "TEST:WF".into(),
963            make_waveform("TEST:WF", ScalarArrayValue::F64(vec![0.0, 0.0])),
964        );
965        records.insert("TEST:TBL".into(), make_nt_table("TEST:TBL"));
966        records.insert("TEST:NDA".into(), make_nt_ndarray("TEST:NDA"));
967        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
968
969        assert!(
970            store
971                .put_nt(
972                    "TEST:AI",
973                    NtPayload::Scalar(NtScalar::from_value(ScalarValue::F64(5.0))),
974                )
975                .await
976        );
977        assert!(
978            store
979                .put_nt(
980                    "TEST:WF",
981                    NtPayload::ScalarArray(NtScalarArray::from_value(ScalarArrayValue::F64(vec![
982                        3.0, 4.0
983                    ],))),
984                )
985                .await
986        );
987
988        let table = NtTable {
989            labels: vec!["X".to_string(), "Y".to_string()],
990            columns: vec![
991                NtTableColumn {
992                    name: "x".to_string(),
993                    values: ScalarArrayValue::F64(vec![2.0, 3.0]),
994                },
995                NtTableColumn {
996                    name: "y".to_string(),
997                    values: ScalarArrayValue::F64(vec![20.0, 30.0]),
998                },
999            ],
1000            descriptor: Some("updated table".to_string()),
1001            alarm: None,
1002            time_stamp: None,
1003        };
1004        assert!(
1005            store
1006                .put_nt("TEST:TBL", NtPayload::Table(table.clone()))
1007                .await
1008        );
1009
1010        let ndarray = NtNdArray {
1011            value: ScalarArrayValue::U8(vec![1, 2, 3, 4]),
1012            codec: NdCodec {
1013                name: "none".to_string(),
1014                parameters: HashMap::new(),
1015            },
1016            compressed_size: 4,
1017            uncompressed_size: 4,
1018            dimension: vec![NdDimension {
1019                size: 4,
1020                offset: 0,
1021                full_size: 4,
1022                binning: 1,
1023                reverse: false,
1024            }],
1025            unique_id: 2,
1026            data_time_stamp: Default::default(),
1027            attribute: vec![],
1028            descriptor: Some("updated ndarray".to_string()),
1029            alarm: None,
1030            time_stamp: None,
1031            display: None,
1032        };
1033        assert!(
1034            store
1035                .put_nt("TEST:NDA", NtPayload::NdArray(ndarray.clone()))
1036                .await
1037        );
1038
1039        assert!(
1040            !store
1041                .put_nt(
1042                    "TEST:AI",
1043                    NtPayload::ScalarArray(NtScalarArray::from_value(ScalarArrayValue::F64(vec![
1044                        1.0
1045                    ]))),
1046                )
1047                .await
1048        );
1049
1050        match store.get_nt("TEST:TBL").await.unwrap() {
1051            NtPayload::Table(nt) => assert_eq!(nt, table),
1052            _ => panic!("expected table payload"),
1053        }
1054        match store.get_nt("TEST:NDA").await.unwrap() {
1055            NtPayload::NdArray(nt) => assert_eq!(nt, ndarray),
1056            _ => panic!("expected ndarray payload"),
1057        }
1058    }
1059
1060    #[tokio::test]
1061    async fn descriptor_matches_value_type() {
1062        let mut records = HashMap::new();
1063        records.insert("TEST:AI".into(), make_ai("TEST:AI", 0.0));
1064        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
1065        let desc = store.get_descriptor("TEST:AI").await.unwrap();
1066        assert_eq!(desc.struct_id.as_deref(), Some("epics:nt/NTScalar:1.0"));
1067        let value_field = desc.field("value").unwrap();
1068        assert!(matches!(
1069            value_field.field_type,
1070            FieldType::Scalar(TypeCode::Float64)
1071        ));
1072    }
1073
1074    #[tokio::test]
1075    async fn subscribe_receives_updates() {
1076        let mut records = HashMap::new();
1077        records.insert("TEST:AO".into(), make_ao("TEST:AO", 0.0));
1078        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
1079
1080        let mut rx = store.subscribe("TEST:AO").await.unwrap();
1081
1082        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(7.7))]);
1083        store.put_value("TEST:AO", &val).await.unwrap();
1084
1085        let update = rx.recv().await.unwrap();
1086        match update {
1087            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(7.7)),
1088            _ => panic!("expected scalar"),
1089        }
1090    }
1091
1092    #[tokio::test]
1093    async fn on_put_callback_is_invoked() {
1094        use std::sync::atomic::{AtomicBool, Ordering};
1095
1096        let called = Arc::new(AtomicBool::new(false));
1097        let called2 = called.clone();
1098
1099        let mut records = HashMap::new();
1100        records.insert("CB:AO".into(), make_ao("CB:AO", 0.0));
1101
1102        let mut on_put = HashMap::new();
1103        let cb: OnPutCallback = Arc::new(move |_name, _val| {
1104            called2.store(true, Ordering::SeqCst);
1105        });
1106        on_put.insert("CB:AO".into(), cb);
1107
1108        let store = SimplePvStore::new(records, on_put, vec![], false);
1109        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(1.0))]);
1110        store.put_value("CB:AO", &val).await.unwrap();
1111
1112        // Give the spawned task time to run.
1113        tokio::task::yield_now().await;
1114        tokio::task::yield_now().await;
1115
1116        assert!(called.load(Ordering::SeqCst));
1117    }
1118}