1use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use tokio::sync::{mpsc, RwLock};
10use tracing::debug;
11
12use spvirit_codec::spvd_decode::{
13 DecodedValue, FieldDesc, FieldType, StructureDesc, TypeCode,
14};
15use spvirit_types::{NtPayload, ScalarArrayValue, ScalarValue};
16
17use crate::apply::{
18 apply_alarm_update, apply_control_update, apply_display_update, apply_scalar_array_put,
19 apply_value_update,
20};
21use crate::monitor::MonitorRegistry;
22use crate::pvstore::PvStore;
23use crate::types::{RecordData, RecordInstance};
24
25pub type OnPutCallback =
27 Arc<dyn Fn(&str, &DecodedValue) + Send + Sync>;
28
29pub type ScanCallback =
31 Arc<dyn Fn(&str) -> ScalarValue + Send + Sync>;
32
33pub type LinkCallback =
35 Arc<dyn Fn(&[ScalarValue]) -> ScalarValue + Send + Sync>;
36
/// A computed link between PVs.
///
/// Whenever one of the `inputs` PVs changes, `compute` is called with the
/// current values of all `inputs` (in this order) and its result is written to
/// the `output` PV.
pub(crate) struct LinkDef {
    /// Name of the PV that receives the computed value.
    pub output: String,
    /// Names of the PVs read (in order) and passed to `compute`.
    pub inputs: Vec<String>,
    /// Function computing the output value from the input values.
    pub compute: LinkCallback,
}
43
/// One stored PV: its record plus the channels of its local subscribers.
struct PvEntry {
    /// The record holding the PV's current state.
    record: RecordInstance,
    /// Senders handed out by `subscribe`; each change pushes the new payload here.
    subscribers: Vec<mpsc::Sender<NtPayload>>,
}
48
/// In-memory [`PvStore`] backed by a map of records, with optional per-PV PUT
/// callbacks, computed links and a monitor registry.
pub struct SimplePvStore {
    /// All PVs, keyed by name.
    pvs: RwLock<HashMap<String, PvEntry>>,
    /// Per-PV callbacks run (on a spawned task) after a PUT changes that PV.
    on_put: HashMap<String, OnPutCallback>,
    /// Computed links re-evaluated when one of their inputs changes.
    links: Vec<LinkDef>,
    /// Forwarded to the record update helpers.
    /// NOTE(review): presumably enables alarm recomputation on value changes — confirm.
    compute_alarms: bool,
    /// Monitor registry notified when `set_value` (or link propagation) changes a PV.
    registry: RwLock<Option<Arc<MonitorRegistry>>>,
}
57
58impl SimplePvStore {
59 pub(crate) fn new(
60 records: HashMap<String, RecordInstance>,
61 on_put: HashMap<String, OnPutCallback>,
62 links: Vec<LinkDef>,
63 compute_alarms: bool,
64 ) -> Self {
65 let pvs = records
66 .into_iter()
67 .map(|(name, record)| {
68 (
69 name,
70 PvEntry {
71 record,
72 subscribers: Vec::new(),
73 },
74 )
75 })
76 .collect();
77 Self {
78 pvs: RwLock::new(pvs),
79 on_put,
80 links,
81 compute_alarms,
82 registry: RwLock::new(None),
83 }
84 }
85
86 pub async fn set_registry(&self, registry: Arc<MonitorRegistry>) {
89 *self.registry.write().await = Some(registry);
90 }
91
92 pub async fn insert(&self, name: String, record: RecordInstance) {
94 let mut pvs = self.pvs.write().await;
95 pvs.insert(
96 name,
97 PvEntry {
98 record,
99 subscribers: Vec::new(),
100 },
101 );
102 }
103
104 pub async fn get_value(&self, name: &str) -> Option<ScalarValue> {
106 let pvs = self.pvs.read().await;
107 pvs.get(name).map(|e| e.record.current_value())
108 }
109
110 pub async fn set_value(&self, name: &str, value: ScalarValue) -> bool {
112 if self.set_value_inner(name, value).await {
113 self.evaluate_links(name).await;
114 true
115 } else {
116 false
117 }
118 }
119
120 async fn set_value_inner(&self, name: &str, value: ScalarValue) -> bool {
123 let payload = {
124 let mut pvs = self.pvs.write().await;
125 if let Some(entry) = pvs.get_mut(name) {
126 let changed = entry.record.set_scalar_value(value, self.compute_alarms);
127 if changed {
128 let payload = entry.record.to_ntpayload();
129 entry.subscribers.retain(|tx| tx.try_send(payload.clone()).is_ok());
130 Some(payload)
131 } else {
132 None
133 }
134 } else {
135 return false;
136 }
137 };
138
139 if let Some(payload) = payload {
140 let reg = self.registry.read().await;
142 if let Some(registry) = reg.as_ref() {
143 registry.notify_monitors(name, &payload).await;
144 }
145 true
146 } else {
147 false
148 }
149 }
150
151 async fn evaluate_links(&self, changed_pv: &str) {
154 if self.links.is_empty() {
155 return;
156 }
157 let mut queue = vec![changed_pv.to_string()];
158 let mut visited = HashSet::new();
159
160 while let Some(pv) = queue.pop() {
161 if !visited.insert(pv.clone()) {
162 debug!("Circular link detected for PV '{}', skipping", pv);
163 continue;
164 }
165 for link in &self.links {
166 if !link.inputs.iter().any(|i| i == &pv) {
167 continue;
168 }
169 let values = {
171 let pvs = self.pvs.read().await;
172 link.inputs
173 .iter()
174 .map(|n| {
175 pvs.get(n)
176 .map(|e| e.record.current_value())
177 .unwrap_or(ScalarValue::F64(0.0))
178 })
179 .collect::<Vec<_>>()
180 };
181 let new_val = (link.compute)(&values);
182 if self.set_value_inner(&link.output, new_val).await {
183 queue.push(link.output.clone());
184 }
185 }
186 }
187 }
188
189 pub async fn pv_names(&self) -> Vec<String> {
191 let pvs = self.pvs.read().await;
192 pvs.keys().cloned().collect()
193 }
194}
195
196impl PvStore for SimplePvStore {
197 fn has_pv(&self, name: &str) -> impl Future<Output = bool> + Send {
198 async move {
199 let pvs = self.pvs.read().await;
200 pvs.contains_key(name)
201 }
202 }
203
204 fn get_snapshot(&self, name: &str) -> impl Future<Output = Option<NtPayload>> + Send {
205 async move {
206 let pvs = self.pvs.read().await;
207 pvs.get(name).map(|e| e.record.to_ntpayload())
208 }
209 }
210
211 fn get_descriptor(&self, name: &str) -> impl Future<Output = Option<StructureDesc>> + Send {
212 async move {
213 let pvs = self.pvs.read().await;
214 pvs.get(name)
215 .map(|e| descriptor_for_payload(&e.record.to_ntpayload()))
216 }
217 }
218
219 fn put_value(
220 &self,
221 name: &str,
222 value: &DecodedValue,
223 ) -> impl Future<Output = Result<Vec<(String, NtPayload)>, String>> + Send {
224 let name = name.to_string();
225 let value = value.clone();
226 async move {
227 let result = {
228 let mut pvs = self.pvs.write().await;
229 let entry = pvs
230 .get_mut(&name)
231 .ok_or_else(|| format!("PV '{}' not found", name))?;
232
233 if !entry.record.writable() {
234 return Err(format!("PV '{}' is not writable", name));
235 }
236
237 let changed = apply_put_to_record(&mut entry.record, &value, self.compute_alarms);
238 if !changed {
239 return Ok(vec![]);
240 }
241
242 let payload = entry.record.to_ntpayload();
243 entry
244 .subscribers
245 .retain(|tx| tx.try_send(payload.clone()).is_ok());
246
247 (name.clone(), payload)
248 }; if let Some(cb) = self.on_put.get(&name) {
252 let cb = cb.clone();
253 let n = name.clone();
254 let v = value.clone();
255 tokio::spawn(async move { cb(&n, &v) });
256 }
257
258 self.evaluate_links(&name).await;
260
261 Ok(vec![result])
262 }
263 }
264
265 fn is_writable(&self, name: &str) -> impl Future<Output = bool> + Send {
266 async move {
267 let pvs = self.pvs.read().await;
268 pvs.get(name).is_some_and(|e| e.record.writable())
269 }
270 }
271
272 fn list_pvs(&self) -> impl Future<Output = Vec<String>> + Send {
273 async move {
274 let pvs = self.pvs.read().await;
275 pvs.keys().cloned().collect()
276 }
277 }
278
279 fn subscribe(
280 &self,
281 name: &str,
282 ) -> impl Future<Output = Option<mpsc::Receiver<NtPayload>>> + Send {
283 let name = name.to_string();
284 async move {
285 let mut pvs = self.pvs.write().await;
286 let entry = pvs.get_mut(&name)?;
287 let (tx, rx) = mpsc::channel(64);
288 entry.subscribers.push(tx);
289 Some(rx)
290 }
291 }
292}
293
294fn apply_put_to_record(
298 record: &mut RecordInstance,
299 value: &DecodedValue,
300 compute_alarms: bool,
301) -> bool {
302 let fields = match value {
303 DecodedValue::Structure(f) => f,
304 other => {
305 return apply_put_to_record(
307 record,
308 &DecodedValue::Structure(vec![("value".to_string(), other.clone())]),
309 compute_alarms,
310 );
311 }
312 };
313
314 let mut changed = false;
315
316 match &mut record.data {
317 RecordData::Ai { nt, .. }
318 | RecordData::Ao { nt, .. }
319 | RecordData::Bi { nt, .. }
320 | RecordData::Bo { nt, .. }
321 | RecordData::StringIn { nt, .. }
322 | RecordData::StringOut { nt, .. } => {
323 for (name, val) in fields {
324 match name.as_str() {
325 "value" => {
326 changed |= apply_value_update(nt, val, compute_alarms);
327 }
328 "alarm" => {
329 changed |= apply_alarm_update(nt, val);
330 }
331 "display" => {
332 changed |= apply_display_update(nt, val);
333 }
334 "control" => {
335 changed |= apply_control_update(nt, val);
336 }
337 _ => {}
338 }
339 }
340 }
341 RecordData::Waveform { nt, nord, .. }
342 | RecordData::Aai { nt, nord, .. }
343 | RecordData::Aao { nt, nord, .. }
344 | RecordData::SubArray { nt, nord, .. } => {
345 changed = apply_scalar_array_put(nt, nord, value);
346 }
347 RecordData::NtTable { .. } | RecordData::NtNdArray { .. } => {
348 debug!("PUT to NtTable/NtNdArray not yet supported in SimplePvStore");
350 }
351 }
352
353 changed
354}
355
356pub(crate) fn descriptor_for_payload(payload: &NtPayload) -> StructureDesc {
359 match payload {
360 NtPayload::Scalar(nt) => nt_scalar_desc(&nt.value),
361 NtPayload::ScalarArray(arr) => nt_scalar_array_desc(&arr.value),
362 _ => StructureDesc::new(),
363 }
364}
365
366fn value_type_code(sv: &ScalarValue) -> TypeCode {
367 match sv {
368 ScalarValue::Bool(_) => TypeCode::Boolean,
369 ScalarValue::I8(_) => TypeCode::Int8,
370 ScalarValue::I16(_) => TypeCode::Int16,
371 ScalarValue::I32(_) => TypeCode::Int32,
372 ScalarValue::I64(_) => TypeCode::Int64,
373 ScalarValue::U8(_) => TypeCode::UInt8,
374 ScalarValue::U16(_) => TypeCode::UInt16,
375 ScalarValue::U32(_) => TypeCode::UInt32,
376 ScalarValue::U64(_) => TypeCode::UInt64,
377 ScalarValue::F32(_) => TypeCode::Float32,
378 ScalarValue::F64(_) => TypeCode::Float64,
379 ScalarValue::Str(_) => TypeCode::String,
380 }
381}
382
383fn array_type_code(sav: &ScalarArrayValue) -> TypeCode {
384 match sav {
385 ScalarArrayValue::Bool(_) => TypeCode::Boolean,
386 ScalarArrayValue::I8(_) => TypeCode::Int8,
387 ScalarArrayValue::I16(_) => TypeCode::Int16,
388 ScalarArrayValue::I32(_) => TypeCode::Int32,
389 ScalarArrayValue::I64(_) => TypeCode::Int64,
390 ScalarArrayValue::U8(_) => TypeCode::UInt8,
391 ScalarArrayValue::U16(_) => TypeCode::UInt16,
392 ScalarArrayValue::U32(_) => TypeCode::UInt32,
393 ScalarArrayValue::U64(_) => TypeCode::UInt64,
394 ScalarArrayValue::F32(_) => TypeCode::Float32,
395 ScalarArrayValue::F64(_) => TypeCode::Float64,
396 ScalarArrayValue::Str(_) => TypeCode::String,
397 }
398}
399
400fn nt_scalar_desc(sv: &ScalarValue) -> StructureDesc {
401 let tc = value_type_code(sv);
402 StructureDesc {
403 struct_id: Some("epics:nt/NTScalar:1.0".to_string()),
404 fields: vec![
405 FieldDesc { name: "value".to_string(), field_type: FieldType::Scalar(tc) },
406 FieldDesc { name: "alarm".to_string(), field_type: FieldType::Structure(alarm_desc()) },
407 FieldDesc { name: "timeStamp".to_string(), field_type: FieldType::Structure(timestamp_desc()) },
408 FieldDesc { name: "display".to_string(), field_type: FieldType::Structure(display_desc()) },
409 FieldDesc { name: "control".to_string(), field_type: FieldType::Structure(control_desc()) },
410 FieldDesc { name: "valueAlarm".to_string(), field_type: FieldType::Structure(value_alarm_desc()) },
411 ],
412 }
413}
414
415fn nt_scalar_array_desc(sav: &ScalarArrayValue) -> StructureDesc {
416 let tc = array_type_code(sav);
417 StructureDesc {
418 struct_id: Some("epics:nt/NTScalarArray:1.0".to_string()),
419 fields: vec![
420 FieldDesc { name: "value".to_string(), field_type: FieldType::ScalarArray(tc) },
421 FieldDesc { name: "alarm".to_string(), field_type: FieldType::Structure(alarm_desc()) },
422 FieldDesc { name: "timeStamp".to_string(), field_type: FieldType::Structure(timestamp_desc()) },
423 FieldDesc { name: "display".to_string(), field_type: FieldType::Structure(display_desc()) },
424 FieldDesc { name: "control".to_string(), field_type: FieldType::Structure(control_desc()) },
425 ],
426 }
427}
428
429fn alarm_desc() -> StructureDesc {
430 StructureDesc {
431 struct_id: Some("alarm_t".to_string()),
432 fields: vec![
433 FieldDesc { name: "severity".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
434 FieldDesc { name: "status".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
435 FieldDesc { name: "message".to_string(), field_type: FieldType::String },
436 ],
437 }
438}
439
440fn timestamp_desc() -> StructureDesc {
441 StructureDesc {
442 struct_id: Some("time_t".to_string()),
443 fields: vec![
444 FieldDesc { name: "secondsPastEpoch".to_string(), field_type: FieldType::Scalar(TypeCode::Int64) },
445 FieldDesc { name: "nanoseconds".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
446 FieldDesc { name: "userTag".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
447 ],
448 }
449}
450
451fn display_desc() -> StructureDesc {
452 StructureDesc {
453 struct_id: Some("display_t".to_string()),
454 fields: vec![
455 FieldDesc { name: "limitLow".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
456 FieldDesc { name: "limitHigh".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
457 FieldDesc { name: "description".to_string(), field_type: FieldType::String },
458 FieldDesc { name: "units".to_string(), field_type: FieldType::String },
459 FieldDesc { name: "precision".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
460 FieldDesc {
461 name: "form".to_string(),
462 field_type: FieldType::Structure(StructureDesc {
463 struct_id: Some("enum_t".to_string()),
464 fields: vec![
465 FieldDesc { name: "index".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
466 FieldDesc { name: "choices".to_string(), field_type: FieldType::StringArray },
467 ],
468 }),
469 },
470 ],
471 }
472}
473
474fn control_desc() -> StructureDesc {
475 StructureDesc {
476 struct_id: Some("control_t".to_string()),
477 fields: vec![
478 FieldDesc { name: "limitLow".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
479 FieldDesc { name: "limitHigh".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
480 FieldDesc { name: "minStep".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
481 ],
482 }
483}
484
485fn value_alarm_desc() -> StructureDesc {
486 StructureDesc {
487 struct_id: Some("valueAlarm_t".to_string()),
488 fields: vec![
489 FieldDesc { name: "active".to_string(), field_type: FieldType::Scalar(TypeCode::Boolean) },
490 FieldDesc { name: "lowAlarmLimit".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
491 FieldDesc { name: "lowWarningLimit".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
492 FieldDesc { name: "highWarningLimit".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
493 FieldDesc { name: "highAlarmLimit".to_string(), field_type: FieldType::Scalar(TypeCode::Float64) },
494 FieldDesc { name: "lowAlarmSeverity".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
495 FieldDesc { name: "lowWarningSeverity".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
496 FieldDesc { name: "highWarningSeverity".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
497 FieldDesc { name: "highAlarmSeverity".to_string(), field_type: FieldType::Scalar(TypeCode::Int32) },
498 FieldDesc { name: "hysteresis".to_string(), field_type: FieldType::Scalar(TypeCode::UInt8) },
499 ],
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DbCommonState, RecordType};
    use spvirit_types::NtScalar;

    /// Builds a read-only ai record holding `val`.
    fn make_ai(name: &str, val: f64) -> RecordInstance {
        RecordInstance {
            name: name.to_string(),
            record_type: RecordType::Ai,
            common: DbCommonState::default(),
            data: RecordData::Ai {
                nt: NtScalar::from_value(ScalarValue::F64(val)),
                inp: None,
                siml: None,
                siol: None,
                simm: false,
            },
            raw_fields: HashMap::new(),
        }
    }

    /// Builds a writable ao record holding `val`.
    fn make_ao(name: &str, val: f64) -> RecordInstance {
        RecordInstance {
            name: name.to_string(),
            record_type: RecordType::Ao,
            common: DbCommonState::default(),
            data: RecordData::Ao {
                nt: NtScalar::from_value(ScalarValue::F64(val)),
                out: None,
                dol: None,
                omsl: crate::types::OutputMode::Supervisory,
                drvl: None,
                drvh: None,
                oroc: None,
                siml: None,
                siol: None,
                simm: false,
            },
            raw_fields: HashMap::new(),
        }
    }

    /// Builds a link writing `compute(inputs)` to `output`.
    fn make_link(
        inputs: &[&str],
        output: &str,
        compute: impl Fn(&[ScalarValue]) -> ScalarValue + Send + Sync + 'static,
    ) -> LinkDef {
        LinkDef {
            output: output.to_string(),
            inputs: inputs.iter().map(|s| s.to_string()).collect(),
            compute: Arc::new(compute),
        }
    }

    /// Link computation: first input plus one (non-F64 input yields -1.0).
    fn add_one(values: &[ScalarValue]) -> ScalarValue {
        match values.first() {
            Some(ScalarValue::F64(v)) => ScalarValue::F64(v + 1.0),
            _ => ScalarValue::F64(-1.0),
        }
    }

    #[tokio::test]
    async fn has_pv_returns_true_for_existing() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
        assert!(store.has_pv("TEST:AI").await);
        assert!(!store.has_pv("MISSING").await);
    }

    #[tokio::test]
    async fn get_snapshot_returns_payload() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 42.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
        let snap = store.get_snapshot("TEST:AI").await.unwrap();
        match snap {
            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(42.0)),
            _ => panic!("expected scalar"),
        }
    }

    #[tokio::test]
    async fn put_value_updates_writable_record() {
        let mut records = HashMap::new();
        records.insert("TEST:AO".into(), make_ao("TEST:AO", 0.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

        let val = DecodedValue::Structure(vec![
            ("value".to_string(), DecodedValue::Float64(99.5)),
        ]);
        let result = store.put_value("TEST:AO", &val).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "TEST:AO");

        let snap = store.get_snapshot("TEST:AO").await.unwrap();
        match snap {
            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(99.5)),
            _ => panic!("expected scalar"),
        }
    }

    #[tokio::test]
    async fn put_value_rejects_readonly() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

        let val = DecodedValue::Float64(5.0);
        let err = store.put_value("TEST:AI", &val).await.unwrap_err();
        assert!(err.contains("not writable"));
    }

    #[tokio::test]
    async fn put_value_rejects_missing_pv() {
        let store = SimplePvStore::new(HashMap::new(), HashMap::new(), vec![], false);
        let err = store
            .put_value("MISSING", &DecodedValue::Float64(1.0))
            .await
            .unwrap_err();
        assert!(err.contains("not found"));
    }

    #[tokio::test]
    async fn set_value_bypasses_writable_check() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 1.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
        assert!(store.set_value("TEST:AI", ScalarValue::F64(10.0)).await);
        let val = store.get_value("TEST:AI").await.unwrap();
        assert_eq!(val, ScalarValue::F64(10.0));
    }

    #[tokio::test]
    async fn set_value_returns_false_for_missing_pv() {
        let store = SimplePvStore::new(HashMap::new(), HashMap::new(), vec![], false);
        assert!(!store.set_value("MISSING", ScalarValue::F64(1.0)).await);
    }

    #[tokio::test]
    async fn set_value_propagates_through_chained_links() {
        let mut records = HashMap::new();
        records.insert("LNK:A".into(), make_ai("LNK:A", 0.0));
        records.insert("LNK:B".into(), make_ai("LNK:B", 0.0));
        records.insert("LNK:C".into(), make_ai("LNK:C", 0.0));
        let links = vec![
            make_link(&["LNK:A"], "LNK:B", add_one),
            make_link(&["LNK:B"], "LNK:C", add_one),
        ];
        let store = SimplePvStore::new(records, HashMap::new(), links, false);

        assert!(store.set_value("LNK:A", ScalarValue::F64(1.0)).await);
        assert_eq!(store.get_value("LNK:B").await, Some(ScalarValue::F64(2.0)));
        assert_eq!(store.get_value("LNK:C").await, Some(ScalarValue::F64(3.0)));
    }

    #[tokio::test]
    async fn cyclic_links_terminate() {
        let mut records = HashMap::new();
        records.insert("CYC:A".into(), make_ai("CYC:A", 0.0));
        records.insert("CYC:B".into(), make_ai("CYC:B", 0.0));
        // A -> B -> A: without cycle detection this would never settle.
        let links = vec![
            make_link(&["CYC:A"], "CYC:B", add_one),
            make_link(&["CYC:B"], "CYC:A", add_one),
        ];
        let store = SimplePvStore::new(records, HashMap::new(), links, false);

        assert!(store.set_value("CYC:A", ScalarValue::F64(1.0)).await);
        assert_eq!(store.get_value("CYC:B").await, Some(ScalarValue::F64(2.0)));
    }

    #[tokio::test]
    async fn descriptor_matches_value_type() {
        let mut records = HashMap::new();
        records.insert("TEST:AI".into(), make_ai("TEST:AI", 0.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);
        let desc = store.get_descriptor("TEST:AI").await.unwrap();
        assert_eq!(desc.struct_id.as_deref(), Some("epics:nt/NTScalar:1.0"));
        let value_field = desc.field("value").unwrap();
        assert!(matches!(value_field.field_type, FieldType::Scalar(TypeCode::Float64)));
    }

    #[tokio::test]
    async fn subscribe_receives_updates() {
        let mut records = HashMap::new();
        records.insert("TEST:AO".into(), make_ao("TEST:AO", 0.0));
        let store = SimplePvStore::new(records, HashMap::new(), vec![], false);

        let mut rx = store.subscribe("TEST:AO").await.unwrap();

        let val = DecodedValue::Structure(vec![
            ("value".to_string(), DecodedValue::Float64(7.7)),
        ]);
        store.put_value("TEST:AO", &val).await.unwrap();

        let update = rx.recv().await.unwrap();
        match update {
            NtPayload::Scalar(nt) => assert_eq!(nt.value, ScalarValue::F64(7.7)),
            _ => panic!("expected scalar"),
        }
    }

    #[tokio::test]
    async fn on_put_callback_is_invoked() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called2 = called.clone();

        let mut records = HashMap::new();
        records.insert("CB:AO".into(), make_ao("CB:AO", 0.0));

        let mut on_put = HashMap::new();
        let cb: OnPutCallback = Arc::new(move |_name, _val| {
            called2.store(true, Ordering::SeqCst);
        });
        on_put.insert("CB:AO".into(), cb);

        let store = SimplePvStore::new(records, on_put, vec![], false);
        let val = DecodedValue::Structure(vec![
            ("value".to_string(), DecodedValue::Float64(1.0)),
        ]);
        store.put_value("CB:AO", &val).await.unwrap();

        // The callback runs on a spawned task; yield so it gets scheduled.
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;

        assert!(called.load(Ordering::SeqCst));
    }
}
667}