iterative-solver 0.0
ArrayHandler.h
1#ifndef LINEARALGEBRA_SRC_MOLPRO_LINALG_ARRAY_ARRAYHANDLER_H
2#define LINEARALGEBRA_SRC_MOLPRO_LINALG_ARRAY_ARRAYHANDLER_H
3#include <algorithm>
4#include <functional>
5#include <list>
6#include <map>
7#include <memory>
8#include <numeric>
9#include <set>
10#include <stdexcept>
11#include <string>
12#include <vector>
13
14#include <molpro/linalg/array/type_traits.h>
15#include <molpro/linalg/itsolv/subspace/Matrix.h>
16#include <molpro/linalg/itsolv/wrap_util.h>
17
21
23namespace util {
24
//! Exception type used by ArrayHandler::error(); it inherits every std::logic_error constructor.
struct ArrayHandlerError : std::logic_error {
  using logic_error::logic_error;
};
28
//! Ordered register of operations; each operation is stored as the tuple of its arguments.
31template <typename... Args>
//! Type of a single registered operation.
33 using OP = std::tuple<Args...>;
//! Ordered register of operations.
34 std::list<std::tuple<Args...>> m_register;
35
/*!
 * @brief Register an operation, grouping it with earlier operations that share its N-th argument.
 *
 * The register is searched from the back for the last operation whose N-th argument compares equal (via `equal`)
 * to the N-th argument of the new operation; the new operation is inserted immediately after it. If no such
 * operation exists, the new operation is appended at the end.
 * @tparam N index of the argument used for grouping
 * @param args arguments of the new operation
 * @param equal binary predicate comparing two N-th arguments
 */
40 template <int N, class ArgEqual>
41 void push(const Args &...args, ArgEqual equal) {
42 auto &&new_op = OP{args...};
43 auto &ref = std::get<N>(new_op);
// search backwards so that the match found is the last member of its group
44 auto rend = std::find_if(m_register.rbegin(), m_register.rend(), [&ref, &equal](const auto &el) {
// NOTE(review): this copies the element's N-th argument; `const auto &` would avoid the copy
45 auto xx = std::get<N>(el);
46 return equal(ref, xx);
47 });
// rend.base() is the forward iterator one past the matching element, i.e. the end of its group
48 auto end_of_group = m_register.end();
49 if (rend != m_register.rend())
50 end_of_group = rend.base();
// new_op binds to a temporary, so forwarding moves it into the list
51 m_register.insert(end_of_group, std::forward<OP>(new_op));
52 }
53
//! Register each operation as it comes with no reordering.
55 void push(const Args &...args) { m_register.push_back({args...}); }
56
//! True if no operations are registered.
57 bool empty() { return m_register.empty(); }
//! Remove all registered operations.
58 void clear() { m_register.clear(); }
59};
60
/*!
 * @brief Deduplicate the arguments of a register of three-argument operations.
 *
 * Each distinct x, y and z (as judged by the supplied equality predicates) is stored once, in order of first
 * appearance, and every operation is re-expressed as a tuple of indices into those unique vectors.
 *
 * @param reg ordered register of operations (x, y, z)
 * @param equal_x equality predicate for the first argument
 * @param equal_y equality predicate for the second argument
 * @param equal_z equality predicate for the third argument
 * @return {per-operation (x index, y index, z index), unique xs, unique ys, unique zs}
 */
