diff --git a/lib/gcforest/layers/fg_pool_layer.py b/lib/gcforest/layers/fg_pool_layer.py index e5463ad..8a6238e 100644 --- a/lib/gcforest/layers/fg_pool_layer.py +++ b/lib/gcforest/layers/fg_pool_layer.py @@ -24,7 +24,7 @@ def __init__(self, layer_config, data_cache): super(FGPoolLayer, self).__init__(layer_config, data_cache) self.win_x = self.get_value("win_x", None, int, required=True) self.win_y = self.get_value("win_y", None, int, required=True) - self.pool_method = self.get_value("pool_method", "avg", basestring) + self.pool_method = self.get_value("pool_method", "avg", str) def fit_transform(self, train_config): LOGGER.info("[data][{}] bottoms={}, tops={}".format(self.name, self.bottom_names, self.top_names)) @@ -49,8 +49,8 @@ def _transform(self, phases, check_top_cache): #assert w % win_x == 0 #nh = int(h / win_y) #nw = int(w / win_x) - nh = (h - 1) / win_y + 1 - nw = (w - 1) / win_x + 1 + nh = int((h - 1) / win_y + 1) + nw = int((w - 1) / win_x + 1) X_pool = np.empty(( n, c, nh, nw), dtype=np.float32) #for k in trange(c, desc='loop channel'): # for di in trange(nh, desc='loop win_y'): diff --git a/lib/gcforest/utils/win_utils.py b/lib/gcforest/utils/win_utils.py index 6930491..dcb4059 100644 --- a/lib/gcforest/utils/win_utils.py +++ b/lib/gcforest/utils/win_utils.py @@ -13,9 +13,9 @@ def get_windows_channel(X, X_win, des_id, nw, nh, win_x, win_y, stride_x, stride (k, di, dj) in range(X.channle, win_y, win_x) """ #des_id = (k * win_y + di) * win_x + dj - dj = des_id % win_x - di = des_id / win_x % win_y - k = des_id / win_x / win_y + dj = int(des_id % win_x) + di = int(des_id / win_x % win_y) + k = int(des_id / win_x / win_y) src = X[:, k, di:di+nh*stride_y:stride_y, dj:dj+nw*stride_x:stride_x].ravel() des = X_win[des_id, :] np.copyto(des, src) @@ -39,8 +39,8 @@ def get_windows(X, win_x, win_y, stride_x=1, stride_y=1, pad_x=0, pad_y=0): X = np.concatenate(( np.zeros((n, c, h, pad_x),dtype=X.dtype), X ), axis=3) n, c, h, w = X.shape nc = win_y * win_x * c - nh = (h - win_y) / stride_y + 1 - nw = (w - win_x) / stride_x + 1 + nh = int((h - win_y) / stride_y + 1) + nw = int((w - win_x) / stride_x + 1) X_win = np.empty(( nc, n * nh * nw), dtype=np.float32) LOGGER.info("get_windows_start: X.shape={}, X_win.shape={}, nw={}, nh={}, c={}, win_x={}, win_y={}, stride_x={}, stride_y={}".format( X.shape, X_win.shape, nw, nh, c, win_x, win_y, stride_x, stride_y))