#![cfg(feature = "alloc")]
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
use core::fmt;
use num_traits::{Float, NumCast};
use rand::Rng;
#[cfg(feature = "serde_with")] use serde_with::serde_as;
use alloc::{boxed::Box, vec, vec::Vec};
/// Dirichlet sampler that draws one independent `Gamma(alpha[i], 1)` variate per
/// component and then divides every draw by their sum.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde_with", serde_as)]
struct DirichletFromGamma<F, const N: usize>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// One Gamma sampler per component, in the same order as `alpha`.
samplers: [Gamma<F>; N],
}
/// Errors that can occur while building a [`DirichletFromGamma`].
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromGammaError {
    /// `Gamma::new(alpha[i], 1)` returned an error for some component.
    GammaNewFailed,
    /// The collected samplers could not be converted into `[Gamma<F>; N]`.
    GammaArrayCreationFailed,
}
impl<F, const N: usize> DirichletFromGamma<F, N>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Builds one `Gamma(alpha[i], 1)` sampler per component of `alpha`.
    ///
    /// Returns [`DirichletFromGammaError::GammaNewFailed`] on the first component
    /// for which `Gamma::new` fails.
    #[inline]
    fn new(alpha: [F; N]) -> Result<DirichletFromGamma<F, N>, DirichletFromGammaError> {
        // Exactly N samplers are pushed below, so reserve once up front.
        let mut gamma_dists = Vec::with_capacity(N);
        for a in alpha {
            let dist =
                Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammaNewFailed)?;
            gamma_dists.push(dist);
        }
        // The Vec holds exactly N elements here, so the conversion should not fail;
        // the error is mapped rather than unwrapped to keep this constructor panic-free.
        Ok(DirichletFromGamma {
            samplers: gamma_dists
                .try_into()
                .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?,
        })
    }
}
impl<F, const N: usize> Distribution<[F; N]> for DirichletFromGamma<F, N>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Draws `Gamma(alpha[i], 1)` for every component in order, then rescales the
    /// draws so that they sum to one.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
        let mut draws = [F::zero(); N];
        let mut total = F::zero();
        for (i, gamma) in self.samplers.iter().enumerate() {
            let x = gamma.sample(rng);
            draws[i] = x;
            total = total + x;
        }
        // Rescale with the reciprocal of the total: one division, then N multiplications.
        let scale = F::one() / total;
        draws.map(|x| x * scale)
    }
}
/// Dirichlet sampler based on stick-breaking with Beta variates.
///
/// Sampler `i` yields the fraction of the still-unassigned mass given to
/// component `i`. The last component receives whatever remains, so only
/// `N - 1` samplers are stored.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
struct DirichletFromBeta<F, const N: usize>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// The `N - 1` Beta samplers built by `DirichletFromBeta::new`, in component order.
samplers: Box<[Beta<F>]>,
}
/// Errors that can occur while building a [`DirichletFromBeta`].
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromBetaError {
/// `Beta::new` returned an error for some component.
BetaNewFailed,
}
impl<F, const N: usize> DirichletFromBeta<F, N>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Builds the `N - 1` Beta samplers used for stick-breaking.
    ///
    /// Component `i` gets `Beta(alpha[i], alpha[i + 1] + ... + alpha[N - 1])`.
    /// Assumes `N >= 2`; the caller is expected to enforce that.
    #[inline]
    fn new(alpha: [F; N]) -> Result<DirichletFromBeta<F, N>, DirichletFromBetaError> {
        // tail[i] = alpha[i + 1] + ... + alpha[N - 1], accumulated from the back.
        let mut tail = vec![alpha[N - 1]; N - 1];
        for i in (0..(N - 2)).rev() {
            tail[i] = tail[i + 1] + alpha[i + 1];
        }

        let samplers = alpha[..(N - 1)]
            .iter()
            .zip(tail.iter())
            .map(|(&a, &b)| Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed))
            .collect::<Result<Vec<_>, _>>()?
            .into_boxed_slice();

        Ok(DirichletFromBeta { samplers })
    }
}
impl<F, const N: usize> Distribution<[F; N]> for DirichletFromBeta<F, N>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Stick-breaking sample: each Beta draw is the fraction of the remaining
    /// mass assigned to the next component, and the final component takes the
    /// leftover.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
        let mut pieces = [F::zero(); N];
        let mut remaining = F::one();
        for (piece, beta) in pieces.iter_mut().zip(&*self.samplers) {
            let fraction = beta.sample(rng);
            *piece = remaining * fraction;
            remaining = remaining * (F::one() - fraction);
        }
        pieces[N - 1] = remaining;
        pieces
    }
}
/// Internal choice of sampling algorithm backing a [`Dirichlet`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde_with", serde_as)]
enum DirichletRepr<F, const N: usize>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Normalized independent Gamma draws; selected by `Dirichlet::new` unless every alpha is <= 0.1.
FromGamma(DirichletFromGamma<F, N>),
/// Beta stick-breaking; selected by `Dirichlet::new` when every alpha is <= 0.1.
FromBeta(DirichletFromBeta<F, N>),
}
/// The Dirichlet distribution `Dirichlet(alpha)` over `N` components.
///
/// Each sample is an array `[F; N]` whose entries sum to one, up to
/// floating-point rounding. Construct it with [`Dirichlet::new`].
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
// NOTE(review): `serde_as` is applied here without a Serialize/Deserialize derive;
// confirm whether serialization of `Dirichlet` is intended.
#[cfg_attr(feature = "serde_with", serde_as)]
#[derive(Clone, Debug, PartialEq)]
pub struct Dirichlet<F, const N: usize>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
// Which sampling algorithm `Dirichlet::new` selected for this alpha.
repr: DirichletRepr<F, N>,
}
/// Error type returned from [`Dirichlet::new`].
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `alpha` has fewer than two elements (`N < 2`).
    AlphaTooShort,
    /// Some `alpha[i]` is not strictly positive (zero, negative, or NaN).
    AlphaTooSmall,
    /// Some `alpha[i]` is subnormal.
    AlphaSubnormal,
    /// Some `alpha[i]` is infinite.
    AlphaInfinite,
    /// Building one of the underlying Gamma samplers failed.
    FailedToCreateGamma,
    /// Building one of the underlying Beta samplers failed.
    FailedToCreateBeta,
    /// Displayed with the same message as `AlphaTooShort`; `Dirichlet::new` in
    /// this module does not produce it.
    SizeTooSmall,
}
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Error::AlphaTooShort | Error::SizeTooSmall => {
                "less than 2 dimensions in Dirichlet distribution"
            }
            Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution",
            Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution",
            Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution",
            Error::FailedToCreateGamma => {
                "failed to create required Gamma distribution for Dirichlet distribution"
            }
            Error::FailedToCreateBeta => {
                "failed to create required Beta distribution for Dirichlet distribution"
            }
        })
    }
}
#[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
impl std::error::Error for Error {}
impl<F, const N: usize> Dirichlet<F, N>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
#[inline]
pub fn new(alpha: [F; N]) -> Result<Dirichlet<F, N>, Error> {
if N < 2 {
return Err(Error::AlphaTooShort);
}
for &ai in alpha.iter() {
if !(ai > F::zero()) {
return Err(Error::AlphaTooSmall);
}
if ai.is_infinite() {
return Err(Error::AlphaInfinite);
}
if !ai.is_normal() {
return Err(Error::AlphaSubnormal);
}
}
if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) {
let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?;
Ok(Dirichlet {
repr: DirichletRepr::FromBeta(dist),
})
} else {
let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?;
Ok(Dirichlet {
repr: DirichletRepr::FromGamma(dist),
})
}
}
}
impl<F, const N: usize> Distribution<[F; N]> for Dirichlet<F, N>
where
    F: Float,
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    /// Draws one sample by delegating to the sampler chosen in `Dirichlet::new`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> [F; N] {
        match self.repr {
            DirichletRepr::FromBeta(ref inner) => inner.sample(rng),
            DirichletRepr::FromGamma(ref inner) => inner.sample(rng),
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
        let mut rng = crate::test::rng(221);
        let samples = d.sample(&mut rng);
        for x in samples {
            assert!(x > 0.0);
        }
        // The components of a Dirichlet sample sum to one, up to rounding.
        let total: f64 = samples.iter().sum();
        assert!((total - 1.0).abs() < 1e-12);
    }

    // The invalid-input tests pin the exact error variant, not just "some panic".
    #[test]
    fn test_dirichlet_invalid_length() {
        assert_eq!(Dirichlet::new([0.5]), Err(Error::AlphaTooShort));
    }

    #[test]
    fn test_dirichlet_alpha_zero() {
        assert_eq!(Dirichlet::new([0.1, 0.0, 0.3]), Err(Error::AlphaTooSmall));
    }

    #[test]
    fn test_dirichlet_alpha_negative() {
        assert_eq!(Dirichlet::new([0.1, -1.5, 0.3]), Err(Error::AlphaTooSmall));
    }

    #[test]
    fn test_dirichlet_alpha_nan() {
        // NaN fails the `> 0` comparison and is reported as AlphaTooSmall.
        assert_eq!(Dirichlet::new([0.5, f64::NAN, 0.25]), Err(Error::AlphaTooSmall));
    }

    #[test]
    fn test_dirichlet_alpha_subnormal() {
        assert_eq!(Dirichlet::new([0.5, 1.5e-321, 0.25]), Err(Error::AlphaSubnormal));
    }

    #[test]
    fn test_dirichlet_alpha_inf() {
        assert_eq!(Dirichlet::new([0.5, f64::INFINITY, 0.25]), Err(Error::AlphaInfinite));
    }

    #[test]
    fn test_dirichlet_repr_selection() {
        // Beta stick-breaking is used only when every alpha is <= 0.1 (inclusive).
        let small = Dirichlet::new([0.05, 0.1]).unwrap();
        assert!(matches!(small.repr, DirichletRepr::FromBeta(_)));
        let mixed = Dirichlet::new([0.05, 0.2]).unwrap();
        assert!(matches!(mixed.repr, DirichletRepr::FromGamma(_)));
    }

    #[test]
    fn dirichlet_distributions_can_be_compared() {
        assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0]));
    }

    /// Draws `n` samples and checks each component's mean against `alpha[i] / sum(alpha)`.
    fn check_dirichlet_means<const N: usize>(alpha: [f64; N], n: i32, rtol: f64, seed: u64) {
        let d = Dirichlet::new(alpha).unwrap();
        let mut rng = crate::test::rng(seed);
        let mut sums = [0.0; N];
        for _ in 0..n {
            let samples = d.sample(&mut rng);
            for i in 0..N {
                sums[i] += samples[i];
            }
        }
        let sample_mean = sums.map(|x| x / n as f64);
        let alpha_sum: f64 = alpha.iter().sum();
        let expected_mean = alpha.map(|x| x / alpha_sum);
        for i in 0..N {
            assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
        }
    }

    #[test]
    fn test_dirichlet_means() {
        let n = 20000;
        let rtol = 2e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means([0.5, 0.25], n, rtol, seed);
        check_dirichlet_means([123.0, 75.0], n, rtol, seed);
        check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed);
        check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed);
    }

    #[test]
    fn test_dirichlet_means_very_small_alpha() {
        let alpha = [0.001; 3];
        let n = 10000;
        let rtol = 1e-2;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }

    #[test]
    fn test_dirichlet_means_small_alpha() {
        let alpha = [0.05, 0.025, 0.075, 0.05];
        let n = 150000;
        let rtol = 1e-3;
        let seed = 1317624576693539401;
        check_dirichlet_means(alpha, n, rtol, seed);
    }
}