-
-
Notifications
You must be signed in to change notification settings - Fork 623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add kwargs to idist.barrier #2310
Changes from 9 commits
ccedbb0
b1cbd1c
596b29e
beca8c9
9b4f70d
31488ce
68393f1
5aa6ade
6419d75
0fa0299
e723e60
d2403d4
d0b85af
9e66d46
7f3a68a
c7d2142
9d87c9d
32515c3
ed2f0ee
5dc455b
c402cd7
b5a9ac6
ab9a74d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
import warnings | ||
from abc import ABCMeta, abstractmethod | ||
from inspect import signature | ||
from numbers import Number | ||
from typing import Any, Callable, List, Optional, Union, cast | ||
from typing import Any, Callable, Dict, List, Optional, Union, cast | ||
|
||
import torch | ||
|
||
|
@@ -275,8 +277,24 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor: | |
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: | ||
pass | ||
|
||
def _check_barrier_fn_kwargs(self, barrier_fn: Callable, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]: | ||
fn_params_name = set( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I’m wondering whether a list comprehension would be simpler. I think the method |
||
map( | ||
lambda param: param.name, | ||
filter( | ||
lambda param: param.kind == param.POSITIONAL_OR_KEYWORD, signature(barrier_fn).parameters.values() | ||
), | ||
) | ||
) | ||
extra_keys = kwargs_dict.keys() - fn_params_name | ||
if extra_keys: | ||
warnings.warn(f"Extra keys : {extra_keys} will not be used by {self._backend}.") | ||
for k in extra_keys: | ||
del kwargs_dict[k] | ||
return kwargs_dict | ||
|
||
@abstractmethod | ||
def barrier(self) -> None: | ||
def barrier(self, **kwargs: Any) -> None: | ||
pass | ||
|
||
|
||
|
@@ -358,5 +376,5 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor: | |
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: | ||
return tensor | ||
|
||
def barrier(self) -> None: | ||
def barrier(self, **kwargs: Any) -> None: | ||
pass |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -194,7 +194,12 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor: | |
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: | ||
return hvd.broadcast(tensor, root_rank=src) | ||
|
||
def barrier(self) -> None: | ||
def barrier(self, **kwargs: Any) -> None: | ||
kwargs = self._check_barrier_fn_kwargs(barrier_fn=hvd.allreduce, kwargs_dict=kwargs) | ||
if "tensor" in kwargs: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why redefine these arguments rather than simply avoiding it? IMO it’s confusing. |
||
del kwargs["tensor"] | ||
if "name" in kwargs: | ||
del kwargs["name"] | ||
# https://github.com/horovod/horovod/issues/159#issuecomment-424834603 | ||
# hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier") | ||
hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier") | ||
hvd.allreduce(tensor=torch.tensor(0, device="cpu"), name="barrier", **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note that a barrier function was recently introduced. We could use it in a follow-up PR (and handle older versions). |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,5 +160,8 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: | |
xm.all_reduce("sum", [tensor,]) | ||
return tensor | ||
|
||
def barrier(self) -> None: | ||
xm.rendezvous("barrier") | ||
def barrier(self, **kwargs: Any) -> None: | ||
kwargs = self._check_barrier_fn_kwargs(barrier_fn=xm.rendezvous, kwargs_dict=kwargs) | ||
if "tag" in kwargs: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same remark as above |
||
del kwargs["tag"] | ||
xm.rendezvous(tag="barrier", **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -421,13 +421,30 @@ def broadcast( | |
return _model.broadcast(tensor, src=src, safe_mode=safe_mode) | ||
|
||
|
||
def barrier() -> None: | ||
def barrier(**kwargs: Any) -> None: | ||
"""Helper method to synchronize all processes. | ||
|
||
Args: | ||
kwargs: acceptable kwargs according to provided backend: | ||
|
||
- | "nccl" or "gloo" : ``group`` (default, GroupMember.WORLD), ``async_op`` (default, False), | ||
| ``device_ids`` (default, None). | ||
|
||
- | "horovod" : ``average`` (default, None), ``compression`` (default, Compression.none), | ||
| ``op`` (default, None), ``prescale_factor`` (default, 1.0), ``postscale_factor`` (default, 1.0), | ||
| ``process_set`` (default, global_process_set). | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Redefinition is confusing IMO |
||
| Arguments ``tensor=torch.tensor(0, device="cpu")`` and ``name="barrier"`` are redefined. | ||
|
||
- | "xla-tpu" : ``payload`` (default, b""), ``replicas`` (default, []). | ||
| Argument ``tag="barrier"`` is redefined. | ||
|
||
.. versionchanged:: 0.5.1 | ||
Method now accepts ``kwargs`` for all supported backends. | ||
""" | ||
if _need_to_sync and isinstance(_model, _SerialModel): | ||
sync(temporary=True) | ||
|
||
_model.barrier() | ||
_model.barrier(**kwargs) | ||
|
||
|
||
def set_local_rank(index: int) -> None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is more general than its name suggests: what you wrote is a signature checker. It could be called
check_method_args
Note that argument types could also be checked.