From e04971aa8b4396a79b00170f7f0df2c9d8665427 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Mon, 6 Nov 2023 13:15:51 -0800 Subject: [PATCH 1/2] fixies Signed-off-by: Ayush Kamat --- latch_cli/snakemake/serialize.py | 28 ++++++++++++++++++++++------ latch_cli/snakemake/workflow.py | 2 +- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/latch_cli/snakemake/serialize.py b/latch_cli/snakemake/serialize.py index 397cc87b..63354480 100644 --- a/latch_cli/snakemake/serialize.py +++ b/latch_cli/snakemake/serialize.py @@ -1,10 +1,11 @@ +import json import os import sys import textwrap import traceback from pathlib import Path from textwrap import dedent -from typing import Dict, List, Optional, Set, Union, get_args +from typing import Any, Dict, List, Optional, Set, Union, get_args import click from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan @@ -91,12 +92,20 @@ def ensure_snakemake_metadata_exists(): # todo(maximsmol): this needs to run in a subprocess because it pollutes globals class SnakemakeWorkflowExtractor(Workflow): - def __init__(self, pkg_root: Path, snakefile: Path): - super().__init__(snakefile=snakefile) + def __init__( + self, + pkg_root: Path, + snakefile: Path, + non_blob_parameters: Optional[Dict[str, Any]] = None, + ): + super().__init__(snakefile=snakefile, overwrite_config=non_blob_parameters) self.pkg_root = pkg_root self._old_cwd = "" + if non_blob_parameters is not None: + print(f"Config: {json.dumps(non_blob_parameters, indent=2)}") + def extract_dag(self): targets: List[str] = ( [self.default_target] if self.default_target is not None else [] @@ -169,7 +178,9 @@ def __exit__(self, typ, value, tb): def snakemake_workflow_extractor( - pkg_root: Path, snakefile: Path + pkg_root: Path, + snakefile: Path, + non_blob_parameters: Optional[Dict[str, Any]] = None, ) -> SnakemakeWorkflowExtractor: snakefile = snakefile.resolve() @@ -184,6 +195,7 @@ def snakemake_workflow_extractor( extractor = SnakemakeWorkflowExtractor( pkg_root=pkg_root, snakefile=snakefile, + non_blob_parameters=non_blob_parameters, ) with extractor: extractor.include( @@ -201,12 +213,16 @@ def extract_snakemake_workflow( jit_wf_version: str, jit_exec_display_name: str, local_to_remote_path_mapping: Optional[Dict[str, str]] = None, + non_blob_parameters: Optional[Dict[str, Any]] = None, ) -> SnakemakeWorkflow: - extractor = snakemake_workflow_extractor(pkg_root, snakefile) + extractor = snakemake_workflow_extractor(pkg_root, snakefile, non_blob_parameters) with extractor: dag = extractor.extract_dag() wf = SnakemakeWorkflow( - dag, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping + dag, + jit_wf_version, + jit_exec_display_name, + local_to_remote_path_mapping, ) wf.compile() diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 170d69fb..197abb42 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -492,7 +492,7 @@ def get_fn_code( print(f"JIT Workflow Version: {{jit_wf_version}}") print(f"JIT Execution Display Name: {{jit_exec_display_name}}") - wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping) + wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping, non_blob_parameters) wf_name = wf.name generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)}, non_blob_parameters) From 9223fcc8e04a3344f6643a2619ab83659b33dce8 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Mon, 6 Nov 2023 13:47:04 -0800 Subject: [PATCH 2/2] fix single task config handling Signed-off-by: Ayush Kamat --- latch_cli/snakemake/single_task_snakemake.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/latch_cli/snakemake/single_task_snakemake.py b/latch_cli/snakemake/single_task_snakemake.py index 505392af..c12497f7 100644 --- a/latch_cli/snakemake/single_task_snakemake.py +++ b/latch_cli/snakemake/single_task_snakemake.py @@ -21,6 +21,7 @@ Shell, ) from snakemake.rules import Rule as RRule +from snakemake.workflow import Workflow as WWorkflow sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) @@ -40,9 +41,15 @@ def eprint(x: str) -> None: non_blob_parameters = data.get("non_blob_parameters", {}) -# todo(ayush): do this without overwriting globals -sw = sys.modules["snakemake.workflow"] -setattr(sw, "config", non_blob_parameters) +old_workflow_init = WWorkflow.__init__ + + +def new_init(self: WWorkflow, *args, **kwargs): + kwargs["overwrite_config"] = non_blob_parameters + old_workflow_init(self, *args, **kwargs) + + +WWorkflow.__init__ = new_init def eprint_named_list(xs):