1use std::ops::Range;
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
/// A fixed-length vector of components quantized to hundredths.
///
/// Each component is stored as an `i8` equal to `round(value * 100)`, giving a
/// representable range of `[-1.0, 1.0]` with a step of `0.01`. Every constructor and
/// mutator in this `impl` keeps the stored raw values within `-100..=100`.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` does not
/// re-validate that range, so deserialized data may fall outside it.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FieldVector {
    /// Quantized components (`value * 100`).
    values: Vec<i8>,
}

impl FieldVector {
    /// Factor between a component's `f32` value and its stored `i8` representation.
    const SCALE: f32 = 100.0;
    /// Largest stored magnitude; corresponds to a component value of `1.0`.
    const RAW_LIMIT: i8 = 100;

    /// Creates a vector of `dims` components, all zero.
    pub fn new(dims: usize) -> Self {
        Self {
            values: vec![0; dims],
        }
    }

    /// Wraps already-quantized raw values (hundredths).
    ///
    /// Entries outside `-100..=100` are clamped into that range so that the
    /// invariant relied upon by [`get`](Self::get), [`max_abs`](Self::max_abs) and the
    /// arithmetic methods holds.
    pub fn from_raw(mut values: Vec<i8>) -> Self {
        for raw in &mut values {
            *raw = (*raw).clamp(-Self::RAW_LIMIT, Self::RAW_LIMIT);
        }
        Self { values }
    }

    /// Builds a vector by quantizing each input, clamped to `[-1.0, 1.0]` first.
    /// NaN inputs become `0`.
    pub fn from_f32_slice(values: &[f32]) -> Self {
        Self {
            values: values.iter().map(|&v| Self::quantize(v)).collect(),
        }
    }

    /// Number of components.
    #[inline]
    pub fn dims(&self) -> usize {
        self.values.len()
    }

    /// Returns component `idx` as an `f32`.
    ///
    /// # Panics
    /// Panics if `idx >= self.dims()`.
    #[inline]
    pub fn get(&self, idx: usize) -> f32 {
        Self::dequantize(self.values[idx])
    }

    /// Sets component `idx` to `value`, clamped to `[-1.0, 1.0]` and rounded to the
    /// nearest hundredth. A NaN `value` is stored as `0`.
    ///
    /// # Panics
    /// Panics if `idx >= self.dims()`.
    #[inline]
    pub fn set(&mut self, idx: usize, value: f32) {
        self.values[idx] = Self::quantize(value);
    }

    /// Returns the stored raw value (hundredths) of component `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.dims()`.
    #[inline]
    pub fn get_raw(&self, idx: usize) -> i8 {
        self.values[idx]
    }

    /// Multiplies every component by `retention` and re-quantizes.
    ///
    /// Results are clamped to `[-1.0, 1.0]` exactly as in [`scale`](Self::scale), so a
    /// `retention` outside `[-1.0, 1.0]` cannot push raw values past `±100` (without
    /// the clamp they would saturate at the `i8` limits `127` / `-128`).
    pub fn decay(&mut self, retention: f32) {
        self.scale(retention);
    }

    /// Adds `other` component-wise, saturating each sum to `[-1.0, 1.0]`.
    ///
    /// # Panics
    /// Debug builds assert equal dimensionality. In release builds, indexing panics if
    /// `other` is shorter than `self`; extra components of `other` are ignored.
    pub fn add(&mut self, other: &FieldVector) {
        debug_assert_eq!(self.dims(), other.dims());
        let limit = i16::from(Self::RAW_LIMIT);
        for (i, slot) in self.values.iter_mut().enumerate() {
            // Widen to i16 so the sum of two in-range values cannot overflow.
            let sum = i16::from(*slot) + i16::from(other.values[i]);
            *slot = sum.clamp(-limit, limit) as i8;
        }
    }

