1
use crate::MAX_SEQ_COUNT;
2
use core::cell::Cell;
3
use paste::paste;
4
#[cfg(feature = "std")]
5
pub use stdmod::*;
6

            
7
/// Core trait for objects which can provide a sequence count.
///
/// All methods take `&self` instead of `&mut self` on purpose: this makes the trait easy to use
/// with static structs relying on interior mutability, for example through [Cell],
/// [core::cell::RefCell] or atomic types.
pub trait SequenceCountProvider {
    /// Raw integer type holding the count. It must be losslessly convertible into a `u64`.
    type Raw: Into<u64>;
    /// Bit width associated with this provider. The implementations in this module set it to
    /// the size of [Self::Raw] in bits.
    const MAX_BIT_WIDTH: usize;

    /// Returns the current count without modifying it.
    fn get(&self) -> Self::Raw;

    /// Advances the count by one step.
    fn increment(&self);

    /// Returns the current count and advances it afterwards.
    ///
    /// The default implementation is simply [Self::get] followed by [Self::increment];
    /// implementors may override it with a more direct version.
    fn get_and_increment(&self) -> Self::Raw {
        let snapshot = self.get();
        self.increment();
        snapshot
    }
}
26

            
27
/// Simple, single-threaded sequence counter for the unsigned primitive integer types.
///
/// The count lives in a [Cell], so it can be advanced through a shared reference. The
/// constructors start the count at 0, and it wraps back to 0 after `max_val` has been returned.
#[derive(Clone, Debug)]
pub struct SeqCountProviderSimple<T: Copy> {
    /// Current value of the counter.
    seq_count: Cell<T>,
    /// Largest value handed out before the counter wraps back to 0.
    max_val: T,
}
32

            
33
macro_rules! impl_for_primitives {
34
    ($($ty: ident,)+) => {
35
        $(
36
            paste! {
37
                impl SeqCountProviderSimple<$ty> {
38
4
                    pub fn [<new_custom_max_val_ $ty>](max_val: $ty) -> Self {
39
4
                        Self {
40
4
                            seq_count: Cell::new(0),
41
4
                            max_val,
42
4
                        }
43
4
                    }
44
4
                    pub fn [<new_ $ty>]() -> Self {
45
4
                        Self {
46
4
                            seq_count: Cell::new(0),
47
4
                            max_val: $ty::MAX
48
4
                        }
49
4
                    }
50
                }
51

            
52
                impl Default for SeqCountProviderSimple<$ty> {
53
2
                    fn default() -> Self {
54
2
                        Self::[<new_ $ty>]()
55
2
                    }
56
                }
57

            
58
                impl SequenceCountProvider for SeqCountProviderSimple<$ty> {
59
                    type Raw = $ty;
60
                    const MAX_BIT_WIDTH: usize = core::mem::size_of::<Self::Raw>() * 8;
61

            
62
12
                    fn get(&self) -> Self::Raw {
63
12
                        self.seq_count.get()
64
12
                    }
65

            
66
33280
                    fn increment(&self) {
67
33280
                        self.get_and_increment();
68
33280
                    }
69

            
70
33288
                    fn get_and_increment(&self) -> Self::Raw {
71
33288
                        let curr_count = self.seq_count.get();
72
33288

            
73
33288
                        if curr_count == self.max_val {
74
4
                            self.seq_count.set(0);
75
33284
                        } else {
76
33284
                            self.seq_count.set(curr_count + 1);
77
33284
                        }
78
33288
                        curr_count
79
33288
                    }
80
                }
81
            }
82
        )+
83
    }
84
}
85

            
86
impl_for_primitives!(u8, u16, u32, u64,);
87

            
88
/// This is a sequence count provider which wraps around at [MAX_SEQ_COUNT].
89
#[derive(Clone)]
90
pub struct CcsdsSimpleSeqCountProvider {
91
    provider: SeqCountProviderSimple<u16>,
92
}
93

            
94
impl Default for CcsdsSimpleSeqCountProvider {
95
4
    fn default() -> Self {
96
4
        Self {
97
4
            provider: SeqCountProviderSimple::new_custom_max_val_u16(MAX_SEQ_COUNT),
98
4
        }
99
4
    }
100
}
101

            
102
/// All calls are forwarded unchanged to the wrapped [SeqCountProviderSimple].
impl SequenceCountProvider for CcsdsSimpleSeqCountProvider {
    type Raw = u16;
    /// NOTE: this is the width of the raw `u16` storage type (16 bits), not the width of the
    /// range bounded by [MAX_SEQ_COUNT].
    const MAX_BIT_WIDTH: usize = core::mem::size_of::<Self::Raw>() * 8;

