1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
//! The hypergeometric distribution.

use crate::Distribution;
use rand::Rng;
use rand::distributions::uniform::Uniform;
use core::fmt;
#[allow(unused_imports)]
use num_traits::Float;

#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
enum SamplingMethod {
    InverseTransform{ initial_p: f64, initial_x: i64 },
    RejectionAcceptance{
        m: f64,
        a: f64,
        lambda_l: f64,
        lambda_r: f64,
        x_l: f64,
        x_r: f64,
        p1: f64,
        p2: f64,
        p3: f64
    },
}

/// The hypergeometric distribution `Hypergeometric(N, K, n)`.
/// 
/// This is the distribution of successes in samples of size `n` drawn without
/// replacement from a population of size `N` containing `K` success states.
/// It has the density function:
/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`,
/// where `binomial(a, b) = a! / (b! * (a - b)!)`.
/// 
/// The [binomial distribution](crate::Binomial) is the analogous distribution
/// for sampling with replacement. It is a good approximation when the population
/// size is much larger than the sample size.
/// 
/// # Example
/// 
/// ```
/// use rand_distr::{Distribution, Hypergeometric};
///
/// let hypergeo = Hypergeometric::new(60, 24, 7).unwrap();
/// let v = hypergeo.sample(&mut rand::thread_rng());
/// println!("{} is from a hypergeometric distribution", v);
/// ```
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Hypergeometric {
    n1: u64,
    n2: u64,
    k: u64,
    offset_x: i64,
    sign_x: i64,
    sampling_method: SamplingMethod,
}

/// Error type returned from `Hypergeometric::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `total_population_size` is too large, causing floating point underflow.
    PopulationTooLarge,
    /// `population_with_feature > total_population_size`.
    ProbabilityTooLarge,
    /// `sample_size > total_population_size`.
    SampleSizeTooLarge,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Error::PopulationTooLarge => "total_population_size is too large causing underflow in geometric distribution",
            Error::ProbabilityTooLarge => "population_with_feature > total_population_size in geometric distribution",
            Error::SampleSizeTooLarge => "sample_size > total_population_size in geometric distribution",
        })
    }
}

#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}

// evaluate fact(numerator.0)*fact(numerator.1) / fact(denominator.0)*fact(denominator.1)
fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, u64)) -> f64 {
    let min_top = u64::min(numerator.0, numerator.1);
    let min_bottom = u64::min(denominator.0, denominator.1);
    // the factorial of this will cancel out:
    let min_all = u64::min(min_top, min_bottom);

    let max_top = u64::max(numerator.0, numerator.1);
    let max_bottom = u64::max(denominator.0, denominator.1);
    let max_all = u64::max(max_top, max_bottom);

    let mut result = 1.0;
    for i in (min_all + 1)..=max_all {
        if i <= min_top {
            result *= i as f64;
        }
        
        if i <= min_bottom {
            result /= i as f64;
        }
        
        if i <= max_top {
            result *= i as f64;
        }
        
        if i <= max_bottom {
            result /= i as f64;
        }
    }
    
    result
}

fn ln_of_factorial(v: f64) -> f64 {
    // the paper calls for ln(v!), but also wants to pass in fractions,
    // so we need to use Stirling's approximation to fill in the gaps:
    v * v.ln() - v
}