    /// Adds `values[i]` (clamped to `[-1.0, 1.0]` and quantized) onto component
    /// `range.start + i`, saturating each sum to `[-1.0, 1.0]`.
    ///
    /// Entries that fall outside `range` or past `self.dims()` are ignored; reversed or
    /// fully out-of-bounds ranges are a no-op.
    pub fn add_to_range(&mut self, values: &[f32], range: Range<usize>) {
        let window = self.clip_range(range);
        let limit = i16::from(Self::RAW_LIMIT);
        for (slot, &v) in self.values[window].iter_mut().zip(values) {
            let sum = i16::from(*slot) + i16::from(Self::quantize(v));
            *slot = sum.clamp(-limit, limit) as i8;
        }
    }

    /// Sets component `range.start + i` to `values[i]`, with the same clamping and
    /// rounding as [`set`](Self::set).
    ///
    /// Entries that fall outside `range` or past `self.dims()` are ignored; reversed or
    /// fully out-of-bounds ranges are a no-op.
    pub fn set_range(&mut self, values: &[f32], range: Range<usize>) {
        let window = self.clip_range(range);
        for (slot, &v) in self.values[window].iter_mut().zip(values) {
            *slot = Self::quantize(v);
        }
    }

    /// Returns the components in `range` (clipped to `self.dims()`) as `f32`s.
    pub fn get_range(&self, range: Range<usize>) -> Vec<f32> {
        self.values[self.clip_range(range)]
            .iter()
            .map(|&raw| Self::dequantize(raw))
            .collect()
    }

    /// Sum of squared component values over `range` (clipped to `self.dims()`).
    pub fn range_energy(&self, range: Range<usize>) -> f32 {
        self.values[self.clip_range(range)]
            .iter()
            .map(|&raw| {
                let f = Self::dequantize(raw);
                f * f
            })
            .sum()
    }

    /// Returns `true` when [`range_energy`](Self::range_energy) strictly exceeds
    /// `threshold`.
    pub fn range_active(&self, range: Range<usize>, threshold: f32) -> bool {
        self.range_energy(range) > threshold
    }

    /// Returns `true` if every component is zero (also `true` for an empty vector).
    pub fn is_zero(&self) -> bool {
        self.values.iter().all(|&raw| raw == 0)
    }

    /// Number of components whose stored value is non-zero.
    pub fn non_zero_count(&self) -> usize {
        self.values.iter().filter(|&&raw| raw != 0).count()
    }

    /// Largest absolute component value, or `0.0` for an empty vector.
    pub fn max_abs(&self) -> f32 {
        // `unsigned_abs` cannot overflow, unlike `i8::abs` on `-128`.
        let peak = self
            .values
            .iter()
            .map(|raw| raw.unsigned_abs())
            .max()
            .unwrap_or(0);
        f32::from(peak) / Self::SCALE
    }

    /// Multiplies every component by `factor`, clamping results to `[-1.0, 1.0]`.
    /// A NaN `factor` zeroes the vector.
    pub fn scale(&mut self, factor: f32) {
        for raw in &mut self.values {
            *raw = Self::quantize(Self::dequantize(*raw) * factor);
        }
    }

    /// Euclidean (L2) norm of the component values.
    pub fn norm(&self) -> f32 {
        self.values
            .iter()
            .map(|&raw| {
                let f = Self::dequantize(raw);
                f * f
            })
            .sum::<f32>()
            .sqrt()
    }

    /// Dot product of the component values.
    ///
    /// Debug builds assert equal dimensionality; in release builds the shorter
    /// vector's length is used.
    pub fn dot(&self, other: &FieldVector) -> f32 {
        debug_assert_eq!(self.dims(), other.dims());
        self.values
            .iter()
            .zip(other.values.iter())
            .map(|(&a, &b)| Self::dequantize(a) * Self::dequantize(b))
            .sum()
    }

    /// Converts a value to stored hundredths, clamping to `[-1.0, 1.0]` first.
    /// NaN maps to `0` because float-to-int `as` casts send NaN to zero.
    #[inline]
    fn quantize(value: f32) -> i8 {
        (value.clamp(-1.0, 1.0) * Self::SCALE).round() as i8
    }

