Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add kwargs to idist.barrier #2310

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ccedbb0
#2213 add kwargs to idist.barrier => needs more tests and documentati…
fco-dv Oct 31, 2021
b1cbd1c
rm Mapping import
fco-dv Oct 31, 2021
596b29e
#2213 pass kwargs in xla and hvdconftest, add tests
fco-dv Nov 2, 2021
beca8c9
#2213 fix tests
fco-dv Nov 2, 2021
9b4f70d
#2213 add tests on warning
fco-dv Nov 4, 2021
31488ce
#2213 fix test
fco-dv Nov 4, 2021
68393f1
#2213 update doc
fco-dv Nov 4, 2021
5aa6ade
#2213 fix docstring
fco-dv Nov 5, 2021
6419d75
#2213 refactor _check_barrier_fn_kwargs
fco-dv Nov 5, 2021
0fa0299
Doctests for `RMSE` and `MeanPairwiseDistance` (#2307)
DevPranjal Nov 1, 2021
e723e60
disabled apex tests (#2308)
KickItLikeShika Nov 1, 2021
d2403d4
refactor metrics - add doctest for psnr (#2311)
sdesrozis Nov 10, 2021
d0b85af
add doctest for nlp metrics (#2317)
sdesrozis Nov 12, 2021
9e66d46
[skip ci] add doctest for `CohenKappa` metric (#2321)
sdesrozis Nov 13, 2021
7f3a68a
[skipci] [Doctest] added contrib regression (``FractionalBias``, ``Fr…
Ishan-Kumar2 Nov 13, 2021
c7d2142
[skip ci] add doctest for regression metrics (#2324)
sdesrozis Nov 13, 2021
9d87c9d
#2313 fix bug in StateParamScheduler attach method (#2316)
fco-dv Nov 15, 2021
32515c3
add run_code_style for windows (#2329)
sdesrozis Nov 16, 2021
ed2f0ee
[skip ci] Add doctest for `LinearCyclicalScheduler` (#2327)
sdesrozis Nov 19, 2021
5dc455b
[skip ci] fix version extraction (#2331)
Priyansi Nov 19, 2021
c402cd7
Paramscheduler emahandler (#2326)
fco-dv Nov 21, 2021
b5a9ac6
#2213 refactor signature checker / update tests
fco-dv Nov 23, 2021
ab9a74d
Merge branch 'master' into kwargs_idist_barrier
fco-dv Nov 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions ignite/distributed/comp_models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from abc import ABCMeta, abstractmethod
from inspect import signature
from numbers import Number
from typing import Any, Callable, List, Optional, Union, cast
from typing import Any, Callable, Dict, List, Optional, Union, cast

import torch

Expand Down Expand Up @@ -275,8 +277,24 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass

def _check_barrier_fn_kwargs(self, barrier_fn: Callable, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is more general than its name suggests. You did a signature checker. It could be called check_method_args and note that type could be also checked.

fn_params_name = set(
Copy link
Contributor

@sdesrozis sdesrozis Nov 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m wondering whether comprehension list would be more simple.

I think the method bind of the class Signature could help.

map(
lambda param: param.name,
filter(
lambda param: param.kind == param.POSITIONAL_OR_KEYWORD, signature(barrier_fn).parameters.values()
),
)
)
extra_keys = kwargs_dict.keys() - fn_params_name
if extra_keys:
warnings.warn(f"Extra keys : {extra_keys} will not be used by {self._backend}.")
for k in extra_keys:
del kwargs_dict[k]
return kwargs_dict

@abstractmethod
def barrier(self, **kwargs: Any) -> None:
    """Synchronize all processes.

    Args:
        kwargs: backend-specific keyword arguments forwarded to the
            underlying synchronization primitive.
    """
    pass


Expand Down Expand Up @@ -358,5 +376,5 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
return tensor

def barrier(self, **kwargs: Any) -> None:
    """No-op synchronization for the serial (single-process) model.

    Args:
        kwargs: accepted for interface compatibility; ignored.
    """
    pass
9 changes: 7 additions & 2 deletions ignite/distributed/comp_models/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,12 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
return hvd.broadcast(tensor, root_rank=src)

def barrier(self, **kwargs: Any) -> None:
    """Synchronize all Horovod processes via a dummy allreduce.

    Args:
        kwargs: keyword arguments forwarded to ``hvd.allreduce`` after
            validation. The ``tensor`` and ``name`` arguments are fixed by
            this implementation, so user-supplied values are discarded.
    """
    kwargs = self._check_barrier_fn_kwargs(barrier_fn=hvd.allreduce, kwargs_dict=kwargs)
    # "tensor" and "name" are set explicitly below; drop user overrides to
    # avoid passing the same argument twice.
    kwargs.pop("tensor", None)
    kwargs.pop("name", None)
    # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
    hvd.allreduce(tensor=torch.tensor(0, device="cpu"), name="barrier", **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that a barrier function was recently introduced. We could use it in a next PR (and handle older version).

5 changes: 3 additions & 2 deletions ignite/distributed/comp_models/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,9 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
dist.broadcast(tensor, src=src)
return tensor

def barrier(self, **kwargs: Any) -> None:
    """Synchronize all processes via ``torch.distributed.barrier``.

    Args:
        kwargs: keyword arguments forwarded to ``dist.barrier`` after
            validation against its signature (e.g. ``group``, ``async_op``,
            ``device_ids``).
    """
    kwargs = self._check_barrier_fn_kwargs(barrier_fn=dist.barrier, kwargs_dict=kwargs)
    dist.barrier(**kwargs)

def _expand_hostlist(nodelist: str) -> List[str]:
"""Expand a compressed hostlist string and returns all hosts listed.
Expand Down
7 changes: 5 additions & 2 deletions ignite/distributed/comp_models/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,8 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
xm.all_reduce("sum", [tensor,])
return tensor

def barrier(self, **kwargs: Any) -> None:
    """Synchronize all XLA processes via ``xm.rendezvous``.

    Args:
        kwargs: keyword arguments forwarded to ``xm.rendezvous`` after
            validation. The ``tag`` argument is fixed by this implementation,
            so a user-supplied value is discarded.
    """
    kwargs = self._check_barrier_fn_kwargs(barrier_fn=xm.rendezvous, kwargs_dict=kwargs)
    # "tag" is set explicitly below; drop a user override to avoid passing
    # the same argument twice.
    kwargs.pop("tag", None)
    xm.rendezvous(tag="barrier", **kwargs)
21 changes: 19 additions & 2 deletions ignite/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,13 +421,30 @@ def broadcast(
return _model.broadcast(tensor, src=src, safe_mode=safe_mode)


def barrier(**kwargs: Any) -> None:
    """Helper method to synchronize all processes.

    Args:
        kwargs: acceptable kwargs according to provided backend:

            - | "nccl" or "gloo" : ``group`` (default, GroupMember.WORLD), ``async_op`` (default, False),
              | ``device_ids`` (default, None).

            - | "horovod" : ``average`` (default, None), ``compression`` (default, Compression.none),
              | ``op`` (default, None), ``prescale_factor`` (default, 1.0), ``postscale_factor`` (default, 1.0),
              | ``process_set`` (default, global_process_set).
              | Arguments ``tensor=torch.tensor(0, device="cpu")`` and ``name="barrier"`` are redefined.

            - | "xla-tpu" : ``payload`` (default, b""), ``replicas`` (default, []).
              | Argument ``tag="barrier"`` is redefined.

    .. versionchanged:: 0.5.1
        Method now accepts ``kwargs`` for all supported backends.
    """
    # Lazily set up a real computation model if still running serially.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    _model.barrier(**kwargs)


def set_local_rank(index: int) -> None:
Expand Down
5 changes: 3 additions & 2 deletions tests/ignite/distributed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,15 @@ def _test(data_src, data_others, safe_mode):
idist.broadcast(None, src=0)


def _test_distrib_barrier(device, kwargs_dict=None):
    """Check that ``idist.barrier`` synchronizes ranks before an all-reduce.

    Rank 0 mutates its tensor before the barrier; after the barrier the
    all-reduced sum must reflect that mutation on every rank.

    Args:
        device: device to place the test tensor on.
        kwargs_dict: optional backend-specific kwargs forwarded to
            ``idist.barrier``.
    """
    t = torch.tensor([idist.get_rank()], device=device, dtype=torch.float)
    true_res = sum(range(idist.get_world_size()))

    if idist.get_rank() == 0:
        t += 10.0

    idist.barrier(**(kwargs_dict or {}))

    tt = idist.all_reduce(t)
    assert tt.item() == true_res + 10.0
Expand Down
23 changes: 23 additions & 0 deletions tests/ignite/distributed/utils/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,29 @@ def test_idist_barrier_hvd(gloo_hvd_executor):
gloo_hvd_executor(_test_distrib_barrier, (device,), np=np, do_init=True)


@pytest.mark.distributed
@pytest.mark.skipif(not has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_idist_barrier_kwargs_hvd(gloo_hvd_executor):
    """Barrier should accept and forward all Horovod allreduce kwargs."""
    from horovod.torch.compression import Compression
    from horovod.torch.mpi_ops import global_process_set

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    np = torch.cuda.device_count() if use_cuda else 4

    kwargs_dict = {
        "tensor": torch.tensor(0, device="cpu"),
        "average": None,
        "name": None,
        "compression": Compression.none,
        "op": None,
        "prescale_factor": 1.0,
        "postscale_factor": 1.0,
        "process_set": global_process_set,
    }
    gloo_hvd_executor(_test_distrib_barrier, (device, kwargs_dict), np=np, do_init=True)


def _test_idist_methods_overhead(ok_factor, sync_model):
import time

Expand Down
35 changes: 35 additions & 0 deletions tests/ignite/distributed/utils/test_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,41 @@ def test_idist_barrier_gloo(distributed_context_single_node_gloo):
_test_distrib_barrier(device)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_idist_barrier_kwargs_nccl(distributed_context_single_node_nccl):
    """Barrier should forward nccl kwargs and warn on unsupported keys."""
    from torch.distributed import GroupMember

    device = idist.device()

    kwargs_dict = dict(group=GroupMember.WORLD, async_op=False, device_ids=None)
    _test_distrib_barrier(device, kwargs_dict)

    # XLA-only keys must trigger a warning listing exactly the unused keys.
    kwargs_dict.update(tag="barrier", payload=b"", replicas=[])
    warn_pattern = r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by nccl."
    with pytest.warns(UserWarning, match=warn_pattern):
        _test_distrib_barrier(device, kwargs_dict)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
def test_idist_barrier_kwargs_gloo(distributed_context_single_node_gloo):
    """Barrier should forward gloo kwargs and warn on unsupported keys."""
    from torch.distributed import GroupMember

    device = idist.device()

    kwargs_dict = dict(group=GroupMember.WORLD, async_op=False, device_ids=None)
    _test_distrib_barrier(device, kwargs_dict)

    # XLA-only keys must trigger a warning listing exactly the unused keys.
    kwargs_dict.update(tag="barrier", payload=b"", replicas=[])
    warn_pattern = r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by gloo."
    with pytest.warns(UserWarning, match=warn_pattern):
        _test_distrib_barrier(device, kwargs_dict)


def _test_idist_methods_overhead(ok_factor):
import time

Expand Down
32 changes: 30 additions & 2 deletions tests/ignite/distributed/utils/test_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,28 @@ def test_idist_barrier_xla():
_test_distrib_barrier(device)


def _test_idist_barrier_xla_in_child_proc(index, kwargs_dict=None):
    """Run the barrier test inside an XLA child process.

    Args:
        index: child process index supplied by the XLA spawner (unused).
        kwargs_dict: optional backend kwargs forwarded to the barrier test.
    """
    device = idist.device()
    _test_distrib_barrier(device, kwargs_dict)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
def test_idist_barrier_kwargs_xla():
    """Barrier should forward xla-tpu kwargs and warn on unsupported keys."""
    device = idist.device()

    kwargs_dict = dict(tag="barrier", payload=b"", replicas=[])
    _test_distrib_barrier(device, kwargs_dict)

    from torch.distributed import GroupMember

    # Native-backend keys must trigger a warning listing the unused keys.
    kwargs_dict.update(group=GroupMember.WORLD, async_op=False, device_ids=None)
    warn_pattern = r"Extra keys : \{((, )?('async_op'|'group'|'device_ids')(, )?)+\} will not be used by xla-tpu."
    with pytest.warns(UserWarning, match=warn_pattern):
        _test_distrib_barrier(device, kwargs_dict)


@pytest.mark.tpu
Expand All @@ -197,6 +216,15 @@ def test_idist_barrier_xla_in_child_proc(xmp_executor):
xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(), nprocs=n)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
def test_idist_barrier_kwargs_xla_in_child_proc(xmp_executor):
    """Spawn XLA workers and run the barrier test with xla-tpu kwargs."""
    nprocs = int(os.environ["NUM_TPU_WORKERS"])
    kwargs_dict = dict(tag="barrier", payload=b"", replicas=[])
    xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(kwargs_dict,), nprocs=nprocs)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
Expand Down