Skip to content

Commit

Permalink
Correctly handle precision matrix (#1110)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel authored Oct 2, 2023
1 parent e193abe commit df14b1b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 1 addition & 2 deletions R/distributions-multivariate_normal.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ MultivariateNormal <- R6::R6Class(
value_error("loc must be at least one-dimensional.")
}

if ((!is.null(covariance_matrix) + !is.null(precision_matrix) +
!is.null(scale_tril)) != 1) {
if ((is.null(covariance_matrix) + is.null(precision_matrix) + is.null(scale_tril)) != 2) {
value_error("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")
}

Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test-distributions-multivariate_normal.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,13 @@ test_that("properites", {
expect_tensor_shape(m$variance, c(2))
expect_tensor_shape(m$entropy(), 1)
})

test_that("works with precision matrix", {

dist <- distr_multivariate_normal(torch_ones(2), precision_matrix = torch_eye(2))

expect_no_error({
dist$sample(10)
})

})

0 comments on commit df14b1b

Please sign in to comment.