Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion sycl/include/sycl/detail/vector_arith.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,10 @@ template <int NumElements>
class vec_arith<std::byte, NumElements>
: public VecOperators<vec<std::byte, NumElements>>::template CombineImpl<
std::bit_or<void>, std::bit_and<void>, std::bit_xor<void>,
std::bit_not<void>> {
std::bit_not<void>, std::equal_to<void>, std::not_equal_to<void>,
std::less<void>, std::greater<void>, std::less_equal<void>,
std::greater_equal<void>, OpAssign<std::bit_or<void>>,
OpAssign<std::bit_and<void>>, OpAssign<std::bit_xor<void>>> {
protected:
// NumElements can never be zero. Still using the redundant check to avoid
// incomplete type errors.
Expand Down
38 changes: 20 additions & 18 deletions sycl/include/sycl/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,10 @@ template <typename Self> class ScalarConversionOperatorsMixIn {
};

template <typename T>
inline constexpr bool is_fundamental_or_half_or_bfloat16 =
inline constexpr bool is_supported_vector_elem_type =
std::is_fundamental_v<T> || std::is_same_v<std::remove_const_t<T>, half> ||
std::is_same_v<std::remove_const_t<T>, ext::oneapi::bfloat16>;
std::is_same_v<std::remove_const_t<T>, ext::oneapi::bfloat16> ||
std::is_same_v<std::remove_const_t<T>, std::byte>;

// Per SYCL specification sycl::vec has different ctors available based on the
// number of elements. Without C++20's concepts we'd have to use partial
Expand Down Expand Up @@ -585,7 +586,7 @@ class __SYCL_EBO vec :
// when NumElements == 1. The template prevents implicit conversion from
// vec<_, 1> to DataT.
template <typename Ty = DataT>
typename std::enable_if_t<detail::is_fundamental_or_half_or_bfloat16<Ty>,
typename std::enable_if_t<detail::is_supported_vector_elem_type<Ty>,
vec &>
operator=(const DataT &Rhs) {
*this = vec{Rhs};
Expand Down Expand Up @@ -918,12 +919,12 @@ class SwizzleOp : public detail::NamedSwizzlesMixinBoth<
template <typename T>
using EnableIfScalarType =
typename std::enable_if_t<std::is_convertible_v<DataT, T> &&
detail::is_fundamental_or_half_or_bfloat16<T>>;
detail::is_supported_vector_elem_type<T>>;

template <typename T>
using EnableIfNoScalarType =
typename std::enable_if_t<!std::is_convertible_v<DataT, T> ||
!detail::is_fundamental_or_half_or_bfloat16<T>>;
!detail::is_supported_vector_elem_type<T>>;

template <int... Indices>
using Swizzle =
Expand Down Expand Up @@ -1143,12 +1144,12 @@ class SwizzleOp : public detail::NamedSwizzlesMixinBoth<
return Tmp RELLOGOP Rhs; \
}

__SYCL_RELLOGOP(==, (!detail::is_byte_v<T>))
__SYCL_RELLOGOP(!=, (!detail::is_byte_v<T>))
__SYCL_RELLOGOP(>, (!detail::is_byte_v<T>))
__SYCL_RELLOGOP(<, (!detail::is_byte_v<T>))
__SYCL_RELLOGOP(>=, (!detail::is_byte_v<T>))
__SYCL_RELLOGOP(<=, (!detail::is_byte_v<T>))
__SYCL_RELLOGOP(==, true)
__SYCL_RELLOGOP(!=, true)
__SYCL_RELLOGOP(>, true)
__SYCL_RELLOGOP(<, true)
__SYCL_RELLOGOP(>=, true)
__SYCL_RELLOGOP(<=, true)
__SYCL_RELLOGOP(&&, (!detail::is_byte_v<T> && !detail::is_vgenfloat_v<T>))
__SYCL_RELLOGOP(||, (!detail::is_byte_v<T> && !detail::is_vgenfloat_v<T>))
#undef __SYCL_RELLOGOP
Expand Down Expand Up @@ -1527,8 +1528,8 @@ class SwizzleOp : public detail::NamedSwizzlesMixinBoth<
m_RightOperation(std::move(Rhs.m_RightOperation)) {}

// Either performing CurrentOperation on results of left and right operands
// or reading values from actual vector. Perform implicit type conversion when
// the number of elements == 1
// or reading values from actual vector. Always perform explicit type conversion
// because std::byte operators are strongly typed.

template <int IdxNum = size()>
CommonDataT getValue(EnableIfOneIndex<IdxNum, size_t> Index) const {
Expand All @@ -1537,8 +1538,8 @@ class SwizzleOp : public detail::NamedSwizzlesMixinBoth<
return (*m_Vector)[Idxs[Index]];
}
auto Op = OperationCurrentT<CommonDataT>();
return Op(m_LeftOperation.getValue(Index),
m_RightOperation.getValue(Index));
return Op(static_cast<DataT>(m_LeftOperation.getValue(Index)),
static_cast<DataT>(m_RightOperation.getValue(Index)));
}

template <int IdxNum = size()>
Expand All @@ -1548,16 +1549,17 @@ class SwizzleOp : public detail::NamedSwizzlesMixinBoth<
return (*m_Vector)[Idxs[Index]];
}
auto Op = OperationCurrentT<DataT>();
return Op(m_LeftOperation.getValue(Index),
m_RightOperation.getValue(Index));
return Op(static_cast<DataT>(m_LeftOperation.getValue(Index)),
static_cast<DataT>(m_RightOperation.getValue(Index)));
}