template <typename X, typename Y, typename Z, class EqualX, class EqualY, class EqualZ>
std::tuple<std::vector<std::tuple<size_t, size_t, size_t>>, std::vector<X>, std::vector<Y>, std::vector<Z>>
remove_duplicates(const std::list<std::tuple<X, Y, Z>> &reg, EqualX equal_x, EqualY equal_y, EqualZ equal_z) {
  // Position of `value` in `unique`, appending it first when no equal element is present.
  // The search is linear because only an equality predicate is available (no ordering or hash).
  auto index_of = [](auto &unique, const auto &value, auto &equal) -> size_t {
    const auto it = std::find_if(unique.cbegin(), unique.cend(), [&](const auto &el) { return equal(value, el); });
    // compute the index before push_back, which may invalidate `it`
    const auto index = static_cast<size_t>(std::distance(unique.cbegin(), it));
    if (it == unique.cend())
      unique.push_back(value);
    return index;
  };
  std::vector<std::tuple<size_t, size_t, size_t>> op_register;
  op_register.reserve(reg.size());
  std::vector<X> xx;
  std::vector<Y> yy;
  std::vector<Z> zz;
  for (const auto &op : reg) {
    const auto ix = index_of(xx, std::get<0>(op), equal_x);
    const auto iy = index_of(yy, std::get<1>(op), equal_y);
    const auto iz = index_of(zz, std::get<2>(op), equal_z);
    op_register.emplace_back(ix, iy, iz);
  }
  // move the locals into the result instead of copying all four vectors
  return {std::move(op_register), std::move(xx), std::move(yy), std::move(zz)};
}
101
/*!
 * @brief Equality predicate comparing std::reference_wrapper objects by identity.
 *
 * Returns true when both wrappers refer to the same object (same address), regardless of the objects' values.
 * The call operator is const so the predicate can be invoked through const objects and const-callable contexts.
 */
