
Commit 03c8d10

Merge pull request #5476 from ChipKerchner/fasterGEMVNRISCV
Traverse matrix data in a cache-friendly manner for GEMV_N (RISCV)
2 parents 3eed188 + 36f9cb8 commit 03c8d10
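
For context, GEMV_N computes y := alpha*A*x + y with A stored column-major: elements within one column are contiguous in memory, while stepping across a row jumps lda elements at a time. The old kernel kept a strip of y in vector registers and visited all n columns in its inner loop, so consecutive loads of A were lda elements apart; the rewrite below instead finishes one column at a time, streaming through A and y with unit stride. A minimal scalar sketch of the new traversal order (illustration only, not the RVV kernel; gemv_n_ref is a hypothetical name):

    #include <stddef.h>

    /* Scalar sketch of the cache-friendly loop order: the inner loop
     * walks a single column of A top to bottom, so every access to
     * a and y is stride-1; only the column switch jumps by lda. */
    static void gemv_n_ref(size_t m, size_t n, float alpha,
                           const float *a, size_t lda,
                           const float *x, float *y)
    {
        for (size_t i = 0; i < n; i++) {      /* one column per pass     */
            float t = alpha * x[i];
            for (size_t j = 0; j < m; j++)    /* contiguous in memory    */
                y[j] += t * a[i * lda + j];
        }
    }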

File tree

1 file changed: +47 -210 lines changed

kernel/riscv64/gemv_n_vector.c

Lines changed: 47 additions & 210 deletions
@@ -26,230 +26,67 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
 
 #include "common.h"
+
 #if !defined(DOUBLE)
-#define VSETVL(n) RISCV_RVV(vsetvl_e32m8)(n)
-#define FLOAT_V_T vfloat32m8_t
-#define VLEV_FLOAT RISCV_RVV(vle32_v_f32m8)
-#define VLSEV_FLOAT RISCV_RVV(vlse32_v_f32m8)
-#define VSEV_FLOAT RISCV_RVV(vse32_v_f32m8)
-#define VSSEV_FLOAT RISCV_RVV(vsse32_v_f32m8)
-#define VFMACCVF_FLOAT RISCV_RVV(vfmacc_vf_f32m8)
-#define VFMUL_VF_FLOAT RISCV_RVV(vfmul_vf_f32m8)
-#define VREINTERPRET_FLOAT RISCV_RVV(vreinterpret_v_i32m8_f32m8)
-#define VFILL_INT RISCV_RVV(vmv_v_x_i32m8)
+#define VSETVL(n) RISCV_RVV(vsetvl_e32m8)(n)
+#define FLOAT_V_T vfloat32m8_t
+#define VLEV_FLOAT RISCV_RVV(vle32_v_f32m8)
+#define VLSEV_FLOAT RISCV_RVV(vlse32_v_f32m8)
+#define VSEV_FLOAT RISCV_RVV(vse32_v_f32m8)
+#define VSSEV_FLOAT RISCV_RVV(vsse32_v_f32m8)
+#define VFMACCVF_FLOAT RISCV_RVV(vfmacc_vf_f32m8)
 #else
-#define VSETVL(n) RISCV_RVV(vsetvl_e64m4)(n)
-#define FLOAT_V_T vfloat64m4_t
-#define VLEV_FLOAT RISCV_RVV(vle64_v_f64m4)
-#define VLSEV_FLOAT RISCV_RVV(vlse64_v_f64m4)
-#define VSEV_FLOAT RISCV_RVV(vse64_v_f64m4)
-#define VSSEV_FLOAT RISCV_RVV(vsse64_v_f64m4)
-#define VFMACCVF_FLOAT RISCV_RVV(vfmacc_vf_f64m4)
-#define VFMUL_VF_FLOAT RISCV_RVV(vfmul_vf_f64m4)
-#define VREINTERPRET_FLOAT RISCV_RVV(vreinterpret_v_i64m4_f64m4)
-#define VFILL_INT RISCV_RVV(vmv_v_x_i64m4)
+#define VSETVL(n) RISCV_RVV(vsetvl_e64m8)(n)
+#define FLOAT_V_T vfloat64m8_t
+#define VLEV_FLOAT RISCV_RVV(vle64_v_f64m8)
+#define VLSEV_FLOAT RISCV_RVV(vlse64_v_f64m8)
+#define VSEV_FLOAT RISCV_RVV(vse64_v_f64m8)
+#define VSSEV_FLOAT RISCV_RVV(vsse64_v_f64m8)
+#define VFMACCVF_FLOAT RISCV_RVV(vfmacc_vf_f64m8)
 #endif
 
 int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer)
 {
-    BLASLONG i = 0, j = 0, k = 0;
-    BLASLONG ix = 0, iy = 0;
-
-    if(n < 0) return(0);
-    FLOAT *a_ptr = a;
-    FLOAT temp[4];
-    FLOAT_V_T va0, va1, vy0, vy1, vy0_temp, vy1_temp, va0_0, va0_1, va1_0, va1_1, va2_0, va2_1, va3_0, va3_1;
-    unsigned int gvl = 0;
-    if(inc_y == 1 && inc_x == 1){
-        gvl = VSETVL(m);
-        if(gvl <= m/2){
-            for(k=0,j=0; k<m/(2*gvl); k++){
-                a_ptr = a;
-                ix = 0;
-                vy0_temp = VLEV_FLOAT(&y[j], gvl);
-                vy1_temp = VLEV_FLOAT(&y[j+gvl], gvl);
-                vy0 = VREINTERPRET_FLOAT(VFILL_INT(0, gvl));
-                vy1 = VREINTERPRET_FLOAT(VFILL_INT(0, gvl));
-                int i;
-
-                int remainder = n % 4;
-                for(i = 0; i < remainder; i++){
-                    temp[0] = x[ix];
-                    va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                    vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
-
-                    va1 = VLEV_FLOAT(&a_ptr[j+gvl], gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, temp[0], va1, gvl);
-                    a_ptr += lda;
-                    ix++;
-                }
-
-                for(i = remainder; i < n; i += 4){
-                    va0_0 = VLEV_FLOAT(&(a_ptr)[j], gvl);
-                    va0_1 = VLEV_FLOAT(&(a_ptr)[j+gvl], gvl);
-                    va1_0 = VLEV_FLOAT(&(a_ptr+lda * 1)[j], gvl);
-                    va1_1 = VLEV_FLOAT(&(a_ptr+lda * 1)[j+gvl], gvl);
-                    va2_0 = VLEV_FLOAT(&(a_ptr+lda * 2)[j], gvl);
-                    va2_1 = VLEV_FLOAT(&(a_ptr+lda * 2)[j+gvl], gvl);
-                    va3_0 = VLEV_FLOAT(&(a_ptr+lda * 3)[j], gvl);
-                    va3_1 = VLEV_FLOAT(&(a_ptr+lda * 3)[j+gvl], gvl);
-
-                    vy0 = VFMACCVF_FLOAT(vy0, x[ix], va0_0, gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, x[ix], va0_1, gvl);
-
-                    vy0 = VFMACCVF_FLOAT(vy0, x[ix+1], va1_0, gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, x[ix+1], va1_1, gvl);
+    if (n < 0) return(0);
 
-                    vy0 = VFMACCVF_FLOAT(vy0, x[ix+2], va2_0, gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, x[ix+2], va2_1, gvl);
-
-                    vy0 = VFMACCVF_FLOAT(vy0, x[ix+3], va3_0, gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, x[ix+3], va3_1, gvl);
-                    a_ptr += 4 * lda;
-                    ix += 4;
-                }
-                vy0 = VFMACCVF_FLOAT(vy0_temp, alpha, vy0, gvl);
-                vy1 = VFMACCVF_FLOAT(vy1_temp, alpha, vy1, gvl);
-                VSEV_FLOAT(&y[j], vy0, gvl);
-                VSEV_FLOAT(&y[j+gvl], vy1, gvl);
-                j += gvl * 2;
-            }
-        }
-        //tail
-        if(gvl <= m - j){
-            a_ptr = a;
-            ix = 0;
-            vy0_temp = VLEV_FLOAT(&y[j], gvl);
-            vy0 = VREINTERPRET_FLOAT(VFILL_INT(0, gvl));
-            int i;
-
-            int remainder = n % 4;
-            for(i = 0; i < remainder; i++){
-                temp[0] = x[ix];
-                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
-                a_ptr += lda;
-                ix++;
-            }
-
-            for(i = remainder; i < n; i += 4){
-                va0_0 = VLEV_FLOAT(&(a_ptr)[j], gvl);
-                va1_0 = VLEV_FLOAT(&(a_ptr+lda * 1)[j], gvl);
-                va2_0 = VLEV_FLOAT(&(a_ptr+lda * 2)[j], gvl);
-                va3_0 = VLEV_FLOAT(&(a_ptr+lda * 3)[j], gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, x[ix], va0_0, gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, x[ix+1], va1_0, gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, x[ix+2], va2_0, gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, x[ix+3], va3_0, gvl);
-                a_ptr += 4 * lda;
-                ix += 4;
-            }
-            vy0 = VFMACCVF_FLOAT(vy0_temp, alpha, vy0, gvl);
-
-            VSEV_FLOAT(&y[j], vy0, gvl);
-
-            j += gvl;
-        }
-
+    FLOAT *a_ptr, *y_ptr, temp;
+    BLASLONG i, j, vl;
+    FLOAT_V_T va, vy;
 
-        for(;j < m;){
-            gvl = VSETVL(m-j);
+    if (inc_y == 1) {
+        for (j = 0; j < n; j++) {
+            temp = alpha * x[0];
+            y_ptr = y;
             a_ptr = a;
-            ix = 0;
-            vy0 = VLEV_FLOAT(&y[j], gvl);
-            for(i = 0; i < n; i++){
-                temp[0] = alpha * x[ix];
-                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
-
-                a_ptr += lda;
-                ix += inc_x;
-            }
-            VSEV_FLOAT(&y[j], vy0, gvl);
-            j += gvl;
-        }
-    }else if (inc_y == 1 && inc_x != 1) {
-        gvl = VSETVL(m);
-        if(gvl <= m/2){
-            for(k=0,j=0; k<m/(2*gvl); k++){
-                a_ptr = a;
-                ix = 0;
-                vy0 = VLEV_FLOAT(&y[j], gvl);
-                vy1 = VLEV_FLOAT(&y[j+gvl], gvl);
-                for(i = 0; i < n; i++){
-                    temp[0] = alpha * x[ix];
-                    va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                    vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
-
-                    va1 = VLEV_FLOAT(&a_ptr[j+gvl], gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, temp[0], va1, gvl);
-                    a_ptr += lda;
-                    ix += inc_x;
-                }
-                VSEV_FLOAT(&y[j], vy0, gvl);
-                VSEV_FLOAT(&y[j+gvl], vy1, gvl);
-                j += gvl * 2;
-            }
-        }
-        //tail
-        for(;j < m;){
-            gvl = VSETVL(m-j);
-            a_ptr = a;
-            ix = 0;
-            vy0 = VLEV_FLOAT(&y[j], gvl);
-            for(i = 0; i < n; i++){
-                temp[0] = alpha * x[ix];
-                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
-
-                a_ptr += lda;
-                ix += inc_x;
+            for (i = m; i > 0; i -= vl) {
+                vl = VSETVL(i);
+                vy = VLEV_FLOAT(y_ptr, vl);
+                va = VLEV_FLOAT(a_ptr, vl);
+                vy = VFMACCVF_FLOAT(vy, temp, va, vl);
+                VSEV_FLOAT(y_ptr, vy, vl);
+                y_ptr += vl;
+                a_ptr += vl;
             }
-            VSEV_FLOAT(&y[j], vy0, gvl);
-            j += gvl;
+            x += inc_x;
+            a += lda;
         }
-    }else{
+    } else {
         BLASLONG stride_y = inc_y * sizeof(FLOAT);
-        gvl = VSETVL(m);
-        if(gvl <= m/2){
-            BLASLONG inc_yv = inc_y * gvl;
-            for(k=0,j=0; k<m/(2*gvl); k++){
-                a_ptr = a;
-                ix = 0;
-                vy0 = VLSEV_FLOAT(&y[iy], stride_y, gvl);
-                vy1 = VLSEV_FLOAT(&y[iy+inc_yv], stride_y, gvl);
-                for(i = 0; i < n; i++){
-                    temp[0] = alpha * x[ix];
-                    va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                    vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
-
-                    va1 = VLEV_FLOAT(&a_ptr[j+gvl], gvl);
-                    vy1 = VFMACCVF_FLOAT(vy1, temp[0], va1, gvl);
-                    a_ptr += lda;
-                    ix += inc_x;
-                }
-                VSSEV_FLOAT(&y[iy], stride_y, vy0, gvl);
-                VSSEV_FLOAT(&y[iy+inc_yv], stride_y, vy1, gvl);
-                j += gvl * 2;
-                iy += inc_yv * 2;
-            }
-        }
-        //tail
-        for(;j < m;){
-            gvl = VSETVL(m-j);
+        for (j = 0; j < n; j++) {
+            temp = alpha * x[0];
+            y_ptr = y;
             a_ptr = a;
-            ix = 0;
-            vy0 = VLSEV_FLOAT(&y[j*inc_y], stride_y, gvl);
-            for(i = 0; i < n; i++){
-                temp[0] = alpha * x[ix];
-                va0 = VLEV_FLOAT(&a_ptr[j], gvl);
-                vy0 = VFMACCVF_FLOAT(vy0, temp[0], va0, gvl);
-
-                a_ptr += lda;
-                ix += inc_x;
+            for (i = m; i > 0; i -= vl) {
+                vl = VSETVL(i);
+                vy = VLSEV_FLOAT(y_ptr, stride_y, vl);
+                va = VLEV_FLOAT(a_ptr, vl);
+                vy = VFMACCVF_FLOAT(vy, temp, va, vl);
+                VSSEV_FLOAT(y_ptr, stride_y, vy, vl);
+                y_ptr += vl * inc_y;
+                a_ptr += vl;
             }
-            VSSEV_FLOAT(&y[j*inc_y], stride_y, vy0, gvl);
-            j += gvl;
+            x += inc_x;
+            a += lda;
         }
     }
     return(0);
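
The rewritten inner loop (vl = VSETVL(i); i -= vl) is the standard RVV strip-mining idiom: vsetvl returns min(remaining, VLMAX), so the final partial vector is handled by the same loop with a shorter vl and no scalar tail loop is needed. A standalone sketch of the same idiom, assuming a toolchain that uses the ratified __riscv_* intrinsic names (the kernel's RISCV_RVV macro abstracts over old and new spellings); this is a plain axpy, not the GEMV kernel itself:

    #include <riscv_vector.h>
    #include <stddef.h>

    /* y := a*x + y, strip-mined the same way as the kernel's i loop. */
    void saxpy_rvv(size_t n, float a, const float *x, float *y)
    {
        while (n > 0) {
            size_t vl = __riscv_vsetvl_e32m8(n);            /* min(n, VLMAX) */
            vfloat32m8_t vx = __riscv_vle32_v_f32m8(x, vl); /* load x chunk  */
            vfloat32m8_t vy = __riscv_vle32_v_f32m8(y, vl); /* load y chunk  */
            vy = __riscv_vfmacc_vf_f32m8(vy, a, vx, vl);    /* vy += a * vx  */
            __riscv_vse32_v_f32m8(y, vy, vl);               /* store back    */
            n -= vl; x += vl; y += vl;
        }
    }

Compile for a vector-capable target, e.g. with -march=rv64gcv.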

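For reference, this kernel is reached through the ordinary GEMV entry points. A minimal CBLAS program that exercises the non-transposed, column-major path (link against OpenBLAS, e.g. -lopenblas; values chosen so y ends up holding the row sums of A):

    #include <stdio.h>
    #include <cblas.h>

    int main(void)
    {
        /* 4x3 column-major A with lda = 4; columns are stored back to back. */
        float A[12] = { 1, 2, 3, 4,   5, 6, 7, 8,   9, 10, 11, 12 };
        float x[3]  = { 1, 1, 1 };
        float y[4]  = { 0, 0, 0, 0 };

        /* y := 1.0*A*x + 1.0*y -> {15, 18, 21, 24} */
        cblas_sgemv(CblasColMajor, CblasNoTrans, 4, 3,
                    1.0f, A, 4, x, 1, 1.0f, y, 1);

        for (int i = 0; i < 4; i++)
            printf("y[%d] = %g\n", i, y[i]);
        return 0;
    }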