1#ifndef LINEARALGEBRA_SRC_MOLPRO_LINALG_ARRAY_UTIL_GEMM_H
2#define LINEARALGEBRA_SRC_MOLPRO_LINALG_ARRAY_UTIL_GEMM_H
6#include "BufferManager.h"
9#include <molpro/Options.h>
10#include <molpro/Profiler.h>
11#include <molpro/cblas.h>
12#include <molpro/linalg/array/DistrArrayFile.h>
13#include <molpro/linalg/array/type_traits.h>
14#include <molpro/linalg/itsolv/subspace/Matrix.h>
15#include <molpro/linalg/itsolv/wrap.h>
16#include <molpro/linalg/itsolv/wrap_util.h>
17#include <molpro/linalg/options.h>
33 const CVecRef<DistrArrayFile>& xx) {
36 prof += xx.size() * yy.size() * yy[0].get().local_buffer()->size() * 2;
41 auto alphadata =
const_cast<value_type*
>(alphas.data().data());
44 MPI_Allreduce(MPI_IN_PLACE, alphadata, alphas.size(), MPI_DOUBLE, MPI_SUM, molpro::mpi::comm_global());
49template <
class AL,
typename = std::enable_if_t<!std::is_same_v<std::decay_t<AL>, DistrArrayFile>>>
51 const CVecRef<AL>& yy) {
60 const CVecRef<DistrArrayFile>& xx,
const VecRef<AL>& yy) {
61 if (yy.empty() or xx.empty())
65 prof += xx.size() * yy.size() * yy[0].get().local_buffer()->size() * 2;
66 if (alphas.rows() != xx.size())
67 throw std::out_of_range(std::string{
"gemm_outer_distr_distr: dimensions of xx and alphas are different: "} +
68 std::to_string(alphas.rows()) +
" " + std::to_string(xx.size()));
69 if (alphas.cols() != yy.size())
70 throw std::out_of_range(std::string{
"gemm_outer_distr_distr: dimensions of yy and alphas are different: "} +
71 std::to_string(alphas.cols()) +
" " + std::to_string(yy.size()));
81 if (xx.size() == 0 || yy.size() == 0) {
85 bool yy_constant_stride =
true;
86 int previous_stride = 0;
87 int yy_stride = yy.front().get().local_buffer()->size();
88 for (
size_t j = 0; j < std::max((
size_t)1, yy.size()) - 1; ++j) {
89 auto unique_ptr_j = yy.at(j).get().local_buffer()->data();
90 auto unique_ptr_jp1 = yy.at(j + 1).get().local_buffer()->data();
91 yy_stride = unique_ptr_jp1 - unique_ptr_j;
94 yy_constant_stride = yy_constant_stride && (yy_stride == previous_stride);
95 previous_stride = yy_stride;
97 yy_constant_stride = yy_constant_stride && (yy_stride > 0);
100 auto number_of_buffers =
options->parameter(
"GEMM_BUFFERS", 2);
102 std::min(
int(yy.front().get().local_buffer()->size()),
options->parameter(
"GEMM_PAGESIZE", 8192)) *
109 for (
auto buffer_iterator = buffer.
begin(); buffer_iterator != buffer.
end(); ++buffer_iterator) {
114 if (yy_constant_stride and not yy.empty()) {
117 std::to_string(yy.size()) +
", " + std::to_string(current_buf_size));
119 cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, current_buf_size, yy.size(), xx.size(), 1,
120 buffer_iterator->data(), buffer.
buffer_stride(), alphadata, yy.size(), 1,
121 yy[0].get().local_buffer()->data() + container_offset, yy_stride);
125 std::to_string(yy.size()) +
", " + std::to_string(current_buf_size));
127 for (
size_t i = 0; i < yy.size(); ++i) {
128 cblas_dgemv(CblasColMajor, CblasNoTrans, current_buf_size, xx.size(), 1, buffer_iterator->data(), buffer.
buffer_stride(),
129 alphadata + i, yy.size(), 1, yy[i].get().local_buffer()->data() + container_offset, 1);
132 }
else if (
gemm_type == gemm_type::inner) {
133 if (yy_constant_stride and not yy.empty()) {
136 std::to_string(yy.size()) +
", " + std::to_string(current_buf_size));
138 cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, xx.size(), yy.size(), current_buf_size, 1,
139 buffer_iterator->data(), buffer.
buffer_stride(), yy[0].get().local_buffer()->data() + container_offset, yy_stride,
140 1, alphadata, xx.size());
144 std::to_string(yy.size()) +
", " + std::to_string(current_buf_size));
146 for (
size_t k = 0; k < yy.size(); ++k) {
147 cblas_dgemv(CblasColMajor, CblasTrans, current_buf_size, xx.size(), 1, buffer_iterator->data(), buffer.
buffer_stride(),
148 yy[k].get().local_buffer()->data() + container_offset, 1, 1, alphadata + k * xx.size(), 1);
157template <
class AL,
class AR = AL>
159 const CVecRef<AR>& yy) {
160 if (std::is_same<AL, DistrArrayFile>::value) {
161 throw std::runtime_error(
"gemm_inner_distr_distr (unbuffered) called with DistrArrayFile (should never happen!)");
166 if (xx.size() == 0 || yy.size() == 0)
170 prof += mat.cols() * mat.rows() * xx.at(0).get().local_buffer()->size() * 2;
171 for (
size_t j = 0; j < mat.cols(); ++j) {
172 auto loc_y = yy.at(j).get().local_buffer();
173 for (
size_t i = 0; i < mat.rows(); ++i) {
174 auto loc_x = xx.at(i).get().local_buffer();
175 mat(i, j) = std::inner_product(
begin(*loc_x),
end(*loc_x),
begin(*loc_y), (value_type)0);
180 MPI_Allreduce(MPI_IN_PLACE,
const_cast<value_type*
>(mat.data().data()), mat.size(), MPI_DOUBLE, MPI_SUM,
181 xx.at(0).get().communicator());
186template <
class AL,
class AR = AL>
188 const VecRef<AL>& yy) {
189 if (std::is_same<AL, DistrArrayFile>::value) {
190 throw std::runtime_error(
"gemm_outer_distr_distr (unbuffered) called with DistrArrayFile (should never happen!)");
194 prof += alphas.rows() * alphas.cols() * yy[0].get().local_buffer()->size() * 2;
195 for (
size_t ii = 0; ii < alphas.rows(); ++ii) {
196 auto loc_x = xx.at(ii).get().local_buffer();
197 for (
size_t jj = 0; jj < alphas.cols(); ++jj) {
198 auto loc_y = yy[jj].get().local_buffer();
199 for (
size_t i = 0; i < loc_y->size(); ++i)
200 (*loc_y)[i] += alphas(ii, jj) * (*loc_x)[i];
207template <
class AL,
class AR = AL>
209 const VecRef<AL>& yy) {
210 for (
size_t ii = 0; ii < alphas.cols(); ++ii) {
211 auto loc_y = yy[ii].get().local_buffer();
212 for (
size_t jj = 0; jj < alphas.rows(); ++jj) {
213 if (loc_y->size() > 0) {
216 for (
auto it = xx.at(jj).get().lower_bound(loc_y->start());
217 it != xx.at(jj).get().upper_bound(loc_y->start() + loc_y->size() - 1); ++it) {
218 std::tie(i, v) = *it;
219 (*loc_y)[i - loc_y->start()] += alphas(jj, ii) * v;
226template <
class AL,
class AR = AL>
228 const CVecRef<AR>& yy) {
231 if (xx.size() == 0 || yy.size() == 0)
233 for (
size_t i = 0; i < mat.rows(); ++i) {
234 auto loc_x = xx.at(i).get().local_buffer();
235 for (
size_t j = 0; j < mat.cols(); ++j) {
237 if (loc_x->size() > 0) {
240 for (
auto it = yy.at(j).get().lower_bound(loc_x->start());
241 it != yy.at(j).get().upper_bound(loc_x->start() + loc_x->size() - 1); ++it) {
242 std::tie(k, v) = *it;
243 mat(i, j) += (*loc_x)[k - loc_x->start()] * v;
249 MPI_Allreduce(MPI_IN_PLACE,
const_cast<value_type*
>(mat.data().data()), mat.size(), MPI_DOUBLE, MPI_SUM,
250 xx.at(0).get().communicator());
257template <
class Handler,
class AL,
class AR = AL>
259 const VecRef<AL>& yy) {
260 for (
size_t ii = 0; ii < alphas.
rows(); ++ii) {
261 for (
size_t jj = 0; jj < alphas.
cols(); ++jj) {
262 handler.axpy(alphas(ii, jj), xx.at(ii).get(), yy[jj].get());
267template <
class Handler,
class AL,
class AR = AL>
269 const CVecRef<AR>& yy) {
271 if (xx.size() == 0 || yy.size() == 0)
273 for (
size_t ii = 0; ii < mat.rows(); ++ii) {
274 for (
size_t jj = 0; jj < mat.cols(); ++jj) {
275 mat(ii, jj) = handler.dot(xx.at(ii).get(), yy.at(jj).get());
BufferManager provides single-buffered or asynchronous double-buffered read access to the data in a c...
Definition: BufferManager.h:20
size_t buffer_offset() const
Definition: BufferManager.h:66
size_t buffer_size() const
Definition: BufferManager.h:56
size_t buffer_stride() const
Definition: BufferManager.h:61
Iterator end()
Definition: BufferManager.h:128
Iterator begin()
Definition: BufferManager.h:126
Matrix container that allows simple data access, slicing, copying and resizing without losing data.
Definition: Matrix.h:28
void fill(T value)
Sets all elements of matrix to value.
Definition: Matrix.h:89
index_type rows() const
Definition: Matrix.h:165
index_type cols() const
Definition: Matrix.h:166
static std::shared_ptr< Profiler > single()
auto begin(Span< T > &x)
Definition: Span.h:84
auto end(Span< T > &x)
Definition: Span.h:94
Definition: ArrayHandler.h:23
void gemm_outer_distr_distr(const Matrix< typename array::mapped_or_value_type_t< AL > > alphas, const CVecRef< DistrArrayFile > &xx, const VecRef< AL > &yy)
Definition: gemm.h:59
Matrix< typename array::mapped_or_value_type_t< AL > > gemm_inner_distr_distr(const CVecRef< AL > &yy, const CVecRef< DistrArrayFile > &xx)
Definition: gemm.h:32
Matrix< typename array::mapped_or_value_type_t< AL > > gemm_inner_distr_sparse(const CVecRef< AL > &xx, const CVecRef< AR > &yy)
Definition: gemm.h:227
void gemm_distr_distr(array::mapped_or_value_type_t< AL > *alphadata, const CVecRef< DistrArrayFile > &xx, const VecRef< AL > &yy, gemm_type gemm_type)
Definition: gemm.h:77
void gemm_outer_distr_sparse(const Matrix< typename array::mapped_or_value_type_t< AL > > alphas, const CVecRef< AR > &xx, const VecRef< AL > &yy)
Definition: gemm.h:208
void gemm_outer_default(Handler &handler, const Matrix< typename Handler::value_type > alphas, const CVecRef< AR > &xx, const VecRef< AL > &yy)
Definition: gemm.h:258
gemm_type
Definition: gemm.h:27
@ outer
Definition: gemm.h:27
@ inner
Definition: gemm.h:27
Matrix< typename Handler::value_type > gemm_inner_default(Handler &handler, const CVecRef< AL > &xx, const CVecRef< AR > &yy)
Definition: gemm.h:268
typename mapped_or_value_type< A >::type mapped_or_value_type_t
Definition: type_traits.h:37
void transpose_copy(ML &&ml, const MR &mr)
Definition: Matrix.h:281
std::vector< std::reference_wrapper< const A > > CVecRef
Definition: wrap.h:14
auto const_cast_wrap(ForwardIt begin, ForwardIt end)
Takes a begin and end iterators and returns a vector of const-casted references to each element.
Definition: wrap.h:42
std::vector< std::reference_wrapper< A > > VecRef
Definition: wrap.h:11
const std::shared_ptr< const molpro::Options > options()
Get the Options object associated with iterative-solver.
Definition: options.cpp:4