Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 229 additions & 44 deletions source/source_cell/module_neighlist/neighbor_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,44 @@
#include <algorithm>
#include <limits>
#include <cassert>
#include "source_base/timer.h"
#include "source_base/tool_quit.h"

#ifdef _OPENMP
#include <omp.h>
#endif

namespace
{
// Atom-count threshold compared against the number of image atoms when
// deciding whether the OpenMP setup/classification paths are worth using.
constexpr int neighbor_setup_openmp_threshold = 100000;

/// Narrows `count` to int after checking that it lies in [1, INT_MAX].
///
/// @param count value to validate
/// @param name  label placed at the front of the diagnostic message
/// @return `count` converted to int
///
/// An out-of-range value is reported through ModuleBase::WARNING_QUIT
/// (presumably terminating the run -- confirm in tool_quit.h).
int checked_positive_int_count(long long count, const std::string& name)
{
    const long long int_limit = std::numeric_limits<int>::max();
    const bool representable = count >= 1 && count <= int_limit;

    if (!representable)
    {
        ModuleBase::WARNING_QUIT(
            "NeighborSearch",
            name + " is outside the supported int range for neighbor-search indexing"
        );
    }
    return static_cast<int>(count);
}

/// Multiplies two positive counts and narrows the product to int.
///
/// Both operands are validated with checked_positive_int_count first; the
/// product itself is then tested by division so that `lhs * rhs` is never
/// evaluated when it would not fit in an int.
///
/// @param lhs  first factor
/// @param rhs  second factor
/// @param name label used in every diagnostic raised here (operands included)
/// @return `lhs * rhs` converted to int
int checked_product_to_int(long long lhs, long long rhs, const std::string& name)
{
    checked_positive_int_count(lhs, name);
    checked_positive_int_count(rhs, name);

    // Largest lhs for which lhs * rhs still fits in an int.
    const long long largest_lhs = std::numeric_limits<int>::max() / rhs;
    if (lhs > largest_lhs)
    {
        ModuleBase::WARNING_QUIT(
            "NeighborSearch",
            name + " exceeds the supported int range for neighbor-search indexing"
        );
    }
    return static_cast<int>(lhs * rhs);
}
}

// ========== Getter methods ==========

Expand Down Expand Up @@ -171,50 +209,132 @@ void NeighborSearch::check_expand_condition(const AtomProvider& ucell)

