Skip to content

Commit

Permalink
fixes cloning modules with empty states (#1108)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfalbel authored Oct 2, 2023
1 parent 4cd7ee3 commit 56974e6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
30 changes: 16 additions & 14 deletions R/nn.R
Original file line number Diff line number Diff line change
Expand Up @@ -522,21 +522,23 @@ create_nn_module_callable <- function(instance) {
instance$clone <- function(deep = FALSE, ..., replace_values = TRUE) {
if (deep && replace_values) {
state_dict <- append(instance$parameters, instance$buffers)
names(state_dict) <- sapply(state_dict, xptr_address)
if (length(state_dict) > 0) {
names(state_dict) <- sapply(state_dict, xptr_address)

state_dict <- state_dict[!duplicated(names(state_dict))]
state_dict <- lapply(state_dict, function(x) x$detach()$clone())

# also need to append a clone of the modules to this list.
# child modules can be duplicated - and have the same name
# child modules are also deep cloned, but we don't need to replace
# their values when cloning because we only have to do it once.
children <- instance$children
names(children) <- sapply(children, rlang::obj_address)
children <- children[!duplicated(names(children))]
children <- lapply(children, function(x) x$clone(deep = deep, replace_values = FALSE))

state_dict <- append(state_dict, children)
state_dict <- state_dict[!duplicated(names(state_dict))]
state_dict <- lapply(state_dict, function(x) x$detach()$clone())

# also need to append a clone of the modules to this list.
# child modules can be duplicated - and have the same name
# child modules are also deep cloned, but we don't need to replace
# their values when cloning because we only have to do it once.
children <- instance$children
names(children) <- sapply(children, rlang::obj_address)
children <- children[!duplicated(names(children))]
children <- lapply(children, function(x) x$clone(deep = deep, replace_values = FALSE))

state_dict <- append(state_dict, children)
}
}

cloned_instance <- clone(deep = deep)
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-nn.R
Original file line number Diff line number Diff line change
Expand Up @@ -801,4 +801,12 @@ test_that("can use a named module dict", {

expect_tensor_shape(z, c(100, 1))
expect_equal(length(dict$parameters), 4)
})

test_that("can clone a module with no state dict", {

expect_no_error({
nn_relu()$clone(TRUE)
})

})

0 comments on commit 56974e6

Please sign in to comment.