Skip to content

Commit

Permalink
v1.0.5 - Now uses a VSGAN class, and more readable code
Browse files Browse the repository at this point in the history
  • Loading branch information
rlaphoenix committed Oct 22, 2019
1 parent 3ff4c9b commit b44c0ea
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 118 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="vsgan",
version="1.0.4",
version="1.0.5",
author="PRAGMA",
author_email="[email protected]",
description="VapourSynth GAN Implementation using RRDBNet, based on ESRGAN's implementation",
Expand Down
253 changes: 136 additions & 117 deletions vsgan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,141 +5,160 @@
# For more details, consult the README.md #
#####################################################################

from vapoursynth import core
import vapoursynth as vs
import numpy as np
import functools

import mvsfunc
import numpy as np
import torch
import vapoursynth as vs
from vapoursynth import core

# - Torch & CUDA, RRDBNet Arch, and Model
DEVICE = None
MODEL = None

class VSGAN:

# - Start VSGAN operations
def start(clip, model, scale, device="cuda", chunk=False, old_arch=False):
global DEVICE
global MODEL
# Setup a device, use CPU instead if cuda isn't available
if not torch.cuda.is_available():
device = 'cpu'
DEVICE = torch.device(device)
# select the arch to be used based on old_arch parameter
if old_arch:
from . import RRDBNet_arch_old as Arch
MODEL = Arch.RRDB_Net(
3, 3, 64, 23,
gc=32,
upscale=scale,
norm_type=None,
act_type="leakyrelu",
mode="CNA",
res_scale=1,
upsample_mode="upconv"
)
else:
from . import RRDBNet_arch as Arch
MODEL = Arch.RRDBNet(3, 3, 64, 23, gc=32)
# load the model with selected arch
MODEL.load_state_dict(torch.load(model), strict=True)
MODEL.eval()
# tie model to PyTorch device
MODEL = MODEL.to(DEVICE)
# remember the clip's original format
original_format = clip.format
# convert clip to RGB24 as it cannot read any other color space
buffer = mvsfunc.ToRGB(clip, depth=8) # expecting RGB24 8bit
if chunk:
crops = {
"left": core.std.CropRel(buffer, left=0, top=0, right=buffer.width / 2, bottom=0),
"right": core.std.CropRel(buffer, left=buffer.width / 2, top=0, right=0, bottom=0)
}
# top left, bottom left, top right, bottom right
def __init__(self, device="cuda"):
    """
    Choose the PyTorch device and create the (still empty) model attributes.

    :param device: device string passed to torch.device(); it is replaced by
        "cpu" whenever torch.cuda.is_available() reports False
    """
    if not torch.cuda.is_available():
        device = "cpu"
    self.torch_device = torch.device(device)
    # Placeholders until load_model() assigns real values
    self.model_file = self.model_scale = self.rrdb_net_model = None

def load_model(self, model, scale, old_arch=False):
    """
    Load a trained model file into the selected RRDB Net architecture.

    :param model: model file handed to torch.load()
    :param scale: stored as self.model_scale (read by get_rrdb_net_arch() for
        the old architecture and by run() to size the output clip)
    :param old_arch: forwarded to get_rrdb_net_arch() to pick the architecture
    """
    self.model_file = model
    self.model_scale = scale
    self.rrdb_net_model = self.get_rrdb_net_arch(old_arch)
    # map_location remaps the stored tensors onto the device chosen in
    # __init__, so a checkpoint saved from CUDA tensors still loads when the
    # class has fallen back to the CPU.
    # NOTE(review): torch.load() unpickles the file; only load model files
    # from a trusted source.
    state_dict = torch.load(self.model_file, map_location=self.torch_device)
    self.rrdb_net_model.load_state_dict(state_dict, strict=True)
    # inference only: switch layers such as dropout/batch-norm to eval mode
    self.rrdb_net_model.eval()
    self.rrdb_net_model = self.rrdb_net_model.to(self.torch_device)

