Skip to content

Commit

Permalink
[FEATURE] Allow YAML maps and lists in overrides, convert to script f…
Browse files Browse the repository at this point in the history
…ormat (#956)

Adds the ability to create map and list-based override variables. 

For example, you can now create lists like this:
```
overrides:
  urls:
    - "https://...1"
    - "https://...2"
```
which is equivalent to:
```
overrides:
  urls: >-
    {
      [
        "https://...1",
        "https://...2",
      ]
    }
```

Likewise, maps can now look like:
```
overrides:
  music_video_category:
    concerts:
      - "https://...1"
      - "https://...2"
    interviews:
      - "https://...3"
```
which is equivalent to:
```
overrides:
  music_video_category: >-
    {
      "concerts": [
        "https://...1",
        "https://...2"
      ],
      "interviews": [
        "https://...3"
      ]
    }
```
  • Loading branch information
jmbannon authored Jun 3, 2024
1 parent 6e31ad3 commit 5e335b1
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/ytdl_sub/config/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from ytdl_sub.utils.exceptions import ValidationException
from ytdl_sub.utils.script import ScriptUtils
from ytdl_sub.utils.scriptable import Scriptable
from ytdl_sub.validators.string_formatter_validators import DictFormatterValidator
from ytdl_sub.validators.string_formatter_validators import StringFormatterValidator
from ytdl_sub.validators.string_formatter_validators import UnstructuredDictFormatterValidator


class Overrides(DictFormatterValidator, Scriptable):
class Overrides(UnstructuredDictFormatterValidator, Scriptable):
"""
Allows you to define variables that can be used in any EntryFormatter or OverridesFormatter.
Expand Down Expand Up @@ -51,11 +51,11 @@ class Overrides(DictFormatterValidator, Scriptable):

@classmethod
def partial_validate(cls, name: str, value: Any) -> None:
dict_formatter = DictFormatterValidator(name=name, value=value)
dict_formatter = UnstructuredDictFormatterValidator(name=name, value=value)
_ = [parse(format_string) for format_string in dict_formatter.dict_with_format_strings]

def __init__(self, name, value):
DictFormatterValidator.__init__(self, name, value)
UnstructuredDictFormatterValidator.__init__(self, name, value)
Scriptable.__init__(self, initialize_base_script=True)

for key in self._keys:
Expand Down
2 changes: 2 additions & 0 deletions src/ytdl_sub/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,8 @@ def _parse_map(self) -> UnresolvedMap:
raise MAP_KEY_WITH_NO_VALUE
if isinstance(key, NonHashable):
raise MAP_KEY_NOT_HASHABLE
if isinstance(key, BuiltInFunction) and issubclass(key.output_type(), NonHashable):
raise MAP_KEY_NOT_HASHABLE
if len(value_args) > 1:
raise MAP_KEY_MULTIPLE_VALUES

Expand Down
3 changes: 3 additions & 0 deletions src/ytdl_sub/script/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,6 @@ def resolve(
raise FunctionRuntimeException(
f"Runtime error occurred when executing the function %{self.name}: {str(exc)}"
) from exc

def __hash__(self):
return hash((self.name, *self.args))
87 changes: 85 additions & 2 deletions src/ytdl_sub/utils/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
from typing import Any
from typing import Dict

from ytdl_sub.script.parser import parse
from ytdl_sub.script.script import _is_function
from ytdl_sub.script.types.array import UnresolvedArray
from ytdl_sub.script.types.function import BuiltInFunction
from ytdl_sub.script.types.function import Function
from ytdl_sub.script.types.map import UnresolvedMap
from ytdl_sub.script.types.resolvable import Argument
from ytdl_sub.script.types.resolvable import Boolean
from ytdl_sub.script.types.resolvable import Float
from ytdl_sub.script.types.resolvable import Integer
from ytdl_sub.script.types.resolvable import String
from ytdl_sub.script.types.variable import Variable
from ytdl_sub.script.utils.exceptions import UNREACHABLE

# pylint: disable=too-many-return-statements


class ScriptUtils:
Expand All @@ -28,12 +42,12 @@ def to_script(cls, value: Any) -> str:
out = ""
elif isinstance(value, str):
out = value
elif isinstance(value, bool):
out = f"{{%bool({value})}}"
elif isinstance(value, int):
out = f"{{%int({value})}}"
elif isinstance(value, float):
out = f"{{%float({value})}}"
elif isinstance(value, bool):
out = f"{{%bool({value})}}"
else:
dumped_json = json.dumps(value, ensure_ascii=False, sort_keys=True)
# Remove triple-single-quotes from JSON to avoid parsing issues
Expand All @@ -43,6 +57,75 @@ def to_script(cls, value: Any) -> str:

return out

@classmethod
def _to_script_argument(cls, value: Any) -> Argument:
# Handle simple types as above
if value is None or (isinstance(value, str) and value == ""):
return String("")
if isinstance(value, str):
ast = parse(text=value).ast
if len(ast) == 1:
return ast[0]
return BuiltInFunction(
name="concat", args=[BuiltInFunction(name="string", args=[arg]) for arg in ast]
)
if isinstance(value, bool):
return Boolean(value)
if isinstance(value, int):
return Integer(value)
if isinstance(value, float):
return Float(value)
if isinstance(value, list):
return UnresolvedArray([cls._to_script_argument(val) for val in value])
if isinstance(value, dict):
return UnresolvedMap(
{
cls._to_script_argument(key): cls._to_script_argument(val)
for key, val in value.items()
}
)

raise UNREACHABLE

@classmethod
def _to_script_code(cls, arg: Argument, top_level: bool = False) -> str:
if not top_level and isinstance(arg, (Integer, Boolean, Float)):
return str(arg.native)

if isinstance(arg, String):
if arg.native == "":
return "" if top_level else "''"
return arg.native if top_level else f"'''{arg.native}'''"

if isinstance(arg, Integer):
out = f"%int({arg.native})"
elif isinstance(arg, Boolean):
out = f"%bool({arg.native})"
elif isinstance(arg, Float):
out = f"%float({arg.native})"
elif isinstance(arg, UnresolvedArray):
out = f"[ {', '.join(cls._to_script_code(val) for val in arg.value)} ]"
elif isinstance(arg, UnresolvedMap):
kv_list = (
f"{cls._to_script_code(key)}: {cls._to_script_code(val)}"
for key, val in arg.value.items()
)
out = f"{{ {', '.join(kv_list)} }}"
elif isinstance(arg, Variable):
out = arg.name
elif isinstance(arg, Function):
out = f"%{arg.name}( {', '.join(cls._to_script_code(val) for val in arg.args)} )"
else:
raise UNREACHABLE
return f"{{ {out} }}" if top_level else out

@classmethod
def to_native_script(cls, value: Any) -> str:
"""
Converts any JSON-compatible value into equivalent script syntax
"""
return cls._to_script_code(cls._to_script_argument(value), top_level=True)

@classmethod
def bool_formatter_output(cls, output: str) -> bool:
"""
Expand Down
13 changes: 13 additions & 0 deletions src/ytdl_sub/validators/string_formatter_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ytdl_sub.script.utils.exceptions import ScriptVariableNotResolved
from ytdl_sub.script.utils.exceptions import UserException
from ytdl_sub.utils.exceptions import StringFormattingVariableNotFoundException
from ytdl_sub.utils.script import ScriptUtils
from ytdl_sub.validators.validators import DictValidator
from ytdl_sub.validators.validators import ListValidator
from ytdl_sub.validators.validators import LiteralDictValidator
Expand Down Expand Up @@ -144,6 +145,18 @@ class OverridesDictFormatterValidator(DictFormatterValidator):
_key_validator = OverridesStringFormatterValidator


class UnstructuredDictFormatterValidator(DictFormatterValidator):
def __init__(self, name, value):
# Convert the unstructured-ness into a script
if isinstance(value, dict):
value = {key: ScriptUtils.to_native_script(val) for key, val in value.items()}
super().__init__(name, value)


class UnstructuredOverridesDictFormatterValidator(UnstructuredDictFormatterValidator):
_key_validator = OverridesStringFormatterValidator


def to_variable_dependency_format_string(script: Script, parsed_format_string: SyntaxTree) -> str:
"""
Create a dummy format string that contains all variable deps as a string.
Expand Down
2 changes: 2 additions & 0 deletions tests/e2e/youtube/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def single_video_preset_dict(output_directory):
"overrides": {
"music_video_artist": "JMC",
"music_video_directory": output_directory,
"test_override_map": {"{music_video_artist}": "{music_video_directory}"},
"test_override_map_get": "{ %map_get(test_override_map, music_video_artist) }",
},
}

Expand Down
1 change: 0 additions & 1 deletion tests/unit/config/test_config_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def test_config_file_working_dir_home_dir(self):
"preset_dict",
[
{"overrides": "not a dict"},
{"overrides": {"nested": {"dict": "value"}}},
{"overrides": ["list"]},
],
)
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/script/types/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,16 @@ def test_map_key_is_non_hashable_variable(self):
"key_variable": "{['non-hashable']}",
}
).resolve()

def test_map_key_is_function(self):
assert Script(
{
"dict": "{{ %concat('hi', %string(' world')) : 'value' }}",
"key_variable": "hashable",
}
).resolve() == ScriptOutput(
{
"key_variable": String("hashable"),
"dict": Map(value={String(value="hi world"): String(value="value")}),
}
)
28 changes: 28 additions & 0 deletions tests/unit/utils/test_script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import pytest
from unit.script.conftest import single_variable_output

from ytdl_sub.script.parser import parse
from ytdl_sub.script.types.function import BuiltInFunction
from ytdl_sub.script.types.map import UnresolvedMap
from ytdl_sub.script.types.resolvable import String
from ytdl_sub.script.types.syntax_tree import SyntaxTree
from ytdl_sub.script.types.variable import Variable
from ytdl_sub.utils.script import ScriptUtils


Expand Down Expand Up @@ -51,3 +57,25 @@ def test_dict_to_script(self):
)
def test_bool_formatter_output(self, input_str: str, expected_output: bool):
assert ScriptUtils.bool_formatter_output(input_str) == expected_output

def test_to_syntax_tree(self):
out = ScriptUtils.to_native_script(
{"{var_a}": "{var_b}", "static_a": "string with {var_c} in it"}
)
assert parse(out) == SyntaxTree(
ast=[
UnresolvedMap(
value={
Variable(name="var_a"): Variable(name="var_b"),
String(value="static_a"): BuiltInFunction(
name="concat",
args=[
BuiltInFunction(name="string", args=[String(value="string with ")]),
BuiltInFunction(name="string", args=[Variable(name="var_c")]),
BuiltInFunction(name="string", args=[String(value=" in it")]),
],
),
}
)
]
)
49 changes: 49 additions & 0 deletions tests/unit/validators/test_string_formatter_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from ytdl_sub.validators.string_formatter_validators import OverridesDictFormatterValidator
from ytdl_sub.validators.string_formatter_validators import OverridesStringFormatterValidator
from ytdl_sub.validators.string_formatter_validators import StringFormatterValidator
from ytdl_sub.validators.string_formatter_validators import UnstructuredDictFormatterValidator
from ytdl_sub.validators.string_formatter_validators import (
UnstructuredOverridesDictFormatterValidator,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -75,3 +79,48 @@ def test_validates_values(self, dict_validator_class, expected_formatter_class):
"key1": key1_format_string,
"key2": key2_format_string,
}


class TestUnstructuredDictFormatterValidator(object):
@pytest.mark.parametrize(
"dict_validator_class, expected_formatter_class",
[
(UnstructuredDictFormatterValidator, StringFormatterValidator),
(UnstructuredOverridesDictFormatterValidator, OverridesStringFormatterValidator),
],
)
def test_validates_values(self, dict_validator_class, expected_formatter_class):
key1_format_string = "string with {variable}"
key2_format_string = "no variables"
key3_int = 3
key4_float = 4.132
key5_bool = True
key6_map = {"{variable}_key": "value", "static_key": "{variable}_value"}
key7_list = ["list_1", "list_{variable_2}"]
key8_many_vars = "string {variable1} with multiple {variable2}"
validator = dict_validator_class(
name="validator",
value={
"key1": key1_format_string,
"key2": key2_format_string,
"key3": key3_int,
"key4": key4_float,
"key5": key5_bool,
"key6": key6_map,
"key7": key7_list,
"key8": key8_many_vars,
},
)

assert len(validator.dict) == 8
assert all(isinstance(val, expected_formatter_class) for val in validator.dict.values())
assert validator.dict_with_format_strings == {
"key1": "{ %concat( %string( '''string with ''' ), %string( variable ) ) }",
"key2": "no variables",
"key3": "{ %int(3) }",
"key4": "{ %float(4.132) }",
"key5": "{ %bool(True) }",
"key6": "{ { %concat( %string( variable ), %string( '''_key''' ) ): '''value''', '''static_key''': %concat( %string( variable ), %string( '''_value''' ) ) } }",
"key7": "{ [ '''list_1''', %concat( %string( '''list_''' ), %string( variable_2 ) ) ] }",
"key8": "{ %concat( %string( '''string ''' ), %string( variable1 ), %string( ''' with multiple ''' ), %string( variable2 ) ) }",
}

0 comments on commit 5e335b1

Please sign in to comment.