1 | //! Bivariate analysis |
---|---|

2 | |

3 | mod bootstrap; |

4 | pub mod regression; |

5 | mod resamples; |

6 | |

7 | use crate::stats::bivariate::resamples::Resamples; |

8 | use crate::stats::float::Float; |

9 | use crate::stats::tuple::{Tuple, TupledDistributionsBuilder}; |

10 | use crate::stats::univariate::Sample; |

11 | #[cfg(feature = "rayon")] |

12 | use rayon::iter::{IntoParallelIterator, ParallelIterator}; |

13 | |

/// Bivariate `(X, Y)` data
///
/// A borrowed view over two parallel slices: the first holds the `X` values and the second the
/// `Y` values, so element `i` of each slice together forms one data point.
///
/// Invariants (checked by `Data::new`):
///
/// - Both slices have the same length
/// - No `NaN`s in the data
/// - At least two data points in the set
pub struct Data<'a, X, Y>(&'a [X], &'a [Y]);

// `Data` only holds shared slice references, so it can be `Copy` no matter what `X` and `Y` are.
// The impls are written by hand because `#[derive(Clone, Copy)]` would add `X: Copy, Y: Copy`
// bounds that are not needed here.
impl<'a, X, Y> Copy for Data<'a, X, Y> {}

// Clippy flags hand-written `Clone` impls on `Copy` types; this one is deliberate (see above).
#[allow(clippy::expl_impl_clone_on_copy)]
impl<'a, X, Y> Clone for Data<'a, X, Y> {
    fn clone(&self) -> Data<'a, X, Y> {
        *self
    }
}

30 | |

31 | impl<'a, X, Y> Data<'a, X, Y> { |

32 | /// Returns the length of the data set |

33 | pub fn len(&self) -> usize { |

34 | self.0.len() |

35 | } |

36 | |

37 | /// Iterate over the data set |

38 | pub fn iter(&self) -> Pairs<'a, X, Y> { |

39 | Pairs { |

40 | data: *self, |

41 | state: 0, |

42 | } |

43 | } |

44 | } |

45 | |

46 | impl<'a, X, Y> Data<'a, X, Y> |

47 | where |

48 | X: Float, |

49 | Y: Float, |

50 | { |

51 | /// Creates a new data set from two existing slices |

52 | pub fn new(xs: &'a [X], ys: &'a [Y]) -> Data<'a, X, Y> { |

53 | assert!( |

54 | xs.len() == ys.len() |

55 | && xs.len() > 1 |

56 | && xs.iter().all(|x| !x.is_nan()) |

57 | && ys.iter().all(|y| !y.is_nan()) |

58 | ); |

59 | |

60 | Data(xs, ys) |

61 | } |

62 | |

63 | // TODO Remove the `T` parameter in favor of `S::Output` |

64 | /// Returns the bootstrap distributions of the parameters estimated by the `statistic` |

65 | /// |

66 | /// - Multi-threaded |

67 | /// - Time: `O(nresamples)` |

68 | /// - Memory: `O(nresamples)` |

69 | pub fn bootstrap<T, S>(&self, nresamples: usize, statistic: S) -> T::Distributions |

70 | where |

71 | S: Fn(Data<X, Y>) -> T + Sync, |

72 | T: Tuple + Send, |

73 | T::Distributions: Send, |

74 | T::Builder: Send, |

75 | { |

76 | #[cfg(feature = "rayon")] |

77 | { |

78 | (0..nresamples) |

79 | .into_par_iter() |

80 | .map_init( |

81 | || Resamples::new(*self), |

82 | |resamples, _| statistic(resamples.next()), |

83 | ) |

84 | .fold( |

85 | || T::Builder::new(0), |

86 | |mut sub_distributions, sample| { |

87 | sub_distributions.push(sample); |

88 | sub_distributions |

89 | }, |

90 | ) |

91 | .reduce( |

92 | || T::Builder::new(0), |

93 | |mut a, mut b| { |

94 | a.extend(&mut b); |

95 | a |

96 | }, |

97 | ) |

98 | .complete() |

99 | } |

100 | #[cfg(not(feature = "rayon"))] |

101 | { |

102 | let mut resamples = Resamples::new(*self); |

103 | (0..nresamples) |

104 | .map(|_| statistic(resamples.next())) |

105 | .fold(T::Builder::new(0), |mut sub_distributions, sample| { |

106 | sub_distributions.push(sample); |

107 | sub_distributions |

108 | }) |

109 | .complete() |

110 | } |

111 | } |

112 | |

113 | /// Returns a view into the `X` data |

114 | pub fn x(&self) -> &'a Sample<X> { |

115 | Sample::new(self.0) |

116 | } |

117 | |

118 | /// Returns a view into the `Y` data |

119 | pub fn y(&self) -> &'a Sample<Y> { |

120 | Sample::new(self.1) |

121 | } |

122 | } |

123 | |

124 | /// Iterator over `Data` |

125 | pub struct Pairs<'a, X: 'a, Y: 'a> { |

126 | data: Data<'a, X, Y>, |

127 | state: usize, |

128 | } |

129 | |

130 | impl<'a, X, Y> Iterator for Pairs<'a, X, Y> { |

131 | type Item = (&'a X, &'a Y); |

132 | |

133 | fn next(&mut self) -> Option<(&'a X, &'a Y)> { |

134 | if self.state < self.data.len() { |

135 | let i = self.state; |

136 | self.state += 1; |

137 | |

138 | // This is safe because i will always be < self.data.{0,1}.len() |

139 | debug_assert!(i < self.data.0.len()); |

140 | debug_assert!(i < self.data.1.len()); |

141 | unsafe { Some((self.data.0.get_unchecked(i), self.data.1.get_unchecked(i))) } |

142 | } else { |

143 | None |

144 | } |

145 | } |

146 | } |

147 |