Skip to content

Commit

Permalink
get_format_from_npy: Update to match changes
Browse files Browse the repository at this point in the history
  • Loading branch information
LightArrowsEXE committed Oct 8, 2024
1 parent 6c3fbbf commit b8d0a2d
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions lvsfunc/nn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_format_from_npy(frame_data: np.ndarray, func_except: FuncExceptT | None
If every array has the same shape, it's assumed to be YUV 4:4:4.
If you output RGB data, you may have to convert it back.
If either U or V arrays are None, it's assumed to be GRAY.
If the array has only one plane, it's assumed to be GRAY.
:param frame_data: The numpy array data to guess the format from.
:param func_except: Function returned for custom error handling.
Expand All @@ -30,22 +30,19 @@ def get_format_from_npy(frame_data: np.ndarray, func_except: FuncExceptT | None

func = fallback(func_except, get_format_from_npy)

if isinstance(frame_data, dict):
y_data = frame_data['plane_0']
num_planes = len(frame_data)
elif isinstance(frame_data, np.ndarray):
y_data = frame_data
num_planes = y_data.ndim if y_data.ndim <= 2 else y_data.shape[2]
else:
if not isinstance(frame_data, np.ndarray):
raise NumpyArrayLoadError(f"Unsupported data type: {type(frame_data)}", func)

num_planes = frame_data.shape[0] if frame_data.ndim > 2 else 1
y_data = frame_data[0] if num_planes > 1 else frame_data

bit_depth = 32 if y_data.dtype == np.float32 else y_data.itemsize * 8

if num_planes == 1:
return get_video_format(depth(core.std.BlankClip(format=vs.GRAY8, keep=True), bit_depth))

y_shape = y_data.shape[:2]
u_shape = frame_data['plane_1'].shape[:2] if isinstance(frame_data, dict) else y_data[:, :, 1].shape
y_shape = y_data.shape
u_shape = frame_data[1].shape

subsampling_map = {
(1, 1): vs.YUV444P8,
Expand Down

0 comments on commit b8d0a2d

Please sign in to comment.