Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 125 additions & 56 deletions source/source_lcao/module_gint/gint_atom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,94 +115,163 @@ void GintAtom::set_phi(const std::vector<Vec3d>& coords, const int stride, T* ph
}
}
}

template <typename T>
void GintAtom::set_phi_dphi(
const std::vector<Vec3d>& coords, const int stride,
T* phi, T* dphi_x, T* dphi_y, T* dphi_z) const
const std::vector<Vec3d>& coords,
const int stride,
T* phi,
T* dphi_x,
T* dphi_y,
T* dphi_z) const
{
const int num_mgrids = coords.size();

// orb_ does not have the member variable dr_uniform

const double dr_uniform = orb_->PhiLN(0, 0).dr_uniform;

const int nylm = std::pow(atom_->nwl + 1, 2);
std::vector<double> rly(nylm);
std::vector<double> grly(nylm * 3);

for(int im = 0; im < num_mgrids; im++)
const double rcut = orb_->getRcut();

const int nylm = (atom_->nwl + 1) * (atom_->nwl + 1);

#pragma omp parallel
{
const Vec3d& coord = coords[im];
// 1e-9 is to avoid division by zero
const double dist = coord.norm() < 1e-9 ? 1e-9 : coord.norm();
std::vector<double> rly(nylm);
std::vector<double> grly(nylm * 3);

if(dist > orb_->getRcut())
#pragma omp for schedule(static)
for(int im = 0; im < num_mgrids; ++im)
{
// if the distance is larger than the cutoff radius,
// the wave function values are all zeros
if(phi != nullptr)
const Vec3d& coord = coords[im];

double dist = coord.norm();
if(dist < 1e-9)
{
ModuleBase::GlobalFunc::ZEROS(phi + im * stride, atom_->nw);
dist = 1e-9;
}
ModuleBase::GlobalFunc::ZEROS(dphi_x + im * stride, atom_->nw);
ModuleBase::GlobalFunc::ZEROS(dphi_y + im * stride, atom_->nw);
ModuleBase::GlobalFunc::ZEROS(dphi_z + im * stride, atom_->nw);
}
else
{
// spherical harmonics
// TODO: vectorize the sph_harm function,
// the vectorized function can be called once for all meshgrids in a biggrid
ModuleBase::Ylm::grad_rl_sph_harm(atom_->nwl, coord.x, coord.y, coord.z, rly.data(), grly.data());

// interpolation
if(dist > rcut)
{
if(phi != nullptr)
{
ModuleBase::GlobalFunc::ZEROS(
phi + im * stride,
atom_->nw);
}

ModuleBase::GlobalFunc::ZEROS(
dphi_x + im * stride,
atom_->nw);

ModuleBase::GlobalFunc::ZEROS(
dphi_y + im * stride,
atom_->nw);

ModuleBase::GlobalFunc::ZEROS(
dphi_z + im * stride,
atom_->nw);

continue;
}

ModuleBase::Ylm::grad_rl_sph_harm(
atom_->nwl,
coord.x,
coord.y,
coord.z,
rly.data(),
grly.data());

const double position = dist / dr_uniform;

const int ip = static_cast<int>(position);

const double x0 = position - ip;
const double x1 = 1.0 - x0;
const double x2 = 2.0 - x0;
const double x3 = 3.0 - x0;
const double x12 = x1 * x2 / 6;
const double x03 = x0 * x3 / 2;

double tmp, dtmp;
const double x12 = x1 * x2 / 6.0;
const double x03 = x0 * x3 / 2.0;
const double dist2 = dist * dist;
const double dist3 = dist2 * dist;

for(int iw = 0; iw < atom_->nw; ++iw)
{
// this is a new 'l', we need 1D orbital wave
// function from interpolation method.
double tmp_iw = 0.0;
double dtmp_iw = 0.0;

if(atom_->iw2_new[iw])
{
auto psi_uniform = p_psi_uniform_[iw];
auto dpsi_uniform = p_dpsi_uniform_[iw];
// use Polynomia Interpolation method to get the
// wave functions
const auto psi_uniform = p_psi_uniform_[iw];
const auto dpsi_uniform = p_dpsi_uniform_[iw];

tmp = x12 * (psi_uniform[ip] * x3 + psi_uniform[ip + 3] * x0)
+ x03 * (psi_uniform[ip + 1] * x2 - psi_uniform[ip + 2] * x1);
tmp_iw =
x12 * (psi_uniform[ip] * x3
+ psi_uniform[ip + 3] * x0)
+ x03 * (psi_uniform[ip + 1] * x2
- psi_uniform[ip + 2] * x1);

dtmp = x12 * (dpsi_uniform[ip] * x3 + dpsi_uniform[ip + 3] * x0)
+ x03 * (dpsi_uniform[ip + 1] * x2 - dpsi_uniform[ip + 2] * x1);
} // new l is used.
dtmp_iw =
x12 * (dpsi_uniform[ip] * x3
+ dpsi_uniform[ip + 3] * x0)
+ x03 * (dpsi_uniform[ip + 1] * x2
- dpsi_uniform[ip + 2] * x1);
}
else
{
continue;
}

// get the 'l' of this localized wave function
const int ll = atom_->iw2l[iw];
const int idx_lm = atom_->iw2_ylm[iw];

const double rl = pow_int(dist, ll);
const double tmprl = tmp / rl;
double rl;

// 3D wave functions
if(phi != nullptr)
switch(ll)
{
phi[im * stride + iw] = tmprl * rly[idx_lm];
case 0:
rl = 1.0;
break;

case 1:
rl = dist;
break;

case 2:
rl = dist2;
break;

case 3:
rl = dist3;
break;

default:
rl = pow_int(dist, ll);
break;
}

// derivative of wave functions with respect to atom positions.
const double tmpdphi_rly = (dtmp - tmp * ll / dist) / rl * rly[idx_lm] / dist;
const double tmprl = tmp_iw / rl;

if(phi != nullptr)
{
phi[im * stride + iw]
= tmprl * rly[idx_lm];
}

const double tmpdphi_rly =
(dtmp_iw - tmp_iw * ll / dist)
/ rl
* rly[idx_lm]
/ dist;

dphi_x[im * stride + iw]
= tmpdphi_rly * coord.x
+ tmprl * grly[idx_lm * 3];

dphi_y[im * stride + iw]
= tmpdphi_rly * coord.y
+ tmprl * grly[idx_lm * 3 + 1];

dphi_x[im * stride + iw] = tmpdphi_rly * coord.x + tmprl * grly[idx_lm*3];
dphi_y[im * stride + iw] = tmpdphi_rly * coord.y + tmprl * grly[idx_lm*3 + 1];
dphi_z[im * stride + iw] = tmpdphi_rly * coord.z + tmprl * grly[idx_lm*3 + 2];
dphi_z[im * stride + iw]
= tmpdphi_rly * coord.z
+ tmprl * grly[idx_lm * 3 + 2];
}
}
}
Expand Down
Binary file not shown.
114 changes: 71 additions & 43 deletions source/source_lcao/module_gint/kernel/phi_operator_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -235,57 +235,85 @@ void PhiOperatorGpu<Real>::phi_mul_phi(
double* hr_d) const
{
// ap_num means number of atom pairs
int ap_num = 0;
int max_m = 0;
int max_n = 0;
int max_k = mgrids_num_;
CHECK_CUDA(cudaEventSynchronize(event_));
for (int i = 0; i < bgrid_batch_->get_batch_size(); i++)

int ap_num = 0;
int max_m = 0;
int max_n = 0;
int max_k = mgrids_num_;
CHECK_CUDA(cudaEventSynchronize(event_));
const int batch_size = bgrid_batch_->get_batch_size();
const auto atom_info_host = atoms_num_info_.get_host_ptr();
const auto phi_start_host = atom_phi_start_.get_host_ptr();

std::unordered_map<long long, int> offset_cache;

for (int i = 0; i < batch_size; i++)
{
auto bgrid = bgrid_batch_->get_bgrids()[i];
const int phi_len_mgrid = bgrid->get_phi_len();
const int mg_num = bgrid->get_mgrids_num();
const int nat_bgrid = bgrid->get_atoms_num();
auto* const atoms = bgrid->get_atoms();
const int pre_atoms = atom_info_host[i].y;

for (int ia_1 = 0; ia_1 < nat_bgrid; ia_1++)
{
auto bgrid = bgrid_batch_->get_bgrids()[i];
// the length of phi on a mesh grid
const int phi_len_mgrid = bgrid->get_phi_len();
const int pre_atoms = atoms_num_info_.get_host_ptr()[i].y;
for (int ia_1 = 0; ia_1 < bgrid->get_atoms_num(); ia_1++)
auto atom_1 = atoms[ia_1];
const int iat_1 = atom_1->get_iat();
const int nw1 = atom_1->get_nw();
atom_1->get_R();

for (int ia_2 = 0; ia_2 < nat_bgrid; ia_2++)
{
auto atom_1 = bgrid->get_atoms()[ia_1];
const int iat_1 = atom_1->get_iat();
const auto& r_1 = atom_1->get_R();
const int nw1 = atom_1->get_nw();
const int phi_1_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_1];
auto atom_2 = atoms[ia_2];
const int iat_2 = atom_2->get_iat();
if(iat_1 > iat_2)
{ continue; }

for (int ia_2 = 0; ia_2 < bgrid->get_atoms_num(); ia_2++)
{
auto atom_2 = bgrid->get_atoms()[ia_2];
const int iat_2 = atom_2->get_iat();
const auto& r_2 = atom_2->get_R();
const int nw2 = atom_2->get_nw();
const int nw2 = atom_2->get_nw();
const int phi_2_offset = phi_start_host[pre_atoms + ia_2];
const auto& r_2 = atom_2->get_R();

if(iat_1 > iat_2)
{ continue; }

int hr_offset = hRGint.find_matrix_offset(iat_1, iat_2, r_1 - r_2);
if (hr_offset == -1)
{ continue; }

long long key = (long long)iat_1 * 10000 + iat_2;
int hr_offset = -1;
if (offset_cache.count(key))
{
hr_offset = offset_cache[key];
}
else
{
hr_offset = hRGint.find_matrix_offset(iat_1, iat_2, r_1 - r_2);
offset_cache[key] = hr_offset;
}

const int phi_2_offset = atom_phi_start_.get_host_ptr()[pre_atoms + ia_2];
if (hr_offset == -1)
{ continue; }

gemm_A_.get_host_ptr()[ap_num] = phi_d + phi_1_offset;
gemm_B_.get_host_ptr()[ap_num] = phi_vldr3_d + phi_2_offset;
gemm_C_.get_host_ptr()[ap_num] = hr_d + hr_offset;
gemm_lda_.get_host_ptr()[ap_num] = phi_len_mgrid;
gemm_ldb_.get_host_ptr()[ap_num] = phi_len_mgrid;
gemm_ldc_.get_host_ptr()[ap_num] = nw2;
gemm_m_.get_host_ptr()[ap_num] = nw1;
gemm_n_.get_host_ptr()[ap_num] = nw2;
gemm_k_.get_host_ptr()[ap_num] = bgrid->get_mgrids_num();
ap_num++;
gemm_A_.gem_B_.get_host_ptr()[ap_num] = phi_vldr3_d + phi_2_offset;
gemm_C_.get_host_ptr()[ap_num] = hr_d + hr_offset;
gemm_lda_.get_host_ptr()[ap_num] = phi_len_mgrid;
gemm_ldb_.get_host_ptr()[ap_num] = phi_len_mgrid;
gemm_ldc_.get_host_ptr()[ap_num] = nw2;
gemm_m_.get_host_ptr()[ap_num] = nw1;
gemm_n_.get_host_ptr()[ap_num] = nw2;
gemm_k_.get_host_ptr()[ap_num] = mg_num;
ap_num++;

max_m = std::max(max_m, nw1);
max_n = std::max(max_n, nw2);
}
max_m = std::max(max_m, nw1);
max_n = std::max(max_n, nw2);
}
}
}
gemm_A_.reserve_host(ap_num);
gemm_B_.reserve_host(ap_num);
gemm_C_.reserve_host(ap_num);
gemm_lda_.reserve_host(ap_num);
gemm_ldb_.reserve_host(ap_num);
gemm_ldc_.reserve_host(ap_num);
gemm_m_.reserve_host(ap_num);
gemm_n_.reserve_host(ap_num);
gemm_k_.reserve_host(ap_num);

gemm_A_.copy_host_to_device_async(ap_num);
gemm_B_.copy_host_to_device_async(ap_num);
Expand Down Expand Up @@ -491,4 +519,4 @@ void PhiOperatorGpu<Real>::phi_dot_dphi_r(
template class PhiOperatorGpu<double>;
template class PhiOperatorGpu<float>;

}
}
Loading