template <typename T = int>
struct RefEqual {
  bool operator()(const std::reference_wrapper<T> &l, const std::reference_wrapper<T> &r) const noexcept {
    return std::addressof(l.get()) == std::addressof(r.get());
  }
};
109} // namespace util
110
161template <class AL, class AR = AL>
163protected:
164 ArrayHandler() : m_counter(std::make_unique<Counter>()){};
165 ArrayHandler(const ArrayHandler &) = default;
166
167 struct Counter {
168 int scal = 0;
169 int dot = 0;
170 int axpy = 0;
171 int copy = 0;
172 int gemm_inner = 0;
173 int gemm_outer = 0;
174 };
175
176 std::unique_ptr<Counter> m_counter;
177
178public:
//! Type of the product of an AL element and an AR element (value_type_L / value_type_R are the
//! mapped-or-value types of AL and AR).
181 using value_type = decltype(value_type_L{} * value_type_R{});
//! Type returned by check_abs<value_type>() — presumably the absolute-value type of value_type.
182 using value_type_abs = decltype(check_abs<value_type>());
183
//! NOTE(review): presumably returns a new AL initialised from `source` — confirm against implementations.
184 virtual AL copy(const AR &source) = 0;
//! Copy content of y into x.
186 virtual void copy(AL &x, const AR &y) = 0;
//! Presumably scales x in place by alpha (BLAS scal convention) — TODO confirm.
187 virtual void scal(value_type alpha, AL &x) = 0;
//! Presumably sets every element of x to alpha — TODO confirm.
188 virtual void fill(value_type alpha, AL &x) = 0;
//! Presumably y <- alpha*x + y (BLAS axpy convention); y is the array that is modified.
189 virtual void axpy(value_type alpha, const AR &x, AL &y) = 0;
//! Dot product between x and y.
190 virtual value_type dot(const AL &x, const AR &y) = 0;
191
//! NOTE(review): presumably accumulates linear combinations of xx, weighted by alphas, into yy — confirm.
//! `alphas` is taken by const value (a copy); changing that would break existing overrides.
195 virtual void gemm_outer(const Matrix<value_type> alphas, const CVecRef<AR> &xx, const VecRef<AL> &yy) = 0;
196
//! Presumably returns the matrix of dot products between the elements of xx and yy — TODO confirm.
200 virtual Matrix<value_type> gemm_inner(const CVecRef<AL> &xx, const CVecRef<AR> &yy) = 0;
201
//! Select n indices with largest by absolute value contributions to the dot product of x and y.
212 virtual std::map<size_t, value_type_abs> select_max_dot(size_t n, const AL &x, const AR &y) = 0;
213
//! Select n indices of x with largest (or smallest) actual (or absolute) value.
222 virtual std::map<size_t, value_type> select(size_t n, const AL &x, bool max = false, bool ignore_sign = false) = 0;
223
224 const Counter &counter() const { return *m_counter; }
225
226 std::string counter_to_string(std::string L, std::string R) {
227 std::string output = "";
228 if (m_counter->scal > 0)
229 output.append(std::to_string(m_counter->scal) + " scaling operations of the " + L + " vectors, ");
230 if (m_counter->copy > 0)
231 output.append(std::to_string(m_counter->copy) + " " + L + "<-" + R + " copy operations, ");
232 if (m_counter->dot > 0)
233 output.append(std::to_string(m_counter->dot) + " dot product operations between the " + L + " and " + R +
234 " vectors, ");
235 if (m_counter->axpy > 0)
236 output.append(std::to_string(m_counter->axpy) + " axpy (" + R + " = a*" + L + " + " + R + ") operations, ");
237 if (m_counter->gemm_inner > 0)
238 output.append(std::to_string(m_counter->gemm_inner) + " gemm_inner operations between the " + L + " and " + R +
239 " vectors, ");
240 if (m_counter->gemm_outer > 0)
241 output.append(std::to_string(m_counter->gemm_outer) + " gemm_outer operations between the " + L + " and " + R +
242 " vectors, ");
243 return output;
244 };
245
// Reset every operation counter to zero.
247 m_counter->scal = 0;
248 m_counter->copy = 0;
249 m_counter->dot = 0;
250 m_counter->axpy = 0;
251 m_counter->gemm_inner = 0;
252 m_counter->gemm_outer = 0;
253 }
254
256 virtual ~ArrayHandler() {
257 std::for_each(m_lazy_handles.begin(), m_lazy_handles.end(), [](auto &el) {
258 if (auto handle = el.lock())
259 handle->invalidate();
260 });
261 }
262
263protected:
268 virtual void error(const std::string &message) { throw util::ArrayHandlerError{message}; };
269
271 virtual void fused_axpy(const std::vector<std::tuple<size_t, size_t, size_t>> &reg,
272 const std::vector<value_type> &alphas,
273 const std::vector<std::reference_wrapper<const AR>> &xx,
274 std::vector<std::reference_wrapper<AL>> &yy) {
275 for (const auto &i : reg) {
276 size_t ai, xi, yi;
277 std::tie(ai, xi, yi) = i;
278 axpy(alphas[ai], xx[xi].get(), yy[yi].get());
279 }
280 }
281
283 virtual void fused_dot(const std::vector<std::tuple<size_t, size_t, size_t>> &reg,
284 const std::vector<std::reference_wrapper<const AL>> &xx,
285 const std::vector<std::reference_wrapper<const AR>> &yy,
286 std::vector<std::reference_wrapper<value_type>> &out) {
287 for (const auto &i : reg) {
288 size_t xi, yi, zi;
289 std::tie(xi, yi, zi) = i;
290 out[zi].get() = dot(xx[xi].get(), yy[yi].get());
291 }
292 }
293
299 public:
//! Shorthand for std::reference_wrapper<T>, used for the references held in the operation registers.
301 template <typename T>
302 using ref_wrap = std::reference_wrapper<T>;
303
304 protected:
//! Types of operations currently registered. Types are strings because derived classes might add new operation
//! types.
306 std::set<std::string> m_op_types;
311
//! Report an error through the handler's error(), which by default throws util::ArrayHandlerError.
//! NOTE(review): m_handler is a reference; this must not be reached after the handler has been destroyed.
312 void error(std::string message) { m_handler.error(message); };
313
322 virtual bool register_op_type(const std::string &type) {
323 if (m_op_types.count(type) == 0 && !m_op_types.empty())
324 return false;
325 m_op_types.insert(type);
326 return true;
327 }
328
330 void clear() {
331 m_op_types.clear();
332 m_axpy.clear();
333 m_dot.clear();
334 }
335
336 public:
//! Create a lazy handle whose operations are evaluated through `handler`. The handle keeps a reference to the
//! handler; handles created through lazy_handle() are invalidated by the handler's destructor.
337 explicit LazyHandle(ArrayHandler<AL, AR> &handler) : m_handler{handler} {}
339
340 virtual void axpy(value_type alpha, const AR &x, AL &y) {
341 if (register_op_type("axpy"))
342 m_axpy.push(alpha, std::cref(x), std::ref(y));
343 else
344 error("Failed to register operation type axpy with the current state of the LazyHandle");
345 }
346 virtual void dot(const AL &x, const AR &y, value_type &out) {
347 if (register_op_type("dotLR"))
348 m_dot.push(std::cref(x), std::cref(y), std::ref(out));
349 else
350 error("Failed to register operation type dot with the current state of the LazyHandle");
351 }
352
/*!
 * @brief Calls the handler to evaluate the registered operations, then clears the registry.
 *
 * Does nothing (and clears nothing) if the handle has been invalidated. Duplicate arguments are merged by
 * util::remove_duplicates before the handler's fused_axpy / fused_dot is called. If a fused call throws,
 * clear() is not reached and the registered operations are kept.
 */
