Skip to content

src/core/vector_base.h

Namespaces

Name
dakku
dakku namespace

Classes

Name
class dakku::VectorBase
CRTP base class for fixed-size arithmetic vectors

Source code

#ifndef DAKKU_CORE_VECTOR_BASE_H_
#define DAKKU_CORE_VECTOR_BASE_H_
#include <core/logger.h>
#include <core/lua.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <ostream>
#include <span>
#include <string>
#include <type_traits>
#include <utility>

namespace dakku {

/// CRTP base for fixed-size vectors holding S components of type T.
///
/// D is the concrete derived type: arithmetic operators and helpers return a
/// D (obtained through derived()), so D is expected to inherit from
/// VectorBase<T, S, D>. ArithmeticType is a project concept declared
/// elsewhere (presumably restricting T to built-in arithmetic types).
///
/// Constructors and assignments validate the stored components with
/// DAKKU_CHECK(!has_nans(), ...). Compound arithmetic operators do not.
template <ArithmeticType T, size_t S, typename D>
class VectorBase {
 public:
  /// Value-initializes every component (zero for arithmetic T).
  VectorBase() : _data() {}

  /// Broadcasts a single scalar (converted to T) to every component.
  template <ArithmeticType Arg>
  VectorBase(Arg value) {
    set(value);
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  /// Reads components from a Lua table using Lua's 1-based indices; entries
  /// the table cannot supply fall back to 0 via get_or.
  VectorBase(const sol::table &table) {
    for (size_t i = 0; i < S; ++i) _data[i] = table.get_or(i + 1, T{0});
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  /// Builds the vector from exactly S scalars, one per component.
  template <ArithmeticType... Args>
  requires(sizeof...(Args) == S) VectorBase(Args &&...args) {
    set(std::forward<Args>(args)...);
    // Validate like every other constructor does.
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  /// Converts from a same-sized vector with another element type and/or
  /// derived type; each component is static_cast to T.
  template <ArithmeticType Other, typename OtherDerived>
  explicit VectorBase(const VectorBase<Other, S, OtherDerived> &other) {
    set(other);
    DAKKU_CHECK(!has_nans(), "has nan");
  }

  VectorBase(const VectorBase &other) : _data(other._data) {
    DAKKU_CHECK(!has_nans(), "has nan");
  }
  VectorBase(VectorBase &&other) noexcept : _data(std::move(other._data)) {
    DAKKU_CHECK(!has_nans(), "has nan");
  }
  VectorBase &operator=(const VectorBase &other) {
    if (this == &other) return *this;
    _data = other._data;
    DAKKU_CHECK(!has_nans(), "has nan");
    return *this;
  }
  VectorBase &operator=(VectorBase &&other) noexcept {
    if (this == &other) return *this;
    _data = std::move(other._data);
    DAKKU_CHECK(!has_nans(), "has nan");
    return *this;
  }

  /// CRTP downcast to the concrete vector type.
  const D &derived() const { return static_cast<const D &>(*this); }
  D &derived() { return static_cast<D &>(*this); }

  /// Sets every component to `value` (converted to T).
  template <ArithmeticType Arg>
  void set(Arg value) {
    _data.fill(static_cast<T>(value));
  }

  /// Sets component `index` to `value` (converted to T), range-checked.
  template <ArithmeticType Arg>
  void set_by_index(size_t index, Arg value) {
    // size_t is unsigned, so only the upper bound can be violated.
    DAKKU_CHECK(index < S, "index out of range: {} >= {}", index, S);
    _data[index] = static_cast<T>(value);
  }

  /// Assigns args[k] to component Is[k] for each position k.
  template <ArithmeticType... Args, size_t... Is>
  requires(sizeof...(Args) == S) void set(std::index_sequence<Is...>,
                                          Args &&...args) {
    (set_by_index(Is, std::forward<Args>(args)), ...);
  }

  /// Assigns the S scalars to components 0..S-1 in order.
  template <ArithmeticType... Args>
  requires(sizeof...(Args) == S) void set(Args &&...args) {
    set(std::index_sequence_for<Args...>{}, std::forward<Args>(args)...);
  }

  /// Copies a same-sized vector of another element type, casting to T.
  template <ArithmeticType Other, typename OtherDerived>
  void set(const VectorBase<Other, S, OtherDerived> &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] = static_cast<T>(rhs[i]);
  }

  /// Range-checked read access to component `i`.
  const T &get(size_t i) const {
    DAKKU_CHECK(i < S, "index out of range {} >= {}", i, S);
    return _data[i];
  }

  /// Number of components (always S).
  [[nodiscard]] size_t size() const { return S; }

  /// Formats the vector as "[c0, c1, ...]" using std::to_string.
  [[nodiscard]] std::string to_string() const {
    std::string ret{"["};
    for (size_t i = 0; i < _data.size(); ++i) {
      ret += std::to_string(_data[i]);
      if (i + 1 != _data.size()) ret += ", ";
    }
    return ret + "]";
  }

  /// True if any component is NaN; always false for non-floating-point T.
  [[nodiscard]] bool has_nans() const {
    if constexpr (std::is_floating_point_v<T>) {
      return std::any_of(_data.begin(), _data.end(),
                         [](T x) { return std::isnan(x); });
    } else {
      return false;
    }
  }

  /// Returns a copy of this vector as the derived type.
  D clone() const { return D{derived()}; }

  /// Unchecked component access (see get() for the checked variant).
  const T &operator[](size_t i) const { return _data[i]; }
  T &operator[](size_t i) { return _data[i]; }

  // ---- addition ----------------------------------------------------------

  D &operator+=(const D &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] += rhs[i];
    return derived();
  }

  template <ArithmeticType V>
  D &operator+=(V rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] += rhs;
    return derived();
  }

