Skip to main content

spvirit_client/
client.rs

1use tokio::io::AsyncWriteExt;
2use tokio::net::TcpStream;
3use tokio::time::timeout;
4
5use crate::auth::{resolved_authnz_host, resolved_authnz_user};
6use crate::search::resolve_pv_server;
7use crate::transport::{read_packet, read_until};
8use crate::types::{PvGetError, PvGetOptions, PvGetResult};
9use spvirit_codec::spvd_encode::encode_pv_request;
10use spvirit_codec::epics_decode::{
11    PvaPacket, PvaPacketCommand, decode_op_response_status as codec_decode_op_response_status,
12};
13use spvirit_codec::spvirit_encode::encode_client_connection_validation;
14pub use spvirit_codec::spvirit_encode::{
15    encode_create_channel_request, encode_get_field_request, encode_get_request,
16    encode_monitor_request, encode_put_request,
17};
18
19pub fn build_client_validation(
20    opts: &crate::types::PvGetOptions,
21    version: u8,
22    is_be: bool,
23) -> Vec<u8> {
24    let user = resolved_authnz_user(opts);
25    let host = resolved_authnz_host(opts);
26    encode_client_connection_validation(87_040, 32_767, 0, "ca", &user, &host, version, is_be)
27}
28
29pub fn op_response_status(
30    raw: &[u8],
31    is_be: bool,
32) -> Result<Option<spvirit_codec::epics_decode::PvaStatus>, PvGetError> {
33    codec_decode_op_response_status(raw, is_be).map_err(PvGetError::Protocol)
34}
35
36pub fn ensure_status_ok(raw: &[u8], is_be: bool, step: &str) -> Result<(), PvGetError> {
37    match op_response_status(raw, is_be)? {
38        None => Ok(()),
39        Some(st) if st.code == 0 => Ok(()),
40        Some(st) => Err(PvGetError::Protocol(format!(
41            "{} failed: {}",
42            step,
43            st.message.unwrap_or_else(|| format!("code={}", st.code))
44        ))),
45    }
46}
47
48pub struct ChannelConn {
49    pub stream: TcpStream,
50    pub sid: u32,
51    pub version: u8,
52    pub is_be: bool,
53    pub server_addr: std::net::SocketAddr,
54}
55
56pub async fn establish_channel(
57    target: std::net::SocketAddr,
58    opts: &PvGetOptions,
59) -> Result<ChannelConn, PvGetError> {
60    let mut stream = timeout(opts.timeout, TcpStream::connect(target))
61        .await
62        .map_err(|_| PvGetError::Timeout("connect"))??;
63
64    let mut version = 2u8;
65    let mut is_be = false;
66
67    for _ in 0..2 {
68        if let Ok(bytes) = read_packet(&mut stream, opts.timeout).await {
69            let mut pkt = PvaPacket::new(&bytes);
70            if let Some(cmd) = pkt.decode_payload() {
71                match cmd {
72                    PvaPacketCommand::Control(payload) => {
73                        if payload.command == 2 {
74                            is_be = pkt.header.flags.is_msb;
75                        }
76                    }
77                    PvaPacketCommand::ConnectionValidation(_) => {
78                        version = pkt.header.version;
79                        is_be = pkt.header.flags.is_msb;
80                    }
81                    _ => {}
82                }
83            }
84        }
85    }
86
87    let validation = build_client_validation(opts, version, is_be);
88    stream.write_all(&validation).await?;
89
90    let _ = read_until(&mut stream, opts.timeout, |cmd| {
91        matches!(cmd, PvaPacketCommand::ConnectionValidated(_))
92    })
93    .await?;
94
95    let cid = 1u32;
96    let create = encode_create_channel_request(cid, &opts.pv_name, version, is_be);
97    stream.write_all(&create).await?;
98
99    let create_resp = read_until(&mut stream, opts.timeout, |cmd| {
100        matches!(cmd, PvaPacketCommand::CreateChannel(_))
101    })
102    .await?;
103    let mut pkt = PvaPacket::new(&create_resp);
104    let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
105        "create_channel decode failed".to_string(),
106    ))?;
107    let sid = match cmd {
108        PvaPacketCommand::CreateChannel(payload) => {
109            if payload.status.as_ref().is_some_and(|s| s.is_error()) {
110                let detail = payload
111                    .status
112                    .as_ref()
113                    .map(ToString::to_string)
114                    .unwrap_or_default();
115                return Err(PvGetError::Protocol(format!(
116                    "create_channel error: {}",
117                    detail
118                )));
119            }
120            payload.sid
121        }
122        _ => {
123            return Err(PvGetError::Protocol(
124                "unexpected create_channel response".to_string(),
125            ));
126        }
127    };
128
129    Ok(ChannelConn {
130        stream,
131        sid,
132        version,
133        is_be,
134        server_addr: target,
135    })
136}
137
138/// Convenience wrapper: GET with no field filtering.
139pub async fn pvget(opts: &PvGetOptions) -> Result<PvGetResult, PvGetError> {
140    pvget_fields(opts, &[]).await
141}
142
143/// GET with optional field filtering.
144///
145/// If `fields` is empty, requests all fields (equivalent to `-r ""`).
146/// Otherwise, encodes a pvRequest like `field(value,alarm,timeStamp)`.
147pub async fn pvget_fields(
148    opts: &PvGetOptions,
149    fields: &[&str],
150) -> Result<PvGetResult, PvGetError> {
151    let target = resolve_pv_server(opts).await?;
152
153    let conn = establish_channel(target, opts).await?;
154    let ChannelConn {
155        mut stream,
156        sid,
157        version,
158        is_be,
159        ..
160    } = conn;
161
162    let ioid = 1u32;
163    let pv_request = if fields.is_empty() {
164        // Empty pvRequest — request all fields
165        vec![0xfd, 0x02, 0x00, 0x80, 0x00, 0x00]
166    } else {
167        encode_pv_request(fields, is_be)
168    };
169    let get_init_req = encode_get_request(
170        sid,
171        ioid,
172        0x08,
173        &pv_request,
174        version,
175        is_be,
176    );
177    stream.write_all(&get_init_req).await?;
178
179    let init_resp = read_until(
180        &mut stream,
181        opts.timeout,
182        |cmd| matches!(cmd, PvaPacketCommand::Op(op) if (op.subcmd & 0x08) != 0),
183    )
184    .await?;
185    let mut pkt = PvaPacket::new(&init_resp);
186    let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
187        "get init response decode failed".to_string(),
188    ))?;
189
190    let desc = match cmd {
191        PvaPacketCommand::Op(op) => op
192            .introspection
193            .ok_or_else(|| PvGetError::Decode("missing introspection".to_string()))?,
194        _ => {
195            return Err(PvGetError::Protocol(
196                "unexpected get init response".to_string(),
197            ));
198        }
199    };
200
201    let get_data_req = encode_get_request(sid, ioid, 0x00, &[], version, is_be);
202    stream.write_all(&get_data_req).await?;
203
204    let data_resp = read_until(
205        &mut stream,
206        opts.timeout,
207        |cmd| matches!(cmd, PvaPacketCommand::Op(op) if op.subcmd == 0x00),
208    )
209    .await?;
210    let mut pkt = PvaPacket::new(&data_resp);
211    let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
212        "get data response decode failed".to_string(),
213    ))?;
214
215    match cmd {
216        PvaPacketCommand::Op(mut op) => {
217            op.decode_with_field_desc(&desc, is_be);
218            if let Some(value) = op.decoded_value {
219                return Ok(PvGetResult {
220                    pv_name: opts.pv_name.clone(),
221                    value,
222                    raw_pva: data_resp,
223                    raw_pvd: op.body,
224                    introspection: desc,
225                });
226            }
227            Err(PvGetError::Decode("no decoded value".to_string()))
228        }
229        _ => Err(PvGetError::Protocol(
230            "unexpected get data response".to_string(),
231        )),
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand, PvaStatus};

    /// pvRequest bytes used by the monitor and put round-trip tests.
    const PV_REQUEST: [u8; 6] = [0xfd, 0x02, 0x00, 0x80, 0x00, 0x00];

    /// Decodes an encoded message, panicking if no payload can be decoded.
    fn decode(msg: Vec<u8>) -> PvaPacketCommand {
        let mut packet = PvaPacket::new(&msg);
        packet.decode_payload().expect("decoded")
    }

    /// Encodes a client get-field request for sid 7 / ioid 1 and checks that
    /// it decodes back with `expected_name` as the field name.
    fn check_get_field_roundtrip(field: Option<&str>, expected_name: &str) {
        match decode(encode_get_field_request(7, 1, field, 2, false)) {
            PvaPacketCommand::GetField(decoded) => {
                assert!(!decoded.is_server);
                assert_eq!(decoded.sid, Some(7));
                assert_eq!(decoded.ioid, Some(1));
                assert_eq!(decoded.field_name.as_deref(), Some(expected_name));
            }
            unexpected => panic!("unexpected decode: {:?}", unexpected),
        }
    }

    #[test]
    fn encode_decode_monitor_request_roundtrip() {
        match decode(encode_monitor_request(1, 2, 0x08, &PV_REQUEST, 2, false)) {
            PvaPacketCommand::Op(op) => {
                assert_eq!(
                    (op.command, op.subcmd, op.sid_or_cid, op.ioid),
                    (13, 0x08, 1, 2)
                );
            }
            unexpected => panic!("unexpected decode: {:?}", unexpected),
        }
    }

    #[test]
    fn encode_decode_put_init_roundtrip() {
        match decode(encode_put_request(5, 6, 0x08, &PV_REQUEST, 2, false)) {
            PvaPacketCommand::Op(op) => {
                assert_eq!(
                    (op.command, op.subcmd, op.sid_or_cid, op.ioid),
                    (11, 0x08, 5, 6)
                );
            }
            unexpected => panic!("unexpected decode: {:?}", unexpected),
        }
    }

    #[test]
    fn encode_decode_get_field_request_roundtrip() {
        check_get_field_roundtrip(Some("*"), "*");
    }

    #[test]
    fn encode_decode_get_field_request_empty_field_roundtrip() {
        // A request encoded without a field name decodes to an empty name.
        check_get_field_roundtrip(None, "");
    }

    #[test]
    fn pva_status_code_zero_is_not_an_error() {
        let is_failure = |status: Option<&PvaStatus>| status.is_some_and(|s| s.is_error());

        let success = PvaStatus {
            code: 0,
            message: None,
            stack: None,
        };
        let failure = PvaStatus {
            code: 1,
            message: Some("bad".to_string()),
            stack: None,
        };

        assert!(!is_failure(None));
        assert!(!is_failure(Some(&success)));
        assert!(is_failure(Some(&failure)));
    }
}