    /// Converts stored hundredths back to an `f32`.
    #[inline]
    fn dequantize(raw: i8) -> f32 {
        f32::from(raw) / Self::SCALE
    }

    /// Restricts `range` to `0..self.dims()`. Returns an empty (never reversed) range
    /// when nothing overlaps, so the result is always safe to slice with.
    fn clip_range(&self, range: Range<usize>) -> Range<usize> {
        let end = range.end.min(self.dims());
        let start = range.start.min(end);
        start..end
    }
}

impl Default for FieldVector {
    /// Creates a 64-dimensional zero vector.
    fn default() -> Self {
        Self::new(64)
    }
}
183
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_is_zero() {
        let v = FieldVector::new(128);
        assert!(v.is_zero());
        assert_eq!(v.non_zero_count(), 0);
    }

    #[test]
    fn test_set_get() {
        let mut v = FieldVector::new(64);
        v.set(0, 0.75);
        v.set(10, -0.5);

        assert!((v.get(0) - 0.75).abs() < 0.01);
        assert!((v.get(10) - -0.5).abs() < 0.01);
    }

    #[test]
    fn test_clamp() {
        let mut v = FieldVector::new(64);
        v.set(0, 1.5);
        v.set(1, -2.0);

        assert!((v.get(0) - 1.0).abs() < 0.01);
        assert!((v.get(1) - -1.0).abs() < 0.01);
    }

    #[test]
    fn test_decay() {
        let mut v = FieldVector::new(64);
        v.set(0, 1.0);
        v.set(1, -1.0);

        v.decay(0.5);

        assert!((v.get(0) - 0.5).abs() < 0.01);
        assert!((v.get(1) - -0.5).abs() < 0.01);
    }

    #[test]
    fn test_range_energy() {
        let mut v = FieldVector::new(128);
        v.set_range(&[0.5; 32], 0..32);

        // 32 components of 0.5 each contribute 0.25 to the energy.
        let energy = v.range_energy(0..32);
        assert!((energy - 8.0).abs() < 0.5);

        // Untouched components contribute nothing.
        assert!(v.range_energy(64..96).abs() < 0.01);
    }

    #[test]
    fn test_add_saturates() {
        let mut a = FieldVector::from_f32_slice(&[0.8, -0.3]);
        let b = FieldVector::from_f32_slice(&[0.5, 0.1]);

        a.add(&b);

        assert_eq!(a.get_raw(0), 100);
        assert_eq!(a.get_raw(1), -20);
    }

    #[test]
    fn test_scale_clamps() {
        let mut v = FieldVector::from_f32_slice(&[0.6, -0.2]);

        v.scale(2.0);

        assert_eq!(v.get_raw(0), 100);
        assert_eq!(v.get_raw(1), -40);
    }

    #[test]
    fn test_dot_norm_max_abs() {
        let a = FieldVector::from_f32_slice(&[0.6, 0.8]);
        let b = FieldVector::from_f32_slice(&[1.0, 0.0]);

        assert!((a.norm() - 1.0).abs() < 1e-4);
        assert!((a.dot(&b) - 0.6).abs() < 1e-4);
        assert!((a.max_abs() - 0.8).abs() < 1e-4);
    }

    #[test]
    fn test_ranges_clip_to_dims() {
        let mut v = FieldVector::new(4);

        // Only indices 2 and 3 exist inside 2..10.
        v.set_range(&[0.5; 8], 2..10);
        assert_eq!(v.non_zero_count(), 2);
        assert_eq!(v.get_range(0..100).len(), 4);
        assert!(v.get_range(6..8).is_empty());

        // 0.5 + 0.75 saturates at 1.0.
        v.add_to_range(&[0.75], 3..4);
        assert_eq!(v.get_raw(3), 100);
    }
}