1
use std::fmt;
2
use std::iter::repeat_with;
3
use std::sync::atomic::AtomicUsize;
4
use std::sync::atomic::Ordering;
5

            
6
use crossbeam_utils::CachePadded;
7

            
8
use crate::thread_id;
9

            
10
/// A sharded atomic counter
11
///
12
/// `ConcurrentCounter` shards cacheline aligned `AtomicUsizes` across a vector for faster updates in
13
/// a high contention scenarios.
14
pub struct ConcurrentCounter {
15
    cells: Vec<CachePadded<AtomicUsize>>,
16
}
17

            
18
impl fmt::Debug for ConcurrentCounter {
19
1
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20
1
        f.debug_struct("ConcurrentCounter")
21
1
            .field("sum", &self.sum())
22
1
            .field("shards", &self.cells.len())
23
1
            .finish()
24
1
    }
25
}
26

            
27
impl ConcurrentCounter {
28
    /// Creates a new `ConcurrentCounter` with a minimum of the
29
    /// `number_of_threads` cells. Concurrent counter will align the
30
    /// `number_of_threads` to the next power of two for better speed when doing
31
    /// the modulus.
32
    #[inline]
33
3
    pub fn new(value: usize, number_of_threads: usize) -> Self {
34
3
        let number_of_threads = number_of_threads.next_power_of_two();
35
10
        let cells: Vec<CachePadded<AtomicUsize>> = repeat_with(|| CachePadded::new(AtomicUsize::new(0)))
36
3
            .take(number_of_threads)
37
3
            .collect();
38
3

            
39
3
        // Make sure the initial value is correct.
40
3
        cells[0].store(value, Ordering::Relaxed);
41
3

            
42
3
        Self { cells }
43
3
    }
44

            
45
    /// Adds `value` to the counter.
46
8000004
    pub fn add(&self, value: usize) {
47
8000004
        let c = self.cells.get(thread_id() & (self.cells.len() - 1)).unwrap();
48
8000004
        c.fetch_add(value, Ordering::Relaxed);
49
8000004
    }
50

            
51
    /// Computes the max of `value` and the counter.
52
    #[inline]
53
    pub fn max(&self, value: usize) {
54
        let c = self.cells.get(thread_id() & (self.cells.len() - 1)).unwrap();
55
        c.fetch_max(value, Ordering::Relaxed);
56
    }
57

            
58
    /// This will fetch the sum of the concurrent counter be iterating through
59
    /// each of the cells and loading the values with the ordering defined by
60
    /// `ordering`. This is only accurate when all writes have been finished by
61
    /// the time this function is called.
62
    #[inline]
63
4
    pub fn sum(&self) -> usize {
64
18
        self.cells.iter().map(|c| c.load(Ordering::Relaxed)).sum()
65
4
    }
66

            
67
    /// This will fetch the max of the concurrent counter be iterating through
68
    /// each of the cells and loading the values with the ordering defined by
69
    /// `ordering`. This is only accurate when all writes have been finished by
70
    /// the time this function is called.
71
    pub fn total_max(&self) -> usize {
72
        self.cells.iter().map(|c| c.load(Ordering::Relaxed)).max().unwrap()
73
    }
74
}
75

            
76
#[cfg(test)]
mod tests {
    use crate::ConcurrentCounter;

    /// A single `add` is visible through `sum`.
    #[test]
    fn basic_test() {
        let counter = ConcurrentCounter::new(0, 1);
        counter.add(1);
        assert_eq!(counter.sum(), 1);
    }

    /// Repeated `add` calls accumulate.
    #[test]
    fn increment_multiple_times() {
        let counter = ConcurrentCounter::new(0, 1);
        counter.add(1);
        counter.add(1);
        counter.add(1);
        assert_eq!(counter.sum(), 3);
    }

    /// The initial value passed to `new` is part of the sum, whichever cell
    /// later additions land in.
    #[test]
    fn initial_value_is_included_in_sum() {
        let counter = ConcurrentCounter::new(5, 4);
        assert_eq!(counter.sum(), 5);
        counter.add(2);
        assert_eq!(counter.sum(), 7);
    }

    /// Asking for zero threads still produces a counter that can be updated.
    #[test]
    fn zero_threads_still_yields_a_usable_counter() {
        let counter = ConcurrentCounter::new(0, 0);
        counter.add(3);
        assert_eq!(counter.sum(), 3);
    }

    /// `max` never lowers the observed maximum and raises it for larger values.
    #[test]
    fn max_tracks_the_largest_value() {
        let counter = ConcurrentCounter::new(5, 4);
        counter.max(3);
        assert_eq!(counter.total_max(), 5);
        counter.max(10);
        assert_eq!(counter.total_max(), 10);
    }

    /// Concurrent increments from several threads are all accounted for, and
    /// the `Debug` output reports the sum and the cell count.
    #[test]
    fn multple_threads_incrementing_multiple_times_concurrently() {
        const WRITE_COUNT: usize = 1_000_000;
        const THREAD_COUNT: usize = 8;

        // Spin up threads that increment the counter concurrently
        let counter = ConcurrentCounter::new(0, THREAD_COUNT);

        std::thread::scope(|s| {
            for _ in 0..THREAD_COUNT {
                s.spawn(|| {
                    for _ in 0..WRITE_COUNT {
                        counter.add(1);
                    }
                });
            }
        });

        assert_eq!(counter.sum(), THREAD_COUNT * WRITE_COUNT);

        assert_eq!(
            format!("Counter is: {counter:?}"),
            "Counter is: ConcurrentCounter { sum: 8000000, shards: 8 }"
        );
    }
}