Skip to content

Commit

Permalink
Fix _whichmodule with multiprocessing (#529)
Browse files Browse the repository at this point in the history

Co-authored-by: Olivier Grisel <[email protected]>
  • Loading branch information
hendrikmakait and ogrisel authored Apr 8, 2024
1 parent d003266 commit f111f7a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
dynamic functions and classes.
([PR #524](https://github.com/cloudpipe/cloudpickle/pull/524))

- Fix a problem with the joint usage of cloudpickle's `_whichmodule` and
`multiprocessing`.
([PR #529](https://github.com/cloudpipe/cloudpickle/pull/529))

3.0.0
=====
Expand Down
1 change: 1 addition & 0 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def _whichmodule(obj, name):
# sys.modules
if (
module_name == "__main__"
or module_name == "__mp_main__"
or module is None
or not isinstance(module, types.ModuleType)
):
Expand Down
24 changes: 24 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pickle

import pytest
from pathlib import Path

try:
# try importing numpy and scipy. These are not hard dependencies and
Expand Down Expand Up @@ -1479,6 +1480,29 @@ def __getattr__(self, name):
finally:
sys.modules.pop("NonModuleObject")

def test_importing_multiprocessing_does_not_impact_whichmodule(self):
# non-regression test for #528
pytest.importorskip("numpy")
script = textwrap.dedent("""
import multiprocessing
import cloudpickle
from numpy import exp
print(cloudpickle.cloudpickle._whichmodule(exp, exp.__name__))
""")
script_path = Path(self.tmpdir) / "whichmodule_and_multiprocessing.py"
with open(script_path, mode="w") as f:
f.write(script)

proc = subprocess.Popen(
[sys.executable, str(script_path)],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
out, _ = proc.communicate()
self.assertEqual(proc.wait(), 0)
self.assertEqual(out, b"numpy.core._multiarray_umath\n")

def test_unrelated_faulty_module(self):
# Check that pickling a dynamically defined function or class does not
# fail when introspecting the currently loaded modules in sys.modules
Expand Down

0 comments on commit f111f7a

Please sign in to comment.