From ca361545b531298a17f1013894dac18f06ad295c Mon Sep 17 00:00:00 2001 From: haochong zhang Date: Tue, 23 Jun 2026 03:47:31 +0000 Subject: [PATCH] refactor(pw): optimize bspline structure factor grid --- .../module_pwdft/structure_factor.cpp | 129 +++++++----------- 1 file changed, 50 insertions(+), 79 deletions(-) diff --git a/source/source_pw/module_pwdft/structure_factor.cpp b/source/source_pw/module_pwdft/structure_factor.cpp index f0db1acd22..5dacda1ad8 100644 --- a/source/source_pw/module_pwdft/structure_factor.cpp +++ b/source/source_pw/module_pwdft/structure_factor.cpp @@ -1,5 +1,4 @@ #include "source_base/global_function.h" -#include "source_base/global_variable.h" #include "source_io/module_parameter/parameter.h" #include "structure_factor.h" #include "source_base/constants.h" @@ -8,6 +7,7 @@ #include "source_base/timer.h" #include "source_base/libm/libm.h" +#include #ifdef _OPENMP #include @@ -203,38 +203,38 @@ void Structure_Factor::setup(const UnitCell* Ucell, const Parallel_Grid& pgrid, // norder: the order of Cardinal B-spline base functions //FURTHER OPTIMIZATION: // 1. Use "r2c" fft -// 2. Add parallel algorithm for fftw or na loop // void Structure_Factor::bspline_sf(const int norder, const UnitCell* Ucell, const Parallel_Grid& pgrid, const ModulePW::PW_Basis* rho_basis) { - double *r = new double [rho_basis->nxyz]; - double *tmpr = new double[rho_basis->nrxx]; - double *zpiece = new double[rho_basis->nxy]; - std::complex *b1 = new std::complex [rho_basis->nx]; - std::complex *b2 = new std::complex [rho_basis->ny]; - std::complex *b3 = new std::complex [rho_basis->nz]; + (void)pgrid; + std::vector tmpr(rho_basis->nrxx); + std::vector> b1(rho_basis->nx); + std::vector> b2(rho_basis->ny); + std::vector> b3(rho_basis->nz); + const int nplane = rho_basis->nplane; + const int startz = rho_basis->startz_current; - for (int it=0; itntype; it++) + // Each rank owns the same atoms; populate only its local FFT z slab. + for (int it = 0; it < Ucell->ntype; it++) { - const int na = Ucell->atoms[it].na; - const ModuleBase::Vector3 * const taud = Ucell->atoms[it].taud.data(); - ModuleBase::GlobalFunc::ZEROS(r,rho_basis->nxyz); + const int na = Ucell->atoms[it].na; + const ModuleBase::Vector3* const taud = Ucell->atoms[it].taud.data(); + ModuleBase::GlobalFunc::ZEROS(tmpr.data(), rho_basis->nrxx); - //A parallel algorithm can be added in the future. #ifdef _OPENMP - #pragma omp parallel for +#pragma omp parallel for #endif - for(int ia = 0 ; ia < na ; ++ia) + for (int ia = 0; ia < na; ++ia) { - double gridx = taud[ia].x * rho_basis->nx; - double gridy = taud[ia].y * rho_basis->ny; - double gridz = taud[ia].z * rho_basis->nz; - double dx = gridx - floor(gridx); - double dy = gridy - floor(gridy); - double dz = gridz - floor(gridz); + const double gridx = taud[ia].x * rho_basis->nx; + const double gridy = taud[ia].y * rho_basis->ny; + const double gridz = taud[ia].z * rho_basis->nz; + const double dx = gridx - floor(gridx); + const double dy = gridy - floor(gridy); + const double dz = gridz - floor(gridz); //I'm not sure if there is a mod function for double data ModuleBase::Bspline bsx, bsy, bsz; @@ -245,79 +245,50 @@ void Structure_Factor::bspline_sf(const int norder, bsy.getbspline(dy); bsz.getbspline(dz); - for(int iz = 0 ; iz <= norder ; ++iz) + for (int iz = 0; iz <= norder; ++iz) { - int icz = int(rho_basis->nz*10-iz+floor(gridz))%rho_basis->nz; - for(int iy = 0 ; iy <= norder ; ++iy) + const int icz = int(rho_basis->nz * 10 - iz + floor(gridz)) % rho_basis->nz; + if (icz < startz || icz >= startz + nplane) { - int icy = int(rho_basis->ny*10-iy+floor(gridy))%rho_basis->ny; - for(int ix = 0 ; ix <= norder ; ++ix ) + continue; + } + const int local_z = icz - startz; + for (int iy = 0; iy <= norder; ++iy) + { + const int icy = int(rho_basis->ny * 10 - iy + floor(gridy)) % rho_basis->ny; + for (int ix = 0; ix <= norder; ++ix) { - int icx = int(rho_basis->nx*10-ix+floor(gridx))%rho_basis->nx; + const int icx = int(rho_basis->nx * 10 - ix + floor(gridx)) % rho_basis->nx; #ifdef _OPENMP - #pragma omp atomic +#pragma omp atomic #endif - r[icz*rho_basis->ny*rho_basis->nx + icx*rho_basis->ny + icy] += bsz.bezier_ele(iz) - * bsy.bezier_ele(iy) - * bsx.bezier_ele(ix); + tmpr[(icx * rho_basis->ny + icy) * nplane + local_z] + += bsz.bezier_ele(iz) * bsy.bezier_ele(iy) * bsx.bezier_ele(ix); } } } } - - //distribute data to different processors for UFFT - //--------------------------------------------------- - for(int iz = 0; iz < rho_basis->nz; iz++) - { - if(GlobalV::MY_RANK==0) - { -#ifdef _OPENMP - #pragma omp parallel for schedule(static, 512) -#endif - for(int ir = 0; ir < rho_basis->nxy; ir++) - { - zpiece[ir] = r[iz*rho_basis->nxy + ir]; - } - } - - #ifdef __MPI - pgrid.zpiece_to_all(zpiece, iz, tmpr); - #else - // Serial build: the whole real-space grid is local, so there is no - // pool to scatter to. zpiece_to_all() is MPI-only, which otherwise - // leaves tmpr uninitialized -> garbage structure factor and a wrong - // total energy. Fill tmpr directly, using the SAME real-space layout - // as zpiece_to_all's serial path: rho[ir*nczp + znow], i.e. xy index - // outer and z innermost (nczp == nz, znow == iz when serial). - for(int ir = 0; ir < rho_basis->nxy; ir++) - { - tmpr[ir*rho_basis->nz + iz] = zpiece[ir]; - } - #endif - - } - //--------------------------------------------------- //It should be optimized with r2c - rho_basis->real2recip(tmpr, &strucFac(it,0)); - this->bsplinecoef(b1,b2,b3,rho_basis->nx, rho_basis->ny, rho_basis->nz, norder); + rho_basis->real2recip(tmpr.data(), &strucFac(it, 0)); + this->bsplinecoef(b1.data(), + b2.data(), + b3.data(), + rho_basis->nx, + rho_basis->ny, + rho_basis->nz, + norder); #ifdef _OPENMP - #pragma omp parallel for schedule(static, 128) +#pragma omp parallel for schedule(static, 128) #endif - for(int ig = 0 ; ig < rho_basis->npw ; ++ig) + for (int ig = 0; ig < rho_basis->npw; ++ig) { - int idx = int(rho_basis->gdirect[ig].x+0.1+rho_basis->nx)%rho_basis->nx; - int idy = int(rho_basis->gdirect[ig].y+0.1+rho_basis->ny)%rho_basis->ny; - int idz = int(rho_basis->gdirect[ig].z+0.1+rho_basis->nz)%rho_basis->nz; - strucFac(it,ig) *= ( b1[idx] * b2[idy] * b3[idz] * double(rho_basis->nxyz) ); + const int idx = int(rho_basis->gdirect[ig].x + 0.1 + rho_basis->nx) % rho_basis->nx; + const int idy = int(rho_basis->gdirect[ig].y + 0.1 + rho_basis->ny) % rho_basis->ny; + const int idz = int(rho_basis->gdirect[ig].z + 0.1 + rho_basis->nz) % rho_basis->nz; + strucFac(it, ig) *= (b1[idx] * b2[idy] * b3[idz] * double(rho_basis->nxyz)); } - } - delete[] r; - delete[] tmpr; - delete[] zpiece; - delete[] b1; - delete[] b2; - delete[] b3; + } return; }