1use tokio::io::AsyncWriteExt;
2use tokio::net::TcpStream;
3use tokio::time::timeout;
4
5use crate::auth::{resolved_authnz_host, resolved_authnz_user};
6use crate::search::resolve_pv_server;
7use crate::transport::{read_packet, read_until};
8use crate::types::{PvGetError, PvGetOptions, PvGetResult};
9use spvirit_codec::spvd_encode::encode_pv_request;
10use spvirit_codec::epics_decode::{
11 PvaPacket, PvaPacketCommand, decode_op_response_status as codec_decode_op_response_status,
12};
13use spvirit_codec::spvirit_encode::encode_client_connection_validation;
14pub use spvirit_codec::spvirit_encode::{
15 encode_create_channel_request, encode_get_field_request, encode_get_request,
16 encode_monitor_request, encode_put_request,
17};
18
19pub fn build_client_validation(
20 opts: &crate::types::PvGetOptions,
21 version: u8,
22 is_be: bool,
23) -> Vec<u8> {
24 let user = resolved_authnz_user(opts);
25 let host = resolved_authnz_host(opts);
26 encode_client_connection_validation(87_040, 32_767, 0, "ca", &user, &host, version, is_be)
27}
28
29pub fn op_response_status(
30 raw: &[u8],
31 is_be: bool,
32) -> Result<Option<spvirit_codec::epics_decode::PvaStatus>, PvGetError> {
33 codec_decode_op_response_status(raw, is_be).map_err(PvGetError::Protocol)
34}
35
36pub fn ensure_status_ok(raw: &[u8], is_be: bool, step: &str) -> Result<(), PvGetError> {
37 match op_response_status(raw, is_be)? {
38 None => Ok(()),
39 Some(st) if st.code == 0 => Ok(()),
40 Some(st) => Err(PvGetError::Protocol(format!(
41 "{} failed: {}",
42 step,
43 st.message.unwrap_or_else(|| format!("code={}", st.code))
44 ))),
45 }
46}
47
48pub struct ChannelConn {
49 pub stream: TcpStream,
50 pub sid: u32,
51 pub version: u8,
52 pub is_be: bool,
53 pub server_addr: std::net::SocketAddr,
54}
55
56pub async fn establish_channel(
57 target: std::net::SocketAddr,
58 opts: &PvGetOptions,
59) -> Result<ChannelConn, PvGetError> {
60 let mut stream = timeout(opts.timeout, TcpStream::connect(target))
61 .await
62 .map_err(|_| PvGetError::Timeout("connect"))??;
63
64 let mut version = 2u8;
65 let mut is_be = false;
66
67 for _ in 0..2 {
68 if let Ok(bytes) = read_packet(&mut stream, opts.timeout).await {
69 let mut pkt = PvaPacket::new(&bytes);
70 if let Some(cmd) = pkt.decode_payload() {
71 match cmd {
72 PvaPacketCommand::Control(payload) => {
73 if payload.command == 2 {
74 is_be = pkt.header.flags.is_msb;
75 }
76 }
77 PvaPacketCommand::ConnectionValidation(_) => {
78 version = pkt.header.version;
79 is_be = pkt.header.flags.is_msb;
80 }
81 _ => {}
82 }
83 }
84 }
85 }
86
87 let validation = build_client_validation(opts, version, is_be);
88 stream.write_all(&validation).await?;
89
90 let _ = read_until(&mut stream, opts.timeout, |cmd| {
91 matches!(cmd, PvaPacketCommand::ConnectionValidated(_))
92 })
93 .await?;
94
95 let cid = 1u32;
96 let create = encode_create_channel_request(cid, &opts.pv_name, version, is_be);
97 stream.write_all(&create).await?;
98
99 let create_resp = read_until(&mut stream, opts.timeout, |cmd| {
100 matches!(cmd, PvaPacketCommand::CreateChannel(_))
101 })
102 .await?;
103 let mut pkt = PvaPacket::new(&create_resp);
104 let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
105 "create_channel decode failed".to_string(),
106 ))?;
107 let sid = match cmd {
108 PvaPacketCommand::CreateChannel(payload) => {
109 if payload.status.as_ref().is_some_and(|s| s.is_error()) {
110 let detail = payload
111 .status
112 .as_ref()
113 .map(ToString::to_string)
114 .unwrap_or_default();
115 return Err(PvGetError::Protocol(format!(
116 "create_channel error: {}",
117 detail
118 )));
119 }
120 payload.sid
121 }
122 _ => {
123 return Err(PvGetError::Protocol(
124 "unexpected create_channel response".to_string(),
125 ));
126 }
127 };
128
129 Ok(ChannelConn {
130 stream,
131 sid,
132 version,
133 is_be,
134 server_addr: target,
135 })
136}
137
138pub async fn pvget(opts: &PvGetOptions) -> Result<PvGetResult, PvGetError> {
140 pvget_fields(opts, &[]).await
141}
142
143pub async fn pvget_fields(
148 opts: &PvGetOptions,
149 fields: &[&str],
150) -> Result<PvGetResult, PvGetError> {
151 let target = resolve_pv_server(opts).await?;
152
153 let conn = establish_channel(target, opts).await?;
154 let ChannelConn {
155 mut stream,
156 sid,
157 version,
158 is_be,
159 ..
160 } = conn;
161
162 let ioid = 1u32;
163 let pv_request = if fields.is_empty() {
164 vec![0xfd, 0x02, 0x00, 0x80, 0x00, 0x00]
166 } else {
167 encode_pv_request(fields, is_be)
168 };
169 let get_init_req = encode_get_request(
170 sid,
171 ioid,
172 0x08,
173 &pv_request,
174 version,
175 is_be,
176 );
177 stream.write_all(&get_init_req).await?;
178
179 let init_resp = read_until(
180 &mut stream,
181 opts.timeout,
182 |cmd| matches!(cmd, PvaPacketCommand::Op(op) if (op.subcmd & 0x08) != 0),
183 )
184 .await?;
185 let mut pkt = PvaPacket::new(&init_resp);
186 let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
187 "get init response decode failed".to_string(),
188 ))?;
189
190 let desc = match cmd {
191 PvaPacketCommand::Op(op) => op
192 .introspection
193 .ok_or_else(|| PvGetError::Decode("missing introspection".to_string()))?,
194 _ => {
195 return Err(PvGetError::Protocol(
196 "unexpected get init response".to_string(),
197 ));
198 }
199 };
200
201 let get_data_req = encode_get_request(sid, ioid, 0x00, &[], version, is_be);
202 stream.write_all(&get_data_req).await?;
203
204 let data_resp = read_until(
205 &mut stream,
206 opts.timeout,
207 |cmd| matches!(cmd, PvaPacketCommand::Op(op) if op.subcmd == 0x00),
208 )
209 .await?;
210 let mut pkt = PvaPacket::new(&data_resp);
211 let cmd = pkt.decode_payload().ok_or(PvGetError::Protocol(
212 "get data response decode failed".to_string(),
213 ))?;
214
215 match cmd {
216 PvaPacketCommand::Op(mut op) => {
217 op.decode_with_field_desc(&desc, is_be);
218 if let Some(value) = op.decoded_value {
219 return Ok(PvGetResult {
220 pv_name: opts.pv_name.clone(),
221 value,
222 raw_pva: data_resp,
223 raw_pvd: op.body,
224 introspection: desc,
225 });
226 }
227 Err(PvGetError::Decode("no decoded value".to_string()))
228 }
229 _ => Err(PvGetError::Protocol(
230 "unexpected get data response".to_string(),
231 )),
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use spvirit_codec::epics_decode::{PvaPacket, PvaPacketCommand, PvaStatus};

    /// pvRequest bytes shared by the monitor and put round-trip tests.
    const REQUEST_BYTES: [u8; 6] = [0xfd, 0x02, 0x00, 0x80, 0x00, 0x00];

    #[test]
    fn encode_decode_monitor_request_roundtrip() {
        let encoded = encode_monitor_request(1, 2, 0x08, &REQUEST_BYTES, 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let op = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::Op(op) => op,
            other => panic!("unexpected decode: {:?}", other),
        };
        assert_eq!(op.command, 13);
        assert_eq!(op.subcmd, 0x08);
        assert_eq!(op.sid_or_cid, 1);
        assert_eq!(op.ioid, 2);
    }

    #[test]
    fn encode_decode_put_init_roundtrip() {
        let encoded = encode_put_request(5, 6, 0x08, &REQUEST_BYTES, 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let op = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::Op(op) => op,
            other => panic!("unexpected decode: {:?}", other),
        };
        assert_eq!(op.command, 11);
        assert_eq!(op.subcmd, 0x08);
        assert_eq!(op.sid_or_cid, 5);
        assert_eq!(op.ioid, 6);
    }

    #[test]
    fn encode_decode_get_field_request_roundtrip() {
        let encoded = encode_get_field_request(7, 1, Some("*"), 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let payload = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::GetField(payload) => payload,
            other => panic!("unexpected decode: {:?}", other),
        };
        assert!(!payload.is_server);
        assert_eq!(payload.sid, Some(7));
        assert_eq!(payload.ioid, Some(1));
        assert_eq!(payload.field_name.as_deref(), Some("*"));
    }

    #[test]
    fn encode_decode_get_field_request_empty_field_roundtrip() {
        let encoded = encode_get_field_request(7, 1, None, 2, false);
        let mut packet = PvaPacket::new(&encoded);
        let payload = match packet.decode_payload().expect("decoded") {
            PvaPacketCommand::GetField(payload) => payload,
            other => panic!("unexpected decode: {:?}", other),
        };
        assert!(!payload.is_server);
        assert_eq!(payload.sid, Some(7));
        assert_eq!(payload.ioid, Some(1));
        // No field name is encoded as an empty string on the wire.
        assert_eq!(payload.field_name.as_deref(), Some(""));
    }

    #[test]
    fn pva_status_code_zero_is_not_an_error() {
        let ok = PvaStatus {
            code: 0,
            message: None,
            stack: None,
        };
        let err = PvaStatus {
            code: 1,
            message: Some("bad".to_string()),
            stack: None,
        };
        // (status, expected "is an error" verdict)
        let cases: [(Option<&PvaStatus>, bool); 3] =
            [(None, false), (Some(&ok), false), (Some(&err), true)];
        for &(status, expected) in cases.iter() {
            assert_eq!(status.is_some_and(|s| s.is_error()), expected);
        }
    }
}