def run(self, clip, chunk=False):
# remember the clip's original format
original_format = clip.format
# convert clip to RGB24 as it cannot read any other color space
buffer = mvsfunc.ToRGB(clip, depth=8) # expecting RGB24 8bit
# send the clip array to execute()
results = []
for crop in [
core.std.CropRel(crops["left"], left=0, top=0, right=0, bottom=crops["left"].height / 2),
core.std.CropRel(crops["left"], left=0, top=crops["left"].height / 2, right=0, bottom=0),
core.std.CropRel(crops["right"], left=0, top=0, right=0, bottom=crops["right"].height / 2),
core.std.CropRel(crops["right"], left=0, top=crops["right"].height / 2, right=0, bottom=0)
]:
for c in self.chunk_clip(buffer) if chunk else [buffer]:
results.append(core.std.FrameEval(
core.std.BlankClip(
crop,
width=crop.width * scale,
height=crop.height * scale
clip=c,
width=c.width * self.model_scale,
height=c.height * self.model_scale
),
functools.partial(
execute,
clip=crop
self.execute,
clip=c
)
))
# if chunked, rejoin the chunked clips otherwise return the result
buffer = core.std.StackHorizontal([
core.std.StackVertical([results[0], results[1]]),
core.std.StackVertical([results[2], results[3]])
])
else:
# take a frame when being used by VapourSynth and send it to the execute function
# returns the edited frame in a 1 frame clip based on the trained model
buffer = core.std.FrameEval(
core.std.BlankClip(
buffer,
width=buffer.width * scale,
height=buffer.height * scale
),
functools.partial(
execute,
clip=buffer
]) if chunk else results[0]
# Convert back to the original color space
if original_format.color_family != buffer.format.color_family:
if original_format.color_family == vs.ColorFamily.RGB:
buffer = mvsfunc.ToRGB(buffer)
if original_format.color_family == vs.ColorFamily.YUV:
buffer = mvsfunc.ToYUV(buffer, css=original_format.name[3:6])
# return the new frame
return buffer

def get_rrdb_net_arch(self, old_arch):
    """
    Build an un-trained network from the old or the current RRDB Net module.

    :param old_arch: truthy to use RRDBNet_arch_old.RRDB_Net, otherwise
        RRDBNet_arch.RRDBNet is used
    :returns: the freshly constructed network (no weights loaded yet)
    """
    if not old_arch:
        from . import RRDBNet_arch
        # NOTE(review): positional values are presumably in/out channels,
        # feature count and block count -- confirm against RRDBNet_arch
        return RRDBNet_arch.RRDBNet(3, 3, 64, 23, gc=32)

    from . import RRDBNet_arch_old
    # only the old architecture is given the scale, via `upscale`
    return RRDBNet_arch_old.RRDB_Net(
        3, 3, 64, 23,
        gc=32,
        upscale=self.model_scale,
        norm_type=None,
        act_type="leakyrelu",
        mode="CNA",
        res_scale=1,
        upsample_mode="upconv",
    )

@staticmethod
def chunk_clip(clip):
    """
    Split a clip into four quadrant clips.

    The clip is first cut into a left and a right column, then every column
    is cut into a top and a bottom part.

    :param clip: clip to split; its width and height attributes are read
    :returns: list of four clips ordered top left, bottom left, top right,
        bottom right
    """
    # Integer midpoint: true division produced float crop amounts and, for an
    # odd width, derived both columns from the same fractional midpoint, so
    # they could not tile the frame. Complementary crops always do.
    half_width = clip.width // 2
    columns = (
        core.std.CropRel(clip, left=0, top=0, right=clip.width - half_width, bottom=0),
        core.std.CropRel(clip, left=half_width, top=0, right=0, bottom=0),
    )
    quadrants = []
    for column in columns:
        # same complementary split, vertically, for each column
        half_height = column.height // 2
        quadrants.append(
            core.std.CropRel(column, left=0, top=0, right=0, bottom=column.height - half_height)
        )
        quadrants.append(
            core.std.CropRel(column, left=0, top=half_height, right=0, bottom=0)
        )
    return quadrants

@staticmethod
def cv2_imread(frame, plane_count):
    """
    Alternative to cv2.imread() that reads a frame's planes into a numpy array

    :param frame: object whose get_read_array(i) returns the data of plane i
    :param plane_count: number of planes to read from the frame
    :returns: array with the planes stacked along the last axis in reverse
        plane order (presumably BGR for an RGB frame, like cv2.imread -- confirm)
    """
    # np.asarray copies only when it has to; np.array(..., copy=False) raises
    # under NumPy 2 whenever a copy cannot be avoided.
    return np.dstack(
        [np.asarray(frame.get_read_array(i)) for i in reversed(range(plane_count))]
    )