    #[inline]
    fn get(&self) -> u16 {
        self.provider.get()
    }

    #[inline]
    fn increment(&self) {
        self.provider.increment()
    }

    #[inline]
    fn get_and_increment(&self) -> u16 {
        self.provider.get_and_increment()
    }
}
113

            
114
#[cfg(feature = "std")]
pub mod stdmod {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Generates a thread-safe, clonable sequence counter type for every listed unsigned
    /// primitive type. The generated name is `SeqCountProviderSync` plus the uppercase type name,
    /// for example `SeqCountProviderSyncU8`.
    macro_rules! sync_clonable_seq_counter_impl {
         ($($ty: ident,)+) => {
             $(paste! {
                 /// Sequence counter which can be shared between threads and which can be
                 /// configured to wrap around at a specified maximum value.
                 ///
                 /// Clones share the same underlying counter. Please note that the API provided
                 /// by this type will not panic on [Mutex] lock errors; the getter functions
                 /// yield 0 in that case instead.
                 #[derive(Clone, Debug)]
                 pub struct [<SeqCountProviderSync $ty:upper>] {
                     seq_count: Arc<Mutex<$ty>>,
                     max_val: $ty
                 }

                 impl [<SeqCountProviderSync $ty:upper>] {
                     /// Creates a counter starting at 0 which spans the full range of the type.
                     pub fn new() -> Self {
                        Self::new_with_max_val($ty::MAX)
                     }

                     /// Creates a counter starting at 0 which wraps back to 0 after `max_val`
                     /// has been returned.
                     pub fn new_with_max_val(max_val: $ty) -> Self {
                         Self {
                             seq_count: Arc::default(),
                             max_val
                         }
                     }
                 }

                 impl Default for [<SeqCountProviderSync $ty:upper>] {
                     /// Same as `new()`: a counter spanning the full range of the type.
                     // A derived implementation would set `max_val` to 0, producing a counter
                     // which never leaves 0. That is inconsistent with `new()` and with
                     // `SeqCountProviderSimple::default()`.
                     fn default() -> Self {
                         Self::new()
                     }
                 }

                 impl SequenceCountProvider for [<SeqCountProviderSync $ty:upper>] {
                    type Raw = $ty;
                    const MAX_BIT_WIDTH: usize = core::mem::size_of::<Self::Raw>() * 8;

                    fn get(&self) -> $ty {
                        match self.seq_count.lock() {
                            Ok(counter) => *counter,
                            // Documented best-effort behavior: a poisoned lock yields 0.
                            Err(_) => 0
                        }
                    }

                    fn increment(&self) {
                        self.get_and_increment();
                    }

                    fn get_and_increment(&self) -> $ty {
                        match self.seq_count.lock() {
                            Ok(mut counter) => {
                                let val = *counter;
                                // The count never exceeds `max_val`, so `val + 1` cannot
                                // overflow: reaching `max_val` resets the counter to 0 instead.
                                *counter = if val == self.max_val { 0 } else { val + 1 };
                                val
                            }
                            // Documented best-effort behavior: a poisoned lock yields 0.
                            Err(_) => 0,
                        }
                    }
                 }
             })+
         }
    }
    sync_clonable_seq_counter_impl!(u8, u16, u32, u64,);
}
179

            
180
#[cfg(test)]
mod tests {
    // The synchronized counters only exist with the `std` feature. Gating their import and
    // tests keeps `cargo test` compiling without that feature.
    #[cfg(feature = "std")]
    use crate::seq_count::SeqCountProviderSyncU8;
    use crate::seq_count::{
        CcsdsSimpleSeqCountProvider, SeqCountProviderSimple, SequenceCountProvider,
    };
    use crate::MAX_SEQ_COUNT;

