1use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use tokio::sync::{RwLock, mpsc};
10use tracing::debug;
11
12use spvirit_codec::spvd_decode::{DecodedValue, FieldDesc, FieldType, StructureDesc, TypeCode};
13use spvirit_types::{NtPayload, ScalarArrayValue, ScalarValue};
14
15use crate::apply::{
16 apply_alarm_update, apply_control_update, apply_display_update, apply_scalar_array_put,
17 apply_value_update,
18};
19use crate::monitor::MonitorRegistry;
20use crate::pvstore::PvStore;
21use crate::types::{RecordData, RecordInstance};
22
/// Callback attached to a PV and invoked after a client PUT (`PvStore::put_value`) has
/// changed that PV. It receives the PV name and the decoded value that was put, and runs on
/// a task spawned with `tokio::spawn`, so it does not delay the PUT reply.
pub type OnPutCallback = Arc<dyn Fn(&str, &DecodedValue) + Send + Sync>;

/// Callback that produces a scalar value for a PV name.
///
/// NOTE(review): not referenced in this file; presumably used by a periodic scan elsewhere
/// in the crate — confirm against its caller.
pub type ScanCallback = Arc<dyn Fn(&str) -> ScalarValue + Send + Sync>;

/// Computation behind a link: receives the current values of the link's inputs (in the
/// order they are listed in [`LinkDef::inputs`]) and returns the value for the output PV.
pub type LinkCallback = Arc<dyn Fn(&[ScalarValue]) -> ScalarValue + Send + Sync>;
31
/// A computed link: whenever any PV in `inputs` changes, `compute` is evaluated on the
/// inputs' current values and the result is written to `output`.
pub(crate) struct LinkDef {
    /// Name of the PV that receives the computed value.
    pub output: String,
    /// Names of the PVs read (in this order) and passed to `compute`.
    pub inputs: Vec<String>,
    /// Function mapping the input values to the new output value.
    pub compute: LinkCallback,
}
38
/// One PV held by [`SimplePvStore`]: its record plus the senders handed out by
/// `PvStore::subscribe`, which receive a fresh payload after each change.
struct PvEntry {
    record: RecordInstance,
    subscribers: Vec<mpsc::Sender<NtPayload>>,
}
43
/// In-memory PV store backed by a map of records, with optional PUT callbacks, computed
/// links and a monitor registry.
pub struct SimplePvStore {
    /// Records keyed by PV name, each with its direct subscribers.
    pvs: RwLock<HashMap<String, PvEntry>>,
    /// Callbacks keyed by PV name, spawned after a client PUT changes that PV.
    on_put: HashMap<String, OnPutCallback>,
    /// Links re-evaluated whenever one of their inputs changes.
    links: Vec<LinkDef>,
    /// Forwarded to the record update helpers when values change.
    compute_alarms: bool,
    /// Monitor registry notified by the internal setters (`set_value`, `set_array_value`,
    /// `put_nt`, link outputs); `put_value` instead returns its change to the caller.
    registry: RwLock<Option<Arc<MonitorRegistry>>>,
}
52
53impl SimplePvStore {
54 pub(crate) fn new(
55 records: HashMap<String, RecordInstance>,
56 on_put: HashMap<String, OnPutCallback>,
57 links: Vec<LinkDef>,
58 compute_alarms: bool,
59 ) -> Self {
60 let pvs = records
61 .into_iter()
62 .map(|(name, record)| {
63 (
64 name,
65 PvEntry {
66 record,
67 subscribers: Vec::new(),
68 },
69 )
70 })
71 .collect();
72 Self {
73 pvs: RwLock::new(pvs),
74 on_put,
75 links,
76 compute_alarms,
77 registry: RwLock::new(None),
78 }
79 }
80
81 pub async fn set_registry(&self, registry: Arc<MonitorRegistry>) {
84 *self.registry.write().await = Some(registry);
85 }
86
87 pub async fn insert(&self, name: String, record: RecordInstance) {
89 let mut pvs = self.pvs.write().await;
90 pvs.insert(
91 name,
92 PvEntry {
93 record,
94 subscribers: Vec::new(),
95 },
96 );
97 }
98
99 pub async fn get_value(&self, name: &str) -> Option<ScalarValue> {
101 let pvs = self.pvs.read().await;
102 pvs.get(name).map(|e| e.record.current_value())
103 }
104
105 pub async fn get_nt(&self, name: &str) -> Option<NtPayload> {
107 let pvs = self.pvs.read().await;
108 pvs.get(name).map(|e| e.record.to_ntpayload())
109 }
110
111 pub async fn set_value(&self, name: &str, value: ScalarValue) -> bool {
113 if self.set_value_inner(name, value).await {
114 self.evaluate_links(name).await;
115 true
116 } else {
117 false
118 }
119 }
120
121 pub async fn set_array_value(&self, name: &str, value: ScalarArrayValue) -> bool {
123 if self.set_array_value_inner(name, value).await {
124 self.evaluate_links(name).await;
125 true
126 } else {
127 false
128 }
129 }
130
131 pub async fn put_nt(&self, name: &str, payload: NtPayload) -> bool {
133 if self.put_nt_inner(name, payload).await {
134 self.evaluate_links(name).await;
135 true
136 } else {
137 false
138 }
139 }
140
141 async fn set_value_inner(&self, name: &str, value: ScalarValue) -> bool {
144 let payload = {
145 let mut pvs = self.pvs.write().await;
146 if let Some(entry) = pvs.get_mut(name) {
147 let changed = entry.record.set_scalar_value(value, self.compute_alarms);
148 if changed {
149 let payload = entry.record.to_ntpayload();
150 entry
151 .subscribers
152 .retain(|tx| tx.try_send(payload.clone()).is_ok());
153 Some(payload)
154 } else {
155 None
156 }
157 } else {
158 return false;
159 }
160 };
161
162 if let Some(payload) = payload {
163 let reg = self.registry.read().await;
165 if let Some(registry) = reg.as_ref() {
166 registry.notify_monitors(name, &payload).await;
167 }
168 true
169 } else {
170 false
171 }
172 }
173
174 async fn set_array_value_inner(&self, name: &str, value: ScalarArrayValue) -> bool {
177 let payload = {
178 let mut pvs = self.pvs.write().await;
179 if let Some(entry) = pvs.get_mut(name) {
180 let changed = entry.record.set_array_value(value);
181 if changed {
182 let payload = entry.record.to_ntpayload();
183 entry
184 .subscribers
185 .retain(|tx| tx.try_send(payload.clone()).is_ok());
186 Some(payload)
187 } else {
188 None
189 }
190 } else {
191 return false;
192 }
193 };
194
195 if let Some(payload) = payload {
196 let reg = self.registry.read().await;
198 if let Some(registry) = reg.as_ref() {
199 registry.notify_monitors(name, &payload).await;
200 }
201 true
202 } else {
203 false
204 }
205 }
206
207 async fn put_nt_inner(&self, name: &str, payload: NtPayload) -> bool {
210 let payload = {
211 let mut pvs = self.pvs.write().await;
212 if let Some(entry) = pvs.get_mut(name) {
213 let changed = entry.record.set_nt_payload(payload);
214 if changed {
215 let payload = entry.record.to_ntpayload();
216 entry
217 .subscribers
218 .retain(|tx| tx.try_send(payload.clone()).is_ok());
219 Some(payload)
220 } else {
221 None
222 }
223 } else {
224 return false;
225 }
226 };
227
228 if let Some(payload) = payload {
229 let reg = self.registry.read().await;
231 if let Some(registry) = reg.as_ref() {
232 registry.notify_monitors(name, &payload).await;
233 }
234 true
235 } else {
236 false
237 }
238 }
239
240 async fn evaluate_links(&self, changed_pv: &str) {
243 if self.links.is_empty() {
244 return;
245 }
246 let mut queue = vec![changed_pv.to_string()];
247 let mut visited = HashSet::new();
248
249 while let Some(pv) = queue.pop() {
250 if !visited.insert(pv.clone()) {
251 debug!("Circular link detected for PV '{}', skipping", pv);
252 continue;
253 }
254 for link in &self.links {
255 if !link.inputs.iter().any(|i| i == &pv) {
256 continue;
257 }
258 let values = {
260 let pvs = self.pvs.read().await;
261 link.inputs
262 .iter()
263 .map(|n| {
264 pvs.get(n)
265 .map(|e| e.record.current_value())
266 .unwrap_or(ScalarValue::F64(0.0))
267 })
268 .collect::<Vec<_>>()
269 };
270 let new_val = (link.compute)(&values);
271 if self.set_value_inner(&link.output, new_val).await {
272 queue.push(link.output.clone());
273 }
274 }
275 }
276 }
277
278 pub async fn pv_names(&self) -> Vec<String> {
280 let pvs = self.pvs.read().await;
281 pvs.keys().cloned().collect()
282 }
283}
284
285impl PvStore for SimplePvStore {
286 fn has_pv(&self, name: &str) -> impl Future<Output = bool> + Send {
287 async move {
288 let pvs = self.pvs.read().await;
289 pvs.contains_key(name)
290 }
291 }
292
293 fn get_snapshot(&self, name: &str) -> impl Future<Output = Option<NtPayload>> + Send {
294 async move {
295 let pvs = self.pvs.read().await;
296 pvs.get(name).map(|e| e.record.to_ntpayload())
297 }
298 }
299
300 fn get_descriptor(&self, name: &str) -> impl Future<Output = Option<StructureDesc>> + Send {
301 async move {
302 let pvs = self.pvs.read().await;
303 pvs.get(name)
304 .map(|e| descriptor_for_payload(&e.record.to_ntpayload()))
305 }
306 }
307
308 fn put_value(
309 &self,
310 name: &str,
311 value: &DecodedValue,
312 ) -> impl Future<Output = Result<Vec<(String, NtPayload)>, String>> + Send {
313 let name = name.to_string();
314 let value = value.clone();
315 async move {
316 let result = {
317 let mut pvs = self.pvs.write().await;
318 let entry = pvs
319 .get_mut(&name)
320 .ok_or_else(|| format!("PV '{}' not found", name))?;
321
322 if !entry.record.writable() {
323 return Err(format!("PV '{}' is not writable", name));
324 }
325
326 let changed = apply_put_to_record(&mut entry.record, &value, self.compute_alarms);
327 if !changed {
328 return Ok(vec![]);
329 }
330
331 let payload = entry.record.to_ntpayload();
332 entry
333 .subscribers
334 .retain(|tx| tx.try_send(payload.clone()).is_ok());
335
336 (name.clone(), payload)
337 }; if let Some(cb) = self.on_put.get(&name) {
341 let cb = cb.clone();
342 let n = name.clone();
343 let v = value.clone();
344 tokio::spawn(async move { cb(&n, &v) });
345 }
346
347 self.evaluate_links(&name).await;
349
350 Ok(vec![result])
351 }
352 }
353
354 fn is_writable(&self, name: &str) -> impl Future<Output = bool> + Send {
355 async move {
356 let pvs = self.pvs.read().await;
357 pvs.get(name).is_some_and(|e| e.record.writable())
358 }
359 }
360
361 fn list_pvs(&self) -> impl Future<Output = Vec<String>> + Send {
362 async move {
363 let pvs = self.pvs.read().await;
364 pvs.keys().cloned().collect()
365 }
366 }
367
368 fn subscribe(
369 &self,
370 name: &str,
371 ) -> impl Future<Output = Option<mpsc::Receiver<NtPayload>>> + Send {
372 let name = name.to_string();
373 async move {
374 let mut pvs = self.pvs.write().await;
375 let entry = pvs.get_mut(&name)?;
376 let (tx, rx) = mpsc::channel(64);
377 entry.subscribers.push(tx);
378 Some(rx)
379 }
380 }
381}
382
383fn apply_put_to_record(
387 record: &mut RecordInstance,
388 value: &DecodedValue,
389 compute_alarms: bool,
390) -> bool {
391 let fields = match value {
392 DecodedValue::Structure(f) => f,
393 other => {
394 return apply_put_to_record(
396 record,
397 &DecodedValue::Structure(vec![("value".to_string(), other.clone())]),
398 compute_alarms,
399 );
400 }
401 };
402
403 let mut changed = false;
404
405 match &mut record.data {
406 RecordData::Ai { nt, .. }
407 | RecordData::Ao { nt, .. }
408 | RecordData::Bi { nt, .. }
409 | RecordData::Bo { nt, .. }
410 | RecordData::StringIn { nt, .. }
411 | RecordData::StringOut { nt, .. } => {
412 for (name, val) in fields {
413 match name.as_str() {
414 "value" => {
415 changed |= apply_value_update(nt, val, compute_alarms);
416 }
417 "alarm" => {
418 changed |= apply_alarm_update(nt, val);
419 }
420 "display" => {
421 changed |= apply_display_update(nt, val);
422 }
423 "control" => {
424 changed |= apply_control_update(nt, val);
425 }
426 _ => {}
427 }
428 }
429 }
430 RecordData::Waveform { nt, nord, .. }
431 | RecordData::Aai { nt, nord, .. }
432 | RecordData::Aao { nt, nord, .. }
433 | RecordData::SubArray { nt, nord, .. } => {
434 changed = apply_scalar_array_put(nt, nord, value);
435 }
436 RecordData::NtTable { .. } | RecordData::NtNdArray { .. } => {
437 debug!("PUT to NtTable/NtNdArray not yet supported in SimplePvStore");
439 }
440 RecordData::NtEnum { nt, .. } => {
441 for (name, val) in fields {
443 if name == "value" {
444 let idx = match val {
445 DecodedValue::Int32(v) => Some(*v),
446 DecodedValue::Int64(v) => Some(*v as i32),
447 DecodedValue::Int16(v) => Some(*v as i32),
448 DecodedValue::Int8(v) => Some(*v as i32),
449 DecodedValue::Float64(v) => Some(*v as i32),
450 _ => None,
451 };
452 if let Some(idx) = idx {
453 if nt.index != idx {
454 nt.index = idx;
455 changed = true;
456 }
457 }
458 }
459 }
460 }
461 RecordData::Generic { .. } => {
462 debug!("PUT to Generic not yet supported in SimplePvStore");
463 }
464 }
465
466 changed
467}
468
469pub(crate) fn descriptor_for_payload(payload: &NtPayload) -> StructureDesc {
472 match payload {
473 NtPayload::Scalar(nt) => nt_scalar_desc(&nt.value),
474 NtPayload::ScalarArray(arr) => nt_scalar_array_desc(&arr.value),
475 _ => StructureDesc::new(),
476 }
477}
478
479fn value_type_code(sv: &ScalarValue) -> TypeCode {
480 match sv {
481 ScalarValue::Bool(_) => TypeCode::Boolean,
482 ScalarValue::I8(_) => TypeCode::Int8,
483 ScalarValue::I16(_) => TypeCode::Int16,
484 ScalarValue::I32(_) => TypeCode::Int32,
485 ScalarValue::I64(_) => TypeCode::Int64,
486 ScalarValue::U8(_) => TypeCode::UInt8,
487 ScalarValue::U16(_) => TypeCode::UInt16,
488 ScalarValue::U32(_) => TypeCode::UInt32,
489 ScalarValue::U64(_) => TypeCode::UInt64,
490 ScalarValue::F32(_) => TypeCode::Float32,
491 ScalarValue::F64(_) => TypeCode::Float64,
492 ScalarValue::Str(_) => TypeCode::String,
493 }
494}
495
496fn array_type_code(sav: &ScalarArrayValue) -> TypeCode {
497 match sav {
498 ScalarArrayValue::Bool(_) => TypeCode::Boolean,
499 ScalarArrayValue::I8(_) => TypeCode::Int8,
500 ScalarArrayValue::I16(_) => TypeCode::Int16,
501 ScalarArrayValue::I32(_) => TypeCode::Int32,
502 ScalarArrayValue::I64(_) => TypeCode::Int64,
503 ScalarArrayValue::U8(_) => TypeCode::UInt8,
504 ScalarArrayValue::U16(_) => TypeCode::UInt16,
505 ScalarArrayValue::U32(_) => TypeCode::UInt32,
506 ScalarArrayValue::U64(_) => TypeCode::UInt64,
507 ScalarArrayValue::F32(_) => TypeCode::Float32,
508 ScalarArrayValue::F64(_) => TypeCode::Float64,
509 ScalarArrayValue::Str(_) => TypeCode::String,
510 }
511}
512
513fn nt_scalar_desc(sv: &ScalarValue) -> StructureDesc {
514 let tc = value_type_code(sv);
515 StructureDesc {
516 struct_id: Some("epics:nt/NTScalar:1.0".to_string()),
517 fields: vec![
518 FieldDesc {
519 name: "value".to_string(),
520 field_type: FieldType::Scalar(tc),
521 },
522 FieldDesc {
523 name: "alarm".to_string(),
524 field_type: FieldType::Structure(alarm_desc()),
525 },
526 FieldDesc {
527 name: "timeStamp".to_string(),
528 field_type: FieldType::Structure(timestamp_desc()),
529 },
530 FieldDesc {
531 name: "display".to_string(),
532 field_type: FieldType::Structure(display_desc()),
533 },
534 FieldDesc {
535 name: "control".to_string(),
536 field_type: FieldType::Structure(control_desc()),
537 },
538 FieldDesc {
539 name: "valueAlarm".to_string(),
540 field_type: FieldType::Structure(value_alarm_desc()),
541 },
542 ],
543 }
544}
545
546fn nt_scalar_array_desc(sav: &ScalarArrayValue) -> StructureDesc {
547 let tc = array_type_code(sav);
548 StructureDesc {
549 struct_id: Some("epics:nt/NTScalarArray:1.0".to_string()),
550 fields: vec![
551 FieldDesc {
552 name: "value".to_string(),
553 field_type: FieldType::ScalarArray(tc),
554 },
555 FieldDesc {
556 name: "alarm".to_string(),
557 field_type: FieldType::Structure(alarm_desc()),
558 },
559 FieldDesc {
560 name: "timeStamp".to_string(),
561 field_type: FieldType::Structure(timestamp_desc()),
562 },
563 FieldDesc {
564 name: "display".to_string(),
565 field_type: FieldType::Structure(display_desc()),
566 },
567 FieldDesc {
568 name: "control".to_string(),
569 field_type: FieldType::Structure(control_desc()),
570 },
571 ],
572 }
573}
574
575fn alarm_desc() -> StructureDesc {
576 StructureDesc {
577 struct_id: Some("alarm_t".to_string()),
578 fields: vec![
579 FieldDesc {
580 name: "severity".to_string(),
581 field_type: FieldType::Scalar(TypeCode::Int32),
582 },
583 FieldDesc {
584 name: "status".to_string(),
585 field_type: FieldType::Scalar(TypeCode::Int32),
586 },
587 FieldDesc {
588 name: "message".to_string(),
589 field_type: FieldType::String,
590 },
591 ],
592 }
593}
594
595fn timestamp_desc() -> StructureDesc {
596 StructureDesc {
597 struct_id: Some("time_t".to_string()),
598 fields: vec![
599 FieldDesc {
600 name: "secondsPastEpoch".to_string(),
601 field_type: FieldType::Scalar(TypeCode::Int64),
602 },
603 FieldDesc {
604 name: "nanoseconds".to_string(),
605 field_type: FieldType::Scalar(TypeCode::Int32),
606 },
607 FieldDesc {
608 name: "userTag".to_string(),
609 field_type: FieldType::Scalar(TypeCode::Int32),
610 },
611 ],
612 }
613}
614
615fn display_desc() -> StructureDesc {
616 StructureDesc {
617 struct_id: Some("display_t".to_string()),
618 fields: vec![
619 FieldDesc {
620 name: "limitLow".to_string(),
621 field_type: FieldType::Scalar(TypeCode::Float64),
622 },
623 FieldDesc {
624 name: "limitHigh".to_string(),
625 field_type: FieldType::Scalar(TypeCode::Float64),
626 },
627 FieldDesc {
628 name: "description".to_string(),
629 field_type: FieldType::String,
630 },
631 FieldDesc {
632 name: "units".to_string(),
633 field_type: FieldType::String,
634 },
635 FieldDesc {
636 name: "precision".to_string(),
637 field_type: FieldType::Scalar(TypeCode::Int32),
638 },
639 FieldDesc {
640 name: "form".to_string(),
641 field_type: FieldType::Structure(StructureDesc {
642 struct_id: Some("enum_t".to_string()),
643 fields: vec![
644 FieldDesc {
645 name: "index".to_string(),
646 field_type: FieldType::Scalar(TypeCode::Int32),
647 },
648 FieldDesc {
649 name: "choices".to_string(),
650 field_type: FieldType::StringArray,
651 },
652 ],
653 }),
654 },
655 ],
656 }
657}
658
659fn control_desc() -> StructureDesc {
660 StructureDesc {
661 struct_id: Some("control_t".to_string()),
662 fields: vec![
663 FieldDesc {
664 name: "limitLow".to_string(),
665 field_type: FieldType::Scalar(TypeCode::Float64),
666 },
667 FieldDesc {
668 name: "limitHigh".to_string(),
669 field_type: FieldType::Scalar(TypeCode::Float64),
670 },
671 FieldDesc {
672 name: "minStep".to_string(),
673 field_type: FieldType::Scalar(TypeCode::Float64),
674 },
675 ],
676 }
677}
678
679fn value_alarm_desc() -> StructureDesc {
680 StructureDesc {
681 struct_id: Some("valueAlarm_t".to_string()),
682 fields: vec![
683 FieldDesc {
684 name: "active".to_string(),
685 field_type: FieldType::Scalar(TypeCode::Boolean),
686 },
687 FieldDesc {
688 name: "lowAlarmLimit".to_string(),
689 field_type: FieldType::Scalar(TypeCode::Float64),
690 },
691 FieldDesc {
692 name: "lowWarningLimit".to_string(),
693 field_type: FieldType::Scalar(TypeCode::Float64),
694 },
695 FieldDesc {
696 name: "highWarningLimit".to_string(),
697 field_type: FieldType::Scalar(TypeCode::Float64),
698 },
699 FieldDesc {
700 name: "highAlarmLimit".to_string(),
701 field_type: FieldType::Scalar(TypeCode::Float64),
702 },
703 FieldDesc {
704 name: "lowAlarmSeverity".to_string(),
705 field_type: FieldType::Scalar(TypeCode::Int32),
706 },
707 FieldDesc {
708 name: "lowWarningSeverity".to_string(),
709 field_type: FieldType::Scalar(TypeCode::Int32),
710 },
711 FieldDesc {
712 name: "highWarningSeverity".to_string(),
713 field_type: FieldType::Scalar(TypeCode::Int32),
714 },
715 FieldDesc {
716 name: "highAlarmSeverity".to_string(),
717 field_type: FieldType::Scalar(TypeCode::Int32),
718 },
719 FieldDesc {
720 name: "hysteresis".to_string(),
721 field_type: FieldType::Scalar(TypeCode::UInt8),
722 },
723 ],
724 }
725}
726
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DbCommonState, RecordType};
    use spvirit_types::{
        NdCodec, NdDimension, NtNdArray, NtPayload, NtScalar, NtScalarArray, NtTable,
        NtTableColumn, ScalarArrayValue, ScalarValue,
    };

    /// Wraps record-specific `data` in a `RecordInstance` with default common state and
    /// no raw fields.
    fn record(name: &str, record_type: RecordType, data: RecordData) -> RecordInstance {
        RecordInstance {
            name: name.to_string(),
            record_type,
            common: DbCommonState::default(),
            data,
            raw_fields: HashMap::new(),
        }
    }

    fn make_ai(name: &str, val: f64) -> RecordInstance {
        record(
            name,
            RecordType::Ai,
            RecordData::Ai {
                nt: NtScalar::from_value(ScalarValue::F64(val)),
                inp: None,
                siml: None,
                siol: None,
                simm: false,
            },
        )
    }

    fn make_ao(name: &str, val: f64) -> RecordInstance {
        record(
            name,
            RecordType::Ao,
            RecordData::Ao {
                nt: NtScalar::from_value(ScalarValue::F64(val)),
                out: None,
                dol: None,
                omsl: crate::types::OutputMode::Supervisory,
                drvl: None,
                drvh: None,
                oroc: None,
                siml: None,
                siol: None,
                simm: false,
            },
        )
    }

    fn make_waveform(name: &str, value: ScalarArrayValue) -> RecordInstance {
        let nelm = value.len();
        record(
            name,
            RecordType::Waveform,
            RecordData::Waveform {
                nt: NtScalarArray::from_value(value),
                inp: None,
                ftvl: "DOUBLE".to_string(),
                nelm,
                nord: nelm,
            },
        )
    }

    fn make_nt_table(name: &str) -> RecordInstance {
        let nt = NtTable {
            labels: vec!["X".to_string(), "Y".to_string()],
            columns: vec![
                NtTableColumn {
                    name: "x".to_string(),
                    values: ScalarArrayValue::F64(vec![1.0, 2.0]),
                },
                NtTableColumn {
                    name: "y".to_string(),
                    values: ScalarArrayValue::F64(vec![10.0, 20.0]),
                },
            ],
            descriptor: Some("table".to_string()),
            alarm: None,
            time_stamp: None,
        };
        record(
            name,
            RecordType::NtTable,
            RecordData::NtTable {
                nt,
                inp: None,
                out: None,
                omsl: crate::types::OutputMode::Supervisory,
            },
        )
    }

    fn make_nt_ndarray(name: &str) -> RecordInstance {
        let nt = NtNdArray {
            value: ScalarArrayValue::U8(vec![0; 4]),
            codec: NdCodec {
                name: "none".to_string(),
                parameters: HashMap::new(),
            },
            compressed_size: 4,
            uncompressed_size: 4,
            dimension: vec![NdDimension {
                size: 2,
                offset: 0,
                full_size: 2,
                binning: 1,
                reverse: false,
            }],
            unique_id: 1,
            data_time_stamp: Default::default(),
            attribute: vec![],
            descriptor: Some("ndarray".to_string()),
            alarm: None,
            time_stamp: None,
            display: None,
        };
        record(
            name,
            RecordType::NtNdArray,
            RecordData::NtNdArray {
                nt,
                inp: None,
                out: None,
                omsl: crate::types::OutputMode::Supervisory,
            },
        )
    }

    /// Store keyed by each record's own name, with the given PUT callbacks, no links and
    /// alarm computation disabled.
    fn store_with(
        records: Vec<RecordInstance>,
        on_put: HashMap<String, OnPutCallback>,
    ) -> SimplePvStore {
        let by_name: HashMap<String, RecordInstance> = records
            .into_iter()
            .map(|rec| (rec.name.clone(), rec))
            .collect();
        SimplePvStore::new(by_name, on_put, vec![], false)
    }

    /// Same as `store_with`, without PUT callbacks.
    fn store_of(records: Vec<RecordInstance>) -> SimplePvStore {
        store_with(records, HashMap::new())
    }

    /// Extracts the scalar value of an NTScalar payload, panicking with `msg` otherwise.
    fn scalar_of(payload: NtPayload, msg: &str) -> ScalarValue {
        match payload {
            NtPayload::Scalar(nt) => nt.value,
            _ => panic!("{msg}"),
        }
    }

    /// A structured PUT of a float64 to the `value` sub-field.
    fn value_put(v: f64) -> DecodedValue {
        DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(v))])
    }

    /// A value of the same element type as `v` but with different contents.
    fn bumped(v: &ScalarArrayValue) -> ScalarArrayValue {
        match v {
            ScalarArrayValue::Bool(_) => ScalarArrayValue::Bool(vec![true, false]),
            ScalarArrayValue::I8(_) => ScalarArrayValue::I8(vec![3, 4]),
            ScalarArrayValue::I16(_) => ScalarArrayValue::I16(vec![3, 4]),
            ScalarArrayValue::I32(_) => ScalarArrayValue::I32(vec![3, 4]),
            ScalarArrayValue::I64(_) => ScalarArrayValue::I64(vec![3, 4]),
            ScalarArrayValue::U8(_) => ScalarArrayValue::U8(vec![3, 4]),
            ScalarArrayValue::U16(_) => ScalarArrayValue::U16(vec![3, 4]),
            ScalarArrayValue::U32(_) => ScalarArrayValue::U32(vec![3, 4]),
            ScalarArrayValue::U64(_) => ScalarArrayValue::U64(vec![3, 4]),
            ScalarArrayValue::F32(_) => ScalarArrayValue::F32(vec![3.0, 4.0]),
            ScalarArrayValue::F64(_) => ScalarArrayValue::F64(vec![3.0, 4.0]),
            ScalarArrayValue::Str(_) => {
                ScalarArrayValue::Str(vec!["x".to_string(), "y".to_string()])
            }
        }
    }

    #[tokio::test]
    async fn has_pv_returns_true_for_existing() {
        let store = store_of(vec![make_ai("TEST:AI", 1.0)]);
        assert!(store.has_pv("TEST:AI").await);
        assert!(!store.has_pv("MISSING").await);
    }

    #[tokio::test]
    async fn get_snapshot_returns_payload() {
        let store = store_of(vec![make_ai("TEST:AI", 42.0)]);
        let snap = store.get_snapshot("TEST:AI").await.unwrap();
        assert_eq!(scalar_of(snap, "expected scalar"), ScalarValue::F64(42.0));
    }

    #[tokio::test]
    async fn put_value_updates_writable_record() {
        let store = store_of(vec![make_ao("TEST:AO", 0.0)]);

        let result = store.put_value("TEST:AO", &value_put(99.5)).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "TEST:AO");

        let snap = store.get_snapshot("TEST:AO").await.unwrap();
        assert_eq!(scalar_of(snap, "expected scalar"), ScalarValue::F64(99.5));
    }

    #[tokio::test]
    async fn put_value_rejects_readonly() {
        let store = store_of(vec![make_ai("TEST:AI", 1.0)]);
        let err = store
            .put_value("TEST:AI", &DecodedValue::Float64(5.0))
            .await
            .unwrap_err();
        assert!(err.contains("not writable"));
    }

    #[tokio::test]
    async fn set_value_bypasses_writable_check() {
        let store = store_of(vec![make_ai("TEST:AI", 1.0)]);
        assert!(store.set_value("TEST:AI", ScalarValue::F64(10.0)).await);
        assert_eq!(
            store.get_value("TEST:AI").await.unwrap(),
            ScalarValue::F64(10.0)
        );
    }

    #[tokio::test]
    async fn set_array_value_updates_all_scalar_array_types() {
        let initial_values: Vec<ScalarArrayValue> = vec![
            ScalarArrayValue::Bool(vec![false, true]),
            ScalarArrayValue::I8(vec![1, 2]),
            ScalarArrayValue::I16(vec![1, 2]),
            ScalarArrayValue::I32(vec![1, 2]),
            ScalarArrayValue::I64(vec![1, 2]),
            ScalarArrayValue::U8(vec![1, 2]),
            ScalarArrayValue::U16(vec![1, 2]),
            ScalarArrayValue::U32(vec![1, 2]),
            ScalarArrayValue::U64(vec![1, 2]),
            ScalarArrayValue::F32(vec![1.0, 2.0]),
            ScalarArrayValue::F64(vec![1.0, 2.0]),
            ScalarArrayValue::Str(vec!["a".to_string(), "b".to_string()]),
        ];

        for (idx, initial) in initial_values.into_iter().enumerate() {
            let pv = format!("TEST:WF:{idx}");
            let store = store_of(vec![make_waveform(&pv, initial.clone())]);

            // Writing the value the record already holds must report "no change".
            assert!(!store.set_array_value(&pv, initial.clone()).await);

            let second = bumped(&initial);
            assert!(store.set_array_value(&pv, second.clone()).await);
            match store.get_snapshot(&pv).await.unwrap() {
                NtPayload::ScalarArray(nt) => assert_eq!(nt.value, second),
                _ => panic!("expected scalar array"),
            }
        }
    }

    #[tokio::test]
    async fn get_nt_returns_full_payload() {
        let store = store_of(vec![make_ai("TEST:AI", 12.5)]);
        let nt = store.get_nt("TEST:AI").await.unwrap();
        assert_eq!(
            scalar_of(nt, "expected scalar payload"),
            ScalarValue::F64(12.5)
        );
    }

    #[tokio::test]
    async fn put_nt_updates_scalar_array_table_and_ndarray() {
        let store = store_of(vec![
            make_ai("TEST:AI", 1.0),
            make_waveform("TEST:WF", ScalarArrayValue::F64(vec![0.0, 0.0])),
            make_nt_table("TEST:TBL"),
            make_nt_ndarray("TEST:NDA"),
        ]);

        let scalar = NtPayload::Scalar(NtScalar::from_value(ScalarValue::F64(5.0)));
        assert!(store.put_nt("TEST:AI", scalar).await);

        let array =
            NtPayload::ScalarArray(NtScalarArray::from_value(ScalarArrayValue::F64(vec![3.0, 4.0])));
        assert!(store.put_nt("TEST:WF", array).await);

        let table = NtTable {
            labels: vec!["X".to_string(), "Y".to_string()],
            columns: vec![
                NtTableColumn {
                    name: "x".to_string(),
                    values: ScalarArrayValue::F64(vec![2.0, 3.0]),
                },
                NtTableColumn {
                    name: "y".to_string(),
                    values: ScalarArrayValue::F64(vec![20.0, 30.0]),
                },
            ],
            descriptor: Some("updated table".to_string()),
            alarm: None,
            time_stamp: None,
        };
        assert!(
            store
                .put_nt("TEST:TBL", NtPayload::Table(table.clone()))
                .await
        );

        let ndarray = NtNdArray {
            value: ScalarArrayValue::U8(vec![1, 2, 3, 4]),
            codec: NdCodec {
                name: "none".to_string(),
                parameters: HashMap::new(),
            },
            compressed_size: 4,
            uncompressed_size: 4,
            dimension: vec![NdDimension {
                size: 4,
                offset: 0,
                full_size: 4,
                binning: 1,
                reverse: false,
            }],
            unique_id: 2,
            data_time_stamp: Default::default(),
            attribute: vec![],
            descriptor: Some("updated ndarray".to_string()),
            alarm: None,
            time_stamp: None,
            display: None,
        };
        assert!(
            store
                .put_nt("TEST:NDA", NtPayload::NdArray(ndarray.clone()))
                .await
        );

        // A payload of the wrong shape for the record must be rejected.
        let mismatched =
            NtPayload::ScalarArray(NtScalarArray::from_value(ScalarArrayValue::F64(vec![1.0])));
        assert!(!store.put_nt("TEST:AI", mismatched).await);

        match store.get_nt("TEST:TBL").await.unwrap() {
            NtPayload::Table(nt) => assert_eq!(nt, table),
            _ => panic!("expected table payload"),
        }
        match store.get_nt("TEST:NDA").await.unwrap() {
            NtPayload::NdArray(nt) => assert_eq!(nt, ndarray),
            _ => panic!("expected ndarray payload"),
        }
    }

    #[tokio::test]
    async fn descriptor_matches_value_type() {
        let store = store_of(vec![make_ai("TEST:AI", 0.0)]);
        let desc = store.get_descriptor("TEST:AI").await.unwrap();
        assert_eq!(desc.struct_id.as_deref(), Some("epics:nt/NTScalar:1.0"));
        let value_field = desc.field("value").unwrap();
        assert!(matches!(
            value_field.field_type,
            FieldType::Scalar(TypeCode::Float64)
        ));
    }

    #[tokio::test]
    async fn subscribe_receives_updates() {
        let store = store_of(vec![make_ao("TEST:AO", 0.0)]);
        let mut rx = store.subscribe("TEST:AO").await.unwrap();

        store.put_value("TEST:AO", &value_put(7.7)).await.unwrap();

        let update = rx.recv().await.unwrap();
        assert_eq!(scalar_of(update, "expected scalar"), ScalarValue::F64(7.7));
    }

    #[tokio::test]
    async fn on_put_callback_is_invoked() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&called);
        let cb: OnPutCallback = Arc::new(move |_name, _val| {
            flag.store(true, Ordering::SeqCst);
        });

        let mut on_put = HashMap::new();
        on_put.insert("CB:AO".to_string(), cb);
        let store = store_with(vec![make_ao("CB:AO", 0.0)], on_put);

        store.put_value("CB:AO", &value_put(1.0)).await.unwrap();

        // The callback runs on a spawned task; give the scheduler a chance to run it.
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;

        assert!(called.load(Ordering::SeqCst));
    }
}