src/core/vector_base.h
Namespaces
| Name |
|---|
| dakku — dakku namespace |
Classes
| | Name |
|---|---|
| class | dakku::VectorBase — vector base |
Source code
#ifndef DAKKU_CORE_VECTOR_BASE_H_
#define DAKKU_CORE_VECTOR_BASE_H_
#include <core/logger.h>
#include <core/lua.h>
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <ostream>
#include <span>
#include <string>
#include <type_traits>
#include <utility>
namespace dakku {
/// Fixed-size vector of S components of arithmetic type T.
///
/// CRTP base: D is the concrete vector type that derives from this class
/// (derived() casts *this to D), and most operations return D so results keep
/// the concrete type. Components are stored in a std::array<T, S>.
///
/// Every constructor and assignment that takes external data validates the
/// result with DAKKU_CHECK(!has_nans(), "has nan").
///
/// @tparam T component type (constrained by the project's ArithmeticType).
/// @tparam S number of components.
/// @tparam D derived (CRTP) vector type.
template <ArithmeticType T, size_t S, typename D>
class VectorBase {
 public:
  /// Value-initializes every component (zero for arithmetic T).
  VectorBase() : _data() {}

  /// Broadcasts a single scalar into every component.
  template <ArithmeticType Arg>
  VectorBase(Arg value) {
    set(value);
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  /// Reads components from a Lua table at keys 1..S (Lua indices are
  /// 1-based); get_or falls back to T{0} for an absent entry.
  VectorBase(const sol::table &table) {
    for (size_t i = 0; i < S; ++i) _data[i] = table.get_or(i + 1, T{0});
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  /// Component-wise constructor taking exactly S arithmetic values.
  template <ArithmeticType... Args>
    requires(sizeof...(Args) == S)
  VectorBase(Args &&...args) {
    set(std::forward<Args>(args)...);
    // Validate like every other constructor does.
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  /// Converting constructor from a vector of another component type and/or
  /// derived type with the same size; each component is static_cast to T.
  template <ArithmeticType Other, typename OtherDerived>
  explicit VectorBase(const VectorBase<Other, S, OtherDerived> &other) {
    set(other);
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  VectorBase(const VectorBase &other) : _data(other._data) {
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  VectorBase(VectorBase &&other) noexcept : _data(std::move(other._data)) {
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  VectorBase &operator=(const VectorBase &other) {
    if (this == &other) return *this;
    _data = other._data;
    DAKKU_CHECK(!has_nans(), "has nan");
    return *this;
  }

  VectorBase &operator=(VectorBase &&other) noexcept {
    if (this == &other) return *this;
    _data = std::move(other._data);
    DAKKU_CHECK(!has_nans(), "has nan");
    return *this;
  }

  /// CRTP access to the concrete derived object.
  const D &derived() const { return static_cast<const D &>(*this); }
  D &derived() { return static_cast<D &>(*this); }

  /// Sets every component to static_cast<T>(value).
  template <ArithmeticType Arg>
  void set(Arg value) {
    _data.fill(static_cast<T>(value));
  }

  /// Sets one component (bounds-checked via DAKKU_CHECK).
  template <ArithmeticType Arg>
  void set_by_index(size_t index, Arg value) {
    // size_t is unsigned, so only the upper bound needs checking.
    DAKKU_CHECK(index < S, "index out of range: {} >= {}", index, S);
    _data[index] = static_cast<T>(value);
  }

  /// Helper for the variadic set(): pairs each argument with its index.
  template <ArithmeticType... Args, size_t... Is>
    requires(sizeof...(Args) == S)
  void set(std::index_sequence<Is...>, Args &&...args) {
    (set_by_index(Is, std::forward<Args>(args)), ...);
  }

  /// Sets all S components from S arithmetic values, in order.
  template <ArithmeticType... Args>
    requires(sizeof...(Args) == S)
  void set(Args &&...args) {
    set(std::index_sequence_for<Args...>{}, std::forward<Args>(args)...);
  }

  /// Copies components from another same-sized vector, casting each to T.
  template <ArithmeticType Other, typename OtherDerived>
  void set(const VectorBase<Other, S, OtherDerived> &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] = static_cast<T>(rhs[i]);
  }

  /// Bounds-checked (via DAKKU_CHECK) read access.
  const T &get(size_t i) const {
    DAKKU_CHECK(i < S, "index out of range {} >= {}", i, S);
    return _data[i];
  }

  [[nodiscard]] size_t size() const { return S; }

  /// Formats as "[c0, c1, ...]" using std::to_string for each component.
  [[nodiscard]] std::string to_string() const {
    std::string ret{"["};
    for (size_t i = 0; i < _data.size(); ++i) {
      ret += std::to_string(_data[i]);
      if (i + 1 != _data.size()) ret += ", ";
    }
    return ret + "]";
  }

  /// True if any component is NaN (always false for integral T).
  [[nodiscard]] bool has_nans() const {
    return std::any_of(std::begin(_data), std::end(_data),
                       [](T x) { return std::isnan(x); });
  }

  D clone() const { return D{derived()}; }

  /// Unchecked component access.
  const T &operator[](size_t i) const { return _data[i]; }
  T &operator[](size_t i) { return _data[i]; }

  // ---- addition -----------------------------------------------------------
  D &operator+=(const D &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] += rhs[i];
    return derived();
  }
  template <ArithmeticType V>
  D &operator+=(V rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] += rhs;
    return derived();
  }
  D operator+(const D &rhs) const {
    D ret = derived();
    ret += rhs;
    return ret;
  }
  template <ArithmeticType V>
  D operator+(V rhs) const {
    D ret = derived();
    ret += rhs;
    return ret;
  }
  template <ArithmeticType U>
  friend D operator+(U a, const D &b) {
    return b + a;
  }

  // ---- subtraction / negation ---------------------------------------------
  /// Component-wise negation (used by operator-(U, const D&) below).
  D operator-() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = -ret[i];
    return ret;
  }
  D &operator-=(const D &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] -= rhs[i];
    return derived();
  }
  template <ArithmeticType V>
  D &operator-=(V rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] -= rhs;
    return derived();
  }
  /// Component-wise difference (used by distance()).
  D operator-(const D &rhs) const {
    D ret = derived();
    ret -= rhs;
    return ret;
  }
  template <ArithmeticType V>
  D operator-(V rhs) const {
    D ret = derived();
    ret -= rhs;
    return ret;
  }
  /// scalar - vector, computed as (-b) + a.
  template <ArithmeticType U>
  friend D operator-(U a, const D &b) {
    D ret = -b;
    ret += a;
    return ret;
  }

