Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement weighted sampling #72

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b2c96e0
Implement weighted sampling
rstub Oct 7, 2023
9a436a7
additional test
rstub Oct 7, 2023
e7407e2
Reduce false positive misses in test coverage
rstub Oct 7, 2023
dbe3072
Implement weighted samlping w/o replacement in C++
rstub Oct 8, 2023
9c9402b
Remove unnecessary struct
rstub Oct 8, 2023
29ce7b4
Replace int with INT
rstub Oct 10, 2023
8414aec
Add some checks
rstub Oct 10, 2023
42814ff
Add fair and biased coin for n == 2
rstub Oct 10, 2023
60cdf8f
Documentation
rstub Oct 10, 2023
ce16f22
Allow for large output size to trigger dqsample_num
rstub Oct 11, 2023
8ed7d33
Factor out creation of alias table
rstub Oct 12, 2023
41498c4
Remove a compiler warning
rstub Oct 12, 2023
8c7683c
Add set-based no-replacement methods for weighted sampling
rstub Oct 12, 2023
54aa1fd
Add messages to static_assert to not force usage of C++17
rstub Oct 12, 2023
36b01c9
Initial rules when to use which algorithm
rstub Oct 12, 2023
66c7bb7
Draft documentation
rstub Oct 12, 2023
60a5303
Test both weighted coin options
rstub Oct 12, 2023
06e2ced
Add references
rstub Oct 13, 2023
ed142a3
Use n/size instead of m/n as arguments
rstub Oct 13, 2023
22c5570
Add more formal references
rstub Oct 13, 2023
17f6b94
Document n=2 case
rstub Oct 13, 2023
cd6285f
Add news and bump version
rstub Oct 13, 2023
f736a66
Fix off-by-one error
rstub Oct 18, 2023
d7dcac3
C++11 does not allow `auto` with arguments
rstub Jan 20, 2024
386558a
Changes from version 0.3.2
rstub Jan 20, 2024
92be4ec
Merge changes from master
rstub Jan 27, 2024
499ab57
Compare the two weights directly for n=2
rstub Apr 13, 2024
81fc04d
Merge branch 'main' into feature/weighted-sampling-2
rstub Apr 22, 2024
d489a1b
Update sampling code to not use deprected dqrng::uniform01()
rstub Apr 22, 2024
557bf06
Get closer to original touchstone config
rstub Apr 22, 2024
74ef89a
Use GHA files from styler package
rstub Apr 22, 2024
482687a
Update performance testing
rstub Apr 23, 2024
1162bf0
Merge branch 'main' into feature/weighted-sampling-2
rstub Apr 23, 2024
60cb414
Add tests with uneven weight distribution
rstub Apr 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
* New methods `variate<dist>(param)`, `generate<dist>(container, param)` etc. using and inspired by [`randutils`](https://www.pcg-random.org/posts/ease-of-use-without-loss-of-power.html).
* The scalar functions `dqrng::runif`, `dqrng::rnorm` and `dqrng::rexp` available from `dqrng.h` have been deprecated and will be removed in a future release. Please use the more flexible and faster `dqrng::random_64bit_accessor` together with `variate<Dist>()` instead. The same applies to `dqrng::uniform01` from `dqrng_distribution.h`, which can be replaced by the member function `dqrng::random_64bit_generator::uniform01`.
* New template function `dqrng::extra::parallel_generate` in `dqrng_extra/parallel_generate.h` as an example for using the global RNG in a parallel context (fixing [#77](https://github.com/daqana/dqrng/issues/77) in [#82](https://github.com/daqana/dqrng/issues/82) together with Philippe Grosjean)

* Implement weighted sampling with and without replacement. ([#72](https://github.com/daqana/dqrng/pull/72) fixing [#18](https://github.com/daqana/dqrng/issues/18), [#45](https://github.com/daqana/dqrng/issues/45) and [#52](https://github.com/daqana/dqrng/issues/52))

# dgrng 0.3.2

Expand Down
8 changes: 4 additions & 4 deletions R/dqsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ dqsample <- function(x, size, replace = FALSE, prob = NULL) {
##' @rdname dqsample
##' @export
dqsample.int <- function(n, size = n, replace = FALSE, prob = NULL) {
if (!is.null(prob)) {
warning("Using 'prob' is not supported yet. Using default 'sample.int'.")
sample.int(n, size, replace, prob)
} else if (n <= .Machine$integer.max)
if (!is.null(prob))
stopifnot(n == length(prob))

if (n <= .Machine$integer.max && size <= .Machine$integer.max)
dqsample_int(n, size, replace, prob, 1L)
else
dqsample_num(n, size, replace, prob, 1L)
Expand Down
23 changes: 23 additions & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,29 @@ bm[, 1:4]

Note that sampling from `10^10` elements triggers "long-vector support" in R.

It is also possible to use weighted sampling both with replacement:

```{r sampling3}
m <- 1e6
n <- 1e4
prob <- dqrunif(m)
bm <- bench::mark(sample.int(m, n, replace = TRUE, prob = prob),
dqsample.int(m, n, replace = TRUE, prob = prob),
check = FALSE)
bm[, 1:4]
```

And without replacement:

```{r sampling4}
bm <- bench::mark(sample.int(m, n, prob = prob),
dqsample.int(m, n, prob = prob),
check = FALSE)
bm[, 1:4]
```

Especially for weighted sampling without replacement the performance advantage compared with R's default methods is particularly large.

In addition the RNGs provide support for multiple independent streams for parallel usage:

```{r parallel}
Expand Down
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,40 @@ bm[, 1:4]
Note that sampling from `10^10` elements triggers “long-vector support”
in R.

It is also possible to use weighted sampling both with replacement:

``` r
m <- 1e6
n <- 1e4
prob <- dqrunif(m)
bm <- bench::mark(sample.int(m, n, replace = TRUE, prob = prob),
dqsample.int(m, n, replace = TRUE, prob = prob),
check = FALSE)
bm[, 1:4]
#> # A tibble: 2 × 4
#> expression min median `itr/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl>
#> 1 sample.int(m, n, replace = TRUE, prob = prob) 22.94ms 24.33ms 39.8
#> 2 dqsample.int(m, n, replace = TRUE, prob = prob) 5.76ms 5.96ms 166.
```

And without replacement:

``` r
bm <- bench::mark(sample.int(m, n, prob = prob),
dqsample.int(m, n, prob = prob),
check = FALSE)
bm[, 1:4]
#> # A tibble: 2 × 4
#> expression min median `itr/sec`
#> <bch:expr> <bch:tm> <bch:tm> <dbl>
#> 1 sample.int(m, n, prob = prob) 14.34s 14.34s 0.0697
#> 2 dqsample.int(m, n, prob = prob) 5.09ms 5.34ms 184.
```

Especially for weighted sampling without replacement the performance
advantage compared with R’s default methods is particularly large.

In addition the RNGs provide support for multiple independent streams
for parallel usage:

Expand Down
220 changes: 217 additions & 3 deletions inst/include/dqrng_sample.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright 2018-2019 Ralf Stubner (daqana GmbH)
// Copyright 2022-2023 Ralf Stubner
// Copyright 2022-2024 Ralf Stubner
//
// This file is part of dqrng.
//
Expand All @@ -20,8 +20,9 @@
#define DQRNG_SAMPLE_H 1

#include <mystdint.h>
#include <queue>
#include <Rcpp.h>
#include <dqrng_types.h>
#include <dqrng_distribution.h>
#include <minimal_int_set.h>

namespace dqrng {
Expand All @@ -30,10 +31,109 @@
inline VEC replacement(dqrng::random_64bit_generator &rng, INT n, INT size, int offset) {
VEC result(size);
std::generate(result.begin(), result.end(),
[n, offset, &rng] () {return (offset + rng(n));});
[&n, &offset, &rng] () {return (offset + rng(n));});
return result;
}

template<typename VEC, typename INT>
inline VEC fair_coin(dqrng::random_64bit_generator &rng, INT n, INT size, int head, int tail) {
VEC result(size);
INT k = 0;
while (k < size) {
uint64_t bits = rng();

for (INT j = 0; j < 64 && k < size; ++k, ++j)
result[k] = (bits >> j) & 1 ? head : tail;
}
return result;
}


template<typename VEC, typename INT, typename FVEC>
inline VEC biased_coin(dqrng::random_64bit_generator &rng, INT n, INT size, FVEC prob, int head, int tail) {
VEC result(size);

// smaller probability scaled by 2^64 in order to compare directly with RNG output
uint64_t p;
if (prob[0] <= prob[1])
p = 0x1p64 * prob[0] / (prob[0] + prob[1]);
else {
p = 0x1p64 * prob[1] / (prob[0] + prob[1]);
std::swap(head, tail);
}
std::generate(result.begin(), result.end(),
[&p, &rng, &head, &tail] () {
return rng() < p ? head : tail;
});
return result;
}

// stochastic acceptance
template<typename VEC, typename INT, typename FVEC>
inline VEC replacement_prob(dqrng::random_64bit_generator &rng, INT n, INT size, FVEC prob, double max_prob, int offset) {
VEC result(size);
std::generate(result.begin(), result.end(),
[&n, &prob, &max_prob, &rng, &offset] () {
while (true) {
INT index = rng(n);
if (rng.uniform01() < prob[index] / max_prob)
return index + offset;
}
});
return result;
}

// create table for alias method (Walker/Voss)
template<typename INT, typename FVEC>
inline std::vector<std::pair<double,INT>> create_alias_table(INT n, FVEC prob, double prob_sum) {
std::vector<std::pair<double,INT>> prob_alias(n);
std::queue<INT> high;
std::queue<INT> low;
for(INT i = 0; i < n; ++i) {
prob_alias[i].first = prob[i] * n / prob_sum;
if (prob_alias[i].first < 1.0)
low.push(i);
else
high.push(i);
}
while(!low.empty() && !high.empty()) {
INT l = low.front();
low.pop();
INT h = high.front();
prob_alias[l].second = h;
prob_alias[h].first = (prob_alias[h].first + prob_alias[l].first) - 1.0;
if (prob_alias[h].first < 1.0) {
low.push(h);
high.pop();
}
}
while (!low.empty()) {
prob_alias[low.front()].first = 1.0;
low.pop();
}
while (!high.empty()) {
prob_alias[high.front()].first = 1.0;
high.pop();

Check warning on line 116 in inst/include/dqrng_sample.h

View check run for this annotation

Codecov / codecov/patch

inst/include/dqrng_sample.h#L115-L116

Added lines #L115 - L116 were not covered by tests
}
return prob_alias;
}

// alias method (Walker/Voss)
template<typename VEC, typename INT, typename FVEC>
inline VEC replacement_alias(dqrng::random_64bit_generator &rng, INT n, INT size, FVEC prob, double prob_sum, int offset) {
VEC result(size);
std::vector<std::pair<double,INT>> prob_alias = create_alias_table(n, prob, prob_sum);
std::generate(result.begin(), result.end(),
[&n, &prob_alias, &rng, &offset] () {
INT index = rng(n);
return (rng.uniform01() < prob_alias[index].first) ?
index + offset : prob_alias[index].second + offset;
});

return result;
}

// Fisher-Yates shuffle
template<typename VEC, typename INT>
inline VEC no_replacement_shuffle(dqrng::random_64bit_generator &rng, INT n, INT size, int offset) {
VEC tmp(n);
Expand All @@ -47,6 +147,7 @@
return VEC(tmp.begin(), tmp.begin() + size);
}

// set-based rejection sampling
template<typename VEC, typename INT, typename SET>
inline VEC no_replacement_set(dqrng::random_64bit_generator &rng, INT n, INT size, int offset) {
VEC result(size);
Expand All @@ -61,9 +162,72 @@
return result;
}

// exponential rank (Efraimidis/Spirakis)
template<typename VEC, typename INT, typename FVEC>
inline VEC no_replacement_exp(dqrng::random_64bit_generator &rng, INT n, INT size, FVEC prob, int offset) {
VEC index(n);
std::iota(index.begin(), index.end(), 0);
FVEC weight(n);
dqrng::exponential_distribution exponential;
std::transform(prob.begin(), prob.end(), weight.begin(),
[&rng, &exponential] (double x) {return exponential(rng) / x;});
std::partial_sort(index.begin(), index.begin() + size, index.end(),
[&weight](size_t i1, size_t i2) {return weight[i1] < weight[i2];});

VEC result(size);
std::transform(index.begin(), index.begin() + size, result.begin(),
[&offset] (INT x) {return x + offset;});
return result;
}

// set-based rejection sampling with stochastic acceptance
template<typename VEC, typename INT, typename SET, typename FVEC>
inline VEC no_replacement_prob_set(dqrng::random_64bit_generator &rng, INT n, INT size, FVEC prob, double max_prob, int offset) {
VEC result(size);
SET elems(n, size);

for (INT i = 0; i < size; ++i) {
INT v;
do {
do {
v = rng(n);
} while (rng.uniform01() >= prob[v] / max_prob);
} while (!elems.insert(v));
result[i] = (offset + v);
}
return result;
}

// set-based rejection sampling with alias selection
template<typename VEC, typename INT, typename SET, typename FVEC>
inline VEC no_replacement_alias_set(dqrng::random_64bit_generator &rng, INT n, INT size, FVEC prob, double prob_sum, int offset) {
VEC result(size);
std::vector<std::pair<double,INT>> prob_alias = create_alias_table(n, prob, prob_sum);
SET elems(n, size);
for (INT i = 0; i < size; ++i) {
INT v;
do {
INT index = rng(n);
v = (rng.uniform01() < prob_alias[index].first) ?
index : prob_alias[index].second;
} while (!elems.insert(v));
result[i] = (offset + v);
}
return result;
}

template<typename VEC, typename INT>
inline VEC sample(dqrng::random_64bit_generator &rng, INT n, INT size, bool replace, int offset = 0) {
static_assert(std::is_integral<INT>::value && std::is_unsigned<INT>::value,
"Provided INT has the wrong type.");
static_assert(std::is_floating_point<typename VEC::value_type>::value ||
std::is_integral<typename VEC::value_type>::value ||
std::is_reference<typename VEC::value_type>::value,
"Provided VEC has the wrong type.");
if (replace || size <= 1) {
if (n == 2)
return dqrng::sample::fair_coin<VEC, INT>(rng, n, size, offset, 1 + offset);

return dqrng::sample::replacement<VEC, INT>(rng, n, size, offset);
} else {
if (!(n >= size))
Expand All @@ -77,6 +241,56 @@
}
}
}

template<typename VEC, typename INT, typename FVEC>
inline VEC sample(dqrng::random_64bit_generator &rng, INT n, INT size, bool replace, FVEC prob, int offset = 0) {
static_assert(std::is_integral<INT>::value && std::is_unsigned<INT>::value,
"Provided INT has the wrong type.");
static_assert(std::is_floating_point<typename VEC::value_type>::value ||
std::is_integral<typename VEC::value_type>::value ||
std::is_reference<typename VEC::value_type>::value,
"Provided VEC has the wrong type.");
static_assert(std::is_floating_point<typename FVEC::value_type>::value ||
std::is_reference<typename FVEC::value_type>::value,
"Provided FVEC has the wrong type.");
if (n != INT(prob.size()))
Rcpp::stop("Argument requirements not fulfilled: n == prob.size()");

Check warning on line 257 in inst/include/dqrng_sample.h

View check run for this annotation

Codecov / codecov/patch

inst/include/dqrng_sample.h#L257

Added line #L257 was not covered by tests
if (replace || size <= 1) {
if (n == 2)
return dqrng::sample::biased_coin<VEC, INT>(rng, n, size, prob, offset, 1 + offset);

double prob_sum = std::accumulate(prob.begin(), prob.end(), 0.0);
if (size >= n)
return dqrng::sample::replacement_alias<VEC, INT>(rng, n, size, prob, prob_sum, offset);

double *max_prob = std::max_element(prob.begin(), prob.end());
if (*max_prob * n / prob_sum < 3.)
return dqrng::sample::replacement_prob<VEC, INT>(rng, n, size, prob, *max_prob, offset);
else
return dqrng::sample::replacement_alias<VEC, INT>(rng, n, size, prob, prob_sum, offset);
} else {
if (!(n >= size))
Rcpp::stop("Argument requirements not fulfilled: n >= size");
if (n < 2 * size)
return dqrng::sample::no_replacement_exp<VEC, INT>(rng, n, size, prob, offset);

double prob_sum = std::accumulate(prob.begin(), prob.end(), 0.0);
double *max_prob = std::max_element(prob.begin(), prob.end());
if (n < 1000 * size) {// check this factor
using set_t = dqrng::minimal_bit_set;
if (*max_prob * n / prob_sum < 3.)
return dqrng::sample::no_replacement_prob_set<VEC, INT, set_t>(rng, n, size, prob, *max_prob, offset);
else
return dqrng::sample::no_replacement_alias_set<VEC, INT, set_t>(rng, n, size, prob, prob_sum, offset);
} else {
using set_t = dqrng::minimal_hash_set<INT>;
if (*max_prob * n / prob_sum < 3.)
return dqrng::sample::no_replacement_prob_set<VEC, INT, set_t>(rng, n, size, prob, *max_prob, offset);
else
return dqrng::sample::no_replacement_alias_set<VEC, INT, set_t>(rng, n, size, prob, prob_sum, offset);
}
}
}
} // sample
} // dqrng

Expand Down
10 changes: 8 additions & 2 deletions src/dqrng.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ Rcpp::IntegerVector dqsample_int(int n,
int offset = 0) {
if (!(n > 0 && size >= 0))
Rcpp::stop("Argument requirements not fulfilled: n > 0 && size >= 0");
return dqrng::sample::sample<Rcpp::IntegerVector, uint32_t>(*rng, uint32_t(n), uint32_t(size), replace, offset);
if (probs.isNull())
return dqrng::sample::sample<Rcpp::IntegerVector, uint32_t>(*rng, uint32_t(n), uint32_t(size), replace, offset);
else
return dqrng::sample::sample<Rcpp::IntegerVector, uint32_t>(*rng, uint32_t(n), uint32_t(size), replace, probs.as(), offset);
}

// [[Rcpp::export(rng = false)]]
Expand All @@ -204,7 +207,10 @@ Rcpp::NumericVector dqsample_num(double n,
#ifndef LONG_VECTOR_SUPPORT
Rcpp::stop("Long vectors are not supported");
#else
return dqrng::sample::sample<Rcpp::NumericVector, uint64_t>(*rng, uint64_t(n), uint64_t(size), replace, offset);
if (probs.isNull())
return dqrng::sample::sample<Rcpp::NumericVector, uint64_t>(*rng, uint64_t(n), uint64_t(size), replace, offset);
else
return dqrng::sample::sample<Rcpp::NumericVector, uint64_t>(*rng, uint64_t(n), uint64_t(size), replace, probs.as(), offset);
#endif
}

Expand Down
Loading
Loading