Skip to content

Commit

Permalink
Merge pull request #795 from slayoo/fix_coalescence_regression
Browse files Browse the repository at this point in the history
introducing RandomCommon with asserts on seed and size types, removing seed from dynamics ctors, hopefully fixing #794
  • Loading branch information
slayoo authored Mar 3, 2022
2 parents 47425d7 + 3c676dc commit 80684b9
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 15 deletions.
5 changes: 3 additions & 2 deletions PySDM/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def _cuda_is_available():
import numpy as np

from PySDM.backends.thrust_rtc import ThrustRTC # pylint: disable=ungrouped-imports
from PySDM.backends.impl_common.random_common import RandomCommon # pylint: disable=ungrouped-imports
ThrustRTC.ENABLE = False

class Random: # pylint: disable=too-few-public-methods
class Random(RandomCommon): # pylint: disable=too-few-public-methods
def __init__(self, size, seed):
self.size = size
super().__init__(size, seed)
self.generator = np.random.default_rng(seed)

def __call__(self, storage):
Expand Down
10 changes: 10 additions & 0 deletions PySDM/backends/impl_common/random_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
common base class for random number generation abstraction layer
"""


class RandomCommon:
def __init__(self, size: int, seed: int):
assert isinstance(size, int)
assert isinstance(seed, int)
self.size = size
5 changes: 3 additions & 2 deletions PySDM/backends/impl_numba/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
random number generator class for Numba backend
"""
import numpy as np
from ..impl_common.random_common import RandomCommon


# TIP: can be called asynchronously
# TIP: sometimes only half array is needed

class Random: # pylint: disable=too-few-public-methods
class Random(RandomCommon): # pylint: disable=too-few-public-methods
def __init__(self, size, seed):
self.size = size
super().__init__(size, seed)
self.generator = np.random.default_rng(seed)

def __call__(self, storage):
Expand Down
5 changes: 3 additions & 2 deletions PySDM/backends/impl_thrust_rtc/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
"""
from PySDM.backends.impl_thrust_rtc.nice_thrust import nice_thrust
from PySDM.backends.impl_thrust_rtc.conf import NICE_THRUST_FLAGS
from ..impl_common.random_common import RandomCommon
from .conf import trtc, rndrtc


# TIP: sometimes only half array is needed

class Random: # pylint: disable=too-few-public-methods
class Random(RandomCommon): # pylint: disable=too-few-public-methods
__urand_init_rng_state_body = trtc.For(['rng', 'states', 'seed'], 'i', '''
rng.state_init(seed, i, 0, states[i]);
''')
Expand All @@ -18,9 +19,9 @@ class Random: # pylint: disable=too-few-public-methods
''')

def __init__(self, size, seed):
super().__init__(size, seed)
rng = rndrtc.DVRNG()
self.generator = trtc.device_vector('RNGState', size)
self.size = size
dseed = trtc.DVInt64(seed)
Random.__urand_init_rng_state_body.launch_n(size, [rng, self.generator, dseed])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, number):

@staticmethod
def DVInt64(number: int): # pylint: disable=invalid-name
assert isinstance(number, int)
return FakeThrustRTC.Number(number)

@staticmethod
Expand All @@ -67,6 +68,7 @@ def DVDouble(number: float): # pylint: disable=invalid-name

@staticmethod
def DVBool(number: bool): # pylint: disable=invalid-name
assert isinstance(number, bool)
return FakeThrustRTC.Number(number)

@staticmethod
Expand Down
8 changes: 1 addition & 7 deletions PySDM/dynamics/collisions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(self,
coalescence_efficiency,
breakup_efficiency,
fragmentation_function,
seed=None,
croupier=None,
optimized_random=False,
substeps: int = 1,
Expand All @@ -49,7 +48,6 @@ def __init__(self,
self.compute_breakup_efficiency = breakup_efficiency
self.compute_number_of_fragments = fragmentation_function

self.seed = seed
self.rnd_opt_frag = None
self.rnd_opt_coll = None
self.rnd_opt_proc = None
Expand Down Expand Up @@ -83,7 +81,7 @@ def register(self, builder):
rnd_args = {
'optimized_random': self.optimized_random,
'dt_min': self.dt_coal_range[0],
'seed': self.seed
'seed': builder.formulae.seed
}
self.rnd_opt_coll = RandomGeneratorOptimizer(**rnd_args)
if self.enable_breakup:
Expand Down Expand Up @@ -239,7 +237,6 @@ class Coalescence(Collision):
def __init__(self,
collision_kernel,
coalescence_efficiency=ConstEc(Ec=1),
seed=None,
croupier=None,
optimized_random=False,
substeps: int = 1,
Expand All @@ -253,7 +250,6 @@ def __init__(self,
coalescence_efficiency,
breakup_efficiency,
fragmentation_function,
seed=seed,
croupier=croupier,
optimized_random=optimized_random,
substeps=substeps,
Expand All @@ -267,7 +263,6 @@ class Breakup(Collision):
def __init__(self,
collision_kernel,
fragmentation_function,
seed=None,
croupier=None,
optimized_random=False,
substeps: int = 1,
Expand All @@ -281,7 +276,6 @@ def __init__(self,
coalescence_efficiency,
breakup_efficiency,
fragmentation_function,
seed=seed,
croupier=croupier,
optimized_random=optimized_random,
substeps=substeps,
Expand Down
2 changes: 1 addition & 1 deletion test-time-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ ghapi
pytest

# note: if cloning both PySDM and PySDM examples, consider "pip install -e"
PySDM-examples @ git+git://github.com/atmos-cloud-sim-uj/PySDM-examples@a578e2c#egg=PySDM-examples
PySDM-examples @ git+git://github.com/atmos-cloud-sim-uj/PySDM-examples@0c46ee4#egg=PySDM-examples
PyMPDATA @ git+https://github.com/atmos-cloud-sim-uj/PyMPDATA@e7b73a7#egg=PyMPDATA
4 changes: 3 additions & 1 deletion tests/unit_tests/dynamics/collisions/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from PySDM.dynamics.collisions import Collision, Breakup, Coalescence


def get_default_args(func):
signature = inspect.signature(func)
return {
Expand All @@ -11,7 +12,8 @@ def get_default_args(func):
if v.default is not inspect.Parameter.empty
}

class Test_defaults:

class TestDefaults:
@staticmethod
@pytest.mark.parametrize("dynamic_class", (Collision, Breakup, Coalescence))
def test_collision_adaptive_default(dynamic_class):
Expand Down

0 comments on commit 80684b9

Please sign in to comment.