rust_ethernet_ip/
python.rs

1#![allow(non_local_definitions)]
2
3use pyo3::prelude::*;
4// use pyo3::wrap_pyfunction;
5use pyo3::types::{PyDict, PyTuple};
6// use pyo3::types::IntoPyDict;
7use crate::{EipClient, PlcValue, SubscriptionOptions};
8use std::collections::HashMap;
9use tokio::runtime::Runtime;
10
11/// Python module for rust_ethernet_ip
12#[pymodule]
13fn rust_ethernet_ip(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
14    m.add_class::<PyEipClient>()?;
15    m.add_class::<PyPlcValue>()?;
16    m.add_class::<PySubscriptionOptions>()?;
17    Ok(())
18}
19
/// Python wrapper for EipClient
///
/// Owns the connected client together with the Tokio runtime that `new`
/// connected it on; the wrapper methods block on that runtime to drive the
/// client's async calls.
#[pyclass]
struct PyEipClient {
    /// The connected EtherNet/IP client.
    client: EipClient,
    /// Runtime the client was connected on. Rust drops fields in declaration
    /// order, so `client` is dropped before this runtime is shut down — keep
    /// this field last.
    runtime: Runtime,
}
26
27// Newtype for (String, PyPlcValue)
28struct TagValueArg {
29    name: String,
30    value: PyPlcValue,
31}
32
33impl<'a> FromPyObject<'a> for TagValueArg {
34    fn extract(ob: &'a pyo3::PyAny) -> PyResult<Self> {
35        let tuple = ob.downcast::<PyTuple>()?;
36        if tuple.len() != 2 {
37            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
38                "Expected tuple of length 2",
39            ));
40        }
41        let name = tuple.get_item(0)?.extract::<String>()?;
42        let value = tuple.get_item(1)?.extract::<PyPlcValue>()?;
43        Ok(TagValueArg { name, value })
44    }
45}
46
47// Newtype for (String, PySubscriptionOptions)
48struct TagSubOptArg {
49    name: String,
50    options: PySubscriptionOptions,
51}
52
53impl<'a> FromPyObject<'a> for TagSubOptArg {
54    fn extract(ob: &'a pyo3::PyAny) -> PyResult<Self> {
55        let tuple = ob.downcast::<PyTuple>()?;
56        if tuple.len() != 2 {
57            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
58                "Expected tuple of length 2",
59            ));
60        }
61        let name = tuple.get_item(0)?.extract::<String>()?;
62        let options = tuple.get_item(1)?.extract::<PySubscriptionOptions>()?;
63        Ok(TagSubOptArg { name, options })
64    }
65}
66
67#[pymethods]
68impl PyEipClient {
69    /// Create a new EipClient instance
70    #[new]
71    fn new(addr: &str) -> PyResult<Self> {
72        let runtime = Runtime::new().unwrap();
73        let client = runtime
74            .block_on(async { EipClient::connect(addr).await })
75            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
76
77        Ok(PyEipClient { client, runtime })
78    }
79
80    /// Read a tag value
81    fn read_tag(&mut self, tag_name: &str) -> PyResult<PyPlcValue> {
82        let value = self
83            .runtime
84            .block_on(async { self.client.read_tag(tag_name).await })
85            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
86
87        Ok(PyPlcValue { value })
88    }
89
90    /// Write a value to a tag
91    fn write_tag(&mut self, tag_name: &str, value: &PyPlcValue) -> PyResult<bool> {
92        let result = self
93            .runtime
94            .block_on(async { self.client.write_tag(tag_name, value.value.clone()).await });
95        match result {
96            Ok(_) => Ok(true),
97            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
98                e.to_string(),
99            )),
100        }
101    }
102
103    /// Read multiple tags in batch
104    fn read_tags_batch(&mut self, tag_names: Vec<String>) -> PyResult<Vec<(String, PyObject)>> {
105        Python::with_gil(|py| {
106            let runtime = tokio::runtime::Runtime::new().unwrap();
107            let results = runtime
108                .block_on(async {
109                    self.client
110                        .read_tags_batch(&tag_names.iter().map(|s| s.as_str()).collect::<Vec<_>>())
111                        .await
112                })
113                .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
114            Ok(results
115                .into_iter()
116                .map(|(name, result)| {
117                    let obj = match result {
118                        Ok(v) => PyPlcValue { value: v }.into_py(py),
119                        Err(e) => PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string())
120                            .into_py(py),
121                    };
122                    (name, obj)
123                })
124                .collect())
125        })
126    }
127
128    /// Write multiple tags in batch
129    fn write_tags_batch(
130        &mut self,
131        tag_values: Vec<TagValueArg>,
132    ) -> PyResult<Vec<(String, PyObject)>> {
133        Python::with_gil(|py| {
134            let runtime = tokio::runtime::Runtime::new().unwrap();
135            let results = runtime
136                .block_on(async {
137                    self.client
138                        .write_tags_batch(
139                            &tag_values
140                                .iter()
141                                .map(|arg| (arg.name.as_str(), arg.value.value.clone()))
142                                .collect::<Vec<_>>(),
143                        )
144                        .await
145                })
146                .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
147            Ok(results
148                .into_iter()
149                .map(|(name, result)| {
150                    let obj = match result {
151                        Ok(()) => py.None(),
152                        Err(e) => PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string())
153                            .into_py(py),
154                    };
155                    (name, obj)
156                })
157                .collect())
158        })
159    }
160
161    /// Subscribe to a tag
162    fn subscribe_to_tag(&self, tag_path: &str, options: &PySubscriptionOptions) -> PyResult<()> {
163        self.runtime
164            .block_on(async {
165                self.client
166                    .subscribe_to_tag(tag_path, options.options.clone())
167                    .await
168            })
169            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
170
171        Ok(())
172    }
173
174    /// Subscribe to multiple tags
175    fn subscribe_to_tags(&self, tags: Vec<TagSubOptArg>) -> PyResult<()> {
176        self.runtime
177            .block_on(async {
178                self.client
179                    .subscribe_to_tags(
180                        &tags
181                            .iter()
182                            .map(|arg| (arg.name.as_str(), arg.options.options.clone()))
183                            .collect::<Vec<_>>(),
184                    )
185                    .await
186            })
187            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
188        Ok(())
189    }
190
191    /// Unregister the session
192    fn unregister_session(&mut self) -> PyResult<()> {
193        self.runtime
194            .block_on(async { self.client.unregister_session().await })
195            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
196
197        Ok(())
198    }
199}
200
201/// Python wrapper for PlcValue
202#[pyclass]
203struct PyPlcValue {
204    value: PlcValue,
205}
206
207impl FromPyObject<'_> for PyPlcValue {
208    fn extract(ob: &PyAny) -> PyResult<Self> {
209        if let Ok(bool_val) = ob.extract::<bool>() {
210            Ok(PyPlcValue {
211                value: PlcValue::Bool(bool_val),
212            })
213        } else if let Ok(int_val) = ob.extract::<i32>() {
214            Ok(PyPlcValue {
215                value: PlcValue::Dint(int_val),
216            })
217        } else if let Ok(float_val) = ob.extract::<f64>() {
218            Ok(PyPlcValue {
219                value: PlcValue::Lreal(float_val),
220            })
221        } else if let Ok(string_val) = ob.extract::<String>() {
222            Ok(PyPlcValue {
223                value: PlcValue::String(string_val),
224            })
225        } else if let Ok(dict) = ob.downcast::<PyDict>() {
226            let mut map = HashMap::new();
227            for (key, value) in dict.iter() {
228                let key = key.extract::<String>()?;
229                let value = value.extract::<PyPlcValue>()?.value;
230                map.insert(key, value);
231            }
232            Ok(PyPlcValue {
233                value: PlcValue::Udt(map),
234            })
235        } else {
236            Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
237                "Unsupported value type",
238            ))
239        }
240    }
241}
242
243#[pymethods]
244impl PyPlcValue {
245    #[new]
246    fn new(value: PyObject) -> PyResult<Self> {
247        Python::with_gil(|py| {
248            if let Ok(val) = value.extract::<bool>(py) {
249                Ok(PyPlcValue {
250                    value: PlcValue::Bool(val),
251                })
252            } else if let Ok(val) = value.extract::<i32>(py) {
253                Ok(PyPlcValue {
254                    value: PlcValue::Dint(val),
255                })
256            } else if let Ok(val) = value.extract::<f32>(py) {
257                Ok(PyPlcValue {
258                    value: PlcValue::Real(val),
259                })
260            } else if let Ok(val) = value.extract::<f64>(py) {
261                Ok(PyPlcValue {
262                    value: PlcValue::Real(val as f32),
263                })
264            } else if let Ok(val) = value.extract::<String>(py) {
265                Ok(PyPlcValue {
266                    value: PlcValue::String(val),
267                })
268            } else {
269                Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
270                    "Unsupported value type",
271                ))
272            }
273        })
274    }
275
276    #[staticmethod]
277    fn real(val: f32) -> Self {
278        PyPlcValue {
279            value: PlcValue::Real(val),
280        }
281    }
282    #[staticmethod]
283    fn lreal(val: f64) -> Self {
284        PyPlcValue {
285            value: PlcValue::Lreal(val),
286        }
287    }
288    #[staticmethod]
289    fn dint(val: i32) -> Self {
290        PyPlcValue {
291            value: PlcValue::Dint(val),
292        }
293    }
294    #[staticmethod]
295    fn lint(val: i64) -> Self {
296        PyPlcValue {
297            value: PlcValue::Lint(val),
298        }
299    }
300    #[staticmethod]
301    fn string(val: String) -> Self {
302        PyPlcValue {
303            value: PlcValue::String(val),
304        }
305    }
306
307    #[getter]
308    fn value(&self, py: Python) -> PyResult<PyObject> {
309        match &self.value {
310            PlcValue::Bool(b) => Ok(b.into_py(py)),
311            PlcValue::Sint(i) => Ok(i.into_py(py)),
312            PlcValue::Int(i) => Ok(i.into_py(py)),
313            PlcValue::Dint(i) => Ok(i.into_py(py)),
314            PlcValue::Lint(i) => Ok(i.into_py(py)),
315            PlcValue::Usint(u) => Ok(u.into_py(py)),
316            PlcValue::Uint(u) => Ok(u.into_py(py)),
317            PlcValue::Udint(u) => Ok(u.into_py(py)),
318            PlcValue::Ulint(u) => Ok(u.into_py(py)),
319            PlcValue::Real(f) => Ok(f.into_py(py)),
320            PlcValue::Lreal(f) => Ok(f.into_py(py)),
321            PlcValue::String(s) => Ok(s.into_py(py)),
322            PlcValue::Udt(map) => {
323                let dict = PyDict::new(py);
324                for (k, v) in map.iter() {
325                    let v_py = PyPlcValue { value: v.clone() }.value(py)?;
326                    dict.set_item(k, v_py)?;
327                }
328                Ok(dict.into_py(py))
329            }
330        }
331    }
332
333    fn __str__(&self) -> String {
334        format!("{:?}", self.value)
335    }
336
337    fn __repr__(&self) -> String {
338        format!("PyPlcValue({:?})", self.value)
339    }
340}
341
342/// Python wrapper for SubscriptionOptions
343#[pyclass]
344struct PySubscriptionOptions {
345    options: SubscriptionOptions,
346}
347
348impl FromPyObject<'_> for PySubscriptionOptions {
349    fn extract(ob: &PyAny) -> PyResult<Self> {
350        let update_rate = ob.getattr("update_rate")?.extract::<u32>()?;
351        let change_threshold = ob.getattr("change_threshold")?.extract::<f32>()?;
352        let timeout = ob.getattr("timeout")?.extract::<u32>()?;
353
354        Ok(PySubscriptionOptions {
355            options: SubscriptionOptions {
356                update_rate,
357                change_threshold,
358                timeout,
359            },
360        })
361    }
362}
363
364#[pymethods]
365impl PySubscriptionOptions {
366    #[new]
367    fn new(update_rate: u32, change_threshold: f32, timeout: u32) -> PyResult<Self> {
368        let options = SubscriptionOptions {
369            update_rate,
370            change_threshold,
371            timeout,
372        };
373
374        Ok(PySubscriptionOptions { options })
375    }
376
377    /// Get the update rate in milliseconds
378    #[getter]
379    fn update_rate(&self) -> u32 {
380        self.options.update_rate
381    }
382
383    /// Get the change threshold for numeric values
384    #[getter]
385    fn change_threshold(&self) -> f32 {
386        self.options.change_threshold
387    }
388
389    /// Get the timeout in milliseconds
390    #[getter]
391    fn timeout(&self) -> u32 {
392        self.options.timeout
393    }
394}