diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..44f59d2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +*.py[cod] +*$py.class \ No newline at end of file diff --git a/input/01.png b/input/01.png new file mode 100644 index 0000000..6843f91 Binary files /dev/null and b/input/01.png differ diff --git a/input/02.png b/input/02.png new file mode 100644 index 0000000..1cf1479 Binary files /dev/null and b/input/02.png differ diff --git a/input/03.png b/input/03.png new file mode 100644 index 0000000..21a7347 Binary files /dev/null and b/input/03.png differ diff --git a/input/04.png b/input/04.png new file mode 100644 index 0000000..48454d5 Binary files /dev/null and b/input/04.png differ diff --git a/input/05.png b/input/05.png new file mode 100644 index 0000000..dd19ca1 Binary files /dev/null and b/input/05.png differ diff --git a/input/06.png b/input/06.png new file mode 100644 index 0000000..ef92fca Binary files /dev/null and b/input/06.png differ diff --git a/input/07.png b/input/07.png new file mode 100644 index 0000000..85fa6d1 Binary files /dev/null and b/input/07.png differ diff --git a/models/README.md b/models/README.md new file mode 100644 index 0000000..285d016 --- /dev/null +++ b/models/README.md @@ -0,0 +1,3 @@ +# Models + +Place `.pth` model files here diff --git a/output/01.png b/output/01.png new file mode 100644 index 0000000..931a9de Binary files /dev/null and b/output/01.png differ diff --git a/output/02.png b/output/02.png new file mode 100644 index 0000000..dea80ec Binary files /dev/null and b/output/02.png differ diff --git a/output/03.png b/output/03.png new file mode 100644 index 0000000..a0997d9 Binary files /dev/null and b/output/03.png differ diff --git a/output/04.png b/output/04.png new file mode 100644 index 0000000..8b82016 Binary files /dev/null and b/output/04.png differ diff --git a/output/05.png b/output/05.png new file mode 100644 index 0000000..9bedc78 Binary files /dev/null and b/output/05.png differ diff --git a/output/06.png b/output/06.png new file mode 100644 index 0000000..20dff9e Binary files /dev/null and b/output/06.png differ diff --git a/output/07.png b/output/07.png new file mode 100644 index 0000000..ecbad32 Binary files /dev/null and b/output/07.png differ diff --git a/run.py b/run.py new file mode 100644 index 0000000..3108adf --- /dev/null +++ b/run.py @@ -0,0 +1,216 @@ +import argparse +import torch +import os +import sys +import cv2 +import numpy as np + +import utils.architectures.SOFVSR_arch as SOFVSR +from torch.autograd import Variable +import utils.common as util +from utils.colors import * + +parser = argparse.ArgumentParser() +parser.add_argument('model') +parser.add_argument('--input', default='input', help='Input folder') +parser.add_argument('--output', default='output', help='Output folder') +parser.add_argument('--cpu', action='store_true', + help='Use CPU instead of CUDA') +parser.add_argument('--denoise', action='store_true', + help='Denoise the chroma layers') +parser.add_argument('--chop_forward', action='store_true',) +args = parser.parse_args() + +if not os.path.exists(args.input): + print('Error: Folder [{:s}] does not exist.'.format(args.input)) + sys.exit(1) +elif os.path.isfile(args.input): + print('Error: Folder [{:s}] is a file.'.format(args.input)) + sys.exit(1) +elif os.path.isfile(args.output): + print('Error: Folder [{:s}] is a file.'.format(args.output)) + sys.exit(1) +elif not os.path.exists(args.output): + os.mkdir(args.output) + +device = torch.device('cpu' if args.cpu 
else 'cuda') + +input_folder = os.path.normpath(args.input) +output_folder = os.path.normpath(args.output) + +def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1, need_HR=False): + # divide into 4 patches + b, n, c, h, w = x.size() + h_half, w_half = h // 2, w // 2 + h_size, w_size = h_half + shave, w_half + shave + inputlist = [ + x[:, :, :, 0:h_size, 0:w_size], + x[:, :, :, 0:h_size, (w - w_size):w], + x[:, :, :, (h - h_size):h, 0:w_size], + x[:, :, :, (h - h_size):h, (w - w_size):w]] + + + if w_size * h_size < min_size: + outputlist = [] + for i in range(0, 4, nGPUs): + input_batch = torch.cat(inputlist[i:(i + nGPUs)], dim=0) + with torch.no_grad(): + model = model.to(device) + _, _, _, output_batch = model(input_batch.to(device)) + outputlist.append(output_batch.data) + else: + outputlist = [ + chop_forward(patch, model, scale, shave, min_size, nGPUs) \ + for patch in inputlist] + + h, w = scale * h, scale * w + h_half, w_half = scale * h_half, scale * w_half + h_size, w_size = scale * h_size, scale * w_size + shave *= scale + + # output = Variable(x.data.new(1, 1, h, w), volatile=True) #UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead. + with torch.no_grad(): + output = Variable(x.data.new(1, 1, h, w)) + for idx, out in enumerate(outputlist): + if len(out.shape) < 4: + outputlist[idx] = out.unsqueeze(0) + output[:, :, 0:h_half, 0:w_half] = outputlist[0][:, :, 0:h_half, 0:w_half] + output[:, :, 0:h_half, w_half:w] = outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size] + output[:, :, h_half:h, 0:w_half] = outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half] + output[:, :, h_half:h, w_half:w] = outputlist[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] + + return output.float().cpu() + + + +def main(): + state_dict = torch.load(args.model) + + # Automatic scale detection + keys = state_dict.keys() + if 'OFR.SR.3.weight' in keys: + scale = 1 + elif 'SR.body.6.bias' in keys: + # 2 and 3 share the same architecture keys so here we check the shape + if state_dict['SR.body.3.weight'].shape[0] == 256: + scale = 2 + elif state_dict['SR.body.3.weight'].shape[0] == 576: + scale = 3 + elif 'SR.body.9.bias' in keys: + scale = 4 + else: + raise ValueError('Scale could not be determined from model') + + # Extract num_frames from model + frame_size = state_dict['SR.body.0.weight'].shape[1] + num_frames = ((frame_size - 1) // scale ** 2) + 1 + + # Extract num_channels + num_channels = state_dict['OFR.RNN1.0.weight'].shape[0] + + # Create model + model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels) + model.load_state_dict(state_dict) + + + images=[] + for root, _, files in os.walk(input_folder): + for file in sorted(files): + if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']: + images.append(os.path.join(root, file)) + + # pad beginning and end frames so they get included in output + images.insert(0, images[0]) + images.append(images[-1]) + + # Inference loop + for idx, path in enumerate(images[1:-1], 0): + img_name = os.path.splitext(os.path.basename(path))[0] + + idx_center = (num_frames - 1) // 2 + idx_frame = idx + LR_name = images[idx_frame + 1] # center frame + print(idx_frame, img_name) + + # read LR frames + LR_list = [] + LR_bicubic = None + for i_frame in range(num_frames): + # Last and second to last frames + if idx == len(images)-2 and num_frames == 3: + # print("second to last frame:", i_frame) + if i_frame == 0: + LR_img = 
cv2.imread(images[idx_frame], cv2.IMREAD_COLOR) + else: + LR_img = cv2.imread(images[idx_frame+1], cv2.IMREAD_COLOR) + elif idx == len(images)-1 and num_frames == 3: + # print("last frame:", i_frame) + LR_img = cv2.imread(images[idx_frame], cv2.IMREAD_COLOR) + # Every other internal frame + else: + # print("normal frame:", idx_frame) + LR_img = cv2.imread(images[idx_frame+i_frame], cv2.IMREAD_COLOR) + + # get the bicubic upscale of the center frame to concatenate for SR + if i_frame == idx_center: + if args.denoise: + LR_bicubic = cv2.blur(LR_img, (3,3)) + else: + LR_bicubic = LR_img + LR_bicubic = util.imresize_np(img=LR_bicubic, scale=scale) # bicubic upscale + + # extract Y channel from frames + # normal path, only Y for both + LR_img = util.bgr2ycbcr(LR_img, only_y=True) + + # expand Y images to add the channel dimension + # normal path, only Y for both + LR_img = util.fix_img_channels(LR_img, 1) + + LR_list.append(LR_img) # h, w, c + + LR = np.concatenate((LR_list), axis=2) # h, w, t + + LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W] + + # generate Cr, Cb channels using bicubic interpolation + LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False) + LR_bicubic = util.np2tensor(LR_bicubic, bgr2rgb=False, add_batch=True) + + if len(LR.size()) == 4: + b, n_frames, h_lr, w_lr = LR.size() + LR = LR.view(b, -1, 1, h_lr, w_lr) # b, t, c, h, w + + if args.chop_forward: + + # crop borders to ensure each patch can be divisible by 2 + _, _, _, h, w = LR.size() + h = int(h//16) * 16 + w = int(w//16) * 16 + LR = LR[:, :, :, :h, :w] + if isinstance(LR_bicubic, torch.Tensor): + SR_cb = LR_bicubic[:, 1, :h * scale, :w * scale] + SR_cr = LR_bicubic[:, 2, :h * scale, :w * scale] + + SR_y = chop_forward(LR, model, scale).squeeze(0) + sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3)) + else: + + with torch.no_grad(): + model.to(device) + _, _, _, fake_H = model(LR.to(device)) + + SR = fake_H.detach()[0].float().cpu() + SR_cb = LR_bicubic[:, 1, :, :] + SR_cr = LR_bicubic[:, 2, :, :] + + sr_img = ycbcr_to_rgb(torch.stack((SR, SR_cb, SR_cr), -3)) + + sr_img = util.tensor2np(sr_img) # uint8 + + # save images + cv2.imwrite(os.path.join(output_folder, os.path.basename(path)), sr_img) + + +if __name__ == '__main__': + main() diff --git a/utils/architectures/SOFVSR_arch.py b/utils/architectures/SOFVSR_arch.py new file mode 100644 index 0000000..b2cbc77 --- /dev/null +++ b/utils/architectures/SOFVSR_arch.py @@ -0,0 +1,214 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils.architectures.video import optical_flow_warp + + +#TODO: +# - change pixelshuffle upscales with available options in block (can also add pa_unconv with pixel attention) +# - make the upscaling layers automatic +# - add the network configuration parameters to the init to pass from options file + +class SOFVSR(nn.Module): + def __init__(self, scale=4, n_frames=3, channels=320): + super(SOFVSR, self).__init__() + self.scale = scale + self.OFR = OFRnet(scale=scale, channels=channels) + self.SR = SRnet(scale=scale, channels=channels, n_frames=n_frames) + + def forward(self, x): + # x: b*n*c*h*w + b, n_frames, c, h, w = x.size() + idx_center = (n_frames - 1) // 2 + + # motion estimation + flow_L1 = [] + flow_L2 = [] + flow_L3 = [] + input = [] + + for idx_frame in range(n_frames): + if idx_frame != idx_center: + input.append(torch.cat((x[:,idx_frame,:,:,:], x[:,idx_center,:,:,:]), 1)) + optical_flow_L1, optical_flow_L2, optical_flow_L3 = self.OFR(torch.cat(input, 
0)) + + optical_flow_L1 = optical_flow_L1.view(-1, b, 2, h//2, w//2) + optical_flow_L2 = optical_flow_L2.view(-1, b, 2, h, w) + optical_flow_L3 = optical_flow_L3.view(-1, b, 2, h*self.scale, w*self.scale) + + # motion compensation + draft_cube = [] + draft_cube.append(x[:, idx_center, :, :, :]) + + for idx_frame in range(n_frames): + if idx_frame == idx_center: + flow_L1.append([]) + flow_L2.append([]) + flow_L3.append([]) + else: # if idx_frame != idx_center: + if idx_frame < idx_center: + idx = idx_frame + if idx_frame > idx_center: + idx = idx_frame - 1 + + flow_L1.append(optical_flow_L1[idx, :, :, :, :]) + flow_L2.append(optical_flow_L2[idx, :, :, :, :]) + flow_L3.append(optical_flow_L3[idx, :, :, :, :]) + + # Generate the draft_cube by subsampling the SR flow optical_flow_L3 + # according to the scale + for i in range(self.scale): + for j in range(self.scale): + draft = optical_flow_warp(x[:, idx_frame, :, :, :], + optical_flow_L3[idx, :, :, i::self.scale, j::self.scale] / self.scale) + draft_cube.append(draft) + draft_cube = torch.cat(draft_cube, 1) + + # super-resolution + SR = self.SR(draft_cube) + + return flow_L1, flow_L2, flow_L3, SR + + +class OFRnet(nn.Module): + def __init__(self, scale, channels): + super(OFRnet, self).__init__() + self.pool = nn.AvgPool2d(2) + self.scale = scale + + ## RNN part + self.RNN1 = nn.Sequential( + nn.Conv2d(4, channels, 3, 1, 1, bias=False), # TODO: change 4 to 8 for 3 channel images + nn.LeakyReLU(0.1, inplace=True), + CasResB(3, channels) + ) + self.RNN2 = nn.Sequential( + nn.Conv2d(channels, 2, 3, 1, 1, bias=False), # TODO: change 2 to 6 for 3 channel images + ) + + # SR part + SR = [] + SR.append(CasResB(3, channels)) + if self.scale == 4: + SR.append(nn.Conv2d(channels, 64 * 4, 1, 1, 0, bias=False)) + SR.append(nn.PixelShuffle(2)) #TODO + SR.append(nn.LeakyReLU(0.1, inplace=True)) + SR.append(nn.Conv2d(64, 64 * 4, 1, 1, 0, bias=False)) + SR.append(nn.PixelShuffle(2)) #TODO + SR.append(nn.LeakyReLU(0.1, inplace=True)) + elif self.scale == 3: + SR.append(nn.Conv2d(channels, 64 * 9, 1, 1, 0, bias=False)) + SR.append(nn.PixelShuffle(3)) #TODO + SR.append(nn.LeakyReLU(0.1, inplace=True)) + elif self.scale == 2: + SR.append(nn.Conv2d(channels, 64 * 4, 1, 1, 0, bias=False)) + SR.append(nn.PixelShuffle(2)) #TODO + SR.append(nn.LeakyReLU(0.1, inplace=True)) + #TODO: test scale 1x + elif self.scale == 1: + SR.append(nn.Conv2d(channels, 64 * 1, 1, 1, 0, bias=False)) + SR.append(nn.LeakyReLU(0.1, inplace=True)) + SR.append(nn.Conv2d(64, 2, 3, 1, 1, bias=False)) + + self.SR = nn.Sequential(*SR) + + def __call__(self, x): + # x: b*2*h*w + #Part 1 + x_L1 = self.pool(x) + b, c, h, w = x_L1.size() + input_L1 = torch.cat((x_L1, torch.zeros(b, 2, h, w).cuda()), 1) + optical_flow_L1 = self.RNN2(self.RNN1(input_L1)) + # optical_flow_L1_upscaled = F.interpolate(optical_flow_L1, scale_factor=2, mode='bilinear', align_corners=False) * 2 + + # TODO: check, temporary fix, since the original interpolation was not producing the correct shape required in Part 2 + # in optical_flow_warp, instead of shape torch.Size([2, 1, 66, 75]) like the image, it was producing torch.Size([2, 1, 66, 74]) + # here I'm forcing it to be interpolated to exactly the size of the image + image_shape = torch.unsqueeze(x[:, 0, :, :], 1).shape + optical_flow_L1_upscaled = F.interpolate(optical_flow_L1, size=(image_shape[2],image_shape[3]), mode='bilinear', align_corners=False) * 2 + # print(optical_flow_L1_upscaled.shape) + # print(torch.unsqueeze(x[:, 0, :, :], 1).shape) + + #Part 2 + x_L2 = 
optical_flow_warp(torch.unsqueeze(x[:, 0, :, :], 1), optical_flow_L1_upscaled) + input_L2 = torch.cat((x_L2, torch.unsqueeze(x[:, 1, :, :], 1), optical_flow_L1_upscaled), 1) + optical_flow_L2 = self.RNN2(self.RNN1(input_L2)) + optical_flow_L1_upscaled + + #Part 3 + x_L3 = optical_flow_warp(torch.unsqueeze(x[:, 0, :, :], 1), optical_flow_L2) + input_L3 = torch.cat((x_L3, torch.unsqueeze(x[:, 1, :, :], 1), optical_flow_L2), 1) + #TODO: 3 channel images breaks here, because the first part has only 2 channels (2 * 1) and the second part now has 6 channels (2 * 3) + optical_flow_L3 = self.SR(self.RNN1(input_L3)) + \ + F.interpolate(optical_flow_L2, scale_factor=self.scale, mode='bilinear', align_corners=False) * self.scale + return optical_flow_L1, optical_flow_L2, optical_flow_L3 + + +class SRnet(nn.Module): + def __init__(self, scale, channels, n_frames): + super(SRnet, self).__init__() + body = [] + # scale ** 2 -> due to the subsampling of the SR flow + body.append(nn.Conv2d(1 * scale ** 2 * (n_frames-1) + 1, channels, 3, 1, 1, bias=False)) + body.append(nn.LeakyReLU(0.1, inplace=True)) + body.append(CasResB(8, channels)) + if scale == 4: + body.append(nn.Conv2d(channels, 64 * 4, 1, 1, 0, bias=False)) + body.append(nn.PixelShuffle(2)) #TODO + body.append(nn.LeakyReLU(0.1, inplace=True)) + body.append(nn.Conv2d(64, 64 * 4, 1, 1, 0, bias=False)) + body.append(nn.PixelShuffle(2)) #TODO + body.append(nn.LeakyReLU(0.1, inplace=True)) + elif scale == 3: + body.append(nn.Conv2d(channels, 64 * 9, 1, 1, 0, bias=False)) + body.append(nn.PixelShuffle(3)) #TODO + body.append(nn.LeakyReLU(0.1, inplace=True)) + elif scale == 2: + body.append(nn.Conv2d(channels, 64 * 4, 1, 1, 0, bias=False)) + body.append(nn.PixelShuffle(2)) #TODO + body.append(nn.LeakyReLU(0.1, inplace=True)) + #TODO: test scale 1x + elif scale == 1: + body.append(nn.Conv2d(channels, 64 * 1, 1, 1, 0, bias=False)) + body.append(nn.LeakyReLU(0.1, inplace=True)) + body.append(nn.Conv2d(64, 1, 3, 1, 1, bias=True)) + + self.body = nn.Sequential(*body) + + def __call__(self, x): + out = self.body(x) + return out + + +class ResB(nn.Module): + def __init__(self, channels): + super(ResB, self).__init__() + self.body = nn.Sequential( + nn.Conv2d(channels//2, channels//2, 1, 1, 0, bias=False), + nn.LeakyReLU(0.1, inplace=True), + nn.Conv2d(channels//2, channels//2, 3, 1, 1, bias=False, groups=channels//2), + nn.Conv2d(channels // 2, channels // 2, 1, 1, 0, bias=False), + nn.LeakyReLU(0.1, inplace=True), + ) + def forward(self, x): + input = x[:, x.shape[1]//2:, :, :] + out = torch.cat((x[:, :x.shape[1]//2, :, :], self.body(input)), 1) + return channel_shuffle(out, 2) + + +class CasResB(nn.Module): + def __init__(self, n_ResB, channels): + super(CasResB, self).__init__() + body = [] + for i in range(n_ResB): + body.append(ResB(channels)) + self.body = nn.Sequential(*body) + def forward(self, x): + return self.body(x) + + +def channel_shuffle(x, groups): + b, c, h, w = x.size() + x = x.view(b, groups, c//groups, h, w) + x = x.permute(0, 2, 1, 3, 4).contiguous() + x = x.view(b, -1, h, w) + return x diff --git a/utils/architectures/__init__.py b/utils/architectures/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/architectures/video.py b/utils/architectures/video.py new file mode 100644 index 0000000..5bfd71b --- /dev/null +++ b/utils/architectures/video.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +from torch.autograd import Variable + + +def 
optical_flow_warp(image, flow, + mode='vsr', + interp_mode='bilinear', + padding_mode='border', + align_corners=True, + mask=None): + """ + Warp an image or feature map with optical flow. + Arguments: + image (Tensor): reference images tensor (b, c, h, w) + flow (Tensor): optical flow to image_ref + (b, 2, h, w) for vsr mode, (n, h, w, 2) for edvr mode. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros', 'border' or 'reflection'. + Default: 'zeros' (EDVR), 'border'(SOF-VSR). + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + Returns: + Tensor: Warped image or feature map. + """ + if mode == 'vsr': + assert image.size()[-2:] == flow.size()[-2:] + elif mode == 'edvr': + assert image.size()[-2:] == flow.size()[1:3] + + b, _, h, w = image.size() + + # create mesh grid (torch) EDVR + #TODO: it produces the same mesh as numpy version, but results in + # images with displacements during inference. Leaving numpy version + # for training and inference until more tests can be done + ''' + grid_y, grid_x = torch.meshgrid( + torch.arange(0, h).type_as(image), + torch.arange(0, w).type_as(image)) + #TODO: check if float64 needed like SOF-VSR: + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + if mode == 'vsr': # to make equivalent with SOF-VSR's np grid + # scales and reshapes grid before adding + # scale grid to [-1,1] + grid[:, :, 0] = 2.0 * grid[:, :, 0] / max(w - 1, 1) - 1.0 + grid[:, :, 1] = 2.0 * grid[:, :, 1] / max(w - 1, 1) - 1.0 + grid = grid.transpose(2, 1) + grid = grid.transpose(1, 0) + grid = grid.expand(b, -1, -1, -1) + #TODO: check if needed: + grid.requires_grad = False + ''' + + # create mesh grid (np) SOF-VSR + # ''' + grid = np.meshgrid(range(w), range(h)) + grid = np.stack(grid, axis=-1).astype(np.float64) # W(x), H(y), 2 + + grid[:, :, 0] = 2.0 * grid[:, :, 0] / (w - 1) - 1.0 + grid[:, :, 1] = 2.0 * grid[:, :, 1] / (h - 1) - 1.0 + grid = grid.transpose(2, 0, 1) + grid = np.tile(grid, (b, 1, 1, 1)) + grid = Variable(torch.Tensor(grid)) + if flow.is_cuda == True: + grid = grid.cuda() + # ''' + + if mode == 'vsr': + # SOF-VSR scaled the grid before summing the flow + flow_0 = torch.unsqueeze(flow[:, 0, :, :] * 31 / (w - 1), dim=1) + flow_1 = torch.unsqueeze(flow[:, 1, :, :] * 31 / (h - 1), dim=1) + grid = grid + torch.cat((flow_0, flow_1), 1) + grid = grid.transpose(1, 2) + grid = grid.transpose(3, 2) + elif mode == 'edvr': + # EDVR scales the grid after summing the flow + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + grid = torch.stack((vgrid_x, vgrid_y), dim=3) #vgrid_scaled + + #TODO: added "align_corners=True" to maintain original behavior, needs testing: + # UserWarning: Default grid_sample and affine_grid behavior will be changed to align_corners=False from 1.4.0. See the documentation of grid_sample for details. 
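+    # At this point `grid` is the normalized base mesh in [-1, 1] plus the
+    # (rescaled) flow, laid out as NHWC; grid_sample samples `image` at those
+    # positions to produce the warped frame.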
+ output = F.grid_sample( + image, grid, padding_mode=padding_mode, mode=interp_mode, align_corners=True) + + # TODO, what if align_corners=False + + if not isinstance(mask, np.ndarray): + return output + else: + # using 'mask' parameter prevents using the masked regions + mask = (1 - mask).astype(np.bool) + + mask = torch.autograd.Variable(torch.ones(x.size())).cuda() + mask = nn.functional.grid_sample(grid, output) + + mask = mask.masked_fill_(mask < 0.999, 0) + mask = mask.masked_fill_(mask > 0, 1) + + return output * mask + + + + + +#create tensor with random data +# image = torch.rand((4, 3, 16, 16)) +# flow = torch.rand((4, 2, 16, 16)) + +# optical_flow_warp(image, flow) \ No newline at end of file diff --git a/utils/colors.py b/utils/colors.py new file mode 100644 index 0000000..85fa4fd --- /dev/null +++ b/utils/colors.py @@ -0,0 +1,174 @@ +''' +Functions for color operations on tensors. +If needed, there are more conversions that can be used: +https://github.com/kornia/kornia/tree/master/kornia/color +https://github.com/R08UST/Color_Conversion_pytorch/blob/master/differentiable_color_conversion/basic_op.py +''' + + +import torch +import math +import cv2 + +def bgr_to_rgb(image: torch.Tensor) -> torch.Tensor: + # flip image channels + out: torch.Tensor = image.flip(-3) #https://github.com/pytorch/pytorch/issues/229 + #out: torch.Tensor = image[[2, 1, 0], :, :] #RGB to BGR #may be faster + return out + +def rgb_to_bgr(image: torch.Tensor) -> torch.Tensor: + #same operation as bgr_to_rgb(), flip image channels + return bgr_to_rgb(image) + +def bgra_to_rgba(image: torch.Tensor) -> torch.Tensor: + out: torch.Tensor = image[[2, 1, 0, 3], :, :] + return out + +def rgba_to_bgra(image: torch.Tensor) -> torch.Tensor: + #same operation as bgra_to_rgba(), flip image channels + return bgra_to_rgba(image) + +def rgb_to_grayscale(input: torch.Tensor) -> torch.Tensor: + r, g, b = torch.chunk(input, chunks=3, dim=-3) + gray: torch.Tensor = 0.299 * r + 0.587 * g + 0.114 * b + #gray = rgb_to_yuv(input,consts='y') + return gray + +def bgr_to_grayscale(input: torch.Tensor) -> torch.Tensor: + input_rgb = bgr_to_rgb(input) + gray: torch.Tensor = rgb_to_grayscale(input_rgb) + #gray = rgb_to_yuv(input_rgb,consts='y') + return gray + +def grayscale_to_rgb(input: torch.Tensor) -> torch.Tensor: + #repeat the gray image to the three channels + rgb: torch.Tensor = input.repeat(3, *[1] * (input.dim() - 1)) + return rgb + +def grayscale_to_bgr(input: torch.Tensor) -> torch.Tensor: + return grayscale_to_rgb(input) + +def rgb_to_ycbcr(input: torch.Tensor, consts='yuv'): + return rgb_to_yuv(input, consts == 'ycbcr') + +def rgb_to_yuv(input: torch.Tensor, consts='yuv'): + """Converts one or more images from RGB to YUV. + Outputs a tensor of the same shape as the `input` image tensor, containing the YUV + value of the pixels. + The output is only well defined if the value in images are in [0,1]. + Y′CbCr is often confused with the YUV color space, and typically the terms YCbCr + and YUV are used interchangeably, leading to some confusion. The main difference + is that YUV is analog and YCbCr is digital: https://en.wikipedia.org/wiki/YCbCr + Args: + input: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. (Could add additional channels, ie, AlphaRGB = AlphaYUV) + consts: YUV constant parameters to use. BT.601 or BT.709. Could add YCbCr + https://en.wikipedia.org/wiki/YUV + Returns: + images: images tensor with the same shape as `input`. 
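+        Note: consts='y' returns only the luma (Y) plane and consts='uv' returns
+        only the two chroma planes; every other option returns a 3-channel stack.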
+ """ + + #channels = input.shape[0] + + if consts == 'BT.709': # HDTV YUV + Wr = 0.2126 + Wb = 0.0722 + Wg = 1 - Wr - Wb #0.7152 + Uc = 0.539 + Vc = 0.635 + delta: float = 0.5 #128 if image range in [0,255] + elif consts == 'ycbcr': # Alt. BT.601 from Kornia YCbCr values, from JPEG conversion + Wr = 0.299 + Wb = 0.114 + Wg = 1 - Wr - Wb #0.587 + Uc = 0.564 #(b-y) #cb + Vc = 0.713 #(r-y) #cr + delta: float = .5 #128 if image range in [0,255] + elif consts == 'yuvK': # Alt. yuv from Kornia YUV values: https://github.com/kornia/kornia/blob/master/kornia/color/yuv.py + Wr = 0.299 + Wb = 0.114 + Wg = 1 - Wr - Wb #0.587 + Ur = -0.147 + Ug = -0.289 + Ub = 0.436 + Vr = 0.615 + Vg = -0.515 + Vb = -0.100 + #delta: float = 0.0 + elif consts == 'y': #returns only Y channel, same as rgb_to_grayscale() + #Note: torchvision uses ITU-R 601-2: Wr = 0.2989, Wg = 0.5870, Wb = 0.1140 + Wr = 0.299 + Wb = 0.114 + Wg = 1 - Wr - Wb #0.587 + else: # Default to 'BT.601', SDTV YUV + Wr = 0.299 + Wb = 0.114 + Wg = 1 - Wr - Wb #0.587 + Uc = 0.493 #0.492 + Vc = 0.877 + delta: float = 0.5 #128 if image range in [0,255] + + r: torch.Tensor = input[..., 0, :, :] + g: torch.Tensor = input[..., 1, :, :] + b: torch.Tensor = input[..., 2, :, :] + #TODO + #r, g, b = torch.chunk(input, chunks=3, dim=-3) #Alt. Which one is faster? Appear to be the same. Differentiable? Kornia uses both in different places + + if consts == 'y': + y: torch.Tensor = Wr * r + Wg * g + Wb * b + #(0.2989 * input[0] + 0.5870 * input[1] + 0.1140 * input[2]).to(img.dtype) + return y + elif consts == 'yuvK': + y: torch.Tensor = Wr * r + Wg * g + Wb * b + u: torch.Tensor = Ur * r + Ug * g + Ub * b + v: torch.Tensor = Vr * r + Vg * g + Vb * b + else: #if consts == 'ycbcr' or consts == 'yuv' or consts == 'BT.709': + y: torch.Tensor = Wr * r + Wg * g + Wb * b + u: torch.Tensor = (b - y) * Uc + delta #cb + v: torch.Tensor = (r - y) * Vc + delta #cr + + if consts == 'uv': #returns only UV channels + return torch.stack((u, v), -3) + else: + return torch.stack((y, u, v), -3) + +def ycbcr_to_rgb(input: torch.Tensor): + return yuv_to_rgb(input, consts = 'ycbcr') + +def yuv_to_rgb(input: torch.Tensor, consts='yuv') -> torch.Tensor: + if consts == 'yuvK': # Alt. yuv from Kornia YUV values: https://github.com/kornia/kornia/blob/master/kornia/color/yuv.py + Wr = 1.14 #1.402 + Wb = 2.029 #1.772 + Wgu = 0.396 #.344136 + Wgv = 0.581 #.714136 + delta: float = 0.0 + elif consts == 'yuv' or consts == 'ycbcr': # BT.601 from Kornia YCbCr values, from JPEG conversion + Wr = 1.403 #1.402 + Wb = 1.773 #1.772 + Wgu = .344 #.344136 + Wgv = .714 #.714136 + delta: float = .5 #128 if image range in [0,255] + + #Note: https://github.com/R08UST/Color_Conversion_pytorch/blob/75150c5fbfb283ae3adb85c565aab729105bbb66/differentiable_color_conversion/basic_op.py#L65 has u and v flipped + y: torch.Tensor = input[..., 0, :, :] + u: torch.Tensor = input[..., 1, :, :] #cb + v: torch.Tensor = input[..., 2, :, :] #cr + #TODO + #y, u, v = torch.chunk(input, chunks=3, dim=-3) #Alt. Which one is faster? Appear to be the same. Differentiable? 
Kornia uses both in different places + + u_shifted: torch.Tensor = u - delta #cb + v_shifted: torch.Tensor = v - delta #cr + + r: torch.Tensor = y + Wr * v_shifted + g: torch.Tensor = y - Wgv * v_shifted - Wgu * u_shifted + b: torch.Tensor = y + Wb * u_shifted + return torch.stack((r, g, b), -3) + +#Not tested: +def rgb2srgb(imgs): + return torch.where(imgs<=0.04045,imgs/12.92,torch.pow((imgs+0.055)/1.055,2.4)) + +#Not tested: +def srgb2rgb(imgs): + return torch.where(imgs<=0.0031308,imgs*12.92,1.055*torch.pow((imgs),1/2.4)-0.055) + diff --git a/utils/common.py b/utils/common.py new file mode 100644 index 0000000..c557ab5 --- /dev/null +++ b/utils/common.py @@ -0,0 +1,886 @@ +import os +import math +import pickle +import random +import numpy as np +import torch +import cv2 +import logging + +import copy +from torchvision.utils import make_grid + +from utils.colors import * + +#################### +# Files & IO +#################### + +###################### get image path list ###################### +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.dng', '.DNG', '.webp','.npy', '.NPY'] + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def _get_paths_from_images(path): + '''get image path list from image folder''' + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +def _get_paths_from_lmdb(dataroot): + '''get image path list from lmdb''' + import lmdb + env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False) + keys_cache_file = os.path.join(dataroot, '_keys_cache.p') + logger = logging.getLogger('base') + if os.path.isfile(keys_cache_file): + logger.info('Read lmdb keys from cache: {}'.format(keys_cache_file)) + keys = pickle.load(open(keys_cache_file, "rb")) + else: + with env.begin(write=False) as txn: + logger.info('Creating lmdb keys cache: {}'.format(keys_cache_file)) + keys = [key.decode('ascii') for key, _ in txn.cursor()] + pickle.dump(keys, open(keys_cache_file, 'wb')) + paths = sorted([key for key in keys if not key.endswith('.meta')]) + return env, paths + + +def get_image_paths(data_type, dataroot): + '''get image path list + support lmdb or image files''' + env, paths = None, None + if dataroot is not None: + if data_type == 'lmdb': + env, paths = _get_paths_from_lmdb(dataroot) + elif data_type == 'img': + paths = sorted(_get_paths_from_images(dataroot)) + else: + raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) + return env, paths + + +###################### read images ###################### +def _read_lmdb_img(env, path): + with env.begin(write=False) as txn: + buf = txn.get(path.encode('ascii')) + buf_meta = txn.get((path + '.meta').encode('ascii')).decode('ascii') + img_flat = np.frombuffer(buf, dtype=np.uint8) + H, W, C = [int(s) for s in buf_meta.split(',')] + img = img_flat.reshape(H, W, C) + return img + + +def read_img(env, path, out_nc=3, fix_channels=True): + ''' + Reads image using cv2 (rawpy if dng) or from lmdb by default + (can also use using PIL instead of cv2) + Arguments: + out_nc: Desired number of channels + fix_channels: changes the images to the desired number of channels + Output: + Numpy 
uint8, HWC, BGR, [0,255] by default + ''' + + img = None + if env is None: # img + if(path[-3:].lower() == 'dng'): # if image is a DNG + import rawpy + with rawpy.imread(path) as raw: + img = raw.postprocess() + if(path[-3:].lower() == 'npy'): # if image is a NPY numpy array + with open(path, 'rb') as f: + img = np.load(f) + else: # else, if image can be read by cv2 + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + #TODO: add variable detecting if cv2 is not available and try PIL instead + # elif: # using PIL instead of OpenCV + # img = Image.open(path).convert('RGB') + # else: # For other images unrecognized by cv2 + # import matplotlib.pyplot as plt + # img = (255*plt.imread(path)[:,:,:3]).astype('uint8') + else: + img = _read_lmdb_img(env, path) + + # if not img: + # raise ValueError(f"Failed to read image: {path}") + + if fix_channels: + img = fix_img_channels(img, out_nc) + + return img + +def fix_img_channels(img, out_nc): + ''' + fix image channels to the expected number + ''' + + # if image has only 2 dimensions, add "channel" dimension (1) + if img.ndim == 2: + #img = img[..., np.newaxis] #alt + #img = np.expand_dims(img, axis=2) + img = np.tile(np.expand_dims(img, axis=2), (1, 1, 3)) + # special case: properly remove alpha channel + if out_nc == 3 and img.shape[2] == 4: + img = bgra2rgb(img) + # remove all extra channels + elif img.shape[2] > out_nc: + img = img[:, :, :out_nc] + # if alpha is expected, add solid alpha channel + elif img.shape[2] == 3 and out_nc == 4: + img = np.dstack((img, np.full(img.shape[:-1], 255, dtype=np.uint8))) + return img + + +#################### +# image processing +# process on numpy image +#################### + +def bgra2rgb(img): + ''' + cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) has an issue removing the alpha channel, + this gets rid of wrong transparent colors that can harm training + ''' + if img.shape[2] == 4: + #b, g, r, a = cv2.split((img*255).astype(np.uint8)) + b, g, r, a = cv2.split((img.astype(np.uint8))) + b = cv2.bitwise_and(b, b, mask=a) + g = cv2.bitwise_and(g, g, mask=a) + r = cv2.bitwise_and(r, r, mask=a) + #return cv2.merge([b, g, r]).astype(np.float32)/255. + return cv2.merge([b, g, r]) + return img + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + # Note: OpenCV uses inverted channels BGR, instead of RGB. + # If images are loaded with something other than OpenCV, + # check that the channels are in the correct order and use + # the alternative conversion functions. 
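+    # e.g. channel_convert(3, 'y', [bgr_img]) returns a list of HxWx1 luma images.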
+ #if in_c == 4 and tar_type == 'RGB-A': # BGRA to BGR, remove alpha channel + #return [cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) for img in img_list] + #return [bgra2rgb(img) for img in img_list] + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'RGB-LAB': #RGB to LAB + return [cv2.cvtColor(img, cv2.COLOR_BGR2LAB) for img in img_list] + elif in_c == 3 and tar_type == 'LAB-RGB': #RGB to LAB + return [cv2.cvtColor(img, cv2.COLOR_LAB2BGR) for img in img_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img_ = img.astype(np.float32) + if in_img_type != np.uint8: + img_ *= 255. + # convert + if only_y: + rlt = np.dot(img_ , [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img_ , [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + +def bgr2ycbcr(img, only_y=True, separate=False): + '''bgr version of matlab rgb2ycbcr + Python opencv library (cv2) cv2.COLOR_BGR2YCrCb has + different parameters with MATLAB color convertion. + only_y: only return Y channel + separate: if true, will returng the channels as + separate images + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img_ = img.astype(np.float32) + if in_img_type != np.uint8: + img_ *= 255. + # convert + if only_y: + rlt = np.dot(img_ , [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img_ , [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + # to make ycrcb like cv2 + # rlt = rlt[:, :, (0, 2, 1)] + + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + + if separate: + rlt = rlt.astype(in_img_type) + # y, cb, cr + return rlt[:, :, 0], rlt[:, :, 1], rlt[:, :, 2] + else: + return rlt.astype(in_img_type) + +''' +def ycbcr2rgb_(img, only_y=True): + """same as matlab ycbcr2rgb + (Note: this implementation is the original from BasicSR, but + appears to be for ycrcb, like cv2) + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img_ = img.astype(np.float32) + if in_img_type != np.uint8: + img_ *= 255. + + # to make ycrcb like cv2 + # rlt = rlt[:, :, (0, 2, 1)] + + # convert + # original (for ycrcb): + rlt = np.matmul(img_ , [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + + #alternative conversion: + # xform = np.array([[1, 0, 1.402], [1, -0.34414, -.71414], [1, 1.772, 0]]) + # img_[:, :, [1, 2]] -= 128 + # rlt = img_.dot(xform.T) + np.putmask(rlt, rlt > 255, 255) + np.putmask(rlt, rlt < 0, 0) + + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. 
+ return rlt.astype(in_img_type) +''' + +def ycbcr2rgb(img, only_y=True): + ''' + bgr version of matlab ycbcr2rgb + Python opencv library (cv2) cv2.COLOR_YCrCb2BGR has + different parameters to MATLAB color convertion. + + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img_ = img.astype(np.float32) + if in_img_type != np.uint8: + img_ *= 255. + + # to make ycrcb like cv2 + # rlt = rlt[:, :, (0, 2, 1)] + + # convert + mat = np.array([[24.966, 128.553, 65.481],[112, -74.203, -37.797], [-18.214, -93.786, 112.0]]) + mat = np.linalg.inv(mat.T) * 255 + offset = np.array([[[16, 128, 128]]]) + + rlt = np.dot((img_ - offset), mat) + rlt = np.clip(rlt, 0, 255) + ## rlt = np.rint(rlt).astype('uint8') + + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + +''' +#TODO: TMP RGB version, to check (PIL) +def rgb2ycbcr(img_rgb): + ## the range of img_rgb should be (0, 1) + img_y = 0.257 * img_rgb[:, :, 0] + 0.504 * img_rgb[:, :, 1] + 0.098 * img_rgb[:, :, 2] + 16 / 255.0 + img_cb = -0.148 * img_rgb[:, :, 0] - 0.291 * img_rgb[:, :, 1] + 0.439 * img_rgb[:, :, 2] + 128 / 255.0 + img_cr = 0.439 * img_rgb[:, :, 0] - 0.368 * img_rgb[:, :, 1] - 0.071 * img_rgb[:, :, 2] + 128 / 255.0 + return img_y, img_cb, img_cr + +#TODO: TMP RGB version, to check (PIL) +def ycbcr2rgb(img_ycbcr): + ## the range of img_ycbcr should be (0, 1) + img_r = 1.164 * (img_ycbcr[:, :, 0] - 16 / 255.0) + 1.596 * (img_ycbcr[:, :, 2] - 128 / 255.0) + img_g = 1.164 * (img_ycbcr[:, :, 0] - 16 / 255.0) - 0.392 * (img_ycbcr[:, :, 1] - 128 / 255.0) - 0.813 * (img_ycbcr[:, :, 2] - 128 / 255.0) + img_b = 1.164 * (img_ycbcr[:, :, 0] - 16 / 255.0) + 2.017 * (img_ycbcr[:, :, 1] - 128 / 255.0) + img_r = img_r[:, :, np.newaxis] + img_g = img_g[:, :, np.newaxis] + img_b = img_b[:, :, np.newaxis] + img_rgb = np.concatenate((img_r, img_g, img_b), 2) + return img_rgb +''' + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + +#TODO: this should probably be elsewhere (augmentations.py) +def augment(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + #rot90n = rot and random.random() < 0.5 + + def _augment(img): + if hflip: img = np.flip(img, axis=1) #img[:, ::-1, :] + if vflip: img = np.flip(img, axis=0) #img[::-1, :, :] + #if rot90: img = img.transpose(1, 0, 2) + if rot90: img = np.rot90(img, 1) #90 degrees # In PIL: img.transpose(Image.ROTATE_90) + #if rot90n: img = np.rot90(img, -1) #-90 degrees + return img + + return [_augment(img) for img in img_list] + + + +#################### +# Normalization functions +#################### + + +#TODO: Could also automatically detect the possible range with min and max, like in def ssim() +def denorm(x, min_max=(-1.0, 1.0)): + ''' + Denormalize from [-1,1] range to [0,1] + formula: xi' = (xi - mu)/sigma + Example: "out = (x + 1.0) / 2.0" for denorm + range (-1,1) to (0,1) + for use with proper act in Generator output (ie. 
tanh) + ''' + out = (x - min_max[0]) / (min_max[1] - min_max[0]) + if isinstance(x, torch.Tensor): + return out.clamp(0, 1) + elif isinstance(x, np.ndarray): + return np.clip(out, 0, 1) + else: + raise TypeError("Got unexpected object type, expected torch.Tensor or \ + np.ndarray") + +def norm(x): + #Normalize (z-norm) from [0,1] range to [-1,1] + out = (x - 0.5) * 2.0 + if isinstance(x, torch.Tensor): + return out.clamp(-1, 1) + elif isinstance(x, np.ndarray): + return np.clip(out, -1, 1) + else: + raise TypeError("Got unexpected object type, expected torch.Tensor or \ + np.ndarray") + + +#################### +# np and tensor conversions +#################### + + +#2tensor +def np2tensor(img, bgr2rgb=True, data_range=1., normalize=False, change_range=True, add_batch=True): + """ + Converts a numpy image array into a Tensor array. + Parameters: + img (numpy array): the input image numpy array + add_batch (bool): choose if new tensor needs batch dimension added + """ + if not isinstance(img, np.ndarray): #images expected to be uint8 -> 255 + raise TypeError("Got unexpected object type, expected np.ndarray") + #check how many channels the image has, then condition, like in my BasicSR. ie. RGB, RGBA, Gray + #if bgr2rgb: + #img = img[:, :, [2, 1, 0]] #BGR to RGB -> in numpy, if using OpenCV, else not needed. Only if image has colors. + if change_range: + if np.issubdtype(img.dtype, np.integer): + info = np.iinfo + elif np.issubdtype(img.dtype, np.floating): + info = np.finfo + img = img*data_range/info(img.dtype).max #uint8 = /255 + img = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float() #"HWC to CHW" and "numpy to tensor" + if bgr2rgb: + if img.shape[0] == 3: #RGB + #BGR to RGB -> in tensor, if using OpenCV, else not needed. Only if image has colors. + img = bgr_to_rgb(img) + elif img.shape[0] == 4: #RGBA + #BGR to RGB -> in tensor, if using OpenCV, else not needed. Only if image has colors.) + img = bgra_to_rgba(img) + if add_batch: + img.unsqueeze_(0) # Add fake batch dimension = 1 . squeeze() will remove the dimensions of size 1 + if normalize: + img = norm(img) + return img + +#2np +def tensor2np(img, rgb2bgr=True, remove_batch=True, data_range=255, + denormalize=False, change_range=True, imtype=np.uint8): + """ + Converts a Tensor array into a numpy image array. + Parameters: + img (tensor): the input image tensor array + 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + remove_batch (bool): choose if tensor of shape BCHW needs to be squeezed + denormalize (bool): Used to denormalize from [-1,1] range back to [0,1] + imtype (type): the desired type of the converted numpy array (np.uint8 + default) + Output: + img (np array): 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + """ + if not isinstance(img, torch.Tensor): + raise TypeError("Got unexpected object type, expected torch.Tensor") + n_dim = img.dim() + + #TODO: Check: could denormalize here in tensor form instead, but end result is the same + + img = img.float().cpu() + + if n_dim == 4 or n_dim == 3: + #if n_dim == 4, has to convert to 3 dimensions, either removing batch or by creating a grid + if n_dim == 4 and remove_batch: + if img.shape[0] > 1: + # leave only the first image in the batch + img = img[0,...] 
+ else: + # remove a fake batch dimension + img = img.squeeze() + # squeeze removes batch and channel of grayscale images (dimensions = 1) + if len(img.shape) < 3: + #add back the lost channel dimension + img = img.unsqueeze(dim=0) + # convert images in batch (BCHW) to a grid of all images (C B*H B*W) + else: + n_img = len(img) + img = make_grid(img, nrow=int(math.sqrt(n_img)), normalize=False) + + if img.shape[0] == 3 and rgb2bgr: #RGB + #RGB to BGR -> in tensor, if using OpenCV, else not needed. Only if image has colors. + img_np = rgb_to_bgr(img).numpy() + elif img.shape[0] == 4 and rgb2bgr: #RGBA + #RGBA to BGRA -> in tensor, if using OpenCV, else not needed. Only if image has colors. + img_np = rgba_to_bgra(img).numpy() + else: + img_np = img.numpy() + img_np = np.transpose(img_np, (1, 2, 0)) # "CHW to HWC" -> # HWC, BGR + elif n_dim == 2: + img_np = img.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + + #if rgb2bgr: + #img_np = img_np[[2, 1, 0], :, :] #RGB to BGR -> in numpy, if using OpenCV, else not needed. Only if image has colors. + #TODO: Check: could denormalize in the begining in tensor form instead + if denormalize: + img_np = denorm(img_np) #denormalize if needed + if change_range: + img_np = np.clip(data_range*img_np,0,data_range).round() #clip to the data_range + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + #has to be in range (0,255) before changing to np.uint8, else np.float32 + return img_np.astype(imtype) + + + + +#################### +# Prepare Images +#################### +# https://github.com/sunreef/BlindSR/blob/master/src/image_utils.py +def patchify_tensor(features, patch_size, overlap=10): + batch_size, channels, height, width = features.size() + + effective_patch_size = patch_size - overlap + n_patches_height = (height // effective_patch_size) + n_patches_width = (width // effective_patch_size) + + if n_patches_height * effective_patch_size < height: + n_patches_height += 1 + if n_patches_width * effective_patch_size < width: + n_patches_width += 1 + + patches = [] + for b in range(batch_size): + for h in range(n_patches_height): + for w in range(n_patches_width): + patch_start_height = min(h * effective_patch_size, height - patch_size) + patch_start_width = min(w * effective_patch_size, width - patch_size) + patches.append(features[b:b+1, :, + patch_start_height: patch_start_height + patch_size, + patch_start_width: patch_start_width + patch_size]) + return torch.cat(patches, 0) + +def recompose_tensor(patches, full_height, full_width, overlap=10): + + batch_size, channels, patch_size, _ = patches.size() + effective_patch_size = patch_size - overlap + n_patches_height = (full_height // effective_patch_size) + n_patches_width = (full_width // effective_patch_size) + + if n_patches_height * effective_patch_size < full_height: + n_patches_height += 1 + if n_patches_width * effective_patch_size < full_width: + n_patches_width += 1 + + n_patches = n_patches_height * n_patches_width + if batch_size % n_patches != 0: + print("Error: The number of patches provided to the recompose function does not match the number of patches in each image.") + final_batch_size = batch_size // n_patches + + blending_in = torch.linspace(0.1, 1.0, overlap) + blending_out = torch.linspace(1.0, 0.1, overlap) + middle_part = torch.ones(patch_size - 2 * overlap) + blending_profile = torch.cat([blending_in, middle_part, blending_out], 0) + + horizontal_blending = 
blending_profile[None].repeat(patch_size, 1) + vertical_blending = blending_profile[:, None].repeat(1, patch_size) + blending_patch = horizontal_blending * vertical_blending + + blending_image = torch.zeros(1, channels, full_height, full_width) + for h in range(n_patches_height): + for w in range(n_patches_width): + patch_start_height = min(h * effective_patch_size, full_height - patch_size) + patch_start_width = min(w * effective_patch_size, full_width - patch_size) + blending_image[0, :, patch_start_height: patch_start_height + patch_size, patch_start_width: patch_start_width + patch_size] += blending_patch[None] + + recomposed_tensor = torch.zeros(final_batch_size, channels, full_height, full_width) + if patches.is_cuda: + blending_patch = blending_patch.cuda() + blending_image = blending_image.cuda() + recomposed_tensor = recomposed_tensor.cuda() + patch_index = 0 + for b in range(final_batch_size): + for h in range(n_patches_height): + for w in range(n_patches_width): + patch_start_height = min(h * effective_patch_size, full_height - patch_size) + patch_start_width = min(w * effective_patch_size, full_width - patch_size) + recomposed_tensor[b, :, patch_start_height: patch_start_height + patch_size, patch_start_width: patch_start_width + patch_size] += patches[patch_index] * blending_patch + patch_index += 1 + recomposed_tensor /= blending_image + + return recomposed_tensor + + + + + + + + + + + + + + + + + + +#TODO: imresize could be an independent file (imresize.py) +#################### +# Matlab imresize +#################### + + +# These next functions are all interpolation methods. x is the distance from the left pixel center +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + +def box(x): + return ((-0.5 <= x) & (x < 0.5)) * 1.0 + +def linear(x): + return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1)) + +def lanczos2(x): + return (((torch.sin(math.pi*x) * torch.sin(math.pi*x/2) + torch.finfo(torch.float32).eps) / + ((math.pi**2 * x**2 / 2) + torch.finfo(torch.float32).eps)) + * (torch.abs(x) < 2)) + +def lanczos3(x): + return (((torch.sin(math.pi*x) * torch.sin(math.pi*x/3) + torch.finfo(torch.float32).eps) / + ((math.pi**2 * x**2 / 3) + torch.finfo(torch.float32).eps)) + * (torch.abs(x) < 3)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. 
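+    # (e.g. with the cubic kernel and no antialiasing widening, kernel_width = 4,
+    # so each row holds P = 6 candidate indices)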
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply kernel + if (scale < 1) and (antialiasing): + weights = scale * kernel(distance_to_center * scale) + else: + weights = kernel(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +def imresize(img, scale, antialiasing=True, interpolation=None): + # The scale should be the same for H and W + # input: img: CHW RGB [0,1] + # output: CHW RGB [0,1] w/o round + + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + + # Choose interpolation method, each method has the matching kernel size + kernel, kernel_width = { + "cubic": (cubic, 4.0), + "lanczos2": (lanczos2, 4.0), + "lanczos3": (lanczos3, 6.0), + "box": (box, 1.0), + "linear": (linear, 2.0), + None: (cubic, 4.0) # set default interpolation method as cubic + }.get(interpolation) + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. 
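+    # The resize is separable: borders are mirror-padded, then the precomputed
+    # weights are applied along H first and along W second.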
+ + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) + out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) + + return out_2 + + +def imresize_np(img, scale, antialiasing=True, interpolation=None): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC BGR [0,1] + # output: HWC BGR [0,1] w/o round + + change_range = False + if img.max() > 1: + img_type = img.dtype + if np.issubdtype(img_type, np.integer): + info = np.iinfo + elif np.issubdtype(img_type, np.floating): + info = np.finfo + img = img/info(img_type).max + change_range = True + + img = torch.from_numpy(img) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + + # Choose interpolation method, each method has the matching kernel size + kernel, kernel_width = { + "cubic": (cubic, 4.0), + "lanczos2": (lanczos2, 4.0), + "lanczos3": (lanczos3, 6.0), + "box": (box, 1.0), + "linear": (linear, 2.0), + None: (cubic, 4.0) # set default interpolation method as cubic + }.get(interpolation) + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. 
+ + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) + out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) + out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) + + out_2 = out_2.numpy().clip(0,1) + + if change_range: + out_2 = out_2*info(img_type).max #uint8 = 255 + out_2 = out_2.astype(img_type) + + return out_2 + + +if __name__ == '__main__': + # test imresize function + # read images + img = cv2.imread('test.png') + img = img * 1.0 / 255 + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + # imresize + scale = 1 / 4 + import time + total_time = 0 + for i in range(10): + start_time = time.time() + rlt = imresize(img, scale, antialiasing=True) + use_time = time.time() - start_time + total_time += use_time + print('average time: {}'.format(total_time / 10)) + + import torchvision.utils + torchvision.utils.save_image( + (rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, normalize=False)
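For reference, the snippet below condenses the checkpoint inspection that run.py performs before building the network: the upscale factor, the number of input frames and the channel width are all recovered from the keys and tensor shapes of the saved state dict. This is only a minimal sketch of that logic; the model filename is hypothetical, and any SOF-VSR `.pth` placed in `models/` can be substituted.

import torch

import utils.architectures.SOFVSR_arch as SOFVSR

state_dict = torch.load('models/4x_SOFVSR.pth', map_location='cpu')  # hypothetical filename
keys = state_dict.keys()

if 'OFR.SR.3.weight' in keys:
    scale = 1
elif 'SR.body.6.bias' in keys:
    # 2x and 3x checkpoints share the same keys; the shape of SR.body.3 tells them apart
    if state_dict['SR.body.3.weight'].shape[0] == 256:
        scale = 2
    elif state_dict['SR.body.3.weight'].shape[0] == 576:
        scale = 3
elif 'SR.body.9.bias' in keys:
    scale = 4
else:
    raise ValueError('Scale could not be determined from model')

# SR.body.0 takes the centre frame plus scale**2 warped drafts per neighbouring frame,
# so its input width encodes how many frames the model was trained with
frame_size = state_dict['SR.body.0.weight'].shape[1]
num_frames = (frame_size - 1) // scale ** 2 + 1
num_channels = state_dict['OFR.RNN1.0.weight'].shape[0]

model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels)
model.load_state_dict(state_dict)
model.eval()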