1#ifndef LINEARALGEBRA_SRC_MOLPRO_LINALG_ARRAY_ARRAYHANDLER_H
2#define LINEARALGEBRA_SRC_MOLPRO_LINALG_ARRAY_ARRAYHANDLER_H
14#include <molpro/linalg/array/type_traits.h>
15#include <molpro/linalg/itsolv/subspace/Matrix.h>
16#include <molpro/linalg/itsolv/wrap_util.h>
26 using std::logic_error::logic_error;
31template <
typename... Args>
33 using OP = std::tuple<Args...>;
40 template <
int N,
class ArgEqual>
41 void push(
const Args &...args, ArgEqual equal) {
42 auto &&new_op =
OP{args...};
43 auto &ref = std::get<N>(new_op);
44 auto rend = std::find_if(
m_register.rbegin(),
m_register.rend(), [&ref, &equal](
const auto &el) {
45 auto xx = std::get<N>(el);
46 return equal(ref, xx);
50 end_of_group = rend.base();
51 m_register.insert(end_of_group, std::forward<OP>(new_op));
72template <
typename X,
typename Y,
typename Z,
class EqualX,
class EqualY,
class EqualZ>
73std::tuple<std::vector<std::tuple<size_t, size_t, size_t>>, std::vector<X>, std::vector<Y>, std::vector<Z>>
74remove_duplicates(
const std::list<std::tuple<X, Y, Z>> ®, EqualX equal_x, EqualY equal_y, EqualZ equal_z) {
75 auto n_op = reg.size();
76 std::vector<std::tuple<size_t, size_t, size_t>> op_register;
77 op_register.reserve(n_op);
81 for (
const auto &op : reg) {
82 auto x = std::get<0>(op);
83 auto y = std::get<1>(op);
84 auto z = std::get<2>(op);
85 auto it_x = std::find_if(cbegin(xx), cend(xx), [&x, &equal_x](
const auto &el) {
return equal_x(x, el); });
86 auto it_y = std::find_if(cbegin(yy), cend(yy), [&y, &equal_y](
const auto &el) {
return equal_y(y, el); });
87 auto it_z = std::find_if(cbegin(zz), cend(zz), [&z, &equal_z](
const auto &el) {
return equal_z(z, el); });
88 auto ix = distance(cbegin(xx), it_x);
89 auto iy = distance(cbegin(yy), it_y);
90 auto iz = distance(cbegin(zz), it_z);
97 op_register.emplace_back(ix, iy, iz);
99 return {op_register, xx, yy, zz};
103template <
typename T =
int>
105 bool operator()(
const std::reference_wrapper<T> &l,
const std::reference_wrapper<T> &r) {
106 return std::addressof(l.get()) == std::addressof(r.get());
161template <
class AL,
class AR = AL>
184 virtual AL
copy(
const AR &source) = 0;
186 virtual void copy(AL &x,
const AR &y) = 0;
212 virtual std::map<size_t, value_type_abs>
select_max_dot(
size_t n,
const AL &x,
const AR &y) = 0;
222 virtual std::map<size_t, value_type>
select(
size_t n,
const AL &x,
bool max =
false,
bool ignore_sign =
false) = 0;
227 std::string output =
"";
229 output.append(std::to_string(
m_counter->scal) +
" scaling operations of the " + L +
" vectors, ");
231 output.append(std::to_string(
m_counter->copy) +
" " + L +
"<-" + R +
" copy operations, ");
233 output.append(std::to_string(
m_counter->dot) +
" dot product operations between the " + L +
" and " + R +
236 output.append(std::to_string(
m_counter->axpy) +
" axpy (" + R +
" = a*" + L +
" + " + R +
") operations, ");
238 output.append(std::to_string(
m_counter->gemm_inner) +
" gemm_inner operations between the " + L +
" and " + R +
241 output.append(std::to_string(
m_counter->gemm_outer) +
" gemm_outer operations between the " + L +
" and " + R +
258 if (auto handle = el.lock())
259 handle->invalidate();
271 virtual void fused_axpy(
const std::vector<std::tuple<size_t, size_t, size_t>> ®,
272 const std::vector<value_type> &alphas,
273 const std::vector<std::reference_wrapper<const AR>> &xx,
274 std::vector<std::reference_wrapper<AL>> &yy) {
275 for (
const auto &i : reg) {
277 std::tie(ai, xi, yi) = i;
278 axpy(alphas[ai], xx[xi].get(), yy[yi].get());
283 virtual void fused_dot(
const std::vector<std::tuple<size_t, size_t, size_t>> ®,
284 const std::vector<std::reference_wrapper<const AL>> &xx,
285 const std::vector<std::reference_wrapper<const AR>> &yy,
286 std::vector<std::reference_wrapper<value_type>> &out) {
287 for (
const auto &i : reg) {
289 std::tie(xi, yi, zi) = i;
290 out[zi].get() =
dot(xx[xi].get(), yy[yi].get());
301 template <
typename T>
342 m_axpy.
push(alpha, std::cref(x), std::ref(y));
344 error(
"Failed to register operation type axpy with the current state of the LazyHandle");
348 m_dot.
push(std::cref(x), std::cref(y), std::ref(out));
350 error(
"Failed to register operation type dot with the current state of the LazyHandle");
358 auto reg = util::remove_duplicates<value_type, ref_wrap<const AR>,
ref_wrap<AL>, std::equal_to<value_type>,
360 m_handler.
fused_axpy(std::get<0>(reg), std::get<1>(reg), std::get<2>(reg), std::get<3>(reg));
367 m_handler.
fused_dot(std::get<0>(reg), std::get<1>(reg), std::get<2>(reg), std::get<3>(reg));
388 template <
typename... Args>
394 template <
typename... Args>
395 void dot(Args &&...args) {
425 *empty_handle = handle;
429 auto handle = std::make_shared<typename ArrayHandler<AL, AR>::LazyHandle>(handler);
Registers operations for lazy evaluation. Evaluation is triggered by calling eval() or on destruction...
Definition: ArrayHandler.h:298
virtual void axpy(value_type alpha, const AR &x, AL &y)
Definition: ArrayHandler.h:340
virtual void eval()
Calls handler to evaluate the registered operations.
Definition: ArrayHandler.h:354
void clear()
Clear the registry.
Definition: ArrayHandler.h:330
virtual void dot(const AL &x, const AR &y, value_type &out)
Definition: ArrayHandler.h:346
virtual ~LazyHandle()
Definition: ArrayHandler.h:338
void error(std::string message)
Definition: ArrayHandler.h:312
void invalidate()
Flag the handler as invalid, so that no new operations are registered and eval() does nothing.
Definition: ArrayHandler.h:373
util::OperationRegister< ref_wrap< const AL >, ref_wrap< const AR >, ref_wrap< value_type > > m_dot
register of dot operations
Definition: ArrayHandler.h:310
LazyHandle(ArrayHandler< AL, AR > &handler)
Definition: ArrayHandler.h:337
ArrayHandler< AL, AR >::value_type value_type
Definition: ArrayHandler.h:300
std::reference_wrapper< T > ref_wrap
Definition: ArrayHandler.h:302
ArrayHandler< AL, AR > & m_handler
all operations are still done through the handler
Definition: ArrayHandler.h:379
std::set< std::string > m_op_types
Types of operations currently registered. Types are strings, because derived classes might add new op...
Definition: ArrayHandler.h:306
bool m_invalid
flags if the handler has been destroyed and LazyHandle is now invalid
Definition: ArrayHandler.h:380
util::OperationRegister< value_type, ref_wrap< const AR >, ref_wrap< AL > > m_axpy
register of axpy operations
Definition: ArrayHandler.h:308
bool invalid()
Definition: ArrayHandler.h:376
virtual bool register_op_type(const std::string &type)
Register an operation type.
Definition: ArrayHandler.h:322
A convenience wrapper around a pointer to the LazyHandle.
Definition: ArrayHandler.h:384
bool is_off()
Returns true if lazy evaluation is off.
Definition: ArrayHandler.h:409
ProxyHandle(std::shared_ptr< LazyHandle > handle)
Definition: ArrayHandler.h:386
void eval()
Definition: ArrayHandler.h:400
void axpy(Args &&...args)
Definition: ArrayHandler.h:389
void invalidate()
Definition: ArrayHandler.h:401
bool invalid()
Definition: ArrayHandler.h:402
bool m_off
whether lazy evaluation is on or off
Definition: ArrayHandler.h:413
void dot(Args &&...args)
Definition: ArrayHandler.h:395
void on()
Turn on lazy evaluation.
Definition: ArrayHandler.h:407
void off()
Turn off lazy evaluation. Next operation will evaluate without delay.
Definition: ArrayHandler.h:405
std::shared_ptr< LazyHandle > m_lazy_handle
Definition: ArrayHandler.h:412
Enhances various operations between pairs of arrays and allows dynamic code injection with uniform in...
Definition: ArrayHandler.h:162
virtual value_type dot(const AL &x, const AR &y)=0
std::unique_ptr< Counter > m_counter
Definition: ArrayHandler.h:176
virtual void fused_dot(const std::vector< std::tuple< size_t, size_t, size_t > > &reg, const std::vector< std::reference_wrapper< const AL > > &xx, const std::vector< std::reference_wrapper< const AR > > &yy, std::vector< std::reference_wrapper< value_type > > &out)
Default implementation of fused_dot without any simplification.
Definition: ArrayHandler.h:283
virtual std::map< size_t, value_type > select(size_t n, const AL &x, bool max=false, bool ignore_sign=false)=0
Select n indices with largest (or smallest) actual (or absolute) value.
typename array::mapped_or_value_type_t< AR > value_type_R
Definition: ArrayHandler.h:180
std::vector< std::weak_ptr< LazyHandle > > m_lazy_handles
keeps track of all created lazy handles
Definition: ArrayHandler.h:416
virtual void scal(value_type alpha, AL &x)=0
typename array::mapped_or_value_type_t< AL > value_type_L
Definition: ArrayHandler.h:179
virtual void gemm_outer(const Matrix< value_type > alphas, const CVecRef< AR > &xx, const VecRef< AL > &yy)=0
decltype(value_type_L{} *value_type_R{}) value_type
Definition: ArrayHandler.h:181
virtual AL copy(const AR &source)=0
ArrayHandler(const ArrayHandler &)=default
virtual void copy(AL &x, const AR &y)=0
Copy content of y into x.
virtual ProxyHandle lazy_handle()=0
Returns a lazy handle. Most implementations simply need to call the overload: return lazy_handle(*thi...
virtual void axpy(value_type alpha, const AR &x, AL &y)=0
virtual Matrix< value_type > gemm_inner(const CVecRef< AL > &xx, const CVecRef< AR > &yy)=0
virtual void fill(value_type alpha, AL &x)=0
virtual void fused_axpy(const std::vector< std::tuple< size_t, size_t, size_t > > &reg, const std::vector< value_type > &alphas, const std::vector< std::reference_wrapper< const AR > > &xx, std::vector< std::reference_wrapper< AL > > &yy)
Default implementation of fused_axpy without any simplification.
Definition: ArrayHandler.h:271
virtual std::map< size_t, value_type_abs > select_max_dot(size_t n, const AL &x, const AR &y)=0
Select n indices with largest by absolute value contributions to the dot product.
virtual ~ArrayHandler()
Destroys ArrayHandler instance and invalidates any LazyHandler it created. Invalidated handler will n...
Definition: ArrayHandler.h:256
ArrayHandler()
Definition: ArrayHandler.h:164
std::string counter_to_string(std::string L, std::string R)
Definition: ArrayHandler.h:226
decltype(check_abs< value_type >()) value_type_abs
Definition: ArrayHandler.h:182
const Counter & counter() const
Definition: ArrayHandler.h:224
ProxyHandle lazy_handle(ArrayHandler< AL, AR > &handler)
Definition: ArrayHandler.h:428
virtual void error(const std::string &message)
Throws an error.
Definition: ArrayHandler.h:268
void save_handle(const std::shared_ptr< LazyHandle > &handle)
Save weak ptr to a lazy handle.
Definition: ArrayHandler.h:419
void clear_counter()
Definition: ArrayHandler.h:246
Matrix container that allows simple data access, slicing, copying and resizing without losing data.
Definition: Matrix.h:28
std::tuple< std::vector< std::tuple< size_t, size_t, size_t > >, std::vector< X >, std::vector< Y >, std::vector< Z > > remove_duplicates(const std::list< std::tuple< X, Y, Z > > &reg, EqualX equal_x, EqualY equal_y, EqualZ equal_z)
Find duplicate references to x and y arrays and store unique elements in a separate vector.
Definition: ArrayHandler.h:74
Definition: ArrayHandler.h:22
typename mapped_or_value_type< A >::type mapped_or_value_type_t
Definition: type_traits.h:37
std::vector< std::reference_wrapper< const A > > CVecRef
Definition: wrap.h:14
std::vector< std::reference_wrapper< A > > VecRef
Definition: wrap.h:11
Definition: ArrayHandler.h:167
int dot
Definition: ArrayHandler.h:169
int gemm_inner
Definition: ArrayHandler.h:172
int copy
Definition: ArrayHandler.h:171
int axpy
Definition: ArrayHandler.h:170
int gemm_outer
Definition: ArrayHandler.h:173
int scal
Definition: ArrayHandler.h:168
Definition: ArrayHandler.h:25
Definition: ArrayHandler.h:32
std::list< std::tuple< Args... > > m_register
ordered register of operations
Definition: ArrayHandler.h:34
void push(const Args &...args, ArgEqual equal)
Definition: ArrayHandler.h:41
void clear()
Definition: ArrayHandler.h:58
bool empty()
Definition: ArrayHandler.h:57
std::tuple< Args... > OP
Definition: ArrayHandler.h:33
void push(const Args &...args)
Register each operation as it comes with no reordering.
Definition: ArrayHandler.h:55
When called, returns true if the addresses of two references are the same.
Definition: ArrayHandler.h:104
bool operator()(const std::reference_wrapper< T > &l, const std::reference_wrapper< T > &r)
Definition: ArrayHandler.h:105