354 virtual void eval() {
355 if (m_invalid)
356 return;
357 if (!m_axpy.empty()) {
358 auto reg = util::remove_duplicates<value_type, ref_wrap<const AR>, ref_wrap<AL>, std::equal_to<value_type>,
360 m_handler.fused_axpy(std::get<0>(reg), std::get<1>(reg), std::get<2>(reg), std::get<3>(reg));
361 }
362 if (!m_dot.empty()) {
363 auto reg =
364 util::remove_duplicates<ref_wrap<const AL>, ref_wrap<const AR>, ref_wrap<value_type>,
366 m_dot.m_register, {}, {}, {});
367 m_handler.fused_dot(std::get<0>(reg), std::get<1>(reg), std::get<2>(reg), std::get<3>(reg));
368 }
369 clear();
370 }
371
373 void invalidate() { m_invalid = true; }
376 bool invalid() { return m_invalid; }
377
378 protected:
//! Set by invalidate(); once true, eval() does nothing.
380 bool m_invalid = false;
381 };
382
385 public:
//! Wraps `handle`. Intentionally not explicit: lazy_handle(handler) returns a std::shared_ptr<LazyHandle> from a
//! function whose return type is ProxyHandle, relying on this implicit conversion.
386 ProxyHandle(std::shared_ptr<LazyHandle> handle) : m_lazy_handle{std::move(handle)} {}
387
388 template <typename... Args>
389 void axpy(Args &&...args) {
390 m_lazy_handle->axpy(std::forward<Args>(args)...);
391 if (m_off)
392 eval();
393 }
394 template <typename... Args>
395 void dot(Args &&...args) {
396 m_lazy_handle->dot(std::forward<Args>(args)...);
397 if (m_off)
398 eval();
399 }
400 void eval() { m_lazy_handle->eval(); }
401 void invalidate() { m_lazy_handle->invalidate(); }
402 bool invalid() { return m_lazy_handle->invalid(); }
403
405 void off() { m_off = true; };
407 void on() { m_off = false; };
409 bool is_off() { return m_off; }
410
411 protected:
//! The wrapped lazy handle; the handler that created it keeps a weak_ptr to the same object.
412 std::shared_ptr<LazyHandle> m_lazy_handle;
//! Whether lazy evaluation is off; when true, axpy() and dot() evaluate immediately.
413 bool m_off = false;
414 };
415
416 std::vector<std::weak_ptr<LazyHandle>> m_lazy_handles;
417
419 void save_handle(const std::shared_ptr<LazyHandle> &handle) {
420 auto empty_handle =
421 std::find_if(m_lazy_handles.begin(), m_lazy_handles.end(), [](const auto &el) { return el.expired(); });
422 if (empty_handle == m_lazy_handles.end())
423 m_lazy_handles.push_back(handle);
424 else
425 *empty_handle = handle;
426 }
427
// Create a LazyHandle bound to `handler` and record a weak reference to it, so the handler's destructor can
// invalidate it. It is returned as a ProxyHandle through ProxyHandle's implicit constructor.
429 auto handle = std::make_shared<typename ArrayHandler<AL, AR>::LazyHandle>(handler);
430 save_handle(handle);
431 return handle;
// NOTE(review): the ';' after the closing brace is redundant.
432 };
433
434public:
// End of class ArrayHandler.
437};
438
439} // namespace molpro::linalg::array
440
441#endif // LINEARALGEBRA_SRC_MOLPRO_LINALG_ARRAY_ARRAYHANDLER_H
Registers operations for lazy evaluation. Evaluation is triggered by calling eval() or on destruction...
Definition: ArrayHandler.h:298
virtual void axpy(value_type alpha, const AR &x, AL &y)
Definition: ArrayHandler.h:340
virtual void eval()
Calls handler to evaluate the registered operations.
Definition: ArrayHandler.h:354
void clear()
Clear the registry.
Definition: ArrayHandler.h:330
virtual void dot(const AL &x, const AR &y, value_type &out)
Definition: ArrayHandler.h:346
virtual ~LazyHandle()
Definition: ArrayHandler.h:338
void error(std::string message)
Definition: ArrayHandler.h:312
void invalidate()
Flag the handle as invalid so that no new operations are registered and eval() does nothing.
Definition: ArrayHandler.h:373
util::OperationRegister< ref_wrap< const AL >, ref_wrap< const AR >, ref_wrap< value_type > > m_dot
register of dot operations
Definition: ArrayHandler.h:310
LazyHandle(ArrayHandler< AL, AR > &handler)
Definition: ArrayHandler.h:337
ArrayHandler< AL, AR >::value_type value_type
Definition: ArrayHandler.h:300
std::reference_wrapper< T > ref_wrap
Definition: ArrayHandler.h:302
ArrayHandler< AL, AR > & m_handler
all operations are still done through the handler
Definition: ArrayHandler.h:379
std::set< std::string > m_op_types
Types of operations currently registered. Types are strings, because derived classes might add new op...
Definition: ArrayHandler.h:306
bool m_invalid
flags if the handler has been destroyed and LazyHandle is now invalid
Definition: ArrayHandler.h:380
util::OperationRegister< value_type, ref_wrap< const AR >, ref_wrap< AL > > m_axpy
register of axpy operations
Definition: ArrayHandler.h:308
bool invalid()
Definition: ArrayHandler.h:376
virtual bool register_op_type(const std::string &type)
Register an operation type.
Definition: ArrayHandler.h:322
A convenience wrapper around a pointer to the LazyHandle.
Definition: ArrayHandler.h:384
bool is_off()
Returns true if lazy evaluation is off.
Definition: ArrayHandler.h:409
ProxyHandle(std::shared_ptr< LazyHandle > handle)
Definition: ArrayHandler.h:386
void eval()
Definition: ArrayHandler.h:400
void axpy(Args &&...args)
Definition: ArrayHandler.h:389
void invalidate()
Definition: ArrayHandler.h:401
bool invalid()
Definition: ArrayHandler.h:402
bool m_off
whether lazy evaluation is on or off
Definition: ArrayHandler.h:413
void dot(Args &&...args)
Definition: ArrayHandler.h:395
void on()
Turn on lazy evaluation.
Definition: ArrayHandler.h:407
void off()
Turn off lazy evaluation. Next operation will evaluate without delay.
Definition: ArrayHandler.h:405
std::shared_ptr< LazyHandle > m_lazy_handle
Definition: ArrayHandler.h:412
Enhances various operations between pairs of arrays and allows dynamic code injection with uniform in...
Definition: ArrayHandler.h:162
virtual value_type dot(const AL &x, const AR &y)=0
std::unique_ptr< Counter > m_counter
Definition: ArrayHandler.h:176
virtual void fused_dot(const std::vector< std::tuple< size_t, size_t, size_t > > &reg, const std::vector< std::reference_wrapper< const AL > > &xx, const std::vector< std::reference_wrapper< const AR > > &yy, std::vector< std::reference_wrapper< value_type > > &out)
Default implementation of fused_dot without any simplification.
Definition: ArrayHandler.h:283
virtual std::map< size_t, value_type > select(size_t n, const AL &x, bool max=false, bool ignore_sign=false)=0
Select n indices with largest (or smallest) actual (or absolute) value.
typename array::mapped_or_value_type_t< AR > value_type_R
Definition: ArrayHandler.h:180
std::vector< std::weak_ptr< LazyHandle > > m_lazy_handles
keeps track of all created lazy handles
Definition: ArrayHandler.h:416
virtual void scal(value_type alpha, AL &x)=0
typename array::mapped_or_value_type_t< AL > value_type_L
Definition: ArrayHandler.h:179
virtual void gemm_outer(const Matrix< value_type > alphas, const CVecRef< AR > &xx, const VecRef< AL > &yy)=0
decltype(value_type_L{} *value_type_R{}) value_type
Definition: ArrayHandler.h:181
virtual AL copy(const AR &source)=0
ArrayHandler(const ArrayHandler &)=default
virtual void copy(AL &x, const AR &y)=0
Copy content of y into x.
virtual ProxyHandle lazy_handle()=0
Returns a lazy handle. Most implementations simply need to call the overload: return lazy_handle(*thi...
virtual void axpy(value_type alpha, const AR &x, AL &y)=0
virtual Matrix< value_type > gemm_inner(const CVecRef< AL > &xx, const CVecRef< AR > &yy)=0
virtual void fill(value_type alpha, AL &x)=0
virtual void fused_axpy(const std::vector< std::tuple< size_t, size_t, size_t > > &reg, const std::vector< value_type > &alphas, const std::vector< std::reference_wrapper< const AR > > &xx, std::vector< std::reference_wrapper< AL > > &yy)
Default implementation of fused_axpy without any simplification.
Definition: ArrayHandler.h:271
virtual std::map< size_t, value_type_abs > select_max_dot(size_t n, const AL &x, const AR &y)=0
Select n indices with largest by absolute value contributions to the dot product.
virtual ~ArrayHandler()
Destroys the ArrayHandler instance and invalidates any LazyHandle it created. Invalidated handle will n...
Definition: ArrayHandler.h:256
ArrayHandler()
Definition: ArrayHandler.h:164
std::string counter_to_string(std::string L, std::string R)
Definition: ArrayHandler.h:226
decltype(check_abs< value_type >()) value_type_abs
Definition: ArrayHandler.h:182
const Counter & counter() const
Definition: ArrayHandler.h:224
ProxyHandle lazy_handle(ArrayHandler< AL, AR > &handler)
Definition: ArrayHandler.h:428
virtual void error(const std::string &message)
Throws an error.
Definition: ArrayHandler.h:268
void save_handle(const std::shared_ptr< LazyHandle > &handle)
Save weak ptr to a lazy handle.
Definition: ArrayHandler.h:419
void clear_counter()
Definition: ArrayHandler.h:246
Matrix container that allows simple data access, slicing, copying and resizing without losing data.
Definition: Matrix.h:28
std::tuple< std::vector< std::tuple< size_t, size_t, size_t > >, std::vector< X >, std::vector< Y >, std::vector< Z > > remove_duplicates(const std::list< std::tuple< X, Y, Z > > &reg, EqualX equal_x, EqualY equal_y, EqualZ equal_z)
Find duplicate references to the x, y and z arguments and store the unique elements in separate vectors.
Definition: ArrayHandler.h:74
Definition: ArrayHandler.h:22
typename mapped_or_value_type< A >::type mapped_or_value_type_t
Definition: type_traits.h:37
std::vector< std::reference_wrapper< const A > > CVecRef
Definition: wrap.h:14
std::vector< std::reference_wrapper< A > > VecRef
Definition: wrap.h:11
Definition: ArrayHandler.h:167
int dot
Definition: ArrayHandler.h:169
int gemm_inner
Definition: ArrayHandler.h:172
int copy
Definition: ArrayHandler.h:171
int axpy
Definition: ArrayHandler.h:170
int gemm_outer
Definition: ArrayHandler.h:173
int scal
Definition: ArrayHandler.h:168
std::list< std::tuple< Args... > > m_register
ordered register of operations
Definition: ArrayHandler.h:34
void push(const Args &...args, ArgEqual equal)
Definition: ArrayHandler.h:41
void clear()
Definition: ArrayHandler.h:58
bool empty()
Definition: ArrayHandler.h:57
std::tuple< Args... > OP
Definition: ArrayHandler.h:33
void push(const Args &...args)
Register each operation as it comes with no reordering.
Definition: ArrayHandler.h:55
When called, returns true if the addresses of the two references are the same.
Definition: ArrayHandler.h:104
bool operator()(const std::reference_wrapper< T > &l, const std::reference_wrapper< T > &r)
Definition: ArrayHandler.h:105