Skip to content

Commit

Permalink
v1.0.3 - Convert back to original colorspace after execution, more co…
Browse files Browse the repository at this point in the history
…mments, pep8
  • Loading branch information
rlaphoenix committed Oct 22, 2019
1 parent 46998a2 commit 46779ae
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 49 deletions.
11 changes: 6 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,24 @@

setup(
name="vsgan",
version="1.0.2",
version="1.0.3",
author="PRAGMA",
author_email="[email protected]",
description="VapourSynth GAN Implementation using RRDBNet, based on ESRGAN's implementation",
license='MIT',
license="MIT",
long_description=readme,
long_description_content_type="text/markdown",
url="https://gitlab.com/imPRAGMA/VSGAN",
packages=find_packages(),
install_requires=[
'numpy',
'torch'
"numpy",
"torch",
"vapoursynth"
],
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
python_requires='>=3.6',
python_requires=">=3.6",
)
109 changes: 65 additions & 44 deletions vsgan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@
# Created by PRAGMA #
# https://github.com/imPRAGMA/VSGAN #
#####################################################################
# Dependencies: #
# - PIP Module: numpy #
# - RRDBNet_arch.py from the xinntao's ESRGAN repo #
# (https://github.com/xinntao/ESRGAN/blob/master/RRDBNet_arch.py) #
# - PyTorch: https://pytorch.org/get-started/locally #
# - mvsfunc: https://github.com/HomeOfVapourSynthEvolution/mvsfunc #
#####################################################################
# For more details, consult the README.md #
#####################################################################

Expand All @@ -19,12 +12,13 @@
import mvsfunc
import torch

# - Torch&Cuda, RRDBNet Arch, and Model
# - Torch & CUDA, RRDBNet Arch, and Model
DEVICE = None
MODEL = None


# - Start VSGAN operations
def Start(clip, model, scale, device='cuda', old_arch=False):
def start(clip, model, scale, device="cuda", old_arch=False):
global DEVICE
global MODEL
# Setup a device, use CPU instead if cuda isn't available
Expand All @@ -33,63 +27,90 @@ def Start(clip, model, scale, device='cuda', old_arch=False):
DEVICE = torch.device(device)
# select the arch to be used based on old_arch parameter
if old_arch:
from . import RRDBNet_arch_old as arch
MODEL = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=scale, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
from . import RRDBNet_arch_old as Arch
MODEL = Arch.RRDB_Net(
3, 3, 64, 23,
gc=32,
upscale=scale,
norm_type=None,
act_type="leakyrelu",
mode="CNA",
res_scale=1,
upsample_mode="upconv"
)
else:
from . import RRDBNet_arch as arch
MODEL = arch.RRDBNet(3, 3, 64, 23, gc=32)
from . import RRDBNet_arch as Arch
MODEL = Arch.RRDBNet(3, 3, 64, 23, gc=32)
# load the model with selected arch
MODEL.load_state_dict(torch.load(model), strict=True)
MODEL.eval()
# tie model to pytorch device
# tie model to PyTorch device
MODEL = MODEL.to(DEVICE)
# remember the clip's original format
orig_format = clip.format
original_format = clip.format
# convert clip to RGB24 as it cannot read any other color space
buffer = mvsfunc.ToRGB(clip, depth=8)
#buffer = core.resize.Point(clip, format=vs.RGB24)
buffer = mvsfunc.ToRGB(clip, depth=8) # expecting RGB24 8bit
# take a frame when being used by VapourSynth and send it to the execute function
# returns the edited frame in a 1 frame clip based on the trained model
buffer = core.std.FrameEval(
core.std.BlankClip(
buffer,
width=clip.width*scale,
height=clip.height*scale
width=clip.width * scale,
height=clip.height * scale
),
functools.partial(
Execute,
execute,
clip=buffer
)
)
# Convert back to the original color space and return it to sender
# Convert back to the original color space
if original_format.color_family != buffer.format.color_family:
if original_format.color_family == vs.ColorFamily.RGB:
buffer = mvsfunc.ToRGB(buffer)
if original_format.color_family == vs.ColorFamily.YUV:
buffer = mvsfunc.ToYUV(buffer, css=original_format.name[3:6])
# return the new frame/(s)
return buffer
# let's not convert back to original since changes with colorspace conversion via mvsfunc
#core.resize.Point(buffer, format=orig_format, matrix_s="709") #should matrix be gotten from original clip?


# - Deals with the number crunching
def Execute(n, clip):
def execute(n, clip):
# get the frame being used
frame = clip.get_frame(n)
# convert it to a numpy readable array
def frameToNumpyArray(frame, planes):
return np.dstack([np.array(frame.get_read_array(i), copy=False) for i in reversed(range(planes))])
numpy_array = frameToNumpyArray(frame, clip.format.num_planes) #num_planes is expected to always be 3
# convert it to a numpy readable array for PyTorch
numpy_array = np.dstack(
[np.array(frame.get_read_array(i), copy=False) for i in reversed(range(clip.format.num_planes))]
)
# use the model's trained data against the images planes
with torch.no_grad():
s = MODEL(torch.from_numpy(np.transpose((numpy_array * 1.0 / 255)[:, :, [2, 1, 0]], (2, 0, 1))).float().unsqueeze(0).to(DEVICE)).data.squeeze().float().cpu().clamp_(0, 1).numpy()
s = MODEL(
torch.from_numpy(
np.transpose((numpy_array * 1.0 / 255)[:, :, [2, 1, 0]], (2, 0, 1))
).float().unsqueeze(0).to(DEVICE)
).data.squeeze().float().cpu().clamp_(0, 1).numpy()
s = (np.transpose(s[[2, 1, 0], :, :], (1, 2, 0)) * 255.0).round()
if len(s.shape) <= 3: s = s.reshape([1] + list(s.shape))
planes = s.shape[-1]
def conv(n, f):
fout = f.copy()
idx = -1
for p in reversed(range(planes)):
idx += 1
d = np.array(fout.get_write_array(p), copy=False)
np.copyto(d, s[n, :, :, idx], casting="unsafe")
del d
return fout
# create a blank clip with the plane types returned by the model
clip = core.std.BlankClip(None, s.shape[-2], s.shape[-3], vs.RGB24, s.shape[-4])
# take the blank clip and insert the new data into the planes (reversed) and return it back to sender
return core.std.ModifyFrame(clip, clip, conv)
if len(s.shape) <= 3:
s = s.reshape([1] + list(s.shape))
plane_count = s.shape[-1] # expecting 3 for RGB24 input

def replace_planes(n, f):
frame = f.copy()
for plane_num, p in enumerate(reversed(range(plane_count))):
# todo ; any better way to do this without storing the np.array in a variable?
# todo ; perhaps some way to directly copy it to s?
d = np.array(frame.get_write_array(p), copy=False)
# copy the value of d, into s[frame_num, :, :, plane_num]
np.copyto(d, s[n, :, :, plane_num], casting="unsafe")
# delete the d variable from memory
del d
return frame
# this is a clip (or array buffer for frames) that we will insert the GAN'd frames into
clip = core.std.BlankClip(
clip=None,
width=s.shape[-2],
height=s.shape[-3],
format=vs.PresetFormat[clip.format.name],
length=s.shape[-4]
)
# take the blank clip and insert the new data into the planes and return it back to sender
return core.std.ModifyFrame(clip=clip, clips=clip, selector=replace_planes)

0 comments on commit 46779ae

Please sign in to comment.