  // ---- multiplication -----------------------------------------------------
  D &operator*=(const D &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] *= rhs[i];
    return derived();
  }
  template <ArithmeticType V>
  D &operator*=(V rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] *= rhs;
    return derived();
  }
  D operator*(const D &rhs) const {
    D ret = derived();
    ret *= rhs;
    return ret;
  }
  template <ArithmeticType V>
  D operator*(V rhs) const {
    D ret = derived();
    ret *= rhs;
    return ret;
  }
  template <ArithmeticType U>
  friend D operator*(U a, const D &b) {
    return b * a;
  }

  // ---- division -----------------------------------------------------------
  D &operator/=(const D &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] /= rhs[i];
    return derived();
  }
  /// Divides every component by a scalar.
  template <ArithmeticType V>
  D &operator/=(V rhs) {
    if constexpr (std::is_floating_point_v<T>) {
      // One division, then S multiplications.
      return derived() *= static_cast<T>(T{1} / rhs);
    } else {
      // For integral T the reciprocal T{1}/rhs truncates (to 0 for |rhs|>1),
      // so divide each component directly instead.
      for (size_t i = 0; i < S; ++i) _data[i] /= rhs;
      return derived();
    }
  }
  D operator/(const D &rhs) const {
    D ret = derived();
    ret /= rhs;
    return ret;
  }
  template <ArithmeticType U>
  D operator/(U rhs) const {
    D ret = derived();
    ret /= rhs;
    return ret;
  }
  template <ArithmeticType U>
  friend D operator/(U a, const D &b) {
    return D(a) / b;
  }

  friend std::ostream &operator<<(std::ostream &os, const VectorBase &vec) {
    return os << vec.to_string();
  }

  /// Exact component-wise comparison.
  bool operator==(const VectorBase &rhs) const { return _data == rhs._data; }
  bool operator!=(const VectorBase &rhs) const { return _data != rhs._data; }

  // ---- named component accessors (compile-time size checked) --------------
  decltype(auto) x() const {
    static_assert(S >= 1, "not enough size to get x");
    return _data[0];
  }
  decltype(auto) x() {
    static_assert(S >= 1, "not enough size to get x");
    return _data[0];
  }
  decltype(auto) y() const {
    static_assert(S >= 2, "not enough size to get y");
    return _data[1];
  }
  decltype(auto) y() {
    static_assert(S >= 2, "not enough size to get y");
    return _data[1];
  }
  decltype(auto) z() const {
    static_assert(S >= 3, "not enough size to get z");
    return _data[2];
  }
  decltype(auto) z() {
    static_assert(S >= 3, "not enough size to get z");
    return _data[2];
  }
  decltype(auto) w() const {
    static_assert(S >= 4, "not enough size to get w");
    return _data[3];
  }
  decltype(auto) w() {
    static_assert(S >= 4, "not enough size to get w");
    return _data[3];
  }

  /// Index of the first largest component.
  [[nodiscard]] size_t max_element_index() const {
    return static_cast<size_t>(std::distance(
        _data.begin(), std::max_element(_data.begin(), _data.end())));
  }
  decltype(auto) max_element() const { return _data[max_element_index()]; }

  // ---- component-wise functions -------------------------------------------
  // The members hold the implementation; the hidden friends delegate to them.
  // (A member cannot call the friend by its unqualified name: lookup finds the
  // member of the same name, and that suppresses argument-dependent lookup.)

  D max(const D &rhs) const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = std::max(ret[i], rhs[i]);
    return ret;
  }
  friend D max(const D &v1, const D &v2) { return v1.max(v2); }

  D min(const D &rhs) const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = std::min(ret[i], rhs[i]);
    return ret;
  }
  friend D min(const D &v1, const D &v2) { return v1.min(v2); }

  /// Dot product with any same-sized vector of the same component type.
  template <typename OtherDerived>
  decltype(auto) dot(const VectorBase<T, S, OtherDerived> &rhs) const {
    // rhs may be a different specialization whose _data is private to us;
    // read it through its public span conversion instead.
    const std::span<const T, S> other = rhs;
    return std::inner_product(_data.begin(), _data.end(), other.begin(), T{});
  }
  decltype(auto) squared_norm() const { return this->dot(*this); }
  decltype(auto) norm() const { return std::sqrt(squared_norm()); }
  decltype(auto) length() const { return norm(); }

  /// Euclidean distance to rhs.
  decltype(auto) distance(const D &rhs) const {
    return (derived() - rhs).length();
  }
  friend decltype(auto) distance(const D &a, const D &b) {
    return a.distance(b);
  }

  D abs() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = std::abs(ret[i]);
    return ret;
  }
  friend D abs(const D &v) { return v.abs(); }

  /// Cross product; only defined for S == 3.
  D cross(const D &rhs) const {
    static_assert(S == 3, "only 3d vector support cross product");
    return D{(y() * rhs.z()) - (z() * rhs.y()),
             (z() * rhs.x()) - (x() * rhs.z()),
             (x() * rhs.y()) - (y() * rhs.x())};
  }

  /// True if every component compares equal to 0.
  [[nodiscard]] bool is_zero() const {
    return std::all_of(_data.begin(), _data.end(),
                       [](const T &v) { return v == 0; });
  }

  D sqrt() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::sqrt(ret[i]));
    return ret;
  }
  friend D sqrt(const D &v) { return v.sqrt(); }

  template <ArithmeticType E>
  D pow(E e) const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::pow(ret[i], e));
    return ret;
  }
  template <ArithmeticType E>
  friend D pow(const D &v, E e) {
    return v.pow(e);
  }

  D exp() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::exp(ret[i]));
    return ret;
  }
  friend D exp(const D &v) { return v.exp(); }

  D floor() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::floor(ret[i]));
    return ret;
  }
  friend D floor(const D &v) { return v.floor(); }

  D ceil() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::ceil(ret[i]));
    return ret;
  }
  friend D ceil(const D &v) { return v.ceil(); }

  /// Component-wise std::lerp from *this (t = 0) toward b (t = 1).
  D lerp(const D &b, T t) const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i)
      ret[i] = static_cast<T>(std::lerp(ret[i], b[i], t));
    return ret;
  }
  friend D lerp(const D &a, const D &b, T t) { return a.lerp(b, t); }

  /// Views over the underlying storage (valid while this object lives).
  operator std::span<T, S>() { return std::span{_data}; }
  operator std::span<const T, S>() const { return std::span{_data}; }

 private:
  std::array<T, S> _data;
};
} // namespace dakku
#endif
Updated on 2022-04-30 at 15:46:11 +0000