#if defined (__AVX512F__) && defined (__AVX512DQ__)

// Drop-in emulations of the FMA intrinsics via separate multiply and add,
// for use where fused multiply-add is unavailable.
#define _mm256_fmadd_ps(x,y,z)  _mm256_add_ps(z, _mm256_mul_ps(x,y))
#define _mm256_fnmadd_ps(x,y,z) _mm256_sub_ps(z, _mm256_mul_ps(x,y))
#define _mm512_fmadd_ps(x,y,z)  _mm512_add_ps(z, _mm512_mul_ps(x,y))
#define _mm512_fnmadd_ps(x,y,z) _mm512_sub_ps(z, _mm512_mul_ps(x,y))
#if defined (__AVX512F__) && defined (__AVX512DQ__)
    float xibuf  [NIMAX/16][4][16];   // x, y, z, r_search of the i-particles (16 lanes)
    float accpbuf[NIMAX/16][5][16];   // ax, ay, az, pot, nngb
#else
    float xibuf  [NIMAX/8][4][8];     // x, y, z, r_search of the i-particles (8 lanes)
    float accpbuf[NIMAX/8][5][8];     // ax, ay, az, pot, nngb
#endif
    float epjbuf  [NJMAX][4];         // x, y, z, m of the j-particles
    float rsearchj[NJMAX];            // search radius of the j-particles
    float spjbuf  [NJMAX][3][4];      // {x,y,z,m}, {3qxx-tr,3qyy-tr,3qzz-tr,m}, {3qxy,3qyz,3qzx,-eps2*tr}
    static double get_a_NaN(){
        union{ long l; double d; } m;
        m.l = -1;                     // all bits set encodes a quiet NaN
        return m.d;
    }

    void set_r_crit2(const double _r_crit2){   // name assumed
        this->r_crit2 = _r_crit2;
    }
    void set_epj_one(const int addr,
                     const double x, const double y, const double z,
                     const double m, const double r_search) {
        epjbuf[addr][0] = x;
        epjbuf[addr][1] = y;
        epjbuf[addr][2] = z;
        epjbuf[addr][3] = m;
        rsearchj[addr] = r_search;
    }
    void set_spj_one(const int addr,
                     const double x,   const double y,   const double z,   const double m,
                     const double qxx, const double qyy, const double qzz,
                     const double qxy, const double qyz, const double qzx){
        const double tr = qxx + qyy + qzz;
        spjbuf[addr][0][0] = x;
        spjbuf[addr][0][1] = y;
        spjbuf[addr][0][2] = z;
        spjbuf[addr][0][3] = m;

        spjbuf[addr][1][0] = 3.0 * qxx - tr;
        spjbuf[addr][1][1] = 3.0 * qyy - tr;
        spjbuf[addr][1][2] = 3.0 * qzz - tr;
        spjbuf[addr][1][3] = m;

        spjbuf[addr][2][0] = 3.0 * qxy;
        spjbuf[addr][2][1] = 3.0 * qyz;
        spjbuf[addr][2][2] = 3.0 * qzx;
        spjbuf[addr][2][3] = -(eps2 * tr);
    }
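    // Why these combinations: the superparticle kernels evaluate the
    // quadrupole field through A = 3Q - tr(Q) I (diagonal entries 3q??-tr,
    // off-diagonals 3q??) plus the constant -eps2*tr that folds the
    // softening into the r.A.r contraction. Precomputing them here keeps
    // the inner loops of kernel_spj_nounroll to multiply-adds only; a
    // scalar reading is sketched next to those kernels below.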
    void set_xi_one(const int addr, const double x, const double y, const double z, const double r_search){
#if defined (__AVX512F__) && defined (__AVX512DQ__)
        const int ah = addr / 16;
        const int al = addr % 16;
#else
        const int ah = addr / 8;
        const int al = addr % 8;
#endif
        xibuf[ah][0][al] = x;
        xibuf[ah][1][al] = y;
        xibuf[ah][2][al] = z;
        xibuf[ah][3][al] = r_search;
    }
    template <typename real_t>
    void accum_accp_one(const int addr, real_t &ax, real_t &ay, real_t &az, real_t &pot){
#if defined (__AVX512F__) && defined (__AVX512DQ__)
        const int ah = addr / 16;
        const int al = addr % 16;
#else
        const int ah = addr / 8;
        const int al = addr % 8;
#endif
        ax  += accpbuf[ah][0][al];
        ay  += accpbuf[ah][1][al];
        az  += accpbuf[ah][2][al];
        pot += accpbuf[ah][3][al];
    }
    template <typename real_t>
    void accum_accp_one(const int addr, real_t &ax, real_t &ay, real_t &az,
                        real_t &pot, real_t &nngb){
#if defined (__AVX512F__) && defined (__AVX512DQ__)
        const int ah = addr / 16;
        const int al = addr % 16;
#else
        const int ah = addr / 8;
        const int al = addr % 8;
#endif
        ax   += accpbuf[ah][0][al];
        ay   += accpbuf[ah][1][al];
        az   += accpbuf[ah][2][al];
        pot  += accpbuf[ah][3][al];
        nngb += accpbuf[ah][4][al];
    }
    template <typename real_t>
    void accum_nngb_one(const int addr, real_t &nngb){   // name assumed; accumulates the neighbour count only
#if defined (__AVX512F__) && defined (__AVX512DQ__)
        const int ah = addr / 16;
        const int al = addr % 16;
#else
        const int ah = addr / 8;
        const int al = addr % 8;
#endif
        nngb += accpbuf[ah][4][al];
    }
    template <typename real_t>
    void get_accp_one(const int addr, real_t &ax, real_t &ay, real_t &az, real_t &pot,
                      real_t &nngb){
#if defined (__AVX512F__) && defined (__AVX512DQ__)
        const int ah = addr / 16;
        const int al = addr % 16;
#else
        const int ah = addr / 8;
        const int al = addr % 8;
#endif
        ax   = accpbuf[ah][0][al];
        ay   = accpbuf[ah][1][al];
        az   = accpbuf[ah][2][al];
        pot  = accpbuf[ah][3][al];
        nngb = accpbuf[ah][4][al];
    }
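    // ------------------------------------------------------------------
    // Illustrative driver sketch (not part of the original file): the
    // intended call sequence, assuming an instance `pg` of this class and
    // a setter set_eps2 alongside set_r_crit2 above (set_eps2 is assumed;
    // the arrays xj, yj, ... are hypothetical caller-side data).
    //
    //   pg.set_eps2(eps2);                       // softening (assumed setter)
    //   pg.set_r_crit2(r_out * r_out);           // cutoff radius squared
    //   for(int j = 0; j < nj; j++)              // nj <= NJMAX
    //       pg.set_epj_one(j, xj[j], yj[j], zj[j], mj[j], rsj[j]);
    //   for(int i = 0; i < ni; i++)              // ni <= NIMAX
    //       pg.set_xi_one(i, xi[i], yi[i], zi[i], rsi[i]);
    //   pg.run_epj_for_p3t_with_linear_cutoff(ni, nj);
    //   for(int i = 0; i < ni; i++){
    //       double ax, ay, az, pot, nngb;
    //       pg.get_accp_one(i, ax, ay, az, pot, nngb);
    //   }
    // ------------------------------------------------------------------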
    void run_epj_for_p3t_with_linear_cutoff(const int ni, const int nj){
        if(ni > NIMAX || nj > NJMAX){
            std::cout<<"ni= "<<ni<<" NIMAX= "<<NIMAX<<" nj= "<<nj<<" NJMAX= "<<NJMAX<<std::endl;
#if defined (__AVX512F__) && defined (__AVX512DQ__)
            for(PS::S32 i=0; i<(ni-1)/16+1; i++){
                for(PS::S32 j=0; j<16; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
#else
            for(PS::S32 i=0; i<(ni-1)/8+1; i++){
                for(PS::S32 j=0; j<8; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
#endif
        }
        assert(ni <= NIMAX);
        assert(nj <= NJMAX);
        kernel_epj_nounroll_for_p3t_with_linear_cutoff(ni, nj);
    }
    void run_epj_for_neighbor_count(const int ni, const int nj){
        if(ni > NIMAX || nj > NJMAX){
            std::cout<<"ni= "<<ni<<" NIMAX= "<<NIMAX<<" nj= "<<nj<<" NJMAX= "<<NJMAX<<std::endl;
#if defined (__AVX512F__) && defined (__AVX512DQ__)
            for(PS::S32 i=0; i<(ni-1)/16+1; i++){
                for(PS::S32 j=0; j<16; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
#else
            for(PS::S32 i=0; i<(ni-1)/8+1; i++){
                for(PS::S32 j=0; j<8; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
#endif
        }
        assert(ni <= NIMAX);
        assert(nj <= NJMAX);
        kernel_epj_nounroll_for_neighbor_count(ni, nj);
    }
    void run_epj(const int ni, const int nj){
        if(ni > NIMAX || nj > NJMAX){
            std::cout<<"ni= "<<ni<<" NIMAX= "<<NIMAX<<" nj= "<<nj<<" NJMAX= "<<NJMAX<<std::endl;
#if defined (__AVX512F__) && defined (__AVX512DQ__)
            for(PS::S32 i=0; i<(ni-1)/16+1; i++){
                for(PS::S32 j=0; j<16; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"i,j,k="<<i<<" "<<j<<" "<<k<<std::endl;
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
#else
            for(PS::S32 i=0; i<(ni-1)/8+1; i++){
                for(PS::S32 j=0; j<8; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"i,j,k="<<i<<" "<<j<<" "<<k<<std::endl;
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
#endif
        }
        assert(ni <= NIMAX);
        assert(nj <= NJMAX);
        kernel_epj_nounroll(ni, nj);
    }
    void run_spj(const int ni, const int nj){
        if(ni > NIMAX || nj > NJMAX){
            std::cout<<"ni= "<<ni<<" NIMAX= "<<NIMAX<<" nj= "<<nj<<" NJMAX= "<<NJMAX<<std::endl;
        }
        assert(ni <= NIMAX);
        assert(nj <= NJMAX);
        kernel_spj_nounroll(ni, nj);
    }
#if defined (__AVX512F__) && defined (__AVX512DQ__)
    typedef __m128  v4sf;
    typedef __m512  v16sf;
    typedef __m512d v8df;
    void kernel_epj_nounroll_for_neighbor_count(const int ni, const int nj){
        for(int i=0; i<ni; i+=16){
            const v16sf xi = *(v16sf *)(xibuf[i/16][0]);
            const v16sf yi = *(v16sf *)(xibuf[i/16][1]);
            const v16sf zi = *(v16sf *)(xibuf[i/16][2]);
            const v16sf rsi= *(v16sf *)(xibuf[i/16][3]);
            v16sf nngb = _mm512_set1_ps(0.0f);
#ifdef AVX_PRELOAD
            v16sf rsjbuf= _mm512_set1_ps(*rsearchj);
            v16sf jbuf = _mm512_broadcast_f32x4(*(v4sf *)epjbuf);
            v16sf xj = _mm512_shuffle_ps(jbuf, jbuf, 0x00);
            v16sf yj = _mm512_shuffle_ps(jbuf, jbuf, 0x55);
            v16sf zj = _mm512_shuffle_ps(jbuf, jbuf, 0xaa);
            v16sf rsj = rsjbuf;
#endif
            for(int j=0; j<nj; j++) {
#ifdef AVX_PRELOAD
                // preload of element j+1 (note: reads one slot past nj-1)
                rsjbuf = _mm512_set1_ps(rsearchj[j+1]);
                jbuf = _mm512_broadcast_f32x4(*(v4sf *)(epjbuf + j+1));
                v16sf dx = _mm512_sub_ps(xi, xj);
                v16sf dy = _mm512_sub_ps(yi, yj);
                v16sf dz = _mm512_sub_ps(zi, zj);
#else
                v16sf dx = _mm512_sub_ps(xi, _mm512_set1_ps(epjbuf[j][0]));
                v16sf dy = _mm512_sub_ps(yi, _mm512_set1_ps(epjbuf[j][1]));
                v16sf dz = _mm512_sub_ps(zi, _mm512_set1_ps(epjbuf[j][2]));
#endif
                v16sf r2_real = _mm512_mul_ps(dx, dx);
                r2_real = _mm512_fmadd_ps(dy, dy, r2_real);
                r2_real = _mm512_fmadd_ps(dz, dz, r2_real);
#ifdef AVX_PRELOAD
                xj = _mm512_shuffle_ps(jbuf, jbuf, 0x00);
                yj = _mm512_shuffle_ps(jbuf, jbuf, 0x55);
                zj = _mm512_shuffle_ps(jbuf, jbuf, 0xaa);
                v16sf vrcrit = _mm512_max_ps(rsi, rsj);
#else
                v16sf vrcrit = _mm512_max_ps(rsi, _mm512_set1_ps(rsearchj[j]));
#endif
                v16sf vrcrit2 = _mm512_mul_ps(vrcrit, vrcrit);
                nngb = _mm512_add_ps(_mm512_mask_blend_ps(
                                         _mm512_cmp_ps_mask(vrcrit2, r2_real, 0x01),
                                         _mm512_set1_ps(1.0f),   // inside the crit radius
                                         _mm512_set1_ps(0.0f)),  // outside
                                     nngb);
#ifdef AVX_PRELOAD
                rsj = rsjbuf;   // rotate the preloaded search radius
#endif
            }
            *(v16sf *)(accpbuf[i/16][4]) = nngb;
        }
    }
    void kernel_epj_nounroll_for_p3t_with_linear_cutoff(const int ni, const int nj){
        const v16sf veps2 = _mm512_set1_ps((float)eps2);
        const v16sf vr_out2 = _mm512_set1_ps((float)r_crit2);
        for(int i=0; i<ni; i+=16){
            const v16sf xi = *(v16sf *)(xibuf[i/16][0]);
            const v16sf yi = *(v16sf *)(xibuf[i/16][1]);
            const v16sf zi = *(v16sf *)(xibuf[i/16][2]);
            const v16sf rsi= *(v16sf *)(xibuf[i/16][3]);
            v16sf ax, ay, az, pot, nngb;
            ax = ay = az = pot = nngb = _mm512_set1_ps(0.0f);
#ifdef AVX_PRELOAD
            v16sf rsjbuf= _mm512_set1_ps(*rsearchj);
            v16sf jbuf = _mm512_broadcast_f32x4(*(v4sf *)epjbuf);
            v16sf xj = _mm512_shuffle_ps(jbuf, jbuf, 0x00);
            v16sf yj = _mm512_shuffle_ps(jbuf, jbuf, 0x55);
            v16sf zj = _mm512_shuffle_ps(jbuf, jbuf, 0xaa);
            v16sf mj = _mm512_shuffle_ps(jbuf, jbuf, 0xff);
            v16sf rsj = rsjbuf;
#endif
            for(int j=0; j<nj; j++) {
#ifdef AVX_PRELOAD
                rsjbuf = _mm512_set1_ps(rsearchj[j+1]);
                jbuf = _mm512_broadcast_f32x4(*(v4sf *)(epjbuf + j+1));
                v16sf dx = _mm512_sub_ps(xi, xj);
                v16sf dy = _mm512_sub_ps(yi, yj);
                v16sf dz = _mm512_sub_ps(zi, zj);
#else
                v16sf dx = _mm512_sub_ps(xi, _mm512_set1_ps(epjbuf[j][0]));
                v16sf dy = _mm512_sub_ps(yi, _mm512_set1_ps(epjbuf[j][1]));
                v16sf dz = _mm512_sub_ps(zi, _mm512_set1_ps(epjbuf[j][2]));
#endif
                v16sf r2_real = _mm512_fmadd_ps(dx, dx, veps2);
                r2_real = _mm512_fmadd_ps(dy, dy, r2_real);
                r2_real = _mm512_fmadd_ps(dz, dz, r2_real);
                v16sf r2 = _mm512_max_ps(r2_real, vr_out2);   // linear cutoff: clip at r_out^2
                v16sf ri1 = _mm512_rsqrt14_ps(r2);
                v16sf ri2 = _mm512_mul_ps(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2
                v16sf v1 = _mm512_set1_ps(1.0f);
                v16sf h = _mm512_fnmadd_ps(r2, ri2, v1);
                v16sf v6p0 = _mm512_set1_ps(6.0f);
                v16sf v5p0 = _mm512_set1_ps(5.0f);
                v16sf v8p0 = _mm512_set1_ps(8.0f);
                v16sf v0p0625 = _mm512_set1_ps(1.0f/16.0f);
                ri2 = _mm512_fmadd_ps(h, v5p0, v6p0);
                ri2 = _mm512_fmadd_ps(h, ri2, v8p0);
                ri2 = _mm512_mul_ps(h, ri2);
                ri2 = _mm512_fmadd_ps(ri2, v0p0625, v1);
                ri1 = _mm512_mul_ps(ri2, ri1);
                ri2 = _mm512_mul_ps(ri1, ri1);
#elif defined(RSQRT_NR_EPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm512_fnmadd_ps(r2, ri2, _mm512_set1_ps(3.0f));
                ri2 = _mm512_mul_ps(ri2, _mm512_set1_ps(0.5f));
                ri1 = _mm512_mul_ps(ri2, ri1);
                ri2 = _mm512_mul_ps(ri1, ri1);
#endif
#ifdef AVX_PRELOAD
                v16sf mri1 = _mm512_mul_ps(ri1, mj);
#else
                v16sf mri1 = _mm512_mul_ps(ri1, _mm512_set1_ps(epjbuf[j][3]));
#endif
                v16sf mri3 = _mm512_mul_ps(mri1, ri2);
#ifdef AVX_PRELOAD
                xj = _mm512_shuffle_ps(jbuf, jbuf, 0x00);
                yj = _mm512_shuffle_ps(jbuf, jbuf, 0x55);
                zj = _mm512_shuffle_ps(jbuf, jbuf, 0xaa);
                mj = _mm512_shuffle_ps(jbuf, jbuf, 0xff);
#endif
                pot = _mm512_sub_ps(pot, mri1);
                ax = _mm512_fnmadd_ps(mri3, dx, ax);
                ay = _mm512_fnmadd_ps(mri3, dy, ay);
                az = _mm512_fnmadd_ps(mri3, dz, az);
#ifdef AVX_PRELOAD
                v16sf vrcrit = _mm512_max_ps(rsi, rsj);
#else
                v16sf vrcrit = _mm512_max_ps(rsi, _mm512_set1_ps(rsearchj[j]));
#endif
                v16sf vrcrit2 = _mm512_mul_ps(vrcrit, vrcrit);
                nngb = _mm512_add_ps(_mm512_mask_blend_ps(
                                         _mm512_cmp_ps_mask(vrcrit2, r2_real, 0x01),
                                         _mm512_set1_ps(1.0f),   // inside the crit radius
                                         _mm512_set1_ps(0.0f)),  // outside
                                     nngb);
#ifdef AVX_PRELOAD
                rsj = rsjbuf;
#endif
            }
            *(v16sf *)(accpbuf[i/16][0]) = ax;
            *(v16sf *)(accpbuf[i/16][1]) = ay;
            *(v16sf *)(accpbuf[i/16][2]) = az;
            *(v16sf *)(accpbuf[i/16][3]) = pot;
            *(v16sf *)(accpbuf[i/16][4]) = nngb;
        }
    }
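    // ------------------------------------------------------------------
    // Reference for the RSQRT_NR_* branches (illustrative sketch, not
    // called anywhere): given an estimate y ~ 1/sqrt(r2) and
    // h = 1 - r2*y*y, the two refinements used above are, per lane,
    //
    //   inline float rsqrt_nr_x2(float y, float r2){   // one Newton step
    //       return y * 0.5f * (3.0f - r2 * y * y);
    //   }
    //   inline float rsqrt_nr_x4(float y, float r2){   // higher-order step
    //       const float h = 1.0f - r2 * y * y;
    //       return y * (1.0f + h * (8.0f + h * (6.0f + 5.0f * h)) * (1.0f/16.0f));
    //   }
    //
    // The second form is the truncated series
    // 1/sqrt(1-h) = 1 + h/2 + 3h^2/8 + 5h^3/16 + O(h^4), so one
    // application drives the relative error of the hardware estimate
    // down to O(h^4).
    // ------------------------------------------------------------------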
#else
    typedef __m128  v4sf;
    typedef __m256  v8sf;
    typedef __m256d v4df;
    void kernel_epj_nounroll_for_neighbor_count(const int ni, const int nj){
        const v8sf vone = _mm256_set1_ps(1.0f);
        const v8sf allbits = _mm256_cmp_ps(vone, vone, 0x00);   // all lanes true
        for(int i=0; i<ni; i+=8){
            const v8sf xi = *(v8sf *)(xibuf[i/8][0]);
            const v8sf yi = *(v8sf *)(xibuf[i/8][1]);
            const v8sf zi = *(v8sf *)(xibuf[i/8][2]);
            const v8sf rsi = *(v8sf *)(xibuf[i/8][3]);
            v8sf nngb = _mm256_set1_ps(0.0f);
            v8sf jbuf = _mm256_broadcast_ps((v4sf *)epjbuf);
            v8sf rsjbuf= _mm256_broadcast_ss(rsearchj);
            v8sf xj = _mm256_shuffle_ps(jbuf, jbuf, 0x00);
            v8sf yj = _mm256_shuffle_ps(jbuf, jbuf, 0x55);
            v8sf zj = _mm256_shuffle_ps(jbuf, jbuf, 0xaa);
            v8sf rsj = rsjbuf;
            for(int j=0; j<nj; j++){
                jbuf = _mm256_broadcast_ps((v4sf *)(epjbuf + j+1));
                rsjbuf = _mm256_broadcast_ss(rsearchj+j+1);
                v8sf dx = _mm256_sub_ps(xi, xj);
                v8sf dy = _mm256_sub_ps(yi, yj);
                v8sf dz = _mm256_sub_ps(zi, zj);
                v8sf r2_real = _mm256_mul_ps(dx, dx);
                r2_real = _mm256_fmadd_ps(dy, dy, r2_real);
                r2_real = _mm256_fmadd_ps(dz, dz, r2_real);
                xj = _mm256_shuffle_ps(jbuf, jbuf, 0x00);
                yj = _mm256_shuffle_ps(jbuf, jbuf, 0x55);
                zj = _mm256_shuffle_ps(jbuf, jbuf, 0xaa);
                v8sf vrcrit = _mm256_max_ps(rsi, rsj);
                v8sf vrcrit2 = _mm256_mul_ps(vrcrit, vrcrit);
                v8sf mask = _mm256_cmp_ps(vrcrit2, r2_real, 0x01);   // true where outside
                rsj = rsjbuf;
                nngb = _mm256_add_ps(_mm256_and_ps( vone, _mm256_xor_ps(mask, allbits) ), nngb);
            }
            *(v8sf *)(accpbuf[i/8][4]) = nngb;
        }
    }
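    // Per lane, the branchless accumulation above is equivalent to
    //   if(!(vrcrit2 < r2_real)) nngb += 1.0f;
    // AVX2 has no mask registers, so the predicate is materialised as an
    // all-ones/all-zeros bit pattern, inverted by XOR with allbits, and
    // turned into 0.0f/1.0f by ANDing with vone before the add.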
    void kernel_epj_nounroll_for_p3t_with_linear_cutoff(const int ni, const int nj){
        const v8sf vone = _mm256_set1_ps(1.0f);
        const v8sf veps2 = _mm256_set1_ps((float)eps2);
        const v8sf vr_out2 = _mm256_set1_ps((float)r_crit2);
        const v8sf allbits = _mm256_cmp_ps(vone, vone, 0x00);   // all lanes true
        for(int i=0; i<ni; i+=8){
            const v8sf xi = *(v8sf *)(xibuf[i/8][0]);
            const v8sf yi = *(v8sf *)(xibuf[i/8][1]);
            const v8sf zi = *(v8sf *)(xibuf[i/8][2]);
            const v8sf rsi = *(v8sf *)(xibuf[i/8][3]);
            v8sf ax, ay, az, pot, nngb;
            ax = ay = az = pot = nngb = _mm256_set1_ps(0.0f);
            v8sf jbuf = _mm256_broadcast_ps((v4sf *)epjbuf);
            v8sf rsjbuf= _mm256_broadcast_ss(rsearchj);
            v8sf xj = _mm256_shuffle_ps(jbuf, jbuf, 0x00);
            v8sf yj = _mm256_shuffle_ps(jbuf, jbuf, 0x55);
            v8sf zj = _mm256_shuffle_ps(jbuf, jbuf, 0xaa);
            v8sf mj = _mm256_shuffle_ps(jbuf, jbuf, 0xff);
            v8sf rsj = rsjbuf;
            for(int j=0; j<nj; j++){
                jbuf = _mm256_broadcast_ps((v4sf *)(epjbuf + j+1));
                rsjbuf = _mm256_broadcast_ss(rsearchj+j+1);
                v8sf dx = _mm256_sub_ps(xi, xj);
                v8sf dy = _mm256_sub_ps(yi, yj);
                v8sf dz = _mm256_sub_ps(zi, zj);
                v8sf r2_real = _mm256_fmadd_ps(dx, dx, veps2);
                r2_real = _mm256_fmadd_ps(dy, dy, r2_real);
                r2_real = _mm256_fmadd_ps(dz, dz, r2_real);
                v8sf r2 = _mm256_max_ps(r2_real, vr_out2);   // linear cutoff: clip at r_out^2
                v8sf ri1 = _mm256_rsqrt_ps(r2);
                v8sf ri2 = _mm256_mul_ps(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2
                v8sf v1 = _mm256_set1_ps(1.0f);
                v8sf h = _mm256_fnmadd_ps(r2, ri2, v1);
                v8sf v6p0 = _mm256_set1_ps(6.0f);
                v8sf v5p0 = _mm256_set1_ps(5.0f);
                v8sf v8p0 = _mm256_set1_ps(8.0f);
                v8sf v0p0625 = _mm256_set1_ps(1.0f/16.0f);
                ri2 = _mm256_fmadd_ps(h, v5p0, v6p0);
                ri2 = _mm256_fmadd_ps(h, ri2, v8p0);
                ri2 = _mm256_mul_ps(h, ri2);
                ri2 = _mm256_fmadd_ps(ri2, v0p0625, v1);
                ri1 = _mm256_mul_ps(ri2, ri1);
                ri2 = _mm256_mul_ps(ri1, ri1);
#elif defined(RSQRT_NR_EPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm256_fnmadd_ps(r2, ri2, _mm256_set1_ps(3.0f));
                ri2 = _mm256_mul_ps(ri2, _mm256_set1_ps(0.5f));
                ri1 = _mm256_mul_ps(ri2, ri1);
                ri2 = _mm256_mul_ps(ri1, ri1);
#endif
                v8sf mri1 = _mm256_mul_ps(mj, ri1);
                v8sf mri3 = _mm256_mul_ps(mri1, ri2);
                xj = _mm256_shuffle_ps(jbuf, jbuf, 0x00);
                yj = _mm256_shuffle_ps(jbuf, jbuf, 0x55);
                zj = _mm256_shuffle_ps(jbuf, jbuf, 0xaa);
                mj = _mm256_shuffle_ps(jbuf, jbuf, 0xff);
                pot = _mm256_sub_ps(pot, mri1);
                ax = _mm256_fnmadd_ps(mri3, dx, ax);
                ay = _mm256_fnmadd_ps(mri3, dy, ay);
                az = _mm256_fnmadd_ps(mri3, dz, az);
                v8sf vrcrit = _mm256_max_ps(rsi, rsj);
                v8sf vrcrit2 = _mm256_mul_ps(vrcrit, vrcrit);
                v8sf mask = _mm256_cmp_ps(vrcrit2, r2_real, 0x01);   // true where outside
                rsj = rsjbuf;
                nngb = _mm256_add_ps(_mm256_and_ps( vone, _mm256_xor_ps(mask, allbits) ), nngb);
            }
            *(v8sf *)(accpbuf[i/8][0]) = ax;
            *(v8sf *)(accpbuf[i/8][1]) = ay;
            *(v8sf *)(accpbuf[i/8][2]) = az;
            *(v8sf *)(accpbuf[i/8][3]) = pot;
            *(v8sf *)(accpbuf[i/8][4]) = nngb;
        }
    }
#endif
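    // Scalar reading of the linear-cutoff kernels above (for reference):
    // per i-j pair, with dr the separation vector,
    //   r2   = eps2 + |dr|^2;
    //   r2c  = max(r2, r_crit2);                 // clip at the outer radius
    //   pot -= m / sqrt(r2c);
    //   acc -= m * dr / (r2c * sqrt(r2c));
    //   if(r2 <= max(rs_i, rs_j)^2) nngb += 1;   // neighbour candidate count
    // Inside r_out the pair force is the softened Newtonian one; outside,
    // it is evaluated as if the pair sat at r_out, which the P3T force
    // splitting is designed to absorb in the long-range (tree) part.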
#if defined (__AVX512F__) && defined (__AVX512DQ__)
    void kernel_epj_nounroll(const int ni, const int nj){
        const v16sf veps2 = _mm512_set1_ps((float)eps2);
        for(int i=0; i<ni; i+=16){
            const v16sf xi = *(v16sf *)(xibuf[i/16][0]);
            const v16sf yi = *(v16sf *)(xibuf[i/16][1]);
            const v16sf zi = *(v16sf *)(xibuf[i/16][2]);
            v16sf ax, ay, az, pot;
            ax = ay = az = pot = _mm512_set1_ps(0.0f);
#ifdef AVX_PRELOAD
            v16sf jbuf = _mm512_broadcast_f32x4(*(v4sf *)epjbuf);
            v16sf xj = _mm512_shuffle_ps(jbuf, jbuf, 0x00);
            v16sf yj = _mm512_shuffle_ps(jbuf, jbuf, 0x55);
            v16sf zj = _mm512_shuffle_ps(jbuf, jbuf, 0xaa);
            v16sf mj = _mm512_shuffle_ps(jbuf, jbuf, 0xff);
#endif
            for(int j=0; j<nj; j++){
#ifdef AVX_PRELOAD
                jbuf = _mm512_broadcast_f32x4(*(v4sf *)(epjbuf + j+1));
                v16sf dx = _mm512_sub_ps(xi, xj);
                v16sf dy = _mm512_sub_ps(yi, yj);
                v16sf dz = _mm512_sub_ps(zi, zj);
#else
                v16sf dx = _mm512_sub_ps(xi, _mm512_set1_ps(epjbuf[j][0]));
                v16sf dy = _mm512_sub_ps(yi, _mm512_set1_ps(epjbuf[j][1]));
                v16sf dz = _mm512_sub_ps(zi, _mm512_set1_ps(epjbuf[j][2]));
#endif
                v16sf r2 = _mm512_fmadd_ps(dx, dx, veps2);
                r2 = _mm512_fmadd_ps(dy, dy, r2);
                r2 = _mm512_fmadd_ps(dz, dz, r2);
                v16sf ri1 = _mm512_rsqrt14_ps(r2);
                v16sf ri2 = _mm512_mul_ps(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X2
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm512_fnmadd_ps(r2, ri2, _mm512_set1_ps(3.0f));
                ri2 = _mm512_mul_ps(ri2, _mm512_set1_ps(0.5f));
                ri1 = _mm512_mul_ps(ri2, ri1);
                ri2 = _mm512_mul_ps(ri1, ri1);
#endif
#ifdef AVX_PRELOAD
                v16sf mri1 = _mm512_mul_ps(ri1, mj);
#else
                v16sf mri1 = _mm512_mul_ps(ri1, _mm512_set1_ps(epjbuf[j][3]));
#endif
                v16sf mri3 = _mm512_mul_ps(mri1, ri2);
#ifdef AVX_PRELOAD
                xj = _mm512_shuffle_ps(jbuf, jbuf, 0x00);
                yj = _mm512_shuffle_ps(jbuf, jbuf, 0x55);
                zj = _mm512_shuffle_ps(jbuf, jbuf, 0xaa);
                mj = _mm512_shuffle_ps(jbuf, jbuf, 0xff);
#endif
                pot = _mm512_sub_ps(pot, mri1);
                ax = _mm512_fnmadd_ps(mri3, dx, ax);
                ay = _mm512_fnmadd_ps(mri3, dy, ay);
                az = _mm512_fnmadd_ps(mri3, dz, az);
            }
            *(v16sf *)(accpbuf[i/16][0]) = ax;
            *(v16sf *)(accpbuf[i/16][1]) = ay;
            *(v16sf *)(accpbuf[i/16][2]) = az;
            *(v16sf *)(accpbuf[i/16][3]) = pot;
        }
    }
#else
    void kernel_epj_nounroll(const int ni, const int nj){
        const v8sf veps2 = _mm256_set1_ps((float)eps2);
        for(int i=0; i<ni; i+=8){
            const v8sf xi = *(v8sf *)(xibuf[i/8][0]);
            const v8sf yi = *(v8sf *)(xibuf[i/8][1]);
            const v8sf zi = *(v8sf *)(xibuf[i/8][2]);
            v8sf ax, ay, az, pot;
            ax = ay = az = pot = _mm256_set1_ps(0.0f);
            v8sf jbuf = _mm256_broadcast_ps((v4sf *)epjbuf);
            v8sf xj = _mm256_shuffle_ps(jbuf, jbuf, 0x00);
            v8sf yj = _mm256_shuffle_ps(jbuf, jbuf, 0x55);
            v8sf zj = _mm256_shuffle_ps(jbuf, jbuf, 0xaa);
            v8sf mj = _mm256_shuffle_ps(jbuf, jbuf, 0xff);
            for(int j=0; j<nj; j++){
                jbuf = _mm256_broadcast_ps((v4sf *)(epjbuf + j+1));
                v8sf dx = _mm256_sub_ps(xi, xj);
                v8sf dy = _mm256_sub_ps(yi, yj);
                v8sf dz = _mm256_sub_ps(zi, zj);
                v8sf r2 = _mm256_fmadd_ps(dx, dx, veps2);
                r2 = _mm256_fmadd_ps(dy, dy, r2);
                r2 = _mm256_fmadd_ps(dz, dz, r2);
                v8sf ri1 = _mm256_rsqrt_ps(r2);
                v8sf ri2 = _mm256_mul_ps(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X2
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm256_fnmadd_ps(r2, ri2, _mm256_set1_ps(3.0f));
                ri2 = _mm256_mul_ps(ri2, _mm256_set1_ps(0.5f));
                ri1 = _mm256_mul_ps(ri2, ri1);
                ri2 = _mm256_mul_ps(ri1, ri1);
#endif
                v8sf mri1 = _mm256_mul_ps(mj, ri1);
                v8sf mri3 = _mm256_mul_ps(mri1, ri2);
                xj = _mm256_shuffle_ps(jbuf, jbuf, 0x00);
                yj = _mm256_shuffle_ps(jbuf, jbuf, 0x55);
                zj = _mm256_shuffle_ps(jbuf, jbuf, 0xaa);
                mj = _mm256_shuffle_ps(jbuf, jbuf, 0xff);
                pot = _mm256_sub_ps(pot, mri1);
                ax = _mm256_fnmadd_ps(mri3, dx, ax);
                ay = _mm256_fnmadd_ps(mri3, dy, ay);
                az = _mm256_fnmadd_ps(mri3, dz, az);
            }
            *(v8sf *)(accpbuf[i/8][0]) = ax;
            *(v8sf *)(accpbuf[i/8][1]) = ay;
            *(v8sf *)(accpbuf[i/8][2]) = az;
            *(v8sf *)(accpbuf[i/8][3]) = pot;
        }
    }
#endif
#if defined (__AVX512F__) && defined (__AVX512DQ__)
    void kernel_spj_nounroll(const int ni, const int nj){
        const v16sf veps2 = _mm512_set1_ps((float)eps2);
        for(int i=0; i<ni; i+=16){
            const v16sf xi = *(v16sf *)(xibuf[i/16][0]);
            const v16sf yi = *(v16sf *)(xibuf[i/16][1]);
            const v16sf zi = *(v16sf *)(xibuf[i/16][2]);
            v16sf ax, ay, az, pot;
            ax = ay = az = pot = _mm512_set1_ps(0.0f);
#ifdef AVX_PRELOAD
            v16sf jbuf0 = _mm512_broadcast_f32x4(*(v4sf *)&spjbuf[0][0]);
            v16sf jbuf1 = _mm512_broadcast_f32x4(*(v4sf *)&spjbuf[0][1]);
            v16sf jbuf2 = _mm512_broadcast_f32x4(*(v4sf *)&spjbuf[0][2]);
#endif
            for(int j=0; j<nj; j++){
#ifdef AVX_PRELOAD
                v16sf xj = _mm512_shuffle_ps(jbuf0, jbuf0, 0x00);
                v16sf yj = _mm512_shuffle_ps(jbuf0, jbuf0, 0x55);
                v16sf zj = _mm512_shuffle_ps(jbuf0, jbuf0, 0xaa);
                jbuf0 = _mm512_broadcast_f32x4(*(v4sf *)&spjbuf[j+1][0]);

                v16sf qxx = _mm512_shuffle_ps(jbuf1, jbuf1, 0x00);
                v16sf qyy = _mm512_shuffle_ps(jbuf1, jbuf1, 0x55);
                v16sf qzz = _mm512_shuffle_ps(jbuf1, jbuf1, 0xaa);
                v16sf mj  = _mm512_shuffle_ps(jbuf1, jbuf1, 0xff);
                jbuf1 = _mm512_broadcast_f32x4(*(v4sf *)&spjbuf[j+1][1]);

                v16sf qxy = _mm512_shuffle_ps(jbuf2, jbuf2, 0x00);
                v16sf qyz = _mm512_shuffle_ps(jbuf2, jbuf2, 0x55);
                v16sf qzx = _mm512_shuffle_ps(jbuf2, jbuf2, 0xaa);
                v16sf mtr = _mm512_shuffle_ps(jbuf2, jbuf2, 0xff);
                jbuf2 = _mm512_broadcast_f32x4(*(v4sf *)&spjbuf[j+1][2]);

                v16sf dx = _mm512_sub_ps(xi, xj);
                v16sf dy = _mm512_sub_ps(yi, yj);
                v16sf dz = _mm512_sub_ps(zi, zj);
#else
                v16sf dx = _mm512_sub_ps(xi, _mm512_set1_ps(spjbuf[j][0][0]));
                v16sf dy = _mm512_sub_ps(yi, _mm512_set1_ps(spjbuf[j][0][1]));
                v16sf dz = _mm512_sub_ps(zi, _mm512_set1_ps(spjbuf[j][0][2]));
#endif
                v16sf r2 = _mm512_fmadd_ps(dx, dx, veps2);
                r2 = _mm512_fmadd_ps(dy, dy, r2);
                r2 = _mm512_fmadd_ps(dz, dz, r2);
                v16sf ri1 = _mm512_rsqrt14_ps(r2);
                v16sf ri2 = _mm512_mul_ps(ri1, ri1);
#ifdef RSQRT_NR_SPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2
                v16sf v1 = _mm512_set1_ps(1.0f);
                v16sf h = _mm512_fnmadd_ps(r2, ri2, v1);
                v16sf v6p0 = _mm512_set1_ps(6.0f);
                v16sf v5p0 = _mm512_set1_ps(5.0f);
                v16sf v8p0 = _mm512_set1_ps(8.0f);
                v16sf v0p0625 = _mm512_set1_ps(1.0f/16.0f);
                ri2 = _mm512_fmadd_ps(h, v5p0, v6p0);
                ri2 = _mm512_fmadd_ps(h, ri2, v8p0);
                ri2 = _mm512_mul_ps(h, ri2);
                ri2 = _mm512_fmadd_ps(ri2, v0p0625, v1);
                ri1 = _mm512_mul_ps(ri2, ri1);
                ri2 = _mm512_mul_ps(ri1, ri1);
#elif defined(RSQRT_NR_SPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm512_fnmadd_ps(r2, ri2, _mm512_set1_ps(3.0f));
                ri2 = _mm512_mul_ps(ri2, _mm512_set1_ps(0.5f));
                ri1 = _mm512_mul_ps(ri2, ri1);
                ri2 = _mm512_mul_ps(ri1, ri1);
#endif
                v16sf ri3 = _mm512_mul_ps(ri1, ri2);
                v16sf ri4 = _mm512_mul_ps(ri2, ri2);
                v16sf ri5 = _mm512_mul_ps(ri2, ri3);
#ifdef AVX_PRELOAD
                v16sf qr_x = _mm512_mul_ps(dx, qxx);
                qr_x = _mm512_fmadd_ps(dy, qxy, qr_x);
                qr_x = _mm512_fmadd_ps(dz, qzx, qr_x);

                v16sf qr_y = _mm512_mul_ps(dy, qyy);
                qr_y = _mm512_fmadd_ps(dx, qxy, qr_y);
                qr_y = _mm512_fmadd_ps(dz, qyz, qr_y);

                v16sf qr_z = _mm512_mul_ps(dz, qzz);
                qr_z = _mm512_fmadd_ps(dx, qzx, qr_z);
                qr_z = _mm512_fmadd_ps(dy, qyz, qr_z);

                v16sf rqr = _mm512_fmadd_ps(qr_x, dx, mtr);
#else
                v16sf qr_x = _mm512_mul_ps(dx, _mm512_set1_ps(spjbuf[j][1][0]));
                qr_x = _mm512_fmadd_ps(dy, _mm512_set1_ps(spjbuf[j][2][0]), qr_x);
                qr_x = _mm512_fmadd_ps(dz, _mm512_set1_ps(spjbuf[j][2][2]), qr_x);

                v16sf qr_y = _mm512_mul_ps(dy, _mm512_set1_ps(spjbuf[j][1][1]));
                qr_y = _mm512_fmadd_ps(dx, _mm512_set1_ps(spjbuf[j][2][0]), qr_y);
                qr_y = _mm512_fmadd_ps(dz, _mm512_set1_ps(spjbuf[j][2][1]), qr_y);

                v16sf qr_z = _mm512_mul_ps(dz, _mm512_set1_ps(spjbuf[j][1][2]));
                qr_z = _mm512_fmadd_ps(dx, _mm512_set1_ps(spjbuf[j][2][2]), qr_z);
                qr_z = _mm512_fmadd_ps(dy, _mm512_set1_ps(spjbuf[j][2][1]), qr_z);

                v16sf rqr = _mm512_fmadd_ps(qr_x, dx, _mm512_set1_ps(spjbuf[j][2][3]));
#endif
                rqr = _mm512_fmadd_ps(qr_y, dy, rqr);
                rqr = _mm512_fmadd_ps(qr_z, dz, rqr);

                v16sf rqr_ri4 = _mm512_mul_ps(rqr, ri4);
#ifndef AVX_PRELOAD
                v16sf mj = _mm512_set1_ps(spjbuf[j][1][3]);
#endif
                v16sf meff  = _mm512_fmadd_ps(rqr_ri4, _mm512_set1_ps(0.5f), mj);
                v16sf meff3 = _mm512_fmadd_ps(rqr_ri4, _mm512_set1_ps(2.5f), mj);
                meff3 = _mm512_mul_ps(meff3, ri3);

                pot = _mm512_fnmadd_ps(meff, ri1, pot);

                ax = _mm512_fmadd_ps(ri5, qr_x, ax);
                ax = _mm512_fnmadd_ps(meff3, dx, ax);
                ay = _mm512_fmadd_ps(ri5, qr_y, ay);
                ay = _mm512_fnmadd_ps(meff3, dy, ay);
                az = _mm512_fmadd_ps(ri5, qr_z, az);
                az = _mm512_fnmadd_ps(meff3, dz, az);
            }
            *(v16sf *)(accpbuf[i/16][0]) = ax;
            *(v16sf *)(accpbuf[i/16][1]) = ay;
            *(v16sf *)(accpbuf[i/16][2]) = az;
            *(v16sf *)(accpbuf[i/16][3]) = pot;
        }
    }
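    // Scalar reading of kernel_spj_nounroll (this variant and the AVX2 one
    // below): with A = 3Q - tr(Q) I prepared by set_spj_one and
    // mtr = -eps2 * tr(Q),
    //   qr    = A . dr;
    //   rqr   = dr . qr + mtr;
    //   meff  = m + 0.5 * rqr / r^4;           // effective monopole for pot
    //   meff3 = (m + 2.5 * rqr / r^4) / r^3;   // effective monopole for acc
    //   pot  -= meff / r;
    //   acc  += qr / r^5 - meff3 * dr;
    // which is the softened quadrupole expansion of a tree cell.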
#else
    void kernel_spj_nounroll(const int ni, const int nj){
        const v8sf veps2 = _mm256_set1_ps((float)eps2);
        for(int i=0; i<ni; i+=8){
            const v8sf xi = *(v8sf *)(xibuf[i/8][0]);
            const v8sf yi = *(v8sf *)(xibuf[i/8][1]);
            const v8sf zi = *(v8sf *)(xibuf[i/8][2]);
            v8sf ax, ay, az, pot;
            ax = ay = az = pot = _mm256_set1_ps(0.0f);
#ifdef AVX_PRELOAD
            v8sf jbuf0 = _mm256_broadcast_ps((v4sf *)&spjbuf[0][0]);
            v8sf jbuf1 = _mm256_broadcast_ps((v4sf *)&spjbuf[0][1]);
            v8sf jbuf2 = _mm256_broadcast_ps((v4sf *)&spjbuf[0][2]);
#else
            v8sf jbuf0, jbuf1, jbuf2;
#endif
            for(int j=0; j<nj; j++){
#ifndef AVX_PRELOAD
                jbuf0 = _mm256_broadcast_ps((v4sf *)&spjbuf[j+0][0]);
#endif
                v8sf xj = _mm256_shuffle_ps(jbuf0, jbuf0, 0x00);
                v8sf yj = _mm256_shuffle_ps(jbuf0, jbuf0, 0x55);
                v8sf zj = _mm256_shuffle_ps(jbuf0, jbuf0, 0xaa);
#ifdef AVX_PRELOAD
                jbuf0 = _mm256_broadcast_ps((v4sf *)&spjbuf[j+1][0]);
#else
                jbuf1 = _mm256_broadcast_ps((v4sf *)&spjbuf[j+0][1]);
#endif
                v8sf qxx = _mm256_shuffle_ps(jbuf1, jbuf1, 0x00);
                v8sf qyy = _mm256_shuffle_ps(jbuf1, jbuf1, 0x55);
                v8sf qzz = _mm256_shuffle_ps(jbuf1, jbuf1, 0xaa);
                v8sf mj  = _mm256_shuffle_ps(jbuf1, jbuf1, 0xff);
#ifdef AVX_PRELOAD
                jbuf1 = _mm256_broadcast_ps((v4sf *)&spjbuf[j+1][1]);
#else
                jbuf2 = _mm256_broadcast_ps((v4sf *)&spjbuf[j+0][2]);
#endif
                v8sf qxy = _mm256_shuffle_ps(jbuf2, jbuf2, 0x00);
                v8sf qyz = _mm256_shuffle_ps(jbuf2, jbuf2, 0x55);
                v8sf qzx = _mm256_shuffle_ps(jbuf2, jbuf2, 0xaa);
                v8sf mtr = _mm256_shuffle_ps(jbuf2, jbuf2, 0xff);
#ifdef AVX_PRELOAD
                jbuf2 = _mm256_broadcast_ps((v4sf *)&spjbuf[j+1][2]);
#endif
                v8sf dx = _mm256_sub_ps(xi, xj);
                v8sf dy = _mm256_sub_ps(yi, yj);
                v8sf dz = _mm256_sub_ps(zi, zj);

                v8sf r2 = _mm256_fmadd_ps(dx, dx, veps2);
                r2 = _mm256_fmadd_ps(dy, dy, r2);
                r2 = _mm256_fmadd_ps(dz, dz, r2);
                v8sf ri1 = _mm256_rsqrt_ps(r2);
                v8sf ri2 = _mm256_mul_ps(ri1, ri1);
#ifdef RSQRT_NR_SPJ_X2
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm256_fnmadd_ps(r2, ri2, _mm256_set1_ps(3.0f));
                ri2 = _mm256_mul_ps(ri2, _mm256_set1_ps(0.5f));
                ri1 = _mm256_mul_ps(ri2, ri1);
                ri2 = _mm256_mul_ps(ri1, ri1);
#endif
                v8sf ri3 = _mm256_mul_ps(ri1, ri2);
                v8sf ri4 = _mm256_mul_ps(ri2, ri2);
                v8sf ri5 = _mm256_mul_ps(ri2, ri3);

                v8sf qr_x = _mm256_mul_ps(dx, qxx);
                qr_x = _mm256_fmadd_ps(dy, qxy, qr_x);
                qr_x = _mm256_fmadd_ps(dz, qzx, qr_x);

                v8sf qr_y = _mm256_mul_ps(dy, qyy);
                qr_y = _mm256_fmadd_ps(dx, qxy, qr_y);
                qr_y = _mm256_fmadd_ps(dz, qyz, qr_y);

                v8sf qr_z = _mm256_mul_ps(dz, qzz);
                qr_z = _mm256_fmadd_ps(dx, qzx, qr_z);
                qr_z = _mm256_fmadd_ps(dy, qyz, qr_z);

                v8sf rqr = _mm256_fmadd_ps(qr_x, dx, mtr);
                rqr = _mm256_fmadd_ps(qr_y, dy, rqr);
                rqr = _mm256_fmadd_ps(qr_z, dz, rqr);

                v8sf rqr_ri4 = _mm256_mul_ps(rqr, ri4);
                v8sf meff  = _mm256_fmadd_ps(rqr_ri4, _mm256_set1_ps(0.5f), mj);
                v8sf meff3 = _mm256_fmadd_ps(rqr_ri4, _mm256_set1_ps(2.5f), mj);
                meff3 = _mm256_mul_ps(meff3, ri3);

                pot = _mm256_fnmadd_ps(meff, ri1, pot);

                ax = _mm256_fmadd_ps(ri5, qr_x, ax);
                ax = _mm256_fnmadd_ps(meff3, dx, ax);
                ay = _mm256_fmadd_ps(ri5, qr_y, ay);
                ay = _mm256_fnmadd_ps(meff3, dy, ay);
                az = _mm256_fmadd_ps(ri5, qr_z, az);
                az = _mm256_fnmadd_ps(meff3, dz, az);
            }
            *(v8sf *)(accpbuf[i/8][0]) = ax;
            *(v8sf *)(accpbuf[i/8][1]) = ay;
            *(v8sf *)(accpbuf[i/8][2]) = az;
            *(v8sf *)(accpbuf[i/8][3]) = pot;
        }
    }
#endif
};
#if defined(CALC_EP_64bit) || defined(CALC_EP_MIX)

// Same FMA emulation as above, for the double-precision kernels.
#define _mm256_fmadd_pd(x,y,z)  _mm256_add_pd(z, _mm256_mul_pd(x,y))
#define _mm256_fnmadd_pd(x,y,z) _mm256_sub_pd(z, _mm256_mul_pd(x,y))
#define _mm512_fmadd_pd(x,y,z)  _mm512_add_pd(z, _mm512_mul_pd(x,y))
#define _mm512_fnmadd_pd(x,y,z) _mm512_sub_pd(z, _mm512_mul_pd(x,y))
class PhantomGrapeQuad64Bit{
    double xibuf  [NIMAX/8] [4][8];   // x, y, z, r_search of the i-particles
    double accpbuf[NIMAX/8] [5][8];   // ax, ay, az, pot, nngb
    double epjbuf  [NJMAX] [4];       // x, y, z, m of the j-particles
    double rsearchj[NJMAX];           // search radius of the j-particles
    double spjbuf  [NJMAX] [3][4];    // same layout as in the 32-bit class
    double eps2;                      // declarations implied by their use below
    double r_crit2;

    static double get_a_NaN(){
        union{ long l; double d; } m;
        m.l = -1;                     // all bits set encodes a quiet NaN
        return m.d;
    }

public:
    // eps2 starts out as NaN, presumably so that a forgotten softening
    // setting poisons every result instead of failing silently.
    PhantomGrapeQuad64Bit() : eps2(get_a_NaN()) {}
    void set_r_crit2(const double _r_crit2){   // name assumed, as above
        this->r_crit2 = _r_crit2;
    }

    void set_epj_one(const int addr,
                     const double x, const double y, const double z,
                     const double m, const double r_search) {
        epjbuf[addr][0] = x;
        epjbuf[addr][1] = y;
        epjbuf[addr][2] = z;
        epjbuf[addr][3] = m;
        rsearchj[addr] = r_search;
    }
    void set_spj_one(const int addr,
                     const double x,   const double y,   const double z,   const double m,
                     const double qxx, const double qyy, const double qzz,
                     const double qxy, const double qyz, const double qzx)
    {
        const double tr = qxx + qyy + qzz;
        spjbuf[addr][0][0] = x;
        spjbuf[addr][0][1] = y;
        spjbuf[addr][0][2] = z;
        spjbuf[addr][0][3] = m;

        spjbuf[addr][1][0] = 3.0 * qxx - tr;
        spjbuf[addr][1][1] = 3.0 * qyy - tr;
        spjbuf[addr][1][2] = 3.0 * qzz - tr;
        spjbuf[addr][1][3] = m;

        spjbuf[addr][2][0] = 3.0 * qxy;
        spjbuf[addr][2][1] = 3.0 * qyz;
        spjbuf[addr][2][2] = 3.0 * qzx;
        spjbuf[addr][2][3] = -(eps2 * tr);
    }
    void set_xi_one(const int addr, const double x, const double y, const double z, const double r_search){
        const int ah = addr / 8;
        const int al = addr % 8;
        xibuf[ah][0][al] = x;
        xibuf[ah][1][al] = y;
        xibuf[ah][2][al] = z;
        xibuf[ah][3][al] = r_search;
    }
    template <typename real_t>
    void get_accp_one(const int addr, real_t &ax, real_t &ay, real_t &az, real_t &pot, real_t &nngb){
        const int ah = addr / 8;
        const int al = addr % 8;
        ax   = accpbuf[ah][0][al];
        ay   = accpbuf[ah][1][al];
        az   = accpbuf[ah][2][al];
        pot  = accpbuf[ah][3][al];
        nngb = accpbuf[ah][4][al];
    }
    template <typename real_t>
    void accum_accp_one(const int addr, real_t &ax, real_t &ay, real_t &az, real_t &pot){
        const int ah = addr / 8;
        const int al = addr % 8;
        ax  += accpbuf[ah][0][al];
        ay  += accpbuf[ah][1][al];
        az  += accpbuf[ah][2][al];
        pot += accpbuf[ah][3][al];
    }
    template <typename real_t>
    void accum_accp_one(const int addr, real_t &ax, real_t &ay, real_t &az, real_t &pot, real_t &nngb){
        const int ah = addr / 8;
        const int al = addr % 8;
        ax   += accpbuf[ah][0][al];
        ay   += accpbuf[ah][1][al];
        az   += accpbuf[ah][2][al];
        pot  += accpbuf[ah][3][al];
        nngb += accpbuf[ah][4][al];
    }
    template <typename real_t>
    void accum_nngb_one(const int addr, real_t &nngb){   // name assumed; accumulates the neighbour count only
        const int ah = addr / 8;
        const int al = addr % 8;
        nngb += accpbuf[ah][4][al];
    }
    void run_epj_for_neighbor_count(const int ni, const int nj){
        if(ni > NIMAX || nj > NJMAX){
            std::cout<<"ni= "<<ni<<" NIMAX= "<<NIMAX<<" nj= "<<nj<<" NJMAX= "<<NJMAX<<std::endl;
            for(PS::S32 i=0; i<(ni-1)/8+1; i++){
                for(PS::S32 j=0; j<8; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
        }
        assert(ni <= NIMAX);
        assert(nj <= NJMAX);
        kernel_epj_nounroll_for_neighbor_count(ni, nj);
    }
    void run_epj_for_p3t_with_linear_cutoff(const int ni, const int nj){
        if(ni > NIMAX || nj > NJMAX){
            std::cout<<"ni= "<<ni<<" NIMAX= "<<NIMAX<<" nj= "<<nj<<" NJMAX= "<<NJMAX<<std::endl;
            for(PS::S32 i=0; i<(ni-1)/8+1; i++){
                for(PS::S32 j=0; j<8; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
        }
        assert(ni <= NIMAX);
        assert(nj <= NJMAX);
        kernel_epj_nounroll_for_p3t_with_linear_cutoff(ni, nj);
    }
    void run_epj(const int ni, const int nj){
        if(ni > NIMAX || nj > NJMAX){
            std::cout<<"ni= "<<ni<<" NIMAX= "<<NIMAX<<" nj= "<<nj<<" NJMAX= "<<NJMAX<<std::endl;
            for(PS::S32 i=0; i<(ni-1)/8+1; i++){
                for(PS::S32 j=0; j<8; j++){
                    for(PS::S32 k=0; k<4; k++){
                        std::cout<<"xibuf[i][k][j]="<<xibuf[i][k][j]<<std::endl;
                    }
                    std::cout<<std::endl;
                }
            }
        }
        assert(ni <= NIMAX);
        assert(nj <= NJMAX);
        kernel_epj_nounroll(ni, nj);
    }
    void run_spj(const int ni, const int nj){
        if(ni > NIMAX || nj > NJMAX){
            std::cout<<"ni= "<<ni<<" NIMAX= "<<NIMAX<<" nj= "<<nj<<" NJMAX= "<<NJMAX<<std::endl;
        }
        assert(ni <= NIMAX);
        assert(nj <= NJMAX);
        kernel_spj_nounroll(ni, nj);
    }
#if defined (__AVX512F__) && defined (__AVX512DQ__)
    typedef __m256  v8sf;
    typedef __m256d v4df;
    typedef __m512  v16sf;
    typedef __m512d v8df;
    void kernel_epj_nounroll(const int ni, const int nj){
        const v8df veps2 = _mm512_set1_pd(eps2);
        for(int i=0; i<ni; i+=8){
            const v8df xi = *(v8df *)(&xibuf[i/8][0]);
            const v8df yi = *(v8df *)(&xibuf[i/8][1]);
            const v8df zi = *(v8df *)(&xibuf[i/8][2]);
            v8df ax, ay, az, pot;
            ax = ay = az = pot = _mm512_set1_pd(0.0);
#ifdef AVX_PRELOAD
            v4df jbuf = *((v4df*)epjbuf);
            v8df jbuf8= _mm512_broadcast_f64x4(jbuf);
            v8df xj = _mm512_permutex_pd(jbuf8, 0x00);
            v8df yj = _mm512_permutex_pd(jbuf8, 0x55);
            v8df zj = _mm512_permutex_pd(jbuf8, 0xaa);
            v8df mj = _mm512_permutex_pd(jbuf8, 0xff);
#endif
            for(int j=0; j<nj; j++){
#ifdef AVX_PRELOAD
                jbuf = *((v4df*)(epjbuf+j+1));
                v8df dx = _mm512_sub_pd(xi, xj);
                v8df dy = _mm512_sub_pd(yi, yj);
                v8df dz = _mm512_sub_pd(zi, zj);
#else
                v8df dx = _mm512_sub_pd(xi, _mm512_set1_pd(epjbuf[j][0]));
                v8df dy = _mm512_sub_pd(yi, _mm512_set1_pd(epjbuf[j][1]));
                v8df dz = _mm512_sub_pd(zi, _mm512_set1_pd(epjbuf[j][2]));
#endif
                v8df r2 = _mm512_fmadd_pd(dx, dx, veps2);
                r2 = _mm512_fmadd_pd(dy, dy, r2);
                r2 = _mm512_fmadd_pd(dz, dz, r2);
                v8df ri1 = _mm512_rsqrt14_pd(r2);
                v8df ri2 = _mm512_mul_pd(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2:
                // ri1 *= 1 + h*(8 + h*(6 + 5h))/16
                v8df v1 = _mm512_set1_pd(1.0);
                v8df h = _mm512_fnmadd_pd(r2, ri2, v1);
                v8df v6p0 = _mm512_set1_pd(6.0);
                v8df v5p0 = _mm512_set1_pd(5.0);
                v8df v8p0 = _mm512_set1_pd(8.0);
                v8df v0p0625 = _mm512_set1_pd(1.0/16.0);
                ri2 = _mm512_fmadd_pd(h, v5p0, v6p0);
                ri2 = _mm512_fmadd_pd(h, ri2, v8p0);
                ri2 = _mm512_mul_pd(h, ri2);
                ri2 = _mm512_fmadd_pd(ri2, v0p0625, v1);
                ri1 = _mm512_mul_pd(ri2, ri1);
                ri2 = _mm512_mul_pd(ri1, ri1);
#elif defined(RSQRT_NR_EPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm512_fnmadd_pd(r2, ri2, _mm512_set1_pd(3.0));
                ri2 = _mm512_mul_pd(ri2, _mm512_set1_pd(0.5));
                ri1 = _mm512_mul_pd(ri2, ri1);
                ri2 = _mm512_mul_pd(ri1, ri1);
#endif
#ifdef AVX_PRELOAD
                v8df mri1 = _mm512_mul_pd(ri1, mj);
#else
                v8df mri1 = _mm512_mul_pd(ri1, _mm512_set1_pd(epjbuf[j][3]));
#endif
                v8df mri3 = _mm512_mul_pd(mri1, ri2);
#ifdef AVX_PRELOAD
                jbuf8=_mm512_broadcast_f64x4(jbuf);
                xj = _mm512_permutex_pd(jbuf8, 0x00);
                yj = _mm512_permutex_pd(jbuf8, 0x55);
                zj = _mm512_permutex_pd(jbuf8, 0xaa);
                mj = _mm512_permutex_pd(jbuf8, 0xff);
#endif
                pot = _mm512_sub_pd(pot, mri1);
                ax = _mm512_fnmadd_pd(mri3, dx, ax);
                ay = _mm512_fnmadd_pd(mri3, dy, ay);
                az = _mm512_fnmadd_pd(mri3, dz, az);
            }
            *(v8df *)(&accpbuf[i/8][0]) = ax;
            *(v8df *)(&accpbuf[i/8][1]) = ay;
            *(v8df *)(&accpbuf[i/8][2]) = az;
            *(v8df *)(&accpbuf[i/8][3]) = pot;
        }
    }
#else
    typedef __m128  v4sf;
    typedef __m256  v8sf;
    typedef __m256d v4df;
    void kernel_epj_nounroll(const int ni, const int nj){
        const v4df veps2 = _mm256_set1_pd(eps2);
        for(int i=0; i<ni; i+=4){
            const int il = i % 8;     // offset of this 4-wide half inside the 8-wide rows
            const v4df xi = *(v4df *)(&xibuf[i/8][0][il]);
            const v4df yi = *(v4df *)(&xibuf[i/8][1][il]);
            const v4df zi = *(v4df *)(&xibuf[i/8][2][il]);
            v4df ax, ay, az, pot;
            ax = ay = az = pot = _mm256_set1_pd(0.0);
            v4df jbuf = *((v4df*)epjbuf);
            v4df xj = _mm256_permute4x64_pd(jbuf, 0x00);
            v4df yj = _mm256_permute4x64_pd(jbuf, 0x55);
            v4df zj = _mm256_permute4x64_pd(jbuf, 0xaa);
            v4df mj = _mm256_permute4x64_pd(jbuf, 0xff);
            for(int j=0; j<nj; j++){
                jbuf = *((v4df*)(epjbuf+j+1));
                v4df dx = _mm256_sub_pd(xi, xj);
                v4df dy = _mm256_sub_pd(yi, yj);
                v4df dz = _mm256_sub_pd(zi, zj);
                v4df r2 = _mm256_fmadd_pd(dx, dx, veps2);
                r2 = _mm256_fmadd_pd(dy, dy, r2);
                r2 = _mm256_fmadd_pd(dz, dz, r2);
                // rsqrt computed in single precision and widened back; lanes
                // with r2 == eps2 (dr = 0, the particle itself) are zeroed
                // through the not-equal mask.
                v4df mask = _mm256_cmp_pd(veps2, r2, 0x4);   // 0x4 = NEQ
                v4df ri1 = _mm256_and_pd(_mm256_cvtps_pd(_mm_rsqrt_ps(_mm256_cvtpd_ps(r2))), mask);
                v4df ri2 = _mm256_mul_pd(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2
                v4df v1 = _mm256_set1_pd(1.0);
                v4df h = _mm256_fnmadd_pd(r2, ri2, v1);
                v4df v6p0 = _mm256_set1_pd(6.0);
                v4df v5p0 = _mm256_set1_pd(5.0);
                v4df v8p0 = _mm256_set1_pd(8.0);
                v4df v0p0625 = _mm256_set1_pd(1.0/16.0);
                ri2 = _mm256_fmadd_pd(h, v5p0, v6p0);
                ri2 = _mm256_fmadd_pd(h, ri2, v8p0);
                ri2 = _mm256_mul_pd(h, ri2);
                ri2 = _mm256_fmadd_pd(ri2, v0p0625, v1);
                ri1 = _mm256_mul_pd(ri2, ri1);
                ri2 = _mm256_mul_pd(ri1, ri1);
#elif defined(RSQRT_NR_EPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm256_fnmadd_pd(r2, ri2, _mm256_set1_pd(3.0));
                ri2 = _mm256_mul_pd(ri2, _mm256_set1_pd(0.5));
                ri1 = _mm256_mul_pd(ri2, ri1);
                ri2 = _mm256_mul_pd(ri1, ri1);
#endif
                v4df mri1 = _mm256_mul_pd(ri1, mj);
                v4df mri3 = _mm256_mul_pd(mri1, ri2);
                xj = _mm256_permute4x64_pd(jbuf, 0x00);
                yj = _mm256_permute4x64_pd(jbuf, 0x55);
                zj = _mm256_permute4x64_pd(jbuf, 0xaa);
                mj = _mm256_permute4x64_pd(jbuf, 0xff);
                pot = _mm256_sub_pd(pot, mri1);
                ax = _mm256_fnmadd_pd(mri3, dx, ax);
                ay = _mm256_fnmadd_pd(mri3, dy, ay);
                az = _mm256_fnmadd_pd(mri3, dz, az);
            }
            *(v4df *)(&accpbuf[i/8][0][il]) = ax;
            *(v4df *)(&accpbuf[i/8][1][il]) = ay;
            *(v4df *)(&accpbuf[i/8][2][il]) = az;
            *(v4df *)(&accpbuf[i/8][3][il]) = pot;
        }
    }
#endif
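    // Note on the AVX2 double path: _mm_rsqrt_ps yields only about 12 valid
    // bits and is widened from float, so for double-precision results one of
    // RSQRT_NR_EPJ_X2 / RSQRT_NR_EPJ_X4 should be enabled; the Newton and
    // higher-order steps above then recover the remaining accuracy.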
#if defined (__AVX512F__) && defined (__AVX512DQ__)
    void kernel_epj_nounroll_for_neighbor_count(const int ni, const int nj) {
        for(int i=0; i<ni; i+=8){
            const v8df xi = *(v8df *)(&xibuf[i/8][0]);
            const v8df yi = *(v8df *)(&xibuf[i/8][1]);
            const v8df zi = *(v8df *)(&xibuf[i/8][2]);
            const v8df rsi= *(v8df *)(&xibuf[i/8][3]);
            v8df nngb = _mm512_set1_pd(0.0);
#ifdef AVX_PRELOAD
            v4df jbuf = *((v4df*)epjbuf);
            v8df jbuf8= _mm512_broadcast_f64x4(jbuf);
            v8df rsjbuf= _mm512_set1_pd(*rsearchj);
            v8df xj = _mm512_permutex_pd(jbuf8, 0x00);
            v8df yj = _mm512_permutex_pd(jbuf8, 0x55);
            v8df zj = _mm512_permutex_pd(jbuf8, 0xaa);
            v8df rsj = rsjbuf;
#endif
            for(int j=0; j<nj; j++){
#ifdef AVX_PRELOAD
                jbuf = *((v4df*)(epjbuf+j+1));
                rsjbuf = _mm512_set1_pd(*(rsearchj+j+1));
                v8df dx = _mm512_sub_pd(xi, xj);
                v8df dy = _mm512_sub_pd(yi, yj);
                v8df dz = _mm512_sub_pd(zi, zj);
#else
                v8df dx = _mm512_sub_pd(xi, _mm512_set1_pd(epjbuf[j][0]));
                v8df dy = _mm512_sub_pd(yi, _mm512_set1_pd(epjbuf[j][1]));
                v8df dz = _mm512_sub_pd(zi, _mm512_set1_pd(epjbuf[j][2]));
#endif
                v8df r2_real = _mm512_mul_pd(dx, dx);
                r2_real = _mm512_fmadd_pd(dy, dy, r2_real);
                r2_real = _mm512_fmadd_pd(dz, dz, r2_real);
#ifdef AVX_PRELOAD
                jbuf8= _mm512_broadcast_f64x4(jbuf);
                xj = _mm512_permutex_pd(jbuf8, 0x00);
                yj = _mm512_permutex_pd(jbuf8, 0x55);
                zj = _mm512_permutex_pd(jbuf8, 0xaa);
                v8df vrcrit = _mm512_max_pd(rsi, rsj);
#else
                v8df vrcrit = _mm512_max_pd(rsi, _mm512_set1_pd(rsearchj[j]));
#endif
                v8df vrcrit2 = _mm512_mul_pd(vrcrit, vrcrit);
                nngb = _mm512_add_pd(
                           _mm512_mask_blend_pd(_mm512_cmp_pd_mask(vrcrit2, r2_real, 0x01),
                                                _mm512_set1_pd(1.0),    // inside the crit radius
                                                _mm512_set1_pd(0.0)),   // outside
                           nngb);
#ifdef AVX_PRELOAD
                rsj = rsjbuf;
#endif
            }
            *(v8df *)(&accpbuf[i/8][4]) = nngb;
        }
    }
    void kernel_epj_nounroll_for_p3t_with_linear_cutoff(const int ni, const int nj) {
        const v8df veps2 = _mm512_set1_pd(eps2);
        const v8df vr_out2 = _mm512_set1_pd(r_crit2);
        for(int i=0; i<ni; i+=8){
            const v8df xi = *(v8df *)(&xibuf[i/8][0]);
            const v8df yi = *(v8df *)(&xibuf[i/8][1]);
            const v8df zi = *(v8df *)(&xibuf[i/8][2]);
            const v8df rsi= *(v8df *)(&xibuf[i/8][3]);
            v8df ax, ay, az, pot, nngb;
            ax = ay = az = pot = nngb = _mm512_set1_pd(0.0);
#ifdef AVX_PRELOAD
            v4df jbuf = *((v4df*)epjbuf);
            v8df jbuf8= _mm512_broadcast_f64x4(jbuf);
            v8df rsjbuf= _mm512_set1_pd(*rsearchj);
            v8df xj = _mm512_permutex_pd(jbuf8, 0x00);
            v8df yj = _mm512_permutex_pd(jbuf8, 0x55);
            v8df zj = _mm512_permutex_pd(jbuf8, 0xaa);
            v8df mj = _mm512_permutex_pd(jbuf8, 0xff);
            v8df rsj = rsjbuf;
#endif
            for(int j=0; j<nj; j++){
#ifdef AVX_PRELOAD
                jbuf = *((v4df*)(epjbuf+j+1));
                rsjbuf = _mm512_set1_pd(*(rsearchj+j+1));
                v8df dx = _mm512_sub_pd(xi, xj);
                v8df dy = _mm512_sub_pd(yi, yj);
                v8df dz = _mm512_sub_pd(zi, zj);
#else
                v8df dx = _mm512_sub_pd(xi, _mm512_set1_pd(epjbuf[j][0]));
                v8df dy = _mm512_sub_pd(yi, _mm512_set1_pd(epjbuf[j][1]));
                v8df dz = _mm512_sub_pd(zi, _mm512_set1_pd(epjbuf[j][2]));
#endif
                v8df r2_real = _mm512_fmadd_pd(dx, dx, veps2);
                r2_real = _mm512_fmadd_pd(dy, dy, r2_real);
                r2_real = _mm512_fmadd_pd(dz, dz, r2_real);
                v8df r2 = _mm512_max_pd( r2_real, vr_out2);   // linear cutoff: clip at r_out^2
                v8df ri1 = _mm512_rsqrt14_pd(r2);
                v8df ri2 = _mm512_mul_pd(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2
                v8df v1 = _mm512_set1_pd(1.0);
                v8df h = _mm512_fnmadd_pd(r2, ri2, v1);
                v8df v6p0 = _mm512_set1_pd(6.0);
                v8df v5p0 = _mm512_set1_pd(5.0);
                v8df v8p0 = _mm512_set1_pd(8.0);
                v8df v0p0625 = _mm512_set1_pd(1.0/16.0);
                ri2 = _mm512_fmadd_pd(h, v5p0, v6p0);
                ri2 = _mm512_fmadd_pd(h, ri2, v8p0);
                ri2 = _mm512_mul_pd(h, ri2);
                ri2 = _mm512_fmadd_pd(ri2, v0p0625, v1);
                ri1 = _mm512_mul_pd(ri2, ri1);
                ri2 = _mm512_mul_pd(ri1, ri1);
#elif defined(RSQRT_NR_EPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm512_fnmadd_pd(r2, ri2, _mm512_set1_pd(3.0));
                ri2 = _mm512_mul_pd(ri2, _mm512_set1_pd(0.5));
                ri1 = _mm512_mul_pd(ri2, ri1);
                ri2 = _mm512_mul_pd(ri1, ri1);
#endif
#ifdef AVX_PRELOAD
                v8df mri1 = _mm512_mul_pd(ri1, mj);
#else
                v8df mri1 = _mm512_mul_pd(ri1, _mm512_set1_pd(epjbuf[j][3]));
#endif
                v8df mri3 = _mm512_mul_pd(mri1, ri2);
#ifdef AVX_PRELOAD
                jbuf8= _mm512_broadcast_f64x4(jbuf);
                xj = _mm512_permutex_pd(jbuf8, 0x00);
                yj = _mm512_permutex_pd(jbuf8, 0x55);
                zj = _mm512_permutex_pd(jbuf8, 0xaa);
                mj = _mm512_permutex_pd(jbuf8, 0xff);
#endif
                pot = _mm512_sub_pd(pot, mri1);
                ax = _mm512_fnmadd_pd(mri3, dx, ax);
                ay = _mm512_fnmadd_pd(mri3, dy, ay);
                az = _mm512_fnmadd_pd(mri3, dz, az);
#ifdef AVX_PRELOAD
                v8df vrcrit = _mm512_max_pd(rsi, rsj);
#else
                v8df vrcrit = _mm512_max_pd(rsi, _mm512_set1_pd(rsearchj[j]));
#endif
                v8df vrcrit2 = _mm512_mul_pd(vrcrit, vrcrit);
                nngb = _mm512_add_pd(
                           _mm512_mask_blend_pd(_mm512_cmp_pd_mask(vrcrit2, r2_real, 0x01),
                                                _mm512_set1_pd(1.0),    // inside the crit radius
                                                _mm512_set1_pd(0.0)),   // outside
                           nngb);
#ifdef AVX_PRELOAD
                rsj = rsjbuf;
#endif
            }
            *(v8df *)(&accpbuf[i/8][0]) = ax;
            *(v8df *)(&accpbuf[i/8][1]) = ay;
            *(v8df *)(&accpbuf[i/8][2]) = az;
            *(v8df *)(&accpbuf[i/8][3]) = pot;
            *(v8df *)(&accpbuf[i/8][4]) = nngb;
        }
    }
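    // Unlike the AVX2 variants below, which materialise comparison results
    // as full-width bit masks and count via XOR/AND with 1.0, the AVX-512
    // kernels above compare into a k-mask register (_mm512_cmp_pd_mask)
    // and turn it into 0.0/1.0 with _mm512_mask_blend_pd before the add;
    // the two idioms are equivalent lane by lane.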
#else
    void kernel_epj_nounroll_for_neighbor_count(const int ni, const int nj){
        const v4df vone = _mm256_set1_pd(1.0);
        const v4df allbits = _mm256_cmp_pd(vone, vone, 0x00);   // all lanes true
        for(int i=0; i<ni; i+=4){
            const int il = i % 8;
            const v4df xi = *(v4df *)(&xibuf[i/8][0][il]);
            const v4df yi = *(v4df *)(&xibuf[i/8][1][il]);
            const v4df zi = *(v4df *)(&xibuf[i/8][2][il]);
            const v4df rsi= *(v4df *)(&xibuf[i/8][3][il]);
            v4df nngb = _mm256_set1_pd(0.0);
            v4df jbuf = *((v4df*)epjbuf);
            v4df rsjbuf= _mm256_broadcast_sd(rsearchj);
            v4df xj = _mm256_permute4x64_pd(jbuf, 0x00);
            v4df yj = _mm256_permute4x64_pd(jbuf, 0x55);
            v4df zj = _mm256_permute4x64_pd(jbuf, 0xaa);
            v4df rsj = rsjbuf;
            for(int j=0; j<nj; j++){
                jbuf = *((v4df*)(epjbuf+j+1));
                rsjbuf = _mm256_broadcast_sd(rsearchj+j+1);
                v4df dx = _mm256_sub_pd(xi, xj);
                v4df dy = _mm256_sub_pd(yi, yj);
                v4df dz = _mm256_sub_pd(zi, zj);
                v4df r2_real = _mm256_mul_pd(dx, dx);
                r2_real = _mm256_fmadd_pd(dy, dy, r2_real);
                r2_real = _mm256_fmadd_pd(dz, dz, r2_real);
                xj = _mm256_permute4x64_pd(jbuf, 0x00);
                yj = _mm256_permute4x64_pd(jbuf, 0x55);
                zj = _mm256_permute4x64_pd(jbuf, 0xaa);
                v4df vrcrit = _mm256_max_pd(rsi, rsj);
                v4df vrcrit2 = _mm256_mul_pd(vrcrit, vrcrit);
                v4df mask = _mm256_cmp_pd(vrcrit2, r2_real, 0x01);   // true where outside
                rsj = rsjbuf;
                nngb = _mm256_add_pd(_mm256_and_pd( vone, _mm256_xor_pd(mask, allbits) ), nngb);
            }
            *(v4df *)(&accpbuf[i/8][4][il]) = nngb;
        }
    }
    void kernel_epj_nounroll_for_p3t_with_linear_cutoff(const int ni, const int nj){
        const v4df vone = _mm256_set1_pd(1.0);
        const v4df veps2 = _mm256_set1_pd(eps2);
        const v4df vr_out2 = _mm256_set1_pd(r_crit2);
        const v4df allbits = _mm256_cmp_pd(vone, vone, 0x00);   // all lanes true
        for(int i=0; i<ni; i+=4){
            const int il = i % 8;
            const v4df xi = *(v4df *)(&xibuf[i/8][0][il]);
            const v4df yi = *(v4df *)(&xibuf[i/8][1][il]);
            const v4df zi = *(v4df *)(&xibuf[i/8][2][il]);
            const v4df rsi= *(v4df *)(&xibuf[i/8][3][il]);
            v4df ax, ay, az, pot, nngb;
            ax = ay = az = pot = nngb = _mm256_set1_pd(0.0);
            v4df jbuf = *((v4df*)epjbuf);
            v4df rsjbuf= _mm256_broadcast_sd(rsearchj);
            v4df xj = _mm256_permute4x64_pd(jbuf, 0x00);
            v4df yj = _mm256_permute4x64_pd(jbuf, 0x55);
            v4df zj = _mm256_permute4x64_pd(jbuf, 0xaa);
            v4df mj = _mm256_permute4x64_pd(jbuf, 0xff);
            v4df rsj = rsjbuf;
            for(int j=0; j<nj; j++){
                jbuf = *((v4df*)(epjbuf+j+1));
                rsjbuf = _mm256_broadcast_sd(rsearchj+j+1);
                v4df dx = _mm256_sub_pd(xi, xj);
                v4df dy = _mm256_sub_pd(yi, yj);
                v4df dz = _mm256_sub_pd(zi, zj);
                v4df r2_real = _mm256_fmadd_pd(dx, dx, veps2);
                r2_real = _mm256_fmadd_pd(dy, dy, r2_real);
                r2_real = _mm256_fmadd_pd(dz, dz, r2_real);
                v4df r2 = _mm256_max_pd( r2_real, vr_out2);   // linear cutoff: clip at r_out^2
                v4df ri1 = _mm256_cvtps_pd(_mm_rsqrt_ps(_mm256_cvtpd_ps(r2)));
                v4df ri2 = _mm256_mul_pd(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2
                v4df v1 = _mm256_set1_pd(1.0);
                v4df h = _mm256_fnmadd_pd(r2, ri2, v1);
                v4df v6p0 = _mm256_set1_pd(6.0);
                v4df v5p0 = _mm256_set1_pd(5.0);
                v4df v8p0 = _mm256_set1_pd(8.0);
                v4df v0p0625 = _mm256_set1_pd(1.0/16.0);
                ri2 = _mm256_fmadd_pd(h, v5p0, v6p0);
                ri2 = _mm256_fmadd_pd(h, ri2, v8p0);
                ri2 = _mm256_mul_pd(h, ri2);
                ri2 = _mm256_fmadd_pd(ri2, v0p0625, v1);
                ri1 = _mm256_mul_pd(ri2, ri1);
                ri2 = _mm256_mul_pd(ri1, ri1);
#elif defined(RSQRT_NR_EPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm256_fnmadd_pd(r2, ri2, _mm256_set1_pd(3.0));
                ri2 = _mm256_mul_pd(ri2, _mm256_set1_pd(0.5));
                ri1 = _mm256_mul_pd(ri2, ri1);
                ri2 = _mm256_mul_pd(ri1, ri1);
#endif
                v4df mri1 = _mm256_mul_pd(ri1, mj);
                v4df mri3 = _mm256_mul_pd(mri1, ri2);
                xj = _mm256_permute4x64_pd(jbuf, 0x00);
                yj = _mm256_permute4x64_pd(jbuf, 0x55);
                zj = _mm256_permute4x64_pd(jbuf, 0xaa);
                mj = _mm256_permute4x64_pd(jbuf, 0xff);
                pot = _mm256_sub_pd(pot, mri1);
                ax = _mm256_fnmadd_pd(mri3, dx, ax);
                ay = _mm256_fnmadd_pd(mri3, dy, ay);
                az = _mm256_fnmadd_pd(mri3, dz, az);
                v4df vrcrit = _mm256_max_pd(rsi, rsj);
                v4df vrcrit2 = _mm256_mul_pd(vrcrit, vrcrit);
                v4df mask = _mm256_cmp_pd(vrcrit2, r2_real, 0x01);   // true where outside
                rsj = rsjbuf;
                nngb = _mm256_add_pd(_mm256_and_pd( vone, _mm256_xor_pd(mask, allbits) ), nngb);
            }
            *(v4df *)(&accpbuf[i/8][0][il]) = ax;
            *(v4df *)(&accpbuf[i/8][1][il]) = ay;
            *(v4df *)(&accpbuf[i/8][2][il]) = az;
            *(v4df *)(&accpbuf[i/8][3][il]) = pot;
            *(v4df *)(&accpbuf[i/8][4][il]) = nngb;
        }
    }
#endif
#undef _mm512_fmadd_pd
#undef _mm512_fnmadd_pd
#undef _mm256_fmadd_pd
#undef _mm256_fnmadd_pd
#if defined (__AVX512F__) && defined (__AVX512DQ__)
    void kernel_spj_nounroll(const int ni, const int nj){
        const v8df veps2 = _mm512_set1_pd(eps2);
        for(int i=0; i<ni; i+=8){
            const v8df xi = *(v8df *)(&xibuf[i/8][0]);
            const v8df yi = *(v8df *)(&xibuf[i/8][1]);
            const v8df zi = *(v8df *)(&xibuf[i/8][2]);
            v8df ax, ay, az, pot;
            ax = ay = az = pot = _mm512_set1_pd(0.0);
#ifdef AVX_PRELOAD
            v4df jbuf0 = *((v4df*)spjbuf[0][0]);
            v4df jbuf1 = *((v4df*)spjbuf[0][1]);
            v4df jbuf2 = *((v4df*)spjbuf[0][2]);
#endif
            for(int j=0; j<nj; j++){
#ifdef AVX_PRELOAD
                v8df jp = _mm512_broadcast_f64x4(jbuf0);
                v8df xj = _mm512_permutex_pd(jp, 0x00);
                v8df yj = _mm512_permutex_pd(jp, 0x55);
                v8df zj = _mm512_permutex_pd(jp, 0xaa);
                jbuf0 = *((v4df*)spjbuf[j+1][0]);

                v8df jq = _mm512_broadcast_f64x4(jbuf1);
                v8df qxx = _mm512_permutex_pd(jq, 0x00);
                v8df qyy = _mm512_permutex_pd(jq, 0x55);
                v8df qzz = _mm512_permutex_pd(jq, 0xaa);
                v8df mj  = _mm512_permutex_pd(jq, 0xff);
                jbuf1 = *((v4df*)spjbuf[j+1][1]);

                v8df jq2 = _mm512_broadcast_f64x4(jbuf2);
                v8df qxy = _mm512_permutex_pd(jq2, 0x00);
                v8df qyz = _mm512_permutex_pd(jq2, 0x55);
                v8df qzx = _mm512_permutex_pd(jq2, 0xaa);
                v8df mtr = _mm512_permutex_pd(jq2, 0xff);
                jbuf2 = *((v4df*)spjbuf[j+1][2]);

                v8df dx = _mm512_sub_pd(xi, xj);
                v8df dy = _mm512_sub_pd(yi, yj);
                v8df dz = _mm512_sub_pd(zi, zj);
#else
                v8df dx = _mm512_sub_pd(xi, _mm512_set1_pd(spjbuf[j][0][0]));
                v8df dy = _mm512_sub_pd(yi, _mm512_set1_pd(spjbuf[j][0][1]));
                v8df dz = _mm512_sub_pd(zi, _mm512_set1_pd(spjbuf[j][0][2]));
#endif
                v8df r2 = _mm512_fmadd_pd(dx, dx, veps2);
                r2 = _mm512_fmadd_pd(dy, dy, r2);
                r2 = _mm512_fmadd_pd(dz, dz, r2);
                v8df ri1 = _mm512_rsqrt14_pd(r2);
                v8df ri2 = _mm512_mul_pd(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2
                v8df v1 = _mm512_set1_pd(1.0);
                v8df h = _mm512_fnmadd_pd(r2, ri2, v1);
                v8df v6p0 = _mm512_set1_pd(6.0);
                v8df v5p0 = _mm512_set1_pd(5.0);
                v8df v8p0 = _mm512_set1_pd(8.0);
                v8df v0p0625 = _mm512_set1_pd(1.0/16.0);
                ri2 = _mm512_fmadd_pd(h, v5p0, v6p0);
                ri2 = _mm512_fmadd_pd(h, ri2, v8p0);
                ri2 = _mm512_mul_pd(h, ri2);
                ri2 = _mm512_fmadd_pd(ri2, v0p0625, v1);
                ri1 = _mm512_mul_pd(ri2, ri1);
                ri2 = _mm512_mul_pd(ri1, ri1);
#elif defined(RSQRT_NR_EPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm512_fnmadd_pd(r2, ri2, _mm512_set1_pd(3.0));
                ri2 = _mm512_mul_pd(ri2, _mm512_set1_pd(0.5));
                ri1 = _mm512_mul_pd(ri2, ri1);
                ri2 = _mm512_mul_pd(ri1, ri1);
#endif
                v8df ri3 = _mm512_mul_pd(ri1, ri2);
                v8df ri4 = _mm512_mul_pd(ri2, ri2);
                v8df ri5 = _mm512_mul_pd(ri2, ri3);
#ifdef AVX_PRELOAD
                v8df qr_x = _mm512_mul_pd(dx, qxx);
                qr_x = _mm512_fmadd_pd(dy, qxy, qr_x);
                qr_x = _mm512_fmadd_pd(dz, qzx, qr_x);

                v8df qr_y = _mm512_mul_pd(dy, qyy);
                qr_y = _mm512_fmadd_pd(dx, qxy, qr_y);
                qr_y = _mm512_fmadd_pd(dz, qyz, qr_y);

                v8df qr_z = _mm512_mul_pd(dz, qzz);
                qr_z = _mm512_fmadd_pd(dx, qzx, qr_z);
                qr_z = _mm512_fmadd_pd(dy, qyz, qr_z);

                v8df rqr = _mm512_fmadd_pd(qr_x, dx, mtr);
#else
                v8df qr_x = _mm512_mul_pd(dx, _mm512_set1_pd(spjbuf[j][1][0]));
                qr_x = _mm512_fmadd_pd(dy, _mm512_set1_pd(spjbuf[j][2][0]), qr_x);
                qr_x = _mm512_fmadd_pd(dz, _mm512_set1_pd(spjbuf[j][2][2]), qr_x);

                v8df qr_y = _mm512_mul_pd(dy, _mm512_set1_pd(spjbuf[j][1][1]));
                qr_y = _mm512_fmadd_pd(dx, _mm512_set1_pd(spjbuf[j][2][0]), qr_y);
                qr_y = _mm512_fmadd_pd(dz, _mm512_set1_pd(spjbuf[j][2][1]), qr_y);

                v8df qr_z = _mm512_mul_pd(dz, _mm512_set1_pd(spjbuf[j][1][2]));
                qr_z = _mm512_fmadd_pd(dx, _mm512_set1_pd(spjbuf[j][2][2]), qr_z);
                qr_z = _mm512_fmadd_pd(dy, _mm512_set1_pd(spjbuf[j][2][1]), qr_z);

                v8df rqr = _mm512_fmadd_pd(qr_x, dx, _mm512_set1_pd(spjbuf[j][2][3]));
#endif
                rqr = _mm512_fmadd_pd(qr_y, dy, rqr);
                rqr = _mm512_fmadd_pd(qr_z, dz, rqr);

                v8df rqr_ri4 = _mm512_mul_pd(rqr, ri4);
#ifndef AVX_PRELOAD
                v8df mj = _mm512_set1_pd(spjbuf[j][1][3]);
#endif
                v8df meff  = _mm512_fmadd_pd(rqr_ri4, _mm512_set1_pd(0.5), mj);
                v8df meff3 = _mm512_fmadd_pd(rqr_ri4, _mm512_set1_pd(2.5), mj);
                meff3 = _mm512_mul_pd(meff3, ri3);

                pot = _mm512_fnmadd_pd(meff, ri1, pot);

                ax = _mm512_fmadd_pd(ri5, qr_x, ax);
                ax = _mm512_fnmadd_pd(meff3, dx, ax);
                ay = _mm512_fmadd_pd(ri5, qr_y, ay);
                ay = _mm512_fnmadd_pd(meff3, dy, ay);
                az = _mm512_fmadd_pd(ri5, qr_z, az);
                az = _mm512_fnmadd_pd(meff3, dz, az);
            }
            *(v8df *)(&accpbuf[i/8][0]) = ax;
            *(v8df *)(&accpbuf[i/8][1]) = ay;
            *(v8df *)(&accpbuf[i/8][2]) = az;
            *(v8df *)(&accpbuf[i/8][3]) = pot;
        }
    }
#else
    void kernel_spj_nounroll(const int ni, const int nj){
        const v4df veps2 = _mm256_set1_pd(eps2);
        for(int i=0; i<ni; i+=4){
            const int il = i % 8;
            const v4df xi = *(v4df *)(&xibuf[i/8][0][il]);
            const v4df yi = *(v4df *)(&xibuf[i/8][1][il]);
            const v4df zi = *(v4df *)(&xibuf[i/8][2][il]);
            v4df ax, ay, az, pot;
            ax = ay = az = pot = _mm256_set1_pd(0.0);
            v4df jbuf0 = *((v4df*)spjbuf[0][0]);
            v4df jbuf1 = *((v4df*)spjbuf[0][1]);
            v4df jbuf2 = *((v4df*)spjbuf[0][2]);
            for(int j=0; j<nj; j++){
                v4df xj = _mm256_permute4x64_pd(jbuf0, 0x00);
                v4df yj = _mm256_permute4x64_pd(jbuf0, 0x55);
                v4df zj = _mm256_permute4x64_pd(jbuf0, 0xaa);
                jbuf0 = *((v4df*)spjbuf[j+1][0]);

                v4df qxx = _mm256_permute4x64_pd(jbuf1, 0x00);
                v4df qyy = _mm256_permute4x64_pd(jbuf1, 0x55);
                v4df qzz = _mm256_permute4x64_pd(jbuf1, 0xaa);
                v4df mj  = _mm256_permute4x64_pd(jbuf1, 0xff);
                jbuf1 = *((v4df*)spjbuf[j+1][1]);

                v4df qxy = _mm256_permute4x64_pd(jbuf2, 0x00);
                v4df qyz = _mm256_permute4x64_pd(jbuf2, 0x55);
                v4df qzx = _mm256_permute4x64_pd(jbuf2, 0xaa);
                v4df mtr = _mm256_permute4x64_pd(jbuf2, 0xff);
                jbuf2 = *((v4df*)spjbuf[j+1][2]);

                v4df dx = _mm256_sub_pd(xi, xj);
                v4df dy = _mm256_sub_pd(yi, yj);
                v4df dz = _mm256_sub_pd(zi, zj);

                v4df r2 = _mm256_fmadd_pd(dx, dx, veps2);
                r2 = _mm256_fmadd_pd(dy, dy, r2);
                r2 = _mm256_fmadd_pd(dz, dz, r2);
                v4df ri1 = _mm256_cvtps_pd(_mm_rsqrt_ps(_mm256_cvtpd_ps(r2)));
                v4df ri2 = _mm256_mul_pd(ri1, ri1);
#ifdef RSQRT_NR_EPJ_X4
                // higher-order correction, h = 1 - r2*ri1^2
                v4df v1 = _mm256_set1_pd(1.0);
                v4df h = _mm256_fnmadd_pd(r2, ri2, v1);
                v4df v6p0 = _mm256_set1_pd(6.0);
                v4df v5p0 = _mm256_set1_pd(5.0);
                v4df v8p0 = _mm256_set1_pd(8.0);
                v4df v0p0625 = _mm256_set1_pd(1.0/16.0);
                ri2 = _mm256_fmadd_pd(h, v5p0, v6p0);
                ri2 = _mm256_fmadd_pd(h, ri2, v8p0);
                ri2 = _mm256_mul_pd(h, ri2);
                ri2 = _mm256_fmadd_pd(ri2, v0p0625, v1);
                ri1 = _mm256_mul_pd(ri2, ri1);
                ri2 = _mm256_mul_pd(ri1, ri1);
#elif defined(RSQRT_NR_EPJ_X2)
                // one Newton step: ri1 *= (3 - r2*ri1^2)/2
                ri2 = _mm256_fnmadd_pd(r2, ri2, _mm256_set1_pd(3.0));
                ri2 = _mm256_mul_pd(ri2, _mm256_set1_pd(0.5));
                ri1 = _mm256_mul_pd(ri2, ri1);
                ri2 = _mm256_mul_pd(ri1, ri1);
#endif
                v4df ri3 = _mm256_mul_pd(ri1, ri2);
                v4df ri4 = _mm256_mul_pd(ri2, ri2);
                v4df ri5 = _mm256_mul_pd(ri2, ri3);

                v4df qr_x = _mm256_mul_pd(dx, qxx);
                qr_x = _mm256_fmadd_pd(dy, qxy, qr_x);
                qr_x = _mm256_fmadd_pd(dz, qzx, qr_x);

                v4df qr_y = _mm256_mul_pd(dy, qyy);
                qr_y = _mm256_fmadd_pd(dx, qxy, qr_y);
                qr_y = _mm256_fmadd_pd(dz, qyz, qr_y);

                v4df qr_z = _mm256_mul_pd(dz, qzz);
                qr_z = _mm256_fmadd_pd(dx, qzx, qr_z);
                qr_z = _mm256_fmadd_pd(dy, qyz, qr_z);

                v4df rqr = _mm256_fmadd_pd(qr_x, dx, mtr);
                rqr = _mm256_fmadd_pd(qr_y, dy, rqr);
                rqr = _mm256_fmadd_pd(qr_z, dz, rqr);

                v4df rqr_ri4 = _mm256_mul_pd(rqr, ri4);

                v4df meff  = _mm256_fmadd_pd(rqr_ri4, _mm256_set1_pd(0.5), mj);
                v4df meff3 = _mm256_fmadd_pd(rqr_ri4, _mm256_set1_pd(2.5), mj);
                meff3 = _mm256_mul_pd(meff3, ri3);

                pot = _mm256_fnmadd_pd(meff, ri1, pot);

                ax = _mm256_fmadd_pd(ri5, qr_x, ax);
                ax = _mm256_fnmadd_pd(meff3, dx, ax);
                ay = _mm256_fmadd_pd(ri5, qr_y, ay);
                ay = _mm256_fnmadd_pd(meff3, dy, ay);
                az = _mm256_fmadd_pd(ri5, qr_z, az);
                az = _mm256_fnmadd_pd(meff3, dz, az);
            }
            *(v4df *)(&accpbuf[i/8][0][il]) = ax;
            *(v4df *)(&accpbuf[i/8][1][il]) = ay;
            *(v4df *)(&accpbuf[i/8][2][il]) = az;
            *(v4df *)(&accpbuf[i/8][3][il]) = pot;
        }
    }
#endif
};