@@ -3087,7 +3087,13 @@ def gather_object(
30873087
30883088
30893089@_exception_logger
3090- def send_object_list (object_list , dst , group = None , device = None ):
3090+ def send_object_list (
3091+ object_list : List [Any ],
3092+ dst : Optional [int ] = None ,
3093+ group : Optional [ProcessGroup ] = None ,
3094+ device : Optional [torch .device ] = None ,
3095+ group_dst : Optional [int ] = None ,
3096+ ):
30913097 """
30923098 Sends picklable objects in ``object_list`` synchronously.
30933099
@@ -3105,7 +3111,8 @@ def send_object_list(object_list, dst, group=None, device=None):
31053111 device (``torch.device``, optional): If not None, the objects are
31063112 serialized and converted to tensors which are moved to the
31073113 ``device`` before sending. Default is ``None``.
3108-
3114+ group_dst (int, optional): Destination rank on ``group``.
3115+ Must specify one of ``dst`` and ``group_dst`` but not both
31093116 Returns:
31103117 ``None``.
31113118
@@ -3143,11 +3150,9 @@ def send_object_list(object_list, dst, group=None, device=None):
31433150 >>> objects
31443151 ['foo', 12, {1: 2}]
31453152 """
3146- if get_rank () == dst :
3147- raise ValueError (
3148- "Invalid destination rank: destination rank should not be the same as "
3149- "the rank of the current process."
3150- )
3153+ group = _group_or_default_group (group )
3154+ group_dst = _canonicalize_group_rank (group , dst , group_dst )
3155+ _check_not_self_rank (group , group_dst , "destination" )
31513156
31523157 if _rank_not_in_group (group ):
31533158 _warn_not_in_group ("send_object_list" )
@@ -3167,7 +3172,7 @@ def send_object_list(object_list, dst, group=None, device=None):
31673172 object_sizes_tensor = torch .cat (size_list )
31683173
31693174 # Send object sizes
3170- send (object_sizes_tensor , dst = dst , group = group )
3175+ send (object_sizes_tensor , group_dst = group_dst , group = group )
31713176
31723177 # Concatenate and send serialized object tensors
31733178 # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
@@ -3177,11 +3182,17 @@ def send_object_list(object_list, dst, group=None, device=None):
31773182 else :
31783183 object_tensor = torch .cat (tensor_list )
31793184
3180- send (object_tensor , dst = dst , group = group )
3185+ send (object_tensor , group_dst = group_dst , group = group )
31813186
31823187
31833188@_exception_logger
3184- def recv_object_list (object_list , src = None , group = None , device = None ):
3189+ def recv_object_list (
3190+ object_list : List [Any ],
3191+ src : Optional [int ] = None ,
3192+ group : Optional [ProcessGroup ] = None ,
3193+ device : Optional [torch .device ] = None ,
3194+ group_src : Optional [int ] = None ,
3195+ ):
31853196 """
31863197 Receives picklable objects in ``object_list`` synchronously.
31873198
@@ -3197,6 +3208,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
31973208 the default process group will be used. Default is ``None``.
31983209 device (``torch.device``, optional): If not None, receives on this device.
31993210 Default is ``None``.
3211+ group_src (int, optional): Destination rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
32003212
32013213 Returns:
32023214 Sender rank. -1 if rank is not part of the group. If rank is part of the group,
@@ -3252,7 +3264,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
32523264 )
32533265
32543266 # Receive object sizes
3255- rank_sizes = recv (object_sizes_tensor , src = src , group = group )
3267+ rank_sizes = recv (object_sizes_tensor , src = src , group = group , group_src = group_src )
32563268
32573269 # Tensor to receive serialized objects into.
32583270 object_tensor = torch .empty ( # type: ignore[call-overload]
@@ -3261,7 +3273,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
32613273 device = current_device ,
32623274 )
32633275
3264- rank_objects = recv (object_tensor , src = src , group = group )
3276+ rank_objects = recv (object_tensor , src = src , group = group , group_src = group_src )
32653277 assert (
32663278 rank_sizes == rank_objects
32673279 ), "Mismatch in return ranks for object sizes and objects."
0 commit comments