template <template <typename> class Operation, typename RhsOperation>
void operatorHelper(const RhsOperation &Rhs) const {
Operation<DataT> Op;
std::array<int, size()> Idxs{Indexes...};
for (size_t I = 0; I < Idxs.size(); ++I) {
DataT Res = Op((*m_Vector)[Idxs[I]], Rhs.getValue(I));
DataT Res =
Op((*m_Vector)[Idxs[I]], static_cast<DataT>(Rhs.getValue(I)));
(*m_Vector)[Idxs[I]] = Res;
}
}
Expand Down
169 changes: 166 additions & 3 deletions sycl/test-e2e/Basic/vector/byte.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ int main() {
assert(VecByte3Or[1] == (VecByte3A[1] | VecByte3B[1]));
assert(VecByte3Xor[2] == (VecByte3A[2] ^ VecByte3B[2]));

// logical binary assignment op for 2 vec
auto VecByte3ACopy = VecByte3A;
VecByte3ACopy &= VecByte3B;
assert(VecByte3ACopy[0] == (VecByte3A[0] & VecByte3B[0]));
VecByte3ACopy = VecByte3A;
VecByte3ACopy |= VecByte3B;
assert(VecByte3ACopy[1] == (VecByte3A[1] | VecByte3B[1]));
VecByte3ACopy = VecByte3A;
VecByte3ACopy ^= VecByte3B;
assert(VecByte3ACopy[2] == (VecByte3A[2] ^ VecByte3B[2]));

// logical binary op between swizzle and vec.
using SwizType = sycl::vec<std::byte, 2>;
auto SwizByte2And = SwizByte2A & (SwizType)SwizByte2B;
Expand All @@ -150,6 +161,18 @@ int main() {
assert(SwizByte2Or[1] == (VecByte4A[1] | VecByte4A[3]));
assert(SwizByte2Xor[0] == (VecByte4A[0] ^ VecByte4A[2]));

// logical binary assignment op between swizzle and vec.
auto VecByte4ACopy = VecByte4A;
auto SwizByte2ACopy = VecByte4ACopy.lo();
SwizByte2ACopy &= (SwizType)SwizByte2B;
assert(SwizByte2ACopy[0] == (SwizByte2A[0] & SwizByte2B[0]));
VecByte4ACopy = VecByte4A;
SwizByte2ACopy |= (SwizType)SwizByte2B;
assert(SwizByte2ACopy[0] == (SwizByte2A[0] | SwizByte2B[0]));
VecByte4ACopy = VecByte4A;
SwizByte2ACopy ^= (SwizType)SwizByte2B;
assert(SwizByte2ACopy[0] == (SwizByte2A[0] ^ SwizByte2B[0]));

// Check overloads with scalar argument for bitwise operators.
auto BitWiseAnd1 = VecByte3A & std::byte{3};
auto BitWiseOr1 = VecByte3A | std::byte{3};
Expand All @@ -161,6 +184,17 @@ int main() {
assert(BitWiseOr1[1] == BitWiseOr2[1]);
assert(BitWiseXor1[2] == BitWiseXor2[2]);

// Check overloads with scalar argument for bitwise assign operators.
VecByte3ACopy = VecByte3A;
VecByte3ACopy &= std::byte{3};
assert(VecByte3ACopy[0] == (VecByte3A[0] & std::byte{3}));
VecByte3ACopy = VecByte3A;
VecByte3ACopy |= std::byte{3};
assert(VecByte3ACopy[1] == (VecByte3A[1] | std::byte{3}));
VecByte3ACopy = VecByte3A;
VecByte3ACopy ^= std::byte{3};
assert(VecByte3ACopy[2] == (VecByte3A[2] ^ std::byte{3}));

// logical binary op for 1 swizzle
auto SwizByte2AndScalarA = SwizByte2A & std::byte{3};
auto SwizByte2OrScalarA = SwizByte2A | std::byte{3};
Expand All @@ -172,19 +206,148 @@ int main() {
assert(SwizByte2OrScalarA[1] == SwizByte2OrScalarB[1]);
assert(SwizByte2XorScalarA[0] == SwizByte2XorScalarB[0]);

// logical binary assign op for 1 swizzle
VecByte4ACopy = VecByte4A;
SwizByte2ACopy &= std::byte{3};
assert(SwizByte2ACopy[0] == (SwizByte2A[0] & std::byte{3}));
VecByte4ACopy = VecByte4A;
SwizByte2ACopy |= std::byte{3};
assert(SwizByte2ACopy[0] == (SwizByte2A[0] | std::byte{3}));
VecByte4ACopy = VecByte4A;
SwizByte2ACopy ^= std::byte{3};
assert(SwizByte2ACopy[0] == (SwizByte2A[0] ^ std::byte{3}));

// bit-wise negation test
auto VecByte4Neg = ~VecByte4A;
assert(VecByte4Neg[0] == ~VecByte4A[0]);

auto SwizByte2Neg = ~SwizByte2B;
assert(SwizByte2Neg[0] == ~SwizByte2B[0]);
}

// Test comparison operations on vec<std::byte> and swizzles.
{
auto SwizByte2A = VecByte4A.lo();
auto SwizByte2B = VecByte4A.hi();

// comparison op for 2 vec
auto VecByte3Eq = VecByte3A == VecByte3B;
auto VecByte3Neq = VecByte3A != VecByte3B;
auto VecByte3Lt = VecByte3A < VecByte3B;
auto VecByte3Lte = VecByte3A <= VecByte3B;
auto VecByte3Gt = VecByte3A > VecByte3B;
auto VecByte3Gte = VecByte3A >= VecByte3B;
// Cast to bool since the result vector element is defined to be int8_t
assert(static_cast<bool>(VecByte3Eq[0]) ==
(VecByte3A[0] == VecByte3B[0]));
assert(static_cast<bool>(VecByte3Neq[1]) ==
(VecByte3A[1] != VecByte3B[1]));
assert(static_cast<bool>(VecByte3Lt[2]) == (VecByte3A[2] < VecByte3B[2]));
assert(static_cast<bool>(VecByte3Lte[0]) ==
(VecByte3A[0] <= VecByte3B[0]));
assert(static_cast<bool>(VecByte3Gt[1]) == (VecByte3A[1] > VecByte3B[1]));
assert(static_cast<bool>(VecByte3Gte[2]) ==
(VecByte3A[2] >= VecByte3B[2]));

// comparison op between swizzle and vec.
using SwizType = sycl::vec<std::byte, 2>;
auto SwizByte2Eq = SwizByte2A == (SwizType)SwizByte2B;
auto SwizByte2Neq = SwizByte2A != (SwizType)SwizByte2B;
auto SwizByte2Lt = SwizByte2A < (SwizType)SwizByte2B;
auto SwizByte2Lte = SwizByte2A <= (SwizType)SwizByte2B;
auto SwizByte2Gt = SwizByte2A > (SwizType)SwizByte2B;
auto SwizByte2Gte = SwizByte2A >= (SwizType)SwizByte2B;
// Cast to bool since the result vector element is defined to be int8_t
assert(static_cast<bool>(SwizByte2Eq[0]) ==
(VecByte4A[0] == VecByte4A[2]));
assert(static_cast<bool>(SwizByte2Neq[0]) ==
(VecByte4A[0] != VecByte4A[2]));
assert(static_cast<bool>(SwizByte2Lt[0]) ==
(VecByte4A[0] < VecByte4A[2]));
assert(static_cast<bool>(SwizByte2Lte[0]) ==
(VecByte4A[0] <= VecByte4A[2]));
assert(static_cast<bool>(SwizByte2Gt[0]) ==
(VecByte4A[0] > VecByte4A[2]));
assert(static_cast<bool>(SwizByte2Gte[0]) ==
(VecByte4A[0] >= VecByte4A[2]));

// Check overloads with scalar argument for comparison operators.
auto BitWiseEq1 = VecByte3A == std::byte{3};
auto BitWiseNeq1 = VecByte3A != std::byte{3};
auto BitWiseLt1 = VecByte3A < std::byte{3};
auto BitWiseLte1 = VecByte3A <= std::byte{3};
auto BitWiseGt1 = VecByte3A > std::byte{3};
auto BitWiseGte1 = VecByte3A >= std::byte{3};
auto BitWiseEq2 = std::byte{3} == VecByte3A;
auto BitWiseNeq2 = std::byte{3} != VecByte3A;
auto BitWiseLt2 = std::byte{3} < VecByte3A;
auto BitWiseLte2 = std::byte{3} <= VecByte3A;
auto BitWiseGt2 = std::byte{3} > VecByte3A;
auto BitWiseGte2 = std::byte{3} >= VecByte3A;
// Cast to bool since the result vector element is defined to be int8_t
assert(static_cast<bool>(BitWiseEq1[0]) ==
(VecByte3A[0] == std::byte{3}));
assert(static_cast<bool>(BitWiseNeq1[0]) ==
(VecByte3A[0] != std::byte{3}));
assert(static_cast<bool>(BitWiseLt1[0]) == (VecByte3A[0] < std::byte{3}));
assert(static_cast<bool>(BitWiseLte1[0]) ==
(VecByte3A[0] <= std::byte{3}));
assert(static_cast<bool>(BitWiseGt1[0]) == (VecByte3A[0] > std::byte{3}));
assert(static_cast<bool>(BitWiseGte1[0]) ==
(VecByte3A[0] >= std::byte{3}));
assert(static_cast<bool>(BitWiseEq2[0]) ==
(std::byte{3} == VecByte3A[0]));
assert(static_cast<bool>(BitWiseNeq2[0]) ==
(std::byte{3} != VecByte3A[0]));
assert(static_cast<bool>(BitWiseLt2[0]) == (std::byte{3} < VecByte3A[0]));
assert(static_cast<bool>(BitWiseLte2[0]) ==
(std::byte{3} <= VecByte3A[0]));
assert(static_cast<bool>(BitWiseGt2[0]) == (std::byte{3} > VecByte3A[0]));
assert(static_cast<bool>(BitWiseGte2[0]) ==
(std::byte{3} >= VecByte3A[0]));

// comparison op between 1 swizzle and a scalar
auto SwizByte2EqScalarA = SwizByte2A == std::byte{3};
auto SwizByte2NeqScalarA = SwizByte2A != std::byte{3};
auto SwizByte2LtScalarA = SwizByte2A < std::byte{3};
auto SwizByte2LteScalarA = SwizByte2A <= std::byte{3};
auto SwizByte2GtScalarA = SwizByte2A > std::byte{3};
auto SwizByte2GteScalarA = SwizByte2A >= std::byte{3};
auto SwizByte2EqScalarB = std::byte{3} == SwizByte2A;
auto SwizByte2NeqScalarB = std::byte{3} != SwizByte2A;
auto SwizByte2LtScalarB = std::byte{3} < SwizByte2A;
auto SwizByte2LteScalarB = std::byte{3} <= SwizByte2A;
auto SwizByte2GtScalarB = std::byte{3} > SwizByte2A;
auto SwizByte2GteScalarB = std::byte{3} >= SwizByte2A;
// Cast to bool since the result vector element is defined to be int8_t
assert(static_cast<bool>(SwizByte2EqScalarA[0]) ==
(SwizByte2A[0] == std::byte{3}));
assert(static_cast<bool>(SwizByte2NeqScalarA[0]) ==
(SwizByte2A[0] != std::byte{3}));
assert(static_cast<bool>(SwizByte2LtScalarA[0]) ==
(SwizByte2A[0] < std::byte{3}));
assert(static_cast<bool>(SwizByte2LteScalarA[0]) ==
(SwizByte2A[0] <= std::byte{3}));
assert(static_cast<bool>(SwizByte2GtScalarA[0]) ==
(SwizByte2A[0] > std::byte{3}));
assert(static_cast<bool>(SwizByte2GteScalarA[0]) ==
(SwizByte2A[0] >= std::byte{3}));
assert(static_cast<bool>(SwizByte2EqScalarB[0]) ==
(std::byte{3} == SwizByte2A[0]));
assert(static_cast<bool>(SwizByte2NeqScalarB[0]) ==
(std::byte{3} != SwizByte2A[0]));
assert(static_cast<bool>(SwizByte2LtScalarB[0]) ==
(std::byte{3} < SwizByte2A[0]));
assert(static_cast<bool>(SwizByte2LteScalarB[0]) ==
(std::byte{3} <= SwizByte2A[0]));
assert(static_cast<bool>(SwizByte2GtScalarB[0]) ==
(std::byte{3} > SwizByte2A[0]));
assert(static_cast<bool>(SwizByte2GteScalarB[0]) ==
(std::byte{3} >= SwizByte2A[0]));
}

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
{
// std::byte is not an arithmetic type and it only supports the following
// overloads of >> and << operators.
//
// 1 template <class IntegerType>
// constexpr std::byte operator<<( std::byte b, IntegerType shift )
// noexcept;
Expand Down