  D operator+(const D &rhs) const {
    D ret = derived();
    ret += rhs;
    return ret;
  }

  template <ArithmeticType V>
  D operator+(V rhs) const {
    D ret = derived();
    ret += rhs;
    return ret;
  }

  template <ArithmeticType U>
  friend D operator+(U a, const D &b) {
    return b + a;
  }

  // ---- subtraction / negation -------------------------------------------

  /// Component-wise negation.
  D operator-() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(-ret[i]);
    return ret;
  }

  D &operator-=(const D &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] -= rhs[i];
    return derived();
  }

  template <ArithmeticType V>
  D &operator-=(V rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] -= rhs;
    return derived();
  }

  /// Component-wise difference (used by distance()).
  D operator-(const D &rhs) const {
    D ret = derived();
    ret -= rhs;
    return ret;
  }

  /// Subtracts a scalar from every component.
  template <ArithmeticType V>
  D operator-(V rhs) const {
    D ret = derived();
    ret -= rhs;
    return ret;
  }

  /// scalar - vector, computed as (-b) + a.
  template <ArithmeticType U>
  friend D operator-(U a, const D &b) {
    D ret = -b;
    ret += a;
    return ret;
  }

  // ---- multiplication ----------------------------------------------------

  D &operator*=(const D &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] *= rhs[i];
    return derived();
  }

  template <ArithmeticType V>
  D &operator*=(V rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] *= rhs;
    return derived();
  }

  D operator*(const D &rhs) const {
    D ret = derived();
    ret *= rhs;
    return ret;
  }

  template <ArithmeticType V>
  D operator*(V rhs) const {
    D ret = derived();
    ret *= rhs;
    return ret;
  }

  template <ArithmeticType U>
  friend D operator*(U a, const D &b) {
    return b * a;
  }

  // ---- division ----------------------------------------------------------

  D &operator/=(const D &rhs) {
    for (size_t i = 0; i < S; ++i) _data[i] /= rhs[i];
    return derived();
  }

  /// Divides every component by a scalar.
  ///
  /// Floating-point T multiplies by the reciprocal (one division instead of
  /// S). Other T divides per component, because the reciprocal T{1} / rhs
  /// would truncate to 0 for integral types and zero the whole vector.
  template <ArithmeticType V>
  D &operator/=(V rhs) {
    if constexpr (std::is_floating_point_v<T>) {
      return derived() *= static_cast<T>(T{1} / rhs);
    } else {
      for (auto &value : _data) value = static_cast<T>(value / rhs);
      return derived();
    }
  }

  D operator/(const D &rhs) const {
    D ret = derived();
    ret /= rhs;
    return ret;
  }

  template <ArithmeticType U>
  D operator/(U rhs) const {
    D ret = derived();
    ret /= rhs;
    return ret;
  }

  /// scalar / vector: broadcasts `a` into a D, then divides component-wise.
  template <ArithmeticType U>
  friend D operator/(U a, const D &b) {
    return D(a) / b;
  }

  friend std::ostream &operator<<(std::ostream &os, const VectorBase &vec) {
    return os << vec.to_string();
  }

  /// Exact component-wise equality (no floating-point tolerance).
  bool operator==(const VectorBase &rhs) const { return _data == rhs._data; }

  bool operator!=(const VectorBase &rhs) const { return _data != rhs._data; }

  // ---- named component accessors (return references) --------------------

  decltype(auto) x() const {
    static_assert(S >= 1, "not enough size to get x");
    return _data[0];
  }

  decltype(auto) x() {
    static_assert(S >= 1, "not enough size to get x");
    return _data[0];
  }

  decltype(auto) y() const {
    static_assert(S >= 2, "not enough size to get y");
    return _data[1];
  }

  decltype(auto) y() {
    static_assert(S >= 2, "not enough size to get y");
    return _data[1];
  }

  decltype(auto) z() const {
    static_assert(S >= 3, "not enough size to get z");
    return _data[2];
  }

  decltype(auto) z() {
    static_assert(S >= 3, "not enough size to get z");
    return _data[2];
  }

  decltype(auto) w() const {
    static_assert(S >= 4, "not enough size to get w");
    return _data[3];
  }

  decltype(auto) w() {
    static_assert(S >= 4, "not enough size to get w");
    return _data[3];
  }

  /// Index of the first largest component.
  [[nodiscard]] size_t max_element_index() const {
    return static_cast<size_t>(std::distance(
        _data.begin(), std::max_element(_data.begin(), _data.end())));
  }

  /// Reference to the first largest component.
  decltype(auto) max_element() const { return _data[max_element_index()]; }

  // ---- element-wise helpers ----------------------------------------------
  // Each helper is implemented as a member; the same-named hidden friend
  // delegates to it. (A member that called the friend unqualified would find
  // itself first, which suppresses ADL, so the friend would never be seen.)

  /// Component-wise maximum with `rhs`.
  D max(const D &rhs) const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = std::max(ret[i], rhs[i]);
    return ret;
  }

  friend D max(const D &v1, const D &v2) { return v1.max(v2); }

  /// Component-wise minimum with `rhs`.
  D min(const D &rhs) const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = std::min(ret[i], rhs[i]);
    return ret;
  }

  friend D min(const D &v1, const D &v2) { return v1.min(v2); }

  /// Dot product with a same-sized vector of the same element type.
  /// Uses the public operator[] so that vectors of a different derived type
  /// (a distinct specialization with its own private _data) are accepted.
  template <typename OtherDerived>
  T dot(const VectorBase<T, S, OtherDerived> &rhs) const {
    T sum{};
    for (size_t i = 0; i < S; ++i) sum += _data[i] * rhs[i];
    return sum;
  }

  decltype(auto) squared_norm() const { return this->dot(*this); }

  /// Euclidean norm: std::sqrt of squared_norm().
  decltype(auto) norm() const { return std::sqrt(squared_norm()); }

  decltype(auto) length() const { return norm(); }

  /// Euclidean distance between `a` and `b`.
  friend decltype(auto) distance(const D &a, const D &b) {
    return (a - b).length();
  }

  decltype(auto) distance(const D &rhs) const {
    return (derived() - rhs).length();
  }

  /// Component-wise absolute value.
  D abs() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = std::abs(ret[i]);
    return ret;
  }

  friend D abs(const D &v) { return v.abs(); }

  /// 3D cross product; only available when S == 3.
  D cross(const D &rhs) const {
    static_assert(S == 3, "only 3d vector support cross product");
    return D{(y() * rhs.z()) - (z() * rhs.y()),
             (z() * rhs.x()) - (x() * rhs.z()),
             (x() * rhs.y()) - (y() * rhs.x())};
  }

  /// True when every component compares equal to 0.
  [[nodiscard]] bool is_zero() const {
    return std::all_of(_data.begin(), _data.end(),
                       [](const T &v) { return v == 0; });
  }

  /// Component-wise square root.
  D sqrt() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::sqrt(ret[i]));
    return ret;
  }

  friend D sqrt(const D &v) { return v.sqrt(); }

  /// Raises each component to the power `e`.
  template <ArithmeticType E>
  D pow(E e) const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::pow(ret[i], e));
    return ret;
  }

  template <ArithmeticType E>
  friend D pow(const D &v, E e) {
    return v.pow(e);
  }

  /// Component-wise exponential.
  D exp() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::exp(ret[i]));
    return ret;
  }

  friend D exp(const D &v) { return v.exp(); }

  /// Component-wise floor.
  D floor() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::floor(ret[i]));
    return ret;
  }

  friend D floor(const D &v) { return v.floor(); }

  /// Component-wise ceiling.
  D ceil() const {
    D ret = derived();
    for (size_t i = 0; i < S; ++i) ret[i] = static_cast<T>(std::ceil(ret[i]));
    return ret;
  }

  friend D ceil(const D &v) { return v.ceil(); }

  /// Component-wise std::lerp from this vector towards `b` by factor `t`.
  D lerp(const D &b, T t) const {
    D ret;
    for (size_t i = 0; i < S; ++i) ret[i] = std::lerp(_data[i], b[i], t);
    return ret;
  }

  friend D lerp(const D &a, const D &b, T t) { return a.lerp(b, t); }

  /// Views over the underlying storage; valid only while this vector lives.
  operator std::span<T, S>() { return std::span{_data}; }
  operator std::span<const T, S>() const { return std::span{_data}; }

 private:
  std::array<T, S> _data;
};
}  // namespace dakku

#endif

Updated on 2022-04-30 at 15:46:11 +0000