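JAX has jax.dtypes.canonicalize_dtype, which maps a requested dtype to the one JAX will actually use (for example, float64 becomes float32 unless 64-bit mode is enabled). PyTorch also has a notion of default dtypes (torch.get_default_dtype).
Could we provide a canonicalize_dtype that works for every supported array library?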
Something like:
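from types import ModuleType
from typing import Any
from array_api_compat import is_jax_namespace
Namespace, DType = ModuleType, Any  # placeholder aliases for the namespace module and its dtype objects
def canonicalize_dtype(xp: Namespace, dtype: DType | type[complex]) -> DType:
    if is_jax_namespace(xp):
        from jax.dtypes import canonicalize_dtype
        return canonicalize_dtype(dtype)  # e.g. float64 -> float32 without x64, and no truncation warning.
    # Other libraries resolve the dtype (including Python types such as float) by allocating a 0-d array.
    return xp.empty((), dtype=dtype).dtype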