/// Rebuilds all_atoms_ so that it holds every periodic image of the atoms
/// provided by `ucell`.
///
/// Layout: images are enumerated with z varying fastest, then y, then x, and
/// inside one image the atoms keep their (type, index-in-type) order, so
/// atom_id == image * atoms_per_image + atom. Only the image with
/// ix == iy == iz == 0 is flagged is_inside.
///
/// All counts are validated to fit in an int before being used as indices.
/// When built with OpenMP, the image loop runs in parallel once the total
/// atom count reaches neighbor_setup_openmp_threshold and more than one
/// thread is available; each iteration writes a disjoint slice of all_atoms_.
///
/// @param ucell source of lattice vectors and atomic positions
void NeighborSearch::set_member_variables(const AtomProvider& ucell)
{
    ModuleBase::timer::start("NeighborSearch", "set_member_variables");
    all_atoms_.clear();

    ModuleBase::Vector3<double> vec1(ucell.get_latvec().e11, ucell.get_latvec().e12, ucell.get_latvec().e13);
    ModuleBase::Vector3<double> vec2(ucell.get_latvec().e21, ucell.get_latvec().e22, ucell.get_latvec().e23);
    ModuleBase::Vector3<double> vec3(ucell.get_latvec().e31, ucell.get_latvec().e32, ucell.get_latvec().e33);

    // Gather the unshifted atoms once; every image is a translated copy of them.
    std::vector<NeighborAtom> base_atoms;
    base_atoms.reserve(ucell.get_natom());

    for (int i = 0; i < ucell.get_ntype(); i++)
    {
        for (int j = 0; j < ucell.get_na(i); j++)
        {
            const auto& tau = ucell.get_tau(i, j);
            base_atoms.push_back(NeighborAtom(
                tau.x,
                tau.y,
                tau.z,
                i,
                j,
                static_cast<int>(base_atoms.size())
            ));
        }
    }

    const int atoms_per_image = checked_positive_int_count(
        static_cast<long long>(base_atoms.size()),
        "atoms_per_image"
    );
    const int nimage_x = checked_positive_int_count(
        static_cast<long long>(glayerX_) + static_cast<long long>(glayerX_minus_),
        "nimage_x"
    );
    const int nimage_y = checked_positive_int_count(
        static_cast<long long>(glayerY_) + static_cast<long long>(glayerY_minus_),
        "nimage_y"
    );
    const int nimage_z = checked_positive_int_count(
        static_cast<long long>(glayerZ_) + static_cast<long long>(glayerZ_minus_),
        "nimage_z"
    );
    const int nimage_xy = checked_product_to_int(nimage_x, nimage_y, "nimage_x * nimage_y");
    const int nimages = checked_product_to_int(nimage_xy, nimage_z, "nimages");
    const int total_atoms = checked_product_to_int(nimages, atoms_per_image, "total_atoms");

#ifdef _OPENMP
    const bool use_parallel = total_atoms >= neighbor_setup_openmp_threshold && omp_get_max_threads() > 1;
#endif

    // Pre-size the container so each image can be written by index; this is
    // what makes the loop below safe to run in parallel.
    all_atoms_.assign(total_atoms, NeighborAtom(0.0, 0.0, 0.0, 0, 0, 0));

#ifdef _OPENMP
#pragma omp parallel for schedule(static) if(use_parallel)
#endif
    for (int image = 0; image < nimages; image++)
    {
        // Decode the flat image index into per-axis image coordinates.
        const int image_z = image % nimage_z;
        const int image_y = (image / nimage_z) % nimage_y;
        const int image_x = image / (nimage_y * nimage_z);
        const int ix = image_x - glayerX_minus_;
        const int iy = image_y - glayerY_minus_;
        const int iz = image_z - glayerZ_minus_;
        const bool is_inside = ix == 0 && iy == 0 && iz == 0;

        // Translation of this image, identical for all of its atoms.
        const double shift_x = vec1[0] * ix + vec2[0] * iy + vec3[0] * iz;
        const double shift_y = vec1[1] * ix + vec2[1] * iy + vec3[1] * iz;
        const double shift_z = vec1[2] * ix + vec2[2] * iy + vec3[2] * iz;
        const int atom_offset = image * atoms_per_image;

        for (int atom = 0; atom < atoms_per_image; atom++)
        {
            const NeighborAtom& base_atom = base_atoms[atom];
            const int atom_id = atom_offset + atom;

            all_atoms_[atom_id] = NeighborAtom(
                base_atom.position_x + shift_x,
                base_atom.position_y + shift_y,
                base_atom.position_z + shift_z,
                base_atom.atom_type,
                base_atom.atom_index,
                atom_id
            );
            all_atoms_[atom_id].is_inside = is_inside;
        }
    }

    ModuleBase::timer::end("NeighborSearch", "set_member_variables");
}

// ========== Main public interface ==========

void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank)
{
ModuleBase::timer::start("NeighborSearch", "init");
// clear possible residual data from previous runs
inside_atoms_.clear();
ghost_atoms_.clear();
Expand Down Expand Up @@ -242,13 +362,15 @@ void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank)
assert(wide_y_ >= 0);
assert(wide_z_ >= 0);

int in_x, in_y, in_z;

