
Commit eeb5cb5

Author: Daniel (committed)
Commit message: the c++17 version, it is deprecated
1 parent 9e60dd2 commit eeb5cb5

38 files changed: +8197 additions, 0 deletions

include/MAC.hpp

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
#include "./Matrix.hpp"

#include <cstdint>  // fixed-width integer types (in case Matrix.hpp does not already pull them in)
#include <utility>  // std::index_sequence

// Force-defined here, so MAC_USE_REAL_INSTRUCTIONS (and the inline-assembly path below) is always enabled.
#define __ARM_CORTEX_M4

#ifdef __ARM_FP16_FORMAT_IEEE
#define MAC_USE_FP16
#endif

#ifdef __ARM_FP16_FORMAT_ALTERNATIVE
#define MAC_USE_FP16
#endif

#ifdef __ARM_CORTEX_M4
#define MAC_USE_REAL_INSTRUCTIONS
#endif

template <Dim_size_t Unrolled, typename InputType, typename WeightType, typename AccumulationType, size_t... UnrollIndexes>
struct MAC {
    __attribute__((always_inline)) static inline AccumulationType OP(const InputType input[Unrolled],
                                                                      const WeightType weights[Unrolled],
                                                                      AccumulationType acc,
                                                                      std::index_sequence<UnrollIndexes...>) noexcept {
        // acc += ((static_cast<AccumulationType>(input[UnrollIndexes]) * static_cast<AccumulationType>(weights[UnrollIndexes])) + ...);
        acc = (acc + ... + (static_cast<AccumulationType>(input[UnrollIndexes]) * static_cast<AccumulationType>(weights[UnrollIndexes])));
        return acc;
    }
};

template <Dim_size_t Unrolled, typename Type, size_t... UnrollIndexes>
struct MAC<Unrolled, Complex<Type>, Complex<Type>, Type, UnrollIndexes...> {
    __attribute__((always_inline)) inline static Type OP(const Complex<Type> input[Unrolled], const Complex<Type> weights[Unrolled], Type acc, std::index_sequence<UnrollIndexes...>) noexcept {
        acc = (acc + ... + (input[UnrollIndexes].Mul_only_Real_result(weights[UnrollIndexes])));  // Fix computation
        return acc;
    }
};

template <Dim_size_t Unrolled, typename Type, size_t... UnrollIndexes>
struct MAC<Unrolled, Type, Complex<Type>, Complex<Type>, UnrollIndexes...> {
    __attribute__((always_inline)) inline static Complex<Type> OP(const Type input[Unrolled], const Complex<Type> weights[Unrolled], Complex<Type> acc, std::index_sequence<UnrollIndexes...>) noexcept {
        acc = (acc + ... + (weights[UnrollIndexes] * input[UnrollIndexes]));  // Fix computation
        return acc;
    }
};

// 32-bit register views: pack two/four narrow values into one word so they can be fed to the
// SMLAD/SXTB16/VCVT helpers below (type punning through unions, relying on GCC/Clang behaviour).
union SMID32_t_int8 {
    uint32_t smid;

    struct {
        int8_t data[4];
    } data;
};

union SMID32_t_int16 {
    uint32_t smid;

    struct {
        int16_t data[2];
    } data;
};

union SMID32_t_int32 {
    uint32_t smid;

    struct {
        int32_t data;
    } data;
};

#ifdef MAC_USE_FP16
union SMID32_t_fp16 {
    uint32_t smid;

    struct {
        __fp16 data[2];
    } data;
};
#endif

#ifdef MAC_USE_REAL_INSTRUCTIONS
// Dual signed 16-bit multiply with 32-bit accumulate (SMLAD instruction).
__attribute__((always_inline)) inline uint32_t __SMLAD(uint32_t op1, uint32_t op2, uint32_t op3) {
    uint32_t result;
    __asm volatile("smlad %0, %1, %2, %3" : "=r"(result) : "r"(op1), "r"(op2), "r"(op3));
    return result;
}

/* __SXTB16 with no rotation: sign-extends bytes 0 and 2 into two 16-bit lanes */
__attribute__((always_inline)) inline uint32_t __SXTB16_ROR0(const uint32_t op1) {
    uint32_t result;
    __asm("sxtb16 %0, %1" : "=r"(result) : "r"(op1));
    return result;
}

/* __SXTB16 with ror #8: sign-extends bytes 1 and 3 into two 16-bit lanes */
__attribute__((always_inline)) inline uint32_t __SXTB16_ROR8(const uint32_t op1) {
    uint32_t result;
    __asm("sxtb16 %0, %1, ror 8" : "=r"(result) : "r"(op1));
    return result;
}
#ifdef MAC_USE_FP16
// Convert the fp16 value in the bottom half of op1 to fp32 (destination type comes first in the mnemonic).
__attribute__((always_inline)) inline float __VCVTB(const uint32_t op1) {
    float result;
    __asm("vcvtb.f32.f16 %0, %1" : "=w"(result) : "w"(op1));
    return result;
}

// Convert the fp16 value in the top half of op1 to fp32.
__attribute__((always_inline)) inline float __VCVTT(const uint32_t op1) {
    float result;
    __asm("vcvtt.f32.f16 %0, %1" : "=w"(result) : "w"(op1));
    return result;
}
#endif
#else
/* Fake Implementation of __SMLAD to emulate behaviour */
__attribute__((always_inline)) inline uint32_t __SMLAD(uint32_t op1, uint32_t op2, uint32_t op3) {
    // uint32_t result;
    // __asm volatile ("smlad %0, %1, %2, %3" : "=r" (result) : "r" (op1), "r" (op2), "r" (op3) );
    const SMID32_t_int16 smid_op1{.smid = op1};
    const SMID32_t_int16 smid_op2{.smid = op2};
    const SMID32_t_int32 smid_op3{.smid = op3};
    const int32_t result = static_cast<int32_t>(smid_op1.data.data[0]) * static_cast<int32_t>(smid_op2.data.data[0]) +
                           static_cast<int32_t>(smid_op1.data.data[1]) * static_cast<int32_t>(smid_op2.data.data[1]) + smid_op3.data.data;

    const SMID32_t_int32 smid_result{.data = {result}};
    return smid_result.smid;
}

/* Fake Implementation of __SXTB16 to emulate behaviour */
__attribute__((always_inline)) inline uint32_t __SXTB16_ROR0(const uint32_t op1) {
    // uint32_t result;
    // __asm ("sxtb16 %0, %1" : "=r" (result) : "r" (op1) );
    const SMID32_t_int8 smid_op1{.smid = op1};
    const SMID32_t_int16 smid_result{.data = {smid_op1.data.data[0], smid_op1.data.data[2]}};

    return smid_result.smid;
}

/* Fake Implementation of __SXTB16 to emulate behaviour */
__attribute__((always_inline)) inline uint32_t __SXTB16_ROR8(const uint32_t op1) {
    // uint32_t result;
    // __asm ("sxtb16 %0, %1, ror 8" : "=r" (result) : "r" (op1) );
    const SMID32_t_int8 smid_op1{.smid = op1};
    const SMID32_t_int16 smid_result{.data = {smid_op1.data.data[1], smid_op1.data.data[3]}};
    return smid_result.smid;
}
#ifdef MAC_USE_FP16
// Bottom half fp16 to fp32
__attribute__((always_inline)) inline float __VCVTB(const uint32_t op1) {
    // float result;
    // __asm("vcvtb.f16.f32 %0, %1" : "=r"(result) : "r"(op1));
    SMID32_t_fp16 smid_op1{.smid = op1};
    return static_cast<float>(smid_op1.data.data[0]);
}

// Top half fp16 to fp32
__attribute__((always_inline)) inline float __VCVTT(const uint32_t op1) {
    // float result;
    // __asm("vcvtt.f16.f32 %0, %1" : "=r"(result) : "r"(op1));
    SMID32_t_fp16 smid_op1{.smid = op1};
    return static_cast<float>(smid_op1.data.data[1]);
}
#endif
#endif

// 2-way int16 MAC: both 16-bit lanes are multiplied and accumulated by a single SMLAD.
template <size_t... UnrollIndexes>
struct MAC<2, int16_t, int16_t, int32_t, UnrollIndexes...> {
    __attribute__((always_inline)) inline static int32_t OP(const int16_t input[2], const int16_t weights[2], int32_t acc, std::index_sequence<UnrollIndexes...>) noexcept {
        const SMID32_t_int16 smid_intput{.data = {input[UnrollIndexes]...}};
        const SMID32_t_int16 smid_weights{.data = {weights[UnrollIndexes]...}};
        SMID32_t_int32 acc_smid{.data = {acc}};
        acc_smid.smid = __SMLAD(smid_intput.smid, smid_weights.smid, acc_smid.smid);
        return acc_smid.data.data;
    }
};

// 4-way int8 MAC: __SXTB16 sign-extends the even bytes (ROR0) and the odd bytes (ROR8) into
// 16-bit lanes, then two SMLADs accumulate all four products into acc.
template <size_t... UnrollIndexes>
struct MAC<4, int8_t, int8_t, int32_t, UnrollIndexes...> {
    __attribute__((always_inline)) inline static int32_t OP(const int8_t input[4], const int8_t weights[4], int32_t acc, std::index_sequence<UnrollIndexes...>) noexcept {
        const SMID32_t_int8 smid_intput{.data = {input[UnrollIndexes]...}};
        const SMID32_t_int8 smid_weights{.data = {weights[UnrollIndexes]...}};
        SMID32_t_int32 acc_smid{.data = {acc}};
        const uint32_t a = __SXTB16_ROR0(smid_intput.smid);
        const uint32_t b = __SXTB16_ROR0(smid_weights.smid);
        acc_smid.smid = __SMLAD(a, b, acc_smid.smid);
        const uint32_t c = __SXTB16_ROR8(smid_intput.smid);
        const uint32_t d = __SXTB16_ROR8(smid_weights.smid);
        acc_smid.smid = __SMLAD(c, d, acc_smid.smid);
        return acc_smid.data.data;
    }
};

#ifdef MAC_USE_FP16
// template <Dim_size_t Unrolled, typename InputType, typename WeightType, typename AccumulationType, size_t... UnrollIndexes>
template <size_t... UnrollIndexes>
struct MAC<2, __fp16, __fp16, float, UnrollIndexes...> {
    __attribute__((always_inline)) inline static float OP(const __fp16 input[2], const __fp16 weights[2], float acc, std::index_sequence<UnrollIndexes...>) noexcept {
        const SMID32_t_fp16 smid_intput{.data = {input[UnrollIndexes]...}};
        // const SMID32_t_fp16 smid_intput{.smid = *(uint32_t*)(void*)input};
        const SMID32_t_fp16 smid_weights{.data = {weights[UnrollIndexes]...}};
        // const SMID32_t_fp16 smid_weights{.smid = *(uint32_t*)(void*)weights};
        float a = __VCVTB(smid_intput.smid);
        float b = __VCVTB(smid_weights.smid);
        acc += a * b;
        a = __VCVTT(smid_intput.smid);
        b = __VCVTT(smid_weights.smid);
        acc += a * b;

        return acc;
    }
};

// template <Dim_size_t Unrolled, typename InputType, typename WeightType, typename AccumulationType, size_t... UnrollIndexes>
template <size_t... UnrollIndexes>
struct MAC<2, float, __fp16, float, UnrollIndexes...> {
    __attribute__((always_inline)) inline static float OP(const float input[2], const __fp16 weights[2], float acc, std::index_sequence<UnrollIndexes...>) noexcept {
        // const SMID32_t_fp16 smid_weights{.data = {weights[UnrollIndexes]...}};
        // const SMID32_t_fp16 smid_weights{.smid = *(uint32_t*)(void*)weights};
        // Reinterpret the two packed __fp16 weights as one 32-bit word (assumes suitable alignment; bypasses strict aliasing).
        const SMID32_t_fp16 *smid_weights = reinterpret_cast<const SMID32_t_fp16 *>(weights);

        const float b = __VCVTB(smid_weights->smid);
        acc += input[0] * b;
        const float c = __VCVTT(smid_weights->smid);
        acc += input[1] * c;

        return acc;
    }
};
#endif
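
Usage note (not part of the commit): a minimal sketch of how the unrolled MAC primitive above might be driven. The wrapper name dot_int8 and the buffers are hypothetical, and it assumes an Arm GCC/Clang toolchain so that the Cortex-M4 instruction path force-enabled at the top of MAC.hpp actually assembles.

// Hypothetical usage sketch, not part of include/MAC.hpp.
#include "MAC.hpp"  // pulls in Matrix.hpp for Dim_size_t
#include <cstdint>
#include <utility>

// Dot product over n int8 values, consumed four at a time through the
// MAC<4, int8_t, int8_t, int32_t, 0, 1, 2, 3> specialization (SXTB16 + SMLAD).
inline int32_t dot_int8(const int8_t *input, const int8_t *weights, unsigned n) {
    int32_t acc = 0;
    unsigned i = 0;
    for (; i + 4 <= n; i += 4) {
        acc = MAC<4, int8_t, int8_t, int32_t, 0, 1, 2, 3>::OP(input + i, weights + i, acc,
                                                              std::make_index_sequence<4>{});
    }
    for (; i < n; ++i) {  // scalar tail for the remaining elements
        acc += static_cast<int32_t>(input[i]) * static_cast<int32_t>(weights[i]);
    }
    return acc;
}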

include/MAC_old.hpp

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
#include "./Matrix.hpp"

template <Dim_size_t Unrolled, typename InputType, typename WeightType, typename AccumulationType, size_t... UnrollIndexes>
struct MAC {
    __attribute__((always_inline)) static inline AccumulationType OP(const InputType input[Unrolled],
                                                                      const WeightType weights[Unrolled],
                                                                      AccumulationType acc,
                                                                      std::index_sequence<UnrollIndexes...>) noexcept {
        // const InputType input[Unrolled] = {LambdaInput(UnrollIndexes)...};
        // const WeightType weights[Unrolled] = {LambdaWeights(UnrollIndexes)...};
        acc += ((static_cast<AccumulationType>(input[UnrollIndexes]) * static_cast<AccumulationType>(weights[UnrollIndexes])) + ...);
        return acc;
    }
};


__attribute__((always_inline)) inline uint32_t __SMLAD(const uint32_t op1, const uint32_t op2, const uint32_t op3)
{
    uint32_t result;
    __asm volatile ("smlad %0, %1, %2, %3" : "=r" (result) : "r" (op1), "r" (op2), "r" (op3));
    return (result);
}

__attribute__((always_inline)) inline uint32_t __SXTB16_ROR0(const uint32_t op1)
{
    uint32_t result;
    __asm ("sxtb16 %0, %1" : "=r" (result) : "r" (op1));
    return (result);
}

__attribute__((always_inline)) inline uint32_t __SXTB16_ROR8(const uint32_t op1)
{
    uint32_t result;
    __asm ("sxtb16 %0, %1, ror 8" : "=r" (result) : "r" (op1));
    return (result);
}


template <size_t... UnrollIndexes>
struct MAC<2, int16_t, int16_t, int32_t, UnrollIndexes...> {
    __attribute__((always_inline)) inline static int32_t OP(const int16_t input[2], const int16_t weights[2], int32_t acc, std::index_sequence<UnrollIndexes...>) noexcept {
        union SMID32_t_16 {
            uint32_t smid;
            struct {
                int16_t data[2];
            } data;
        };

        union SMID32_t_32 {
            uint32_t smid;
            struct {
                int32_t data;
            } data;
        };

        const SMID32_t_16 smid_intput{.data = {input[UnrollIndexes]...}};
        const SMID32_t_16 smid_weights{.data = {weights[UnrollIndexes]...}};
        SMID32_t_32 acc_smid{.data = {acc}};
        acc_smid.smid = __SMLAD(smid_intput.smid, smid_weights.smid, acc_smid.smid);
        return acc_smid.data.data;
    }
};

template <size_t... UnrollIndexes>
struct MAC<4, int8_t, int8_t, int32_t, UnrollIndexes...> {
    __attribute__((always_inline)) inline static int32_t OP(const int8_t input[4], const int8_t weights[4], int32_t acc, std::index_sequence<UnrollIndexes...>) noexcept {
        union SMID32_t_8 {
            uint32_t smid;
            struct {
                int8_t data[4];
            } data;
        };

        union SMID32_t_32 {
            uint32_t smid;
            struct {
                int32_t data;
            } data;
        };
        const SMID32_t_8 smid_intput{.data = {input[UnrollIndexes]...}};
        const SMID32_t_8 smid_weights{.data = {weights[UnrollIndexes]...}};
        SMID32_t_32 acc_smid{.data = {acc}};
        uint32_t a = __SXTB16_ROR0(smid_intput.smid);
        uint32_t b = __SXTB16_ROR0(smid_weights.smid);
        acc_smid.smid = __SMLAD(a, b, acc_smid.smid);
        a = __SXTB16_ROR8(smid_intput.smid);
        b = __SXTB16_ROR8(smid_weights.smid);
        acc_smid.smid = __SMLAD(a, b, acc_smid.smid);
        return acc_smid.data.data;
    }
};
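
Aside (not part of the commit): besides hoisting the SIMD-style unions to file scope and adding the Complex and fp16 specializations, the new MAC.hpp changes the shape of the C++17 fold expression in the generic template. The old header sums the products with a unary right fold and then adds the result to acc; the new one folds acc in directly with a binary left fold, which changes only the association order. A small standalone illustration (plain float, hypothetical function names):

// Illustrative only: the two fold-expression shapes used by the generic MAC::OP.
#include <cstddef>
#include <utility>

template <std::size_t... I>
float mac_old_style(const float *x, const float *w, float acc, std::index_sequence<I...>) {
    // Unary right fold: (x0*w0 + (x1*w1 + (x2*w2 + x3*w3))), added to acc afterwards.
    acc += ((x[I] * w[I]) + ...);
    return acc;
}

template <std::size_t... I>
float mac_new_style(const float *x, const float *w, float acc, std::index_sequence<I...>) {
    // Binary left fold over acc: ((((acc + x0*w0) + x1*w1) + x2*w2) + x3*w3).
    acc = (acc + ... + (x[I] * w[I]));
    return acc;
}

int main() {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    const float w[4] = {0.5f, 0.5f, 0.5f, 0.5f};
    // Both evaluate to 5.0f here; for floating point the two forms can differ only in rounding.
    const float a = mac_old_style(x, w, 0.0f, std::make_index_sequence<4>{});
    const float b = mac_new_style(x, w, 0.0f, std::make_index_sequence<4>{});
    return (a == b) ? 0 : 1;
}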