impl Hypergeometric {
    /// Constructs a new `Hypergeometric` with the shape parameters
    /// `N = total_population_size`,
    /// `K = population_with_feature`,
    /// `n = sample_size`.
    #[allow(clippy::many_single_char_names)] // Same names as in the reference.
    pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result<Self, Error> {
        if population_with_feature > total_population_size {
            return Err(Error::ProbabilityTooLarge);
        }

        if sample_size > total_population_size {
            return Err(Error::SampleSizeTooLarge);
        }

        // set-up constants as function of original parameters
        let n = total_population_size;
        let (mut sign_x, mut offset_x) = (1, 0);
        let (n1, n2) = {
            // switch around success and failure states if necessary to ensure n1 <= n2
            let population_without_feature = n - population_with_feature;
            if population_with_feature > population_without_feature {
                sign_x = -1;
                offset_x = sample_size as i64;
                (population_without_feature, population_with_feature)
            } else {
                (population_with_feature, population_without_feature)
            }
        };
        // when sampling more than half the total population, take the smaller
        // group as sampled instead (we can then return n1-x instead).
        // 
        // Note: the boundary condition given in the paper is `sample_size < n / 2`;
        // we're deviating here, because when n is even, it doesn't matter whether
        // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching
        // when `k == n/2`, we'd actually be taking the _larger_ group as sampled.
        let k = if sample_size <= n / 2 {
            sample_size
        } else {
            offset_x += n1 as i64 * sign_x;
            sign_x *= -1;
            n - sample_size
        };

        // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`,
        // where `M` is the mode of the distribution.
        // Use algorithm HIN for the remaining parameter space.
        // 
        // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer
        // generation of hypergeometric random variates.
        // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145
        // https://www.researchgate.net/publication/233212638
        const HIN_THRESHOLD: f64 = 10.0;
        let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor();
        let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD {
            let (initial_p, initial_x) = if k < n2 {
                (fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0)
            } else {
                (fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64)
            };

            if initial_p <= 0.0 || !initial_p.is_finite() {
                return Err(Error::PopulationTooLarge);
            }

            SamplingMethod::InverseTransform { initial_p, initial_x }
        } else {
            let a = ln_of_factorial(m) +
                ln_of_factorial(n1 as f64 - m) +
                ln_of_factorial(k as f64 - m) +
                ln_of_factorial((n2 - k) as f64 + m);

            let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64;
            let denominator = (n - 1) as f64 * n as f64 * n as f64;
            let d = 1.5 * (numerator / denominator).sqrt() + 0.5;

            let x_l = m - d + 0.5;
            let x_r = m + d + 0.5;

            let k_l = f64::exp(a -
                ln_of_factorial(x_l) -
                ln_of_factorial(n1 as f64 - x_l) -
                ln_of_factorial(k as f64 - x_l) -
                ln_of_factorial((n2 - k) as f64 + x_l));
            let k_r = f64::exp(a -
                ln_of_factorial(x_r - 1.0) -
                ln_of_factorial(n1 as f64 - x_r + 1.0) -
                ln_of_factorial(k as f64 - x_r + 1.0) -
                ln_of_factorial((n2 - k) as f64 + x_r - 1.0));
            
            let numerator = x_l * ((n2 - k) as f64 + x_l);
            let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0);
            let lambda_l = -((numerator / denominator).ln());

            let numerator = (n1 as f64 - x_r + 1.0) * (k as f64 - x_r + 1.0);
            let denominator = x_r * ((n2 - k) as f64 + x_r);
            let lambda_r = -((numerator / denominator).ln());

            // the paper literally gives `p2 + kL/lambdaL` where it (probably)
            // should have been `p2 <- p1 + kL/lambdaL`; another print error?!
            let p1 = 2.0 * d;
            let p2 = p1 + k_l / lambda_l;
            let p3 = p2 + k_r / lambda_r;

            SamplingMethod::RejectionAcceptance {
                m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3
            }
        };

        Ok(Hypergeometric { n1, n2, k, offset_x, sign_x, sampling_method })
    }
}

impl Distribution<u64> for Hypergeometric {
    #[allow(clippy::many_single_char_names)] // Same names as in the reference.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        use SamplingMethod::*;

