1use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use tokio::sync::{RwLock, mpsc};
10use tracing::debug;
11
12use spvirit_codec::spvd_decode::{DecodedValue, FieldDesc, FieldType, StructureDesc, TypeCode};
13use spvirit_types::{NtPayload, ScalarArrayValue, ScalarValue};
14
15use crate::apply::{
16 apply_alarm_update, apply_control_update, apply_display_update, apply_scalar_array_put,
17 apply_value_update,
18};
19use crate::monitor::MonitorRegistry;
20use crate::pvstore::PvStore;
21use crate::types::{RecordData, RecordInstance};
22
/// Callback fired after a client PUT changes a PV. It receives the PV name and the
/// decoded value that was written.
pub type OnPutCallback = Arc<dyn Fn(&str, &DecodedValue) + Send + Sync>;

/// Callback producing a scalar value for the named PV.
/// NOTE(review): not referenced in this file; presumably used by a scan task elsewhere — confirm.
pub type ScanCallback = Arc<dyn Fn(&str) -> ScalarValue + Send + Sync>;

/// Computes a link's output value from its inputs' current values. The values are passed
/// in the same order as `LinkDef::inputs`.
pub type LinkCallback = Arc<dyn Fn(&[ScalarValue]) -> ScalarValue + Send + Sync>;
31
/// A derived-value link. When any PV listed in `inputs` changes, `compute` is evaluated
/// over the inputs' current values and the result is written to `output`.
pub(crate) struct LinkDef {
    /// Name of the PV that receives the computed value.
    pub output: String,
    /// Names of the PVs whose values (in this order) form the `compute` argument slice.
    pub inputs: Vec<String>,
    /// Function producing the output value from the input values.
    pub compute: LinkCallback,
}
38
/// Storage slot for one PV: the record itself plus its direct-subscriber channels.
struct PvEntry {
    /// The record backing this PV.
    record: RecordInstance,
    /// Senders paired with receivers handed out by `subscribe`. Changed payloads are
    /// pushed to them with a non-blocking `try_send`.
    subscribers: Vec<mpsc::Sender<NtPayload>>,
}
43
/// In-memory PV store backed by a name-keyed map of record instances.
pub struct SimplePvStore {
    /// All PVs keyed by name.
    pvs: RwLock<HashMap<String, PvEntry>>,
    /// Per-PV callbacks, fired on a spawned task after a client PUT changes that PV.
    on_put: HashMap<String, OnPutCallback>,
    /// Links re-evaluated after a PV changes.
    links: Vec<LinkDef>,
    /// Forwarded to record updates.
    /// NOTE(review): presumably toggles alarm recomputation — confirm in the record code.
    compute_alarms: bool,
    /// Optional monitor registry notified by the `set_*` / `put_nt` / link update paths.
    registry: RwLock<Option<Arc<MonitorRegistry>>>,
}
52
53impl SimplePvStore {
54 pub(crate) fn new(
55 records: HashMap<String, RecordInstance>,
56 on_put: HashMap<String, OnPutCallback>,
57 links: Vec<LinkDef>,
58 compute_alarms: bool,
59 ) -> Self {
60 let pvs = records
61 .into_iter()
62 .map(|(name, record)| {
63 (
64 name,
65 PvEntry {
66 record,
67 subscribers: Vec::new(),
68 },
69 )
70 })
71 .collect();
72 Self {
73 pvs: RwLock::new(pvs),
74 on_put,
75 links,
76 compute_alarms,
77 registry: RwLock::new(None),
78 }
79 }
80
81 pub async fn set_registry(&self, registry: Arc<MonitorRegistry>) {
84 *self.registry.write().await = Some(registry);
85 }
86
87 pub async fn insert(&self, name: String, record: RecordInstance) {
89 let mut pvs = self.pvs.write().await;
90 pvs.insert(
91 name,
92 PvEntry {
93 record,
94 subscribers: Vec::new(),
95 },
96 );
97 }
98
99 pub async fn get_value(&self, name: &str) -> Option<ScalarValue> {
101 let pvs = self.pvs.read().await;
102 pvs.get(name).map(|e| e.record.current_value())
103 }
104
105 pub async fn get_nt(&self, name: &str) -> Option<NtPayload> {
107 let pvs = self.pvs.read().await;
108 pvs.get(name).map(|e| e.record.to_ntpayload())
109 }
110
111 pub async fn set_value(&self, name: &str, value: ScalarValue) -> bool {
113 if self.set_value_inner(name, value).await {
114 self.evaluate_links(name).await;
115 true
116 } else {
117 false
118 }
119 }
120
121 pub async fn set_array_value(&self, name: &str, value: ScalarArrayValue) -> bool {
123 if self.set_array_value_inner(name, value).await {
124 self.evaluate_links(name).await;
125 true
126 } else {
127 false
128 }
129 }
130
131 pub async fn put_nt(&self, name: &str, payload: NtPayload) -> bool {
133 if self.put_nt_inner(name, payload).await {
134 self.evaluate_links(name).await;
135 true
136 } else {
137 false
138 }
139 }
140
141 async fn set_value_inner(&self, name: &str, value: ScalarValue) -> bool {
144 let payload = {
145 let mut pvs = self.pvs.write().await;
146 if let Some(entry) = pvs.get_mut(name) {
147 let changed = entry.record.set_scalar_value(value, self.compute_alarms);
148 if changed {
149 let payload = entry.record.to_ntpayload();
150 entry
151 .subscribers
152 .retain(|tx| tx.try_send(payload.clone()).is_ok());
153 Some(payload)
154 } else {
155 None
156 }
157 } else {
158 return false;
159 }
160 };
161
162 if let Some(payload) = payload {
163 let reg = self.registry.read().await;
165 if let Some(registry) = reg.as_ref() {
166 registry.notify_monitors(name, &payload).await;
167 }
168 true
169 } else {
170 false
171 }
172 }
173
174 async fn set_array_value_inner(&self, name: &str, value: ScalarArrayValue) -> bool {
177 let payload = {
178 let mut pvs = self.pvs.write().await;
179 if let Some(entry) = pvs.get_mut(name) {
180 let changed = entry.record.set_array_value(value);
181 if changed {
182 let payload = entry.record.to_ntpayload();
183 entry
184 .subscribers
185 .retain(|tx| tx.try_send(payload.clone()).is_ok());
186 Some(payload)
187 } else {
188 None
189 }
190 } else {
191 return false;
192 }
193 };
194
195 if let Some(payload) = payload {
196 let reg = self.registry.read().await;
198 if let Some(registry) = reg.as_ref() {
199 registry.notify_monitors(name, &payload).await;
200 }
201 true
202 } else {
203 false
204 }
205 }
206
207 async fn put_nt_inner(&self, name: &str, payload: NtPayload) -> bool {
210 let payload = {
211 let mut pvs = self.pvs.write().await;
212 if let Some(entry) = pvs.get_mut(name) {
213 let changed = entry.record.set_nt_payload(payload);
214 if changed {
215 let payload = entry.record.to_ntpayload();
216 entry
217 .subscribers
218 .retain(|tx| tx.try_send(payload.clone()).is_ok());
219 Some(payload)
220 } else {
221 None
222 }
223 } else {
224 return false;
225 }
226 };
227
228 if let Some(payload) = payload {
229 let reg = self.registry.read().await;
231 if let Some(registry) = reg.as_ref() {
232 registry.notify_monitors(name, &payload).await;
233 }
234 true
235 } else {
236 false
237 }
238 }
239
240 async fn evaluate_links(&self, changed_pv: &str) {
243 if self.links.is_empty() {
244 return;
245 }
246 let mut queue = vec![changed_pv.to_string()];
247 let mut visited = HashSet::new();
248
249 while let Some(pv) = queue.pop() {
250 if !visited.insert(pv.clone()) {
251 debug!("Circular link detected for PV '{}', skipping", pv);
252 continue;
253 }
254 for link in &self.links {
255 if !link.inputs.iter().any(|i| i == &pv) {
256 continue;
257 }
258 let values = {
260 let pvs = self.pvs.read().await;
261 link.inputs
262 .iter()
263 .map(|n| {
264 pvs.get(n)
265 .map(|e| e.record.current_value())
266 .unwrap_or(ScalarValue::F64(0.0))
267 })
268 .collect::<Vec<_>>()
269 };
270 let new_val = (link.compute)(&values);
271 if self.set_value_inner(&link.output, new_val).await {
272 queue.push(link.output.clone());
273 }
274 }
275 }
276 }
277
278 pub async fn pv_names(&self) -> Vec<String> {
280 let pvs = self.pvs.read().await;
281 pvs.keys().cloned().collect()
282 }
283}
284
285impl PvStore for SimplePvStore {
286 fn has_pv(&self, name: &str) -> impl Future<Output = bool> + Send {
287 async move {
288 let pvs = self.pvs.read().await;
289 pvs.contains_key(name)
290 }
291 }
292
293 fn get_snapshot(&self, name: &str) -> impl Future<Output = Option<NtPayload>> + Send {
294 async move {
295 let pvs = self.pvs.read().await;
296 pvs.get(name).map(|e| e.record.to_ntpayload())
297 }
298 }
299
300 fn get_descriptor(&self, name: &str) -> impl Future<Output = Option<StructureDesc>> + Send {
301 async move {
302 let pvs = self.pvs.read().await;
303 pvs.get(name)
304 .map(|e| descriptor_for_payload(&e.record.to_ntpayload()))
305 }
306 }
307
308 fn put_value(
309 &self,
310 name: &str,
311 value: &DecodedValue,
312 ) -> impl Future<Output = Result<Vec<(String, NtPayload)>, String>> + Send {
313 let name = name.to_string();
314 let value = value.clone();
315 async move {
316 let result = {
317 let mut pvs = self.pvs.write().await;
318 let entry = pvs
319 .get_mut(&name)
320 .ok_or_else(|| format!("PV '{}' not found", name))?;
321
322 if !entry.record.writable() {
323 return Err(format!("PV '{}' is not writable", name));
324 }
325
326 let changed = apply_put_to_record(&mut entry.record, &value, self.compute_alarms);
327 if !changed {
328 return Ok(vec![]);
329 }
330
331 let payload = entry.record.to_ntpayload();
332 entry
333 .subscribers
334 .retain(|tx| tx.try_send(payload.clone()).is_ok());
335
336 (name.clone(), payload)
337 }; if let Some(cb) = self.on_put.get(&name) {
341 let cb = cb.clone();
342 let n = name.clone();
343 let v = value.clone();
344 tokio::spawn(async move { cb(&n, &v) });
345 }
346
347 self.evaluate_links(&name).await;
349
350 Ok(vec![result])
351 }
352 }
353
354 fn is_writable(&self, name: &str) -> impl Future<Output = bool> + Send {
355 async move {
356 let pvs = self.pvs.read().await;
357 pvs.get(name).is_some_and(|e| e.record.writable())
358 }
359 }
360
361 fn list_pvs(&self) -> impl Future<Output = Vec<String>> + Send {
362 async move {
363 let pvs = self.pvs.read().await;
364 pvs.keys().cloned().collect()
365 }
366 }
367
368 fn subscribe(
369 &self,
370 name: &str,
371 ) -> impl Future<Output = Option<mpsc::Receiver<NtPayload>>> + Send {
372 let name = name.to_string();
373 async move {
374 let mut pvs = self.pvs.write().await;
375 let entry = pvs.get_mut(&name)?;
376 let (tx, rx) = mpsc::channel(64);
377 entry.subscribers.push(tx);
378 Some(rx)
379 }
380 }
381}
382
383fn apply_put_to_record(
387 record: &mut RecordInstance,
388 value: &DecodedValue,
389 compute_alarms: bool,
390) -> bool {
391 let fields = match value {
392 DecodedValue::Structure(f) => f,
393 other => {
394 return apply_put_to_record(
396 record,
397 &DecodedValue::Structure(vec![("value".to_string(), other.clone())]),
398 compute_alarms,
399 );
400 }
401 };
402
403 let mut changed = false;
404
405 match &mut record.data {
406 RecordData::Ai { nt, .. }
407 | RecordData::Ao { nt, .. }
408 | RecordData::Bi { nt, .. }
409 | RecordData::Bo { nt, .. }
410 | RecordData::StringIn { nt, .. }
411 | RecordData::StringOut { nt, .. } => {
412 for (name, val) in fields {
413 match name.as_str() {
414 "value" => {
415 changed |= apply_value_update(nt, val, compute_alarms);
416 }
417 "alarm" => {
418 changed |= apply_alarm_update(nt, val);
419 }
420 "display" => {
421 changed |= apply_display_update(nt, val);
422 }
423 "control" => {
424 changed |= apply_control_update(nt, val);
425 }
426 _ => {}
427 }
428 }
429 }
430 RecordData::Waveform { nt, nord, .. }
431 | RecordData::Aai { nt, nord, .. }
432 | RecordData::Aao { nt, nord, .. }
433 | RecordData::SubArray { nt, nord, .. } => {
434 changed = apply_scalar_array_put(nt, nord, value);
435 }
436 RecordData::NtTable { .. } | RecordData::NtNdArray { .. } => {
437 debug!("PUT to NtTable/NtNdArray not yet supported in SimplePvStore");
439 }
440 }
441
442 changed
443}
444
445pub(crate) fn descriptor_for_payload(payload: &NtPayload) -> StructureDesc {
448 match payload {
449 NtPayload::Scalar(nt) => nt_scalar_desc(&nt.value),
450 NtPayload::ScalarArray(arr) => nt_scalar_array_desc(&arr.value),
451 _ => StructureDesc::new(),
452 }
453}
454
455fn value_type_code(sv: &ScalarValue) -> TypeCode {
456 match sv {
457 ScalarValue::Bool(_) => TypeCode::Boolean,
458 ScalarValue::I8(_) => TypeCode::Int8,
459 ScalarValue::I16(_) => TypeCode::Int16,
460 ScalarValue::I32(_) => TypeCode::Int32,
461 ScalarValue::I64(_) => TypeCode::Int64,
462 ScalarValue::U8(_) => TypeCode::UInt8,
463 ScalarValue::U16(_) => TypeCode::UInt16,
464 ScalarValue::U32(_) => TypeCode::UInt32,
465 ScalarValue::U64(_) => TypeCode::UInt64,
466 ScalarValue::F32(_) => TypeCode::Float32,
467 ScalarValue::F64(_) => TypeCode::Float64,
468 ScalarValue::Str(_) => TypeCode::String,
469 }
470}
471
472fn array_type_code(sav: &ScalarArrayValue) -> TypeCode {
473 match sav {
474 ScalarArrayValue::Bool(_) => TypeCode::Boolean,
475 ScalarArrayValue::I8(_) => TypeCode::Int8,
476 ScalarArrayValue::I16(_) => TypeCode::Int16,
477 ScalarArrayValue::I32(_) => TypeCode::Int32,
478 ScalarArrayValue::I64(_) => TypeCode::Int64,
479 ScalarArrayValue::U8(_) => TypeCode::UInt8,
480 ScalarArrayValue::U16(_) => TypeCode::UInt16,
481 ScalarArrayValue::U32(_) => TypeCode::UInt32,
482 ScalarArrayValue::U64(_) => TypeCode::UInt64,
483 ScalarArrayValue::F32(_) => TypeCode::Float32,
484 ScalarArrayValue::F64(_) => TypeCode::Float64,
485 ScalarArrayValue::Str(_) => TypeCode::String,
486 }
487}
488
489fn nt_scalar_desc(sv: &ScalarValue) -> StructureDesc {
490 let tc = value_type_code(sv);
491 StructureDesc {
492 struct_id: Some("epics:nt/NTScalar:1.0".to_string()),
493 fields: vec![
494 FieldDesc {
495 name: "value".to_string(),
496 field_type: FieldType::Scalar(tc),
497 },
498 FieldDesc {
499 name: "alarm".to_string(),
500 field_type: FieldType::Structure(alarm_desc()),
501 },
502 FieldDesc {
503 name: "timeStamp".to_string(),
504 field_type: FieldType::Structure(timestamp_desc()),
505 },
506 FieldDesc {
507 name: "display".to_string(),
508 field_type: FieldType::Structure(display_desc()),
509 },
510 FieldDesc {
511 name: "control".to_string(),
512 field_type: FieldType::Structure(control_desc()),
513 },
514 FieldDesc {
515 name: "valueAlarm".to_string(),
516 field_type: FieldType::Structure(value_alarm_desc()),
517 },
518 ],
519 }
520}
521
522fn nt_scalar_array_desc(sav: &ScalarArrayValue) -> StructureDesc {
523 let tc = array_type_code(sav);
524 StructureDesc {
525 struct_id: Some("epics:nt/NTScalarArray:1.0".to_string()),
526 fields: vec![
527 FieldDesc {
528 name: "value".to_string(),
529 field_type: FieldType::ScalarArray(tc),
530 },
531 FieldDesc {
532 name: "alarm".to_string(),
533 field_type: FieldType::Structure(alarm_desc()),
534 },
535 FieldDesc {
536 name: "timeStamp".to_string(),
537 field_type: FieldType::Structure(timestamp_desc()),
538 },
539 FieldDesc {
540 name: "display".to_string(),
541 field_type: FieldType::Structure(display_desc()),
542 },
543 FieldDesc {
544 name: "control".to_string(),
545 field_type: FieldType::Structure(control_desc()),
546 },
547 ],
548 }
549}
550
551fn alarm_desc() -> StructureDesc {
552 StructureDesc {
553 struct_id: Some("alarm_t".to_string()),
554 fields: vec![
555 FieldDesc {
556 name: "severity".to_string(),
557 field_type: FieldType::Scalar(TypeCode::Int32),
558 },
559 FieldDesc {
560 name: "status".to_string(),
561 field_type: FieldType::Scalar(TypeCode::Int32),
562 },
563 FieldDesc {
564 name: "message".to_string(),
565 field_type: FieldType::String,
566 },
567 ],
568 }
569}
570
571fn timestamp_desc() -> StructureDesc {
572 StructureDesc {
573 struct_id: Some("time_t".to_string()),
574 fields: vec![
575 FieldDesc {
576 name: "secondsPastEpoch".to_string(),
577 field_type: FieldType::Scalar(TypeCode::Int64),
578 },
579 FieldDesc {
580 name: "nanoseconds".to_string(),
581 field_type: FieldType::Scalar(TypeCode::Int32),
582 },
583 FieldDesc {
584 name: "userTag".to_string(),
585 field_type: FieldType::Scalar(TypeCode::Int32),
586 },
587 ],
588 }
589}
590
591fn display_desc() -> StructureDesc {
592 StructureDesc {
593 struct_id: Some("display_t".to_string()),
594 fields: vec![
595 FieldDesc {
596 name: "limitLow".to_string(),
597 field_type: FieldType::Scalar(TypeCode::Float64),
598 },
599 FieldDesc {
600 name: "limitHigh".to_string(),
601 field_type: FieldType::Scalar(TypeCode::Float64),
602 },
603 FieldDesc {
604 name: "description".to_string(),
605 field_type: FieldType::String,
606 },
607 FieldDesc {
608 name: "units".to_string(),
609 field_type: FieldType::String,
610 },
611 FieldDesc {
612 name: "precision".to_string(),
613 field_type: FieldType::Scalar(TypeCode::Int32),
614 },
615 FieldDesc {
616 name: "form".to_string(),
617 field_type: FieldType::Structure(StructureDesc {
618 struct_id: Some("enum_t".to_string()),
619 fields: vec![
620 FieldDesc {
621 name: "index".to_string(),
622 field_type: FieldType::Scalar(TypeCode::Int32),
623 },
624 FieldDesc {
625 name: "choices".to_string(),
626 field_type: FieldType::StringArray,
627 },
628 ],
629 }),
630 },
631 ],
632 }
633}
634
635fn control_desc() -> StructureDesc {
636 StructureDesc {
637 struct_id: Some("control_t".to_string()),
638 fields: vec![
639 FieldDesc {
640 name: "limitLow".to_string(),
641 field_type: FieldType::Scalar(TypeCode::Float64),
642 },
643 FieldDesc {
644 name: "limitHigh".to_string(),
645 field_type: FieldType::Scalar(TypeCode::Float64),
646 },
647 FieldDesc {
648 name: "minStep".to_string(),
649 field_type: FieldType::Scalar(TypeCode::Float64),
650 },
651 ],
652 }
653}
654
655fn value_alarm_desc() -> StructureDesc {
656 StructureDesc {
657 struct_id: Some("valueAlarm_t".to_string()),
658 fields: vec![
659 FieldDesc {
660 name: "active".to_string(),
661 field_type: FieldType::Scalar(TypeCode::Boolean),
662 },
663 FieldDesc {
664 name: "lowAlarmLimit".to_string(),
665 field_type: FieldType::Scalar(TypeCode::Float64),
666 },
667 FieldDesc {
668 name: "lowWarningLimit".to_string(),
669 field_type: FieldType::Scalar(TypeCode::Float64),
670 },
671 FieldDesc {
672 name: "highWarningLimit".to_string(),
673 field_type: FieldType::Scalar(TypeCode::Float64),
674 },
675 FieldDesc {
676 name: "highAlarmLimit".to_string(),
677 field_type: FieldType::Scalar(TypeCode::Float64),
678 },
679 FieldDesc {
680 name: "lowAlarmSeverity".to_string(),
681 field_type: FieldType::Scalar(TypeCode::Int32),
682 },
683 FieldDesc {
684 name: "lowWarningSeverity".to_string(),
685 field_type: FieldType::Scalar(TypeCode::Int32),
686 },
687 FieldDesc {
688 name: "highWarningSeverity".to_string(),
689 field_type: FieldType::Scalar(TypeCode::Int32),
690 },
691 FieldDesc {
692 name: "highAlarmSeverity".to_string(),
693 field_type: FieldType::Scalar(TypeCode::Int32),
694 },
695 FieldDesc {
696 name: "hysteresis".to_string(),
697 field_type: FieldType::Scalar(TypeCode::UInt8),
698 },
699 ],
700 }
701}
702
// Unit tests for `SimplePvStore`: lookup, snapshots, client PUTs, server-side writes,
// descriptors, direct subscriptions and on-put callbacks.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DbCommonState, RecordType};
    use spvirit_types::{
        NdCodec, NdDimension, NtNdArray, NtPayload, NtScalar, NtScalarArray, NtTable,
        NtTableColumn, ScalarArrayValue, ScalarValue,
    };

    // `ai` record holding `val`. It is read-only to client PUTs (see `put_value_rejects_readonly`).
    fn make_ai(name: &str, val: f64) -> RecordInstance {
        RecordInstance {
            name: name.to_string(),
            record_type: RecordType::Ai,
            common: DbCommonState::default(),
            data: RecordData::Ai {
                nt: NtScalar::from_value(ScalarValue::F64(val)),
                inp: None,
                siml: None,
                siol: None,
                simm: false,
            },
            raw_fields: HashMap::new(),
        }
    }

    // `ao` record holding `val`, in supervisory mode with no links or drive limits.
    fn make_ao(name: &str, val: f64) -> RecordInstance {
        RecordInstance {
            name: name.to_string(),
            record_type: RecordType::Ao,
            common: DbCommonState::default(),
            data: RecordData::Ao {
                nt: NtScalar::from_value(ScalarValue::F64(val)),
                out: None,
                dol: None,
                omsl: crate::types::OutputMode::Supervisory,
                drvl: None,
                drvh: None,
                oroc: None,
                siml: None,
                siol: None,
                simm: false,
            },
            raw_fields: HashMap::new(),
        }
    }

    // Waveform record whose NELM and NORD both equal the initial array length.
    fn make_waveform(name: &str, value: ScalarArrayValue) -> RecordInstance {
        let nelm = value.len();
        RecordInstance {
            name: name.to_string(),
            record_type: RecordType::Waveform,
            common: DbCommonState::default(),
            data: RecordData::Waveform {
                nt: NtScalarArray::from_value(value),
                inp: None,
                ftvl: "DOUBLE".to_string(),
                nelm,
                nord: nelm,
            },
            raw_fields: HashMap::new(),
        }
    }

    // NTTable record with two float64 columns `x` and `y`.
    fn make_nt_table(name: &str) -> RecordInstance {
        RecordInstance {
            name: name.to_string(),
            record_type: RecordType::NtTable,
            common: DbCommonState::default(),
            data: RecordData::NtTable {
                nt: NtTable {
                    labels: vec!["X".to_string(), "Y".to_string()],
                    columns: vec![
                        NtTableColumn {
                            name: "x".to_string(),
                            values: ScalarArrayValue::F64(vec![1.0, 2.0]),
                        },
                        NtTableColumn {
                            name: "y".to_string(),
                            values: ScalarArrayValue::F64(vec![10.0, 20.0]),
                        },
                    ],
                    descriptor: Some("table".to_string()),
                    alarm: None,
                    time_stamp: None,
                },
                inp: None,
                out: None,
                omsl: crate::types::OutputMode::Supervisory,
            },
            raw_fields: HashMap::new(),
        }
    }

    // NTNDArray record with a 4-byte, uncompressed ("none" codec) U8 payload.
    fn make_nt_ndarray(name: &str) -> RecordInstance {
        RecordInstance {
            name: name.to_string(),
            record_type: RecordType::NtNdArray,
            common: DbCommonState::default(),
            data: RecordData::NtNdArray {
                nt: NtNdArray {
                    value: ScalarArrayValue::U8(vec![0; 4]),
                    codec: NdCodec {
                        name: "none".to_string(),
                        parameters: HashMap::new(),
                    },
                    compressed_size: 4,
                    uncompressed_size: 4,
                    dimension: vec![NdDimension {
                        size: 2,
                        offset: 0,
                        full_size: 2,
                        binning: 1,
                        reverse: false,
                    }],
                    unique_id: 1,
                    data_time_stamp: Default::default(),
                    attribute: vec![],
                    descriptor: Some("ndarray".to_string()),
                    alarm: None,
                    time_stamp: None,
                    display: None,
                },
                inp: None,
                out: None,
                omsl: crate::types::OutputMode::Supervisory,
            },
            raw_fields: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn has_pv_returns_true_for_existing() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
        assert!(store.has_pv("TEST:AI").await);
        assert!(!store.has_pv("MISSING").await);
    }

    #[tokio::test]
    async fn get_snapshot_returns_payload() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 42.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
        let snap = store.get_snapshot("TEST:AI").await.unwrap();
        match snap {
            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(42.0)),
            _ => panic!("expected scalar"),
        }
    }

    // A structured `{ value: ... }` PUT to a writable record reports the PV as changed.
    #[tokio::test]
    async fn put_value_updates_writable_record() {
        let mut records = HashMap::new();
        records.insert("TEST:AO".into(), make_ao("TEST:AO", 0.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(99.5))]);
        let result = store.put_value("TEST:AO", &val).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "TEST:AO");

        let snap = store.get_snapshot("TEST:AO").await.unwrap();
        match snap {
            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(99.5)),
            _ => panic!("expected scalar"),
        }
    }

    // Client PUTs to read-only records are rejected with a "not writable" error.
    #[tokio::test]
    async fn put_value_rejects_readonly() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

        let val = DecodedValue::Float64(5.0);
        let err = store.put_value("TEST:AI", &val).await.unwrap_err();
        assert!(err.contains("not writable"));
    }

    // Server-side `set_value` ignores the writable flag that client PUTs honour.
    #[tokio::test]
    async fn set_value_bypasses_writable_check() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
        assert!(store.set_value("TEST:AI", ScalarValue::F64(10.0)).await);
        let val = store.get_value("TEST:AI").await.unwrap();
        assert_eq!(val, ScalarValue::F64(10.0));
    }

    // For every element type: writing the current array reports no change, and writing a
    // different array reports a change and is visible in the snapshot.
    #[tokio::test]
    async fn set_array_value_updates_all_scalar_array_types() {
        let cases: Vec<ScalarArrayValue> = vec![
            ScalarArrayValue::Bool(vec![false, true]),
            ScalarArrayValue::I8(vec![1, 2]),
            ScalarArrayValue::I16(vec![1, 2]),
            ScalarArrayValue::I32(vec![1, 2]),
            ScalarArrayValue::I64(vec![1, 2]),
            ScalarArrayValue::U8(vec![1, 2]),
            ScalarArrayValue::U16(vec![1, 2]),
            ScalarArrayValue::U32(vec![1, 2]),
            ScalarArrayValue::U64(vec![1, 2]),
            ScalarArrayValue::F32(vec![1.0, 2.0]),
            ScalarArrayValue::F64(vec![1.0, 2.0]),
            ScalarArrayValue::Str(vec!["a".to_string(), "b".to_string()]),
        ];

        for (idx, updated) in cases.into_iter().enumerate() {
            let pv = format!("TEST:WF:{idx}");
            let mut records = HashMap::new();
            records.insert(pv.clone(), make_waveform(&pv, updated.clone()));
            let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

            assert!(!store.set_array_value(&pv, updated.clone()).await);

            let second = match updated {
                ScalarArrayValue::Bool(_) => ScalarArrayValue::Bool(vec![true, false]),
                ScalarArrayValue::I8(_) => ScalarArrayValue::I8(vec![3, 4]),
                ScalarArrayValue::I16(_) => ScalarArrayValue::I16(vec![3, 4]),
                ScalarArrayValue::I32(_) => ScalarArrayValue::I32(vec![3, 4]),
                ScalarArrayValue::I64(_) => ScalarArrayValue::I64(vec![3, 4]),
                ScalarArrayValue::U8(_) => ScalarArrayValue::U8(vec![3, 4]),
                ScalarArrayValue::U16(_) => ScalarArrayValue::U16(vec![3, 4]),
                ScalarArrayValue::U32(_) => ScalarArrayValue::U32(vec![3, 4]),
                ScalarArrayValue::U64(_) => ScalarArrayValue::U64(vec![3, 4]),
                ScalarArrayValue::F32(_) => ScalarArrayValue::F32(vec![3.0, 4.0]),
                ScalarArrayValue::F64(_) => ScalarArrayValue::F64(vec![3.0, 4.0]),
                ScalarArrayValue::Str(_) => {
                    ScalarArrayValue::Str(vec!["x".to_string(), "y".to_string()])
                }
            };

            assert!(store.set_array_value(&pv, second.clone()).await);
            let snap = store.get_snapshot(&pv).await.unwrap();
            match snap {
                NtPayload::ScalarArray(nt) => assert_eq!(nt.value, second),
                _ => panic!("expected scalar array"),
            }
        }
    }

    #[tokio::test]
    async fn get_nt_returns_full_payload() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 12.5));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

        let nt = store.get_nt("TEST:AI").await.unwrap();
        match nt {
            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(12.5)),
            _ => panic!("expected scalar payload"),
        }
    }

    // `put_nt` replaces the payload of scalar, array, table and ndarray records. A payload
    // whose kind does not match the record (array into an `ai`) is reported as not applied.
    #[tokio::test]
    async fn put_nt_updates_scalar_array_table_and_ndarray() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
        records.insert(
            "TEST:WF".into(),
            make_waveform("TEST:WF", ScalarArrayValue::F64(vec![0.0, 0.0])),
        );
        records.insert("TEST:TBL".into(), make_nt_table("TEST:TBL"));
        records.insert("TEST:NDA".into(), make_nt_ndarray("TEST:NDA"));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

        assert!(
            store
                .put_nt(
                    "TEST:AI",
                    NtPayload::Scalar(NtScalar::from_value(ScalarValue::F64(5.0))),
                )
                .await
        );
        assert!(
            store
                .put_nt(
                    "TEST:WF",
                    NtPayload::ScalarArray(NtScalarArray::from_value(ScalarArrayValue::F64(vec![
                        3.0, 4.0
                    ],))),
                )
                .await
        );

        let table = NtTable {
            labels: vec!["X".to_string(), "Y".to_string()],
            columns: vec![
                NtTableColumn {
                    name: "x".to_string(),
                    values: ScalarArrayValue::F64(vec![2.0, 3.0]),
                },
                NtTableColumn {
                    name: "y".to_string(),
                    values: ScalarArrayValue::F64(vec![20.0, 30.0]),
                },
            ],
            descriptor: Some("updated table".to_string()),
            alarm: None,
            time_stamp: None,
        };
        assert!(
            store
                .put_nt("TEST:TBL", NtPayload::Table(table.clone()))
                .await
        );

        let ndarray = NtNdArray {
            value: ScalarArrayValue::U8(vec![1, 2, 3, 4]),
            codec: NdCodec {
                name: "none".to_string(),
                parameters: HashMap::new(),
            },
            compressed_size: 4,
            uncompressed_size: 4,
            dimension: vec![NdDimension {
                size: 4,
                offset: 0,
                full_size: 4,
                binning: 1,
                reverse: false,
            }],
            unique_id: 2,
            data_time_stamp: Default::default(),
            attribute: vec![],
            descriptor: Some("updated ndarray".to_string()),
            alarm: None,
            time_stamp: None,
            display: None,
        };
        assert!(
            store
                .put_nt("TEST:NDA", NtPayload::NdArray(ndarray.clone()))
                .await
        );

        // Mismatched payload kind: an array payload into a scalar `ai` record.
        assert!(
            !store
                .put_nt(
                    "TEST:AI",
                    NtPayload::ScalarArray(NtScalarArray::from_value(ScalarArrayValue::F64(vec![
                        1.0
                    ]))),
                )
                .await
        );

        match store.get_nt("TEST:TBL").await.unwrap() {
            NtPayload::Table(nt) => assert_eq!(nt, table),
            _ => panic!("expected table payload"),
        }
        match store.get_nt("TEST:NDA").await.unwrap() {
            NtPayload::NdArray(nt) => assert_eq!(nt, ndarray),
            _ => panic!("expected ndarray payload"),
        }
    }

    // An F64 record advertises an NTScalar descriptor whose `value` field is Float64.
    #[tokio::test]
    async fn descriptor_matches_value_type() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 0.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
        let desc = store.get_descriptor("TEST:AI").await.unwrap();
        assert_eq!(desc.struct_id.as_deref(), Some("epics:nt/NTScalar:1.0"));
        let value_field = desc.field("value").unwrap();
        assert!(matches!(
            value_field.field_type,
            FieldType::Scalar(TypeCode::Float64)
        ));
    }

    // A direct subscriber receives the payload produced by a changing client PUT.
    #[tokio::test]
    async fn subscribe_receives_updates() {
        let mut records = HashMap::new();
        records.insert("TEST:AO".into(), make_ao("TEST:AO", 0.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

        let mut rx = store.subscribe("TEST:AO").await.unwrap();

        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(7.7))]);
        store.put_value("TEST:AO", &val).await.unwrap();

        let update = rx.recv().await.unwrap();
        match update {
            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(7.7)),
            _ => panic!("expected scalar"),
        }
    }

    // The registered on-put callback runs after a changing client PUT.
    #[tokio::test]
    async fn on_put_callback_is_invoked() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called2 = called.clone();

        let mut records = HashMap::new();
        records.insert("CB:AO".into(), make_ao("CB:AO", 0.0));

        let mut on_put = HashMap::new();
        let cb: OnPutCallback = Arc::new(move |_name, _val| {
            called2.store(true, Ordering::SeqCst);
        });
        on_put.insert("CB:AO".into(), cb);

        let store = SimplePvStore::new(records, on_put, vec![], false);
        let val = DecodedValue::Structure(vec![("value".to_string(), DecodedValue::Float64(1.0))]);
        store.put_value("CB:AO", &val).await.unwrap();

        // The callback runs on a spawned task. Yield so the runtime gets a chance to run it
        // (assumes the default single-threaded test runtime).
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;

        assert!(called.load(Ordering::SeqCst));
    }
}