Closed
Labels
- API extension: Adds new functions or objects to the API.
- topic: Manipulation: Array manipulation and transformation.
Description
This RFC proposes adding support to the array API specification for repeating each element of an array.
Overview
Based on array comparison data, the API is available in most array libraries. The main exception is PyTorch, which deviates in naming (`repeat_interleave` versus `repeat` in NumPy and others).
Prior art
- NumPy: https://numpy.org/doc/stable/reference/generated/numpy.repeat.html
- PyTorch: https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/repeat
- JAX: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.repeat.html
- CuPy: https://docs.cupy.dev/en/stable/reference/generated/cupy.repeat.html
- Dask: https://docs.dask.org/en/stable/generated/dask.array.repeat.html#dask.array.repeat
Proposal
`def repeat(x: array, repeats: Union[int, Sequence[int], array], /, *, axis: Optional[int] = None)`

- `repeats`: the number of repetitions for each element.
  If `axis` is not `None`,
  - if `repeats` is an array, `repeats.shape` must broadcast to `x.shape[axis]`.
  - if `repeats` is a sequence of ints, `len(repeats)` must broadcast to `x.shape[axis]`.
  - if `repeats` is an integer, `repeats` must be broadcast to match the size of the specified `axis`.
  If `axis` is `None`,
  - if `repeats` is an array, `repeats.shape` must broadcast to `prod(x.shape)`.
  - if `repeats` is a sequence of ints, `len(repeats)` must broadcast to `prod(x.shape)`.
  - if `repeats` is an integer, `repeats` must be broadcast to match the size of the flattened array.
- `axis`: specifies the axis along which to repeat values. If `None`, use a flattened input array and return a flat output array.
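The rules above can be sketched against NumPy, whose `np.repeat` already behaves this way. The helper name `repeat_sketch` is hypothetical and exists only to illustrate the proposed signature; it is not part of any library, and the explicit `np.broadcast_to` step is an assumption about how the broadcasting requirement could be checked.

```python
import numpy as np
from typing import Optional, Sequence, Union


def repeat_sketch(
    x: np.ndarray,
    repeats: Union[int, Sequence[int], np.ndarray],
    /,
    *,
    axis: Optional[int] = None,
) -> np.ndarray:
    """Hypothetical sketch of the proposed signature, delegating to np.repeat."""
    # Size that `repeats` has to broadcast to: the length of `axis`, or the
    # total number of elements when the input is flattened (axis=None).
    target = x.size if axis is None else x.shape[axis]
    # Raises ValueError if `repeats` cannot broadcast to (target,).
    reps = np.broadcast_to(np.asarray(repeats), (target,))
    return np.repeat(x, reps, axis=axis)


x = np.array([[1, 2], [3, 4]])

flat = repeat_sketch(x, 2)                      # axis=None: flatten, then repeat
rows = repeat_sketch(x, [1, 2], axis=0)         # one count per row
cols = repeat_sketch(x, np.array([2]), axis=1)  # size-1 array broadcasts
```

Here `flat` is one-dimensional, while `rows` and `cols` keep the input's number of dimensions and grow only along the chosen axis.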
Questions
- Both PyTorch and JAX support a kwarg for specifying the output size in order to avoid stream synchronization (PyTorch) and to allow compilation (JAX). Without such kwarg support, is this API viable? And what are the reasons for needing this kwarg when other array libraries (TensorFlow) omit such a kwarg?
- When flattening the input array, flatten in row-major order? (precedent: `nonzero`)
- Is PyTorch okay adding a `repeat` function in its main namespace, given the divergence in behavior for `torch.Tensor.repeat`, which behaves similarly to `np.tile`?
- CuPy only allows `int`, `List`, and `Tuple` for repeats, not an array. PyTorch may prefer a list of `ints` (see "Unnecessary cuda synchronizations that we should remove in PyTorch", pytorch/pytorch#108968).
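The first three questions can be made concrete with NumPy, used here only as a reference point (PyTorch and JAX are not exercised). This sketch shows that the output length equals the sum of `repeats`, so it depends on data rather than on shapes alone; that NumPy flattens in row-major order when `axis` is `None`; and how `np.tile` differs from `np.repeat`.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])
counts = np.array([0, 1, 2, 3])

# Output length is sum(counts): it depends on the values stored in
# `counts`, not just on the shapes of the inputs.
out = np.repeat(x, counts)
out_len = int(counts.sum())

# With axis=None, NumPy flattens in row-major (C) order before repeating.
matches_c_order = np.array_equal(out, np.repeat(x.ravel(order="C"), counts))

# repeat duplicates each element in place; tile duplicates the whole array.
repeated = np.repeat(np.array([1, 2, 3]), 2)
tiled = np.tile(np.array([1, 2, 3]), 2)
```

A library that compiles ahead of time or wants to avoid device synchronization cannot know `counts.sum()` without reading the data, which is presumably the gap an output-size kwarg fills.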
Related
- Adding tuple argument support to `numpy.repeat` to avoid repeated invocations: "ENH: Allow tuple arguments for `numpy.repeat`" (numpy/numpy#21435) and "ENH: Introduce multiple pair parameters in the 'repeat' function" (numpy/numpy#23937).
- Mention of xarray's need for `repeat`: "Common APIs across array libraries (1 year later)" (#187 (comment)).
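For context on the first item, repeating along more than one axis in NumPy currently takes one `np.repeat` call per axis. The sketch below shows that chained form. It assumes this is the kind of repeated invocation the linked requests aim to avoid, and it does not show any proposed tuple syntax.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])

# One call per axis: repeat each row twice, then each column three times.
step = np.repeat(x, 2, axis=0)
out = np.repeat(step, 3, axis=1)
```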