    #[test]
    fn test_u8_counter() {
        let u8_counter = SeqCountProviderSimple::<u8>::default();
        assert_eq!(u8_counter.get(), 0);
        assert_eq!(u8_counter.get_and_increment(), 0);
        assert_eq!(u8_counter.get_and_increment(), 1);
        assert_eq!(u8_counter.get(), 2);
    }

    #[test]
    fn test_u8_counter_overflow() {
        let u8_counter = SeqCountProviderSimple::new_u8();
        // 256 increments: the full range once, ending back at 0.
        for _ in 0..=u8::MAX {
            u8_counter.increment();
        }
        assert_eq!(u8_counter.get(), 0);
    }

    #[test]
    fn test_u8_counter_custom_max_val() {
        let u8_counter = SeqCountProviderSimple::new_custom_max_val_u8(2);
        assert_eq!(u8_counter.get_and_increment(), 0);
        assert_eq!(u8_counter.get_and_increment(), 1);
        assert_eq!(u8_counter.get_and_increment(), 2);
        assert_eq!(u8_counter.get(), 0);
    }

    #[test]
    fn test_simple_counter_bit_widths() {
        assert_eq!(<SeqCountProviderSimple<u8> as SequenceCountProvider>::MAX_BIT_WIDTH, 8);
        assert_eq!(<SeqCountProviderSimple<u16> as SequenceCountProvider>::MAX_BIT_WIDTH, 16);
        assert_eq!(<SeqCountProviderSimple<u32> as SequenceCountProvider>::MAX_BIT_WIDTH, 32);
        assert_eq!(<SeqCountProviderSimple<u64> as SequenceCountProvider>::MAX_BIT_WIDTH, 64);
    }

    #[test]
    fn test_ccsds_counter() {
        let ccsds_counter = CcsdsSimpleSeqCountProvider::default();
        assert_eq!(ccsds_counter.get(), 0);
        assert_eq!(ccsds_counter.get_and_increment(), 0);
        assert_eq!(ccsds_counter.get_and_increment(), 1);
        assert_eq!(ccsds_counter.get(), 2);
    }

    #[test]
    fn test_ccsds_counter_overflow() {
        let ccsds_counter = CcsdsSimpleSeqCountProvider::default();
        for _ in 0..MAX_SEQ_COUNT {
            ccsds_counter.increment();
        }
        // The maximum value itself is still handed out before wrapping.
        assert_eq!(ccsds_counter.get_and_increment(), MAX_SEQ_COUNT);
        assert_eq!(ccsds_counter.get(), 0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_atomic_ref_counters() {
        let sync_u8_counter = SeqCountProviderSyncU8::new();
        assert_eq!(sync_u8_counter.get(), 0);
        assert_eq!(sync_u8_counter.get_and_increment(), 0);
        assert_eq!(sync_u8_counter.get_and_increment(), 1);
        assert_eq!(sync_u8_counter.get(), 2);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_atomic_ref_counters_overflow() {
        let sync_u8_counter = SeqCountProviderSyncU8::new();
        for _ in 0..=u8::MAX {
            sync_u8_counter.increment();
        }
        assert_eq!(sync_u8_counter.get(), 0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_atomic_ref_counters_overflow_custom_max_val() {
        let sync_u8_counter = SeqCountProviderSyncU8::new_with_max_val(128);
        for _ in 0..=128 {
            sync_u8_counter.increment();
        }
        assert_eq!(sync_u8_counter.get(), 0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_atomic_ref_counters_clones_share_state() {
        let sync_u8_counter = SeqCountProviderSyncU8::new();
        let clone = sync_u8_counter.clone();
        sync_u8_counter.increment();
        assert_eq!(clone.get(), 1);
    }
}