@staticmethod
def cv2_imwrite(image, out_color_space="RGB24"):
"""
Alternative to cv2.imwrite() that will convert the data into an image readable by VapourSynth
"""
if len(image.shape) <= 3:
image = image.reshape([1] + list(image.shape))
# Define the shapes items
plane_count = image.shape[-1]
image_width = image.shape[-2]
image_height = image.shape[-3]
image_length = image.shape[-4]
# this is a clip (or array buffer for frames) that we will insert the GAN'd frames into
buffer = core.std.BlankClip(
clip=None,
width=image_width,
height=image_height,
format=vs.PresetFormat[out_color_space],
length=image_length
)
# Convert back to the original color space
if original_format.color_family != buffer.format.color_family:
if original_format.color_family == vs.ColorFamily.RGB:
buffer = mvsfunc.ToRGB(buffer)
if original_format.color_family == vs.ColorFamily.YUV:
buffer = mvsfunc.ToYUV(buffer, css=original_format.name[3:6])
# return the new frame/(s)
return buffer

def replace_planes(n, f):
frame = f.copy()
for i, plane_num in enumerate(reversed(range(plane_count))):
# todo ; any better way to do this without storing the np.array in a variable?
# todo ; perhaps some way to directly copy it to s?
d = np.array(frame.get_write_array(plane_num), copy=False)
# copy the value of d, into s[frame_num, :, :, plane_num]
np.copyto(d, image[n, :, :, i], casting="unsafe")
# delete the d variable from memory
del d
return frame

# - Deals with the number crunching
def execute(n, clip):
# get the frame being used
frame = clip.get_frame(n)
# convert it to a numpy readable array for PyTorch
numpy_array = np.dstack(
[np.array(frame.get_read_array(i), copy=False) for i in reversed(range(clip.format.num_planes))]
)
# use the model's trained data against the images planes
with torch.no_grad():
s = MODEL(
torch.from_numpy(
np.transpose((numpy_array * 1.0 / 255)[:, :, [2, 1, 0]], (2, 0, 1))
).float().unsqueeze(0).to(DEVICE)
).data.squeeze().float().cpu().clamp_(0, 1).numpy()
s = (np.transpose(s[[2, 1, 0], :, :], (1, 2, 0)) * 255.0).round()
if len(s.shape) <= 3:
s = s.reshape([1] + list(s.shape))
plane_count = s.shape[-1] # expecting 3 for RGB24 input
# take the blank clip and insert the new data into the planes and return it back to sender
return core.std.ModifyFrame(clip=buffer, clips=buffer, selector=replace_planes)

def replace_planes(n, f):
frame = f.copy()
for plane_num, p in enumerate(reversed(range(plane_count))):
# todo ; any better way to do this without storing the np.array in a variable?
# todo ; perhaps some way to directly copy it to s?
d = np.array(frame.get_write_array(p), copy=False)
# copy the value of d, into s[frame_num, :, :, plane_num]
np.copyto(d, s[n, :, :, plane_num], casting="unsafe")
# delete the d variable from memory
del d
return frame
# this is a clip (or array buffer for frames) that we will insert the GAN'd frames into
clip = core.std.BlankClip(
clip=None,
width=s.shape[-2],
height=s.shape[-3],
format=vs.PresetFormat[clip.format.name],
length=s.shape[-4]
)
# take the blank clip and insert the new data into the planes and return it back to sender
return core.std.ModifyFrame(clip=clip, clips=clip, selector=replace_planes)
def execute(self, n, clip):
    """
    Essentially the same as ESRGAN, except it replaces the cv2 functions with ones geared towards VapourSynth
    https://github.com/xinntao/ESRGAN/blob/master/test.py#L26

    :param n: number of the frame to process
    :param clip: clip the frame is taken from; clip.format supplies the plane
        count and the name of the output color space
    :returns: whatever self.cv2_imwrite() returns for the processed image
    """
    # get the frame being used
    frame = clip.get_frame(n)
    img = self.cv2_imread(frame=frame, plane_count=clip.format.num_planes)
    # scale the 0-255 values into the 0-1 range
    img = img * 1.0 / 255
    # reverse the last (channel) axis, then move it to the front for PyTorch
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    # add the batch axis and move the data onto the model's device
    img_lr = img.unsqueeze(0).to(self.torch_device)
    with torch.no_grad():
        # squeeze(0) removes only the batch axis; a blanket squeeze() would
        # also drop a size-1 height or width and break the indexing below
        output = self.rrdb_net_model(img_lr).data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    # undo the channel reversal and return to a channel-last layout
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round()
    return self.cv2_imwrite(image=output, out_color_space=clip.format.name)

0 comments on commit b44c0ea

Please sign in to comment.