        let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self;
        let x = match sampling_method {
            InverseTransform { initial_p: mut p, initial_x: mut x } => {
                let mut u = rng.gen::<f64>();
                while u > p && x < k as i64 { // the paper erroneously uses `until n < p`, which doesn't make any sense
                    u -= p;
                    p *= ((n1 as i64 - x as i64) * (k as i64 - x as i64)) as f64;
                    p /= ((x as i64 + 1) * (n2 as i64 - k as i64 + 1 + x as i64)) as f64;
                    x += 1;
                }
                x
            },
            RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => {
                let distr_region_select = Uniform::new(0.0, p3);
                loop {
                    let (y, v) = loop {
                        let u = distr_region_select.sample(rng);
                        let v = rng.gen::<f64>(); // for the accept/reject decision
            
                        if u <= p1 {
                            // Region 1, central bell
                            let y = (x_l + u).floor();
                            break (y, v);
                        } else if u <= p2 {
                            // Region 2, left exponential tail
                            let y = (x_l + v.ln() / lambda_l).floor();
                            if y as i64 >= i64::max(0, k as i64 - n2 as i64) {
                                let v = v * (u - p1) * lambda_l;
                                break (y, v);
                            }
                        } else {
                            // Region 3, right exponential tail
                            let y = (x_r - v.ln() / lambda_r).floor();
                            if y as u64 <= u64::min(n1, k) {
                                let v = v * (u - p2) * lambda_r;
                                break (y, v);
                            }
                        }
                    };
        
                    // Step 4: Acceptance/Rejection Comparison
                    if m < 100.0 || y <= 50.0 {
                        // Step 4.1: evaluate f(y) via recursive relationship
                        let mut f = 1.0;
                        if m < y {
                            for i in (m as u64 + 1)..=(y as u64) {
                                f *= (n1 - i + 1) as f64 * (k - i + 1) as f64;
                                f /= i as f64 * (n2 - k + i) as f64;
                            }
                        } else {
                            for i in (y as u64 + 1)..=(m as u64) {
                                f *= i as f64 * (n2 - k + i) as f64;
                                f /= (n1 - i) as f64 * (k - i) as f64;
                            }
                        }
        
                        if v <= f { break y as i64; }
                    } else {
                        // Step 4.2: Squeezing
                        let y1 = y + 1.0;
                        let ym = y - m;
                        let yn = n1 as f64 - y + 1.0;
                        let yk = k as f64 - y + 1.0;
                        let nk = n2 as f64 - k as f64 + y1;
                        let r = -ym / y1;
                        let s = ym / yn;
                        let t = ym / yk;
                        let e = -ym / nk;
                        let g = yn * yk / (y1 * nk) - 1.0;
                        let dg = if g < 0.0 {
                            1.0 + g
                        } else {
                            1.0
                        };
                        let gu = g * (1.0 + g * (-0.5 + g / 3.0));
                        let gl = gu - g.powi(4) / (4.0 * dg);
                        let xm = m + 0.5;
                        let xn = n1 as f64 - m + 0.5;
                        let xk = k as f64 - m + 0.5;
                        let nm = n2 as f64 - k as f64 + xm;
                        let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) +
                            xn * s * (1.0 + s * (-0.5 + s / 3.0)) +
                            xk * t * (1.0 + t * (-0.5 + t / 3.0)) +
                            nm * e * (1.0 + e * (-0.5 + e / 3.0)) +
                            y * gu - m * gl + 0.0034;
                        let av = v.ln();
                        if av > ub { continue; }
                        let dr = if r < 0.0 {
                            xm * r.powi(4) / (1.0 + r)
                        } else {
                            xm * r.powi(4)
                        };
                        let ds = if s < 0.0 {
                            xn * s.powi(4) / (1.0 + s)
                        } else {
                            xn * s.powi(4)
                        };
                        let dt = if t < 0.0 {
                            xk * t.powi(4) / (1.0 + t)
                        } else {
                            xk * t.powi(4)
                        };
                        let de = if e < 0.0 {
                            nm * e.powi(4) / (1.0 + e)
                        } else {
                            nm * e.powi(4)
                        };
        
                        if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 {
                            break y as i64;
                        }
        
                        // Step 4.3: Final Acceptance/Rejection Test
                        let av_critical = a -
                            ln_of_factorial(y) -
                            ln_of_factorial(n1 as f64 - y) - 
                            ln_of_factorial(k as f64 - y) - 
                            ln_of_factorial((n2 - k) as f64 + y);
                        if v.ln() <= av_critical {
                            break y as i64;
                        }
                    }
                }
            }
        };

        (offset_x + sign_x * x) as u64
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_hypergeometric_invalid_params() {
        assert!(Hypergeometric::new(100, 101, 5).is_err());
        assert!(Hypergeometric::new(100, 10, 101).is_err());
        assert!(Hypergeometric::new(100, 101, 101).is_err());
        assert!(Hypergeometric::new(100, 10, 5).is_ok());
    }

    fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R)
    {
        let distr = Hypergeometric::new(n, k, s).unwrap();

        let expected_mean = s as f64 * k as f64 / n as f64;
        let expected_variance = {
            let numerator = (s * k * (n - k) * (n - s)) as f64;
            let denominator = (n * n * (n - 1)) as f64;
            numerator / denominator
        };

        let mut results = [0.0; 1000];
        for i in results.iter_mut() {
            *i = distr.sample(rng) as f64;
        }

        let mean = results.iter().sum::<f64>() / results.len() as f64;
        assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0);

        let variance =
            results.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / results.len() as f64;
        assert!((variance - expected_variance).abs() < expected_variance / 10.0);
    }

    #[test]
    fn test_hypergeometric() {
        let mut rng = crate::test::rng(737);

        // exercise algorithm HIN:
        test_hypergeometric_mean_and_variance(500, 400, 30, &mut rng);
        test_hypergeometric_mean_and_variance(250, 200, 230, &mut rng);
        test_hypergeometric_mean_and_variance(100, 20, 6, &mut rng);
        test_hypergeometric_mean_and_variance(50, 10, 47, &mut rng);

        // exercise algorithm H2PE
        test_hypergeometric_mean_and_variance(5000, 2500, 500, &mut rng);
        test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng);
        test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng);
    }

    #[test]
    fn hypergeometric_distributions_can_be_compared() {
        assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3));
    }
}