for (size_t i = 0; i < all_atoms_.size(); i++)
auto classify_atom = [&](const NeighborAtom& atom)
{
int in_x;
int in_y;
int in_z;

if(wide_x_ < coord_tolerance)
{
if(std::abs(all_atoms_[i].position_x - atoms.x_low) < coord_tolerance)
if(std::abs(atom.position_x - atoms.x_low) < coord_tolerance)
{
in_x = x_;
}
Expand All @@ -260,13 +382,13 @@ void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank)
else
{
in_x = std::min(
static_cast<int>(std::floor((all_atoms_[i].position_x - atoms.x_low) / wide_x_)),
static_cast<int>(std::floor((atom.position_x - atoms.x_low) / wide_x_)),
nx - 1
);
}
if(wide_y_ < coord_tolerance)
{
if(std::abs(all_atoms_[i].position_y - atoms.y_low) < coord_tolerance)
if(std::abs(atom.position_y - atoms.y_low) < coord_tolerance)
{
in_y = y_;
}
Expand All @@ -278,13 +400,13 @@ void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank)
else
{
in_y = std::min(
static_cast<int>(std::floor((all_atoms_[i].position_y - atoms.y_low) / wide_y_)),
static_cast<int>(std::floor((atom.position_y - atoms.y_low) / wide_y_)),
ny - 1
);
}
if(wide_z_ < coord_tolerance)
{
if(std::abs(all_atoms_[i].position_z - atoms.z_low) < coord_tolerance)
if(std::abs(atom.position_z - atoms.z_low) < coord_tolerance)
{
in_z = z_;
}
Expand All @@ -296,39 +418,102 @@ void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank)
else
{
in_z = std::min(
static_cast<int>(std::floor((all_atoms_[i].position_z - atoms.z_low) / wide_z_)),
static_cast<int>(std::floor((atom.position_z - atoms.z_low) / wide_z_)),
nz - 1
);
}

if (in_x == x_ && in_y == y_ && in_z == z_ &&
all_atoms_[i].position_x <= atoms.x_high &&
all_atoms_[i].position_y <= atoms.y_high &&
all_atoms_[i].position_z <= atoms.z_high &&
all_atoms_[i].is_inside)
atom.position_x <= atoms.x_high &&
atom.position_y <= atoms.y_high &&
atom.position_z <= atoms.z_high &&
atom.is_inside)
{
inside_atoms_.push_back(all_atoms_[i]);
return 1;
}
else if (distance(
all_atoms_[i].position_x,
all_atoms_[i].position_y,
all_atoms_[i].position_z,
atom.position_x,
atom.position_y,
atom.position_z,
atoms.x_low,
atoms.y_low,
atoms.z_low) <= search_radius_ * search_radius_)
{
ghost_atoms_.push_back(all_atoms_[i]);
return 2;
}

return 0;
};

#ifdef _OPENMP
const bool use_parallel_classification =
static_cast<int>(all_atoms_.size()) >= neighbor_setup_openmp_threshold && omp_get_max_threads() > 1;
if (use_parallel_classification)
{
std::vector<unsigned char> categories(all_atoms_.size(), 0);
int ninside = 0;
int nghost = 0;

#pragma omp parallel for schedule(static) reduction(+:ninside, nghost)
for (int i = 0; i < static_cast<int>(all_atoms_.size()); i++)
{
const int category = classify_atom(all_atoms_[i]);
categories[i] = static_cast<unsigned char>(category);
ninside += category == 1;
nghost += category == 2;
}

inside_atoms_.reserve(ninside);
ghost_atoms_.reserve(nghost);

int inside_index = 0;
int ghost_index = 0;
for (size_t i = 0; i < all_atoms_.size(); i++)
{
if (categories[i] == 1)
{
inside_atoms_.push_back(all_atoms_[i]);
inside_index++;
}
else if (categories[i] == 2)
{
ghost_atoms_.push_back(all_atoms_[i]);
ghost_index++;
}
}
assert(inside_index == ninside);
assert(ghost_index == nghost);
}
else
#endif
{
inside_atoms_.reserve(ucell.get_natom());
ghost_atoms_.reserve(all_atoms_.size());
for (size_t i = 0; i < all_atoms_.size(); i++)
{
const int category = classify_atom(all_atoms_[i]);
if (category == 1)
{
inside_atoms_.push_back(all_atoms_[i]);
}
else if (category == 2)
{
ghost_atoms_.push_back(all_atoms_[i]);
}
}
}

neighbor_list_.initialize(inside_atoms_.size(), all_atoms_.size() * neighbor_reserve_factor);
ModuleBase::timer::end("NeighborSearch", "init");
}

// Builds neighbor_list_ from the inside_atoms_ / ghost_atoms_ sets that init()
// filled, delegating the work to bin_manager_ in three ordered steps.
// The whole call is bracketed by the "build_neighbors" timer.
void NeighborSearch::build_neighbors()
{
ModuleBase::timer::start("NeighborSearch", "build_neighbors");
// Step 1: set up the bins from the search radius and both atom sets.
bin_manager_.init_bins(search_radius_, inside_atoms_, ghost_atoms_);
// Step 2: presumably assigns each inside/ghost atom to a bin -- confirm in the bin manager.
bin_manager_.do_binning(inside_atoms_, ghost_atoms_);
// Step 3: fill neighbor_list_ for the inside atoms.
bin_manager_.build_atom_neighbors(neighbor_list_, inside_atoms_);
ModuleBase::timer::end("NeighborSearch", "build_neighbors");
}

// ========== Utility methods ==========
Expand Down Expand Up @@ -374,4 +559,4 @@ void NeighborSearch::decompose(int mpi_size, int &nx, int &ny, int &nz)
break;
}
}
}
}
Loading