Skip to main content

spvirit_client/
client.rs

1use tokio::io::AsyncWriteExt;
2use tokio::net::TcpStream;
3use tokio::time::timeout;
4
5use crate::auth::{resolved_authnz_host, resolved_authnz_user};
6use crate::search::resolve_pv_server;
7use crate::transport::{read_packet, read_until};
8use crate::types::{PvGetError, PvGetOptions, PvGetResult};
9use spvirit_codec::epics_decode::{
10    decode_op_response_status as codec_decode_op_response_status, PvaPacket, PvaPacketCommand,
11};
12pub use spvirit_codec::spvirit_encode::{
13    encode_create_channel_request, encode_get_field_request, encode_get_request,
14    encode_monitor_request, encode_put_request,
15};
16use spvirit_codec::spvirit_encode::encode_client_connection_validation;
17
18pub fn build_client_validation(
19    opts: &crate::types::PvGetOptions,
20    version: u8,
21    is_be: bool,
22) -> Vec<u8> {
23    let user = resolved_authnz_user(opts);
24    let host = resolved_authnz_host(opts);
25    encode_client_connection_validation(87_040, 32_767, 0, "ca", &user, &host, version, is_be)
26}
27
28pub fn op_response_status(
29    raw: &[u8],
30    is_be: bool,
31) -> Result<Option<spvirit_codec::epics_decode::PvaStatus>, PvGetError> {
32    codec_decode_op_response_status(raw, is_be).map_err(PvGetError::Protocol)
33}
34
35pub fn ensure_status_ok(raw: &[u8], is_be: bool, step: &str) -> Result<(), PvGetError> {
36    match op_response_status(raw, is_be)? {
37        None => Ok(()),
38        Some(st) if st.code == 0 => Ok(()),
39        Some(st) => Err(PvGetError::Protocol(format!(
40            "{} failed: {}",
41            step,
42            st.message.unwrap_or_else(|| format!("code={}", st.code))
43        ))),
44    }
45}
46
47pub struct ChannelConn {
48    pub stream: TcpStream,
49    pub sid: u32,
50    pub version: u8,
51    pub is_be: bool,
52}
53
54pub async fn establish_channel(
55    target: std::net::SocketAddr,
56    opts: &PvGetOptions,
57) -> Result<ChannelConn, PvGetError> {
58    let mut stream = timeout(opts.timeout, TcpStream::connect(target))
59        .await
60        .map_err(|_| PvGetError::Timeout("connect"))??;
61
62    let mut version = 2u8;
63    let mut is_be = false;
64
65    for _ in 0..2 {
66        if let Ok(bytes) = read_packet(&mut stream, opts.timeout).await {
67            let mut pkt = PvaPacket::new(&bytes);
68            if let Some(cmd) = pkt.decode_payload() {
69                match cmd {
70                    PvaPacketCommand::Control(payload) => {
71                        if payload.command == 2 {
72                            is_be = pkt.header.flags.is_msb;
73                        }
74                    }
75                    PvaPacketCommand::ConnectionValidation(_) => {
76                        version = pkt.header.version;
77                        is_be = pkt.header.flags.is_msb;
78                    }
79                    _ => {}
80                }
81            }
82        }
83    }
84
85    let validation = build_client_validation(opts, version, is_be);
86    stream.write_all(&validation).await?;
87
88    let _ = read_until(&mut stream, opts.timeout, |cmd| {
89        matches!(cmd, PvaPacketCommand::ConnectionValidated(_))
90    })
91    .await?;
92
93    let cid = 1u32;
94    let create = encode_create_channel_request(cid, &opts.pv_name, version, is_be);
95    stream.write_all(&create).await?;
96
97    let create_resp = read_until(&mut stream, opts.timeout, |cmd| {
98        matches!(cmd, PvaPacketCommand::CreateChannel(_))
99    })
100    .await?;
101    let mut pkt = PvaPacket::new(&create_resp);
102    let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
103        "create_channel decode failed".to_string(),
104    ))?;
105    let sid = match cmd {
106        PvaPacketCommand::CreateChannel(payload) => {
107            if payload.status.as_ref().is_some_and(|s| s.is_error()) {
108                let detail = payload
109                    .status
110                    .as_ref()
111                    .map(ToString::to_string)
112                    .unwrap_or_default();
113                return Err(PvGetError::Protocol(format!(
114                    "create_channel error: {}",
115                    detail
116                )));
117            }
118            payload.sid
119        }
120        _ => {
121            return Err(PvGetError::Protocol(
122                "unexpected create_channel response".to_string(),
123            ))
124        }
125    };
126
127    Ok(ChannelConn {
128        stream,
129        sid,
130        version,
131        is_be,
132    })
133}
134
135pub async fn pvget(opts: &PvGetOptions) -> Result<PvGetResult, PvGetError> {
136    let target = resolve_pv_server(opts).await?;
137
138    let conn = establish_channel(target, opts).await?;
139    let ChannelConn {
140        mut stream,
141        sid,
142        version,
143        is_be,
144    } = conn;
145
146    let ioid = 1u32;
147    // Match EPICS pvget GET init payload (extra pvRequest bytes observed in capture).
148    let get_init_req = encode_get_request(
149        sid,
150        ioid,
151        0x08,
152        &[0xfd, 0x02, 0x00, 0x80, 0x00, 0x00],
153        version,
154        is_be,
155    );
156    stream.write_all(&get_init_req).await?;
157
158    let init_resp = read_until(
159        &mut stream,
160        opts.timeout,
161        |cmd| matches!(cmd, PvaPacketCommand::Op(op) if (op.subcmd & 0x08) != 0),
162    )
163    .await?;
164    let mut pkt = PvaPacket::new(&init_resp);
165    let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
166        "get init response decode failed".to_string(),
167    ))?;
168
169    let desc = match cmd {
170        PvaPacketCommand::Op(op) => op
171            .introspection
172            .ok_or_else(|| PvGetError::Decode("missing introspection".to_string()))?,
173        _ => {
174            return Err(PvGetError::Protocol(
175                "unexpected get init response".to_string(),
176            ))
177        }
178    };
179
180    let get_data_req = encode_get_request(sid, ioid, 0x00, &[], version, is_be);
181    stream.write_all(&get_data_req).await?;
182
183    let data_resp = read_until(
184        &mut stream,
185        opts.timeout,
186        |cmd| matches!(cmd, PvaPacketCommand::Op(op) if op.subcmd == 0x00),
187    )
188    .await?;
189    let mut pkt = PvaPacket::new(&data_resp);
190    let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
191        "get data response decode failed".to_string(),
192    ))?;
193
194    match cmd {
195        PvaPacketCommand::Op(mut op) => {
196            op.decode_with_field_desc(&desc, is_be);
197            if let Some(value) = op.decoded_value {
198                return Ok(PvGetResult {
199                    pv_name: opts.pv_name.clone(),
200                    value,
201                    raw_pva: data_resp,
202                    raw_pvd: op.body,
203                    introspection: desc,
204                });
205            }
206            Err(PvGetError::Decode("no decoded value".to_string()))
207        }
208        _ => Err(PvGetError::Protocol(
209            "unexpected get data response".to_string(),
210        )),
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand, PvaStatus};

    /// pvRequest bytes used by the op-request roundtrips (the same bytes `pvget`
    /// sends with its GET init).
    const PV_REQUEST: [u8; 6] = [0xfd, 0x02, 0x00, 0x80, 0x00, 0x00];

    #[test]
    fn encode_decode_monitor_request_roundtrip() {
        let encoded = encode_monitor_request(1, 2, 0x08, &PV_REQUEST, 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let op = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::Op(op) => op,
            other => panic!("unexpected decode: {:?}", other),
        };
        assert_eq!(
            (op.command, op.subcmd, op.sid_or_cid, op.ioid),
            (13, 0x08, 1, 2)
        );
    }

    #[test]
    fn encode_decode_put_init_roundtrip() {
        let encoded = encode_put_request(5, 6, 0x08, &PV_REQUEST, 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let op = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::Op(op) => op,
            other => panic!("unexpected decode: {:?}", other),
        };
        assert_eq!(
            (op.command, op.subcmd, op.sid_or_cid, op.ioid),
            (11, 0x08, 5, 6)
        );
    }

    #[test]
    fn encode_decode_get_field_request_roundtrip() {
        let encoded = encode_get_field_request(7, 1, Some("*"), 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let field = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::GetField(payload) => payload,
            other => panic!("unexpected decode: {:?}", other),
        };
        assert!(!field.is_server);
        assert_eq!((field.sid, field.ioid), (Some(7), Some(1)));
        assert_eq!(field.field_name.as_deref(), Some("*"));
    }

    #[test]
    fn encode_decode_get_field_request_empty_field_roundtrip() {
        let encoded = encode_get_field_request(7, 1, None, 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let field = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::GetField(payload) => payload,
            other => panic!("unexpected decode: {:?}", other),
        };
        assert!(!field.is_server);
        assert_eq!((field.sid, field.ioid), (Some(7), Some(1)));
        // An absent field name is encoded as the empty string.
        assert_eq!(field.field_name.as_deref(), Some(""));
    }

    #[test]
    fn pva_status_code_zero_is_not_an_error() {
        let ok = PvaStatus {
            code: 0,
            message: None,
            stack: None,
        };
        let err = PvaStatus {
            code: 1,
            message: Some("bad".to_string()),
            stack: None,
        };
        // Same predicate `establish_channel` applies to the create-channel status.
        let flags_error = |status: Option<&PvaStatus>| status.is_some_and(|s| s.is_error());
        assert!(!flags_error(None));
        assert!(!flags_error(Some(&ok)));
        assert!(flags_error(Some(&err)));
    }
}