1use crate::{
4 base58::*,
5 byte_converter::FromByteSlice,
6 converting_callback_receiver::ConvertingCallbackReceiver,
7 converting_receiver::{BrickletError, BrickletRecvTimeoutError, ConvertingReceiver},
8 ip_connection::{GetRequestSender, Request, SocketThreadRequest},
9 low_level_traits::*,
10};
11use std::sync::{
12 mpsc::{channel, Sender},
13 Arc, Mutex,
14};
15
16use std::error::Error;
17
/// Per-function-id state of the "response expected" setting of a device.
///
/// `Eq` is derived alongside `PartialEq` because equality on these unit variants is total.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum ResponseExpectedFlag {
    /// No function with this id exists; getting or setting the flag is an error.
    InvalidFunctionId,
    /// No response is expected; the flag may be changed.
    False,
    /// A response is expected; the flag may be changed.
    True,
    /// The function always responds; the flag cannot be changed.
    AlwaysTrue,
}
25
26impl From<bool> for ResponseExpectedFlag {
27 fn from(b: bool) -> Self {
28 if b {
29 ResponseExpectedFlag::True
30 } else {
31 ResponseExpectedFlag::False
32 }
33 }
34}
35
36#[derive(Clone)]
37pub(crate) struct Device {
38 pub api_version: [u8; 3],
39 pub response_expected: [ResponseExpectedFlag; 256],
40 pub internal_uid: u32,
41 pub req_tx: Sender<SocketThreadRequest>,
42 pub high_level_locks: Vec<Arc<Mutex<()>>>,
43}
44
/// Error returned when the response-expected flag is queried for a function id
/// that is flagged as invalid for the device. Carries the offending function id.
#[derive(Debug, Copy, Clone)]
pub struct GetResponseExpectedError(u8);

impl std::error::Error for GetResponseExpectedError {}

impl std::fmt::Display for GetResponseExpectedError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let GetResponseExpectedError(function_id) = *self;
        write!(f, "Can not get response expected: Invalid function id {}", function_id)
    }
}
56
/// Error returned when the response-expected flag of a function cannot be changed.
/// Both variants carry the offending function id.
#[derive(Debug, Copy, Clone)]
pub enum SetResponseExpectedError {
    /// The function id is flagged as invalid for the device.
    InvalidFunctionId(u8),
    /// The function is flagged as always responding, so its flag is fixed.
    IsAlwaysTrue(u8),
}

impl std::error::Error for SetResponseExpectedError {}

impl std::fmt::Display for SetResponseExpectedError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            Self::IsAlwaysTrue(_) => f.write_str("Can not set response expected: function always responds."),
            Self::InvalidFunctionId(function_id) => {
                write!(f, "Can not set response expected: Invalid function id {}", function_id)
            }
        }
    }
}
76
77impl Device {
78 pub(crate) fn new<T: GetRequestSender>(api_version: [u8; 3], uid: &str, req_sender: T, high_level_function_count: u8) -> Device {
79 match uid.base58_to_u32() {
80 Ok(internal_uid) => Device {
81 api_version,
82 internal_uid,
83 req_tx: req_sender.get_rs().socket_thread_tx.clone(),
84 response_expected: [ResponseExpectedFlag::InvalidFunctionId; 256],
85 high_level_locks: vec![Arc::new(Mutex::new(())); high_level_function_count as usize],
86 },
87 Err(e) => panic!("UID {} could not be parsed: {}", uid, e.description()),
89 }
90 }
91
92 pub(crate) fn get_response_expected(&self, function_id: u8) -> Result<bool, GetResponseExpectedError> {
93 match self.response_expected[function_id as usize] {
94 ResponseExpectedFlag::False => Ok(false),
95 ResponseExpectedFlag::True => Ok(true),
96 ResponseExpectedFlag::AlwaysTrue => Ok(true),
97 ResponseExpectedFlag::InvalidFunctionId => Err(GetResponseExpectedError(function_id)),
98 }
99 }
100
101 pub(crate) fn set_response_expected(&mut self, function_id: u8, response_expected: bool) -> Result<(), SetResponseExpectedError> {
102 if self.response_expected[function_id as usize] == ResponseExpectedFlag::AlwaysTrue {
103 Err(SetResponseExpectedError::IsAlwaysTrue(function_id))
104 } else if self.response_expected[function_id as usize] == ResponseExpectedFlag::InvalidFunctionId {
105 Err(SetResponseExpectedError::InvalidFunctionId(function_id))
106 } else {
107 self.response_expected[function_id as usize] = ResponseExpectedFlag::from(response_expected);
108 Ok(())
109 }
110 }
111
112 pub(crate) fn set_response_expected_all(&mut self, response_expected: bool) {
113 for resp_exp in self.response_expected.iter_mut() {
114 if *resp_exp == ResponseExpectedFlag::True || *resp_exp == ResponseExpectedFlag::False {
115 *resp_exp = ResponseExpectedFlag::from(response_expected);
116 }
117 }
118 }
119
120 pub(crate) fn set<T: FromByteSlice>(&self, function_id: u8, payload: Vec<u8>) -> ConvertingReceiver<T> {
121 let (sent_tx, sent_rx) = channel();
122 if self.response_expected[function_id as usize] == ResponseExpectedFlag::False {
123 let (tx, rx) = channel();
124 self.req_tx
125 .send(SocketThreadRequest::Request(
126 Request::Set { uid: self.internal_uid, function_id, payload, response_sender: None },
127 sent_tx,
128 ))
129 .expect("The socket thread queue was disconnected from the ip connection. This is a bug in the rust bindings.");
130 let timeout = sent_rx.recv().expect("The sent queue was dropped. This is a bug in the rust bindings.");
131 let _ = tx.send(Err(BrickletError::SuccessButResponseExpectedIsDisabled));
132 ConvertingReceiver::new(rx, timeout)
133 } else {
134 let (tx, rx) = channel();
135 self.req_tx
136 .send(SocketThreadRequest::Request(
137 Request::Set { uid: self.internal_uid, function_id, payload, response_sender: Some(tx) },
138 sent_tx,
139 ))
140 .expect("The socket thread queue was disconnected from the ip connection. This is a bug in the rust bindings.");
141 let timeout = sent_rx.recv().expect("The sent queue was dropped. This is a bug in the rust bindings.");
142 ConvertingReceiver::new(rx, timeout)
143 }
144 }
145
146 pub(crate) fn get_callback_receiver<T: FromByteSlice>(&self, function_id: u8) -> ConvertingCallbackReceiver<T> {
147 let (tx, rx) = channel();
148 let (sent_tx, sent_rx) = channel();
149 self.req_tx
150 .send(SocketThreadRequest::Request(
151 Request::RegisterCallback { uid: self.internal_uid, function_id, response_sender: tx },
152 sent_tx,
153 ))
154 .expect("The socket thread queue was disconnected from the ip connection. This is a bug in the rust bindings.");
155 sent_rx.recv().expect("The sent queue was dropped. This is a bug in the rust bindings.");
156 ConvertingCallbackReceiver::new(rx)
157 }
158
159 pub(crate) fn get<T: FromByteSlice>(&self, function_id: u8, payload: Vec<u8>) -> ConvertingReceiver<T> {
160 let (tx, rx) = channel();
161 let (sent_tx, sent_rx) = channel();
162 self.req_tx
163 .send(SocketThreadRequest::Request(Request::Get { uid: self.internal_uid, function_id, payload, response_sender: tx }, sent_tx))
164 .expect("The socket thread queue was disconnected from the ip connection. This is a bug in the rust bindings.");
165 let timeout = sent_rx.recv().expect("The sent queue was dropped. This is a bug in the rust bindings.");
166 ConvertingReceiver::new(rx, timeout)
167 }
168
169 pub(crate) fn set_high_level<
170 PayloadT,
171 OutputT,
172 LlwT: LowLevelWrite<OutputT>,
173 ClosureT: FnMut(usize, usize, &[PayloadT]) -> Result<LlwT, BrickletRecvTimeoutError>,
174 >(
175 &self,
176 high_level_function_idx: u8,
177 payload: &[PayloadT],
178 max_payload_len: usize,
179 chunk_len: usize,
180 low_level_closure: &mut ClosureT,
181 ) -> Result<(usize, OutputT), BrickletRecvTimeoutError> {
182 if payload.len() > max_payload_len {
183 return Err(BrickletRecvTimeoutError::InvalidParameter);
184 }
185
186 let length = payload.len();
187
188 let mut chunk_offset = 0;
189 {
190 let _lock_guard = self.high_level_locks[high_level_function_idx as usize].lock().unwrap();
191 if length == 0 {
192 match low_level_closure(length, chunk_offset, &[]) {
193 Ok(low_level_result) => return Ok((low_level_result.ll_message_written(), low_level_result.get_result())),
194 Err(e) => return Err(e),
195 }
196 }
197 let mut written_sum = 0;
198 loop {
199 match low_level_closure(length, chunk_offset, &payload[chunk_offset..std::cmp::min(chunk_offset + chunk_len, length)]) {
200 Ok(low_level_result) => {
201 let written = low_level_result.ll_message_written();
202 let output = low_level_result.get_result();
203 written_sum += written;
204 if written < chunk_len {
205 return Ok((written_sum, output));
206 }
207 chunk_offset += chunk_len;
208 if chunk_offset >= length {
209 return Ok((written_sum, output));
210 }
211 }
212 Err(e) => return Err(e),
213 }
214 }
215 }
216 }
217
218 pub(crate) fn get_high_level<
219 PayloadT: Default + Clone + Copy,
220 OutputT,
221 LlrT: LowLevelRead<PayloadT, OutputT>,
222 ClosureT: FnMut() -> Result<LlrT, BrickletRecvTimeoutError>,
223 >(
224 &self,
225 high_level_function_idx: u8,
226 low_level_closure: &mut ClosureT,
227 ) -> Result<(Vec<PayloadT>, OutputT), BrickletRecvTimeoutError> {
228 let mut chunk_offset = 0;
229 {
230 let _lock_guard = self.high_level_locks[high_level_function_idx as usize].lock().unwrap();
231 let mut result = low_level_closure()?;
232 let mut out_of_sync = result.ll_message_chunk_offset() != 0;
233 let message_length = result.ll_message_length();
234
235 if !out_of_sync {
236 let mut buf = vec![PayloadT::default(); message_length];
237 let first_read_length = std::cmp::min(result.ll_message_chunk_data().len(), message_length - chunk_offset);
238 buf[chunk_offset..chunk_offset + first_read_length].copy_from_slice(&result.ll_message_chunk_data()[0..first_read_length]);
239 chunk_offset += first_read_length;
240 while chunk_offset < message_length {
241 result = low_level_closure()?;
242 out_of_sync = result.ll_message_chunk_offset() != chunk_offset || result.ll_message_length() != message_length;
243 if out_of_sync {
244 break;
245 }
246
247 let read_length = std::cmp::min(result.ll_message_chunk_data().len(), message_length - chunk_offset);
248 buf[chunk_offset..chunk_offset + read_length].copy_from_slice(&result.ll_message_chunk_data()[0..read_length]);
249 chunk_offset += read_length;
250 }
251 if !out_of_sync {
252 return Ok((buf, result.get_result()));
253 }
254 }
255
256 assert!(out_of_sync);
257 while chunk_offset + result.ll_message_chunk_data().len() < message_length {
258 chunk_offset += result.ll_message_chunk_data().len();
259 result = low_level_closure()?;
260 }
261 Err(BrickletRecvTimeoutError::MalformedPacket)
262 }
263 }
264}