Program Listing for File gmmtmap.hpp

Return to documentation for file (include/mod/gmmtmap.hpp)

/*
 *   Copyright (c) Chittaranjan Srinivas Swaminathan
 *   This file is part of mod.
 *
 *   mod is free software: you can redistribute it and/or
 *   modify it under the terms of the GNU Lesser General Public License as
 *   published by the Free Software Foundation, either version 3 of the License,
 *   or (at your option) any later version.
 *
 *   mod is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *   GNU Lesser General Public License for more details.
 *
 *   You should have received a copy of the GNU Lesser General Public License
 *   along with mod.  If not, see
 *   <https://www.gnu.org/licenses/>.
 */

#pragma once

#include <array>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <Eigen/Core>
#include <boost/chrono.hpp>
#include <boost/geometry.hpp>
#include <boost/geometry/geometries/point_xy.hpp>
#include <boost/geometry/index/rtree.hpp>
#include <boost/log/trivial.hpp>

#include <mod/base.hpp>

namespace MoD {

namespace bg = boost::geometry;
namespace bgi = boost::geometry::index;

/// 2D point with double coordinates (Boost.Geometry xy point model).
using Point2D = bg::model::d2::point_xy<double>;

/// Box whose corners are Point2D values (Boost.Geometry box model).
using Box = bg::model::box<Point2D>;

/// Value type stored in GMMTMap's R-tree: a location paired with two indices.
/// NOTE(review): the indices are presumably {cluster index, mean index}, which
/// matches the (cluster_idx, mean_idx) pair taken by GMMTMap::getHeadingAtDist —
/// confirm against the code that fills the tree.
using TreeValue = std::pair<Point2D, std::array<size_t, 2>>;

/// One cluster (mixture component) of a GMM-T map.
///
/// Stores the cluster's mixing factor, a sequence of 2D mean points and a
/// sequence of heading values.
/// NOTE(review): `heading[i]` presumably belongs to `mean[i]`; confirm against
/// the code that fills it.
struct GMMTMapCluster {
  /// Mixture weight of this cluster. Zero-initialised so that a
  /// default-constructed cluster never exposes an indeterminate value.
  double mixing_factor{0.0};

  /// Mean points of the cluster (two doubles each, presumably {x, y}).
  std::vector<std::array<double, 2>> mean;

  /// Heading values associated with the cluster.
  std::vector<double> heading;

  GMMTMapCluster() = default;

  /// Builds a fully specified cluster.
  /// @param pi       Mixing factor of the cluster.
  /// @param mean     Mean points; copied into the cluster.
  /// @param heading  Heading values; taken by value and moved in, so callers
  ///                 passing an rvalue avoid a second copy.
  GMMTMapCluster(double pi, const std::vector<std::array<double, 2>> &mean, std::vector<double> heading)
      : mixing_factor(pi), mean(mean), heading(std::move(heading)) {}
};

class GMMTMap : public Base {
 public:
  explicit GMMTMap(const std::string &fileName) { readFromXML(fileName); }

  void readFromXML(const std::string &fileName);

  void computeHeadingAndConstructRTree();

  std::vector<TreeValue> getNearestNeighbors(double x, double y) const;

  inline std::vector<TreeValue> operator()(double x, double y) const { return this->getNearestNeighbors(x, y); };

  inline int getM() const { return M_; }

  inline int getK() const { return K_; }

  inline double getStdDev() const { return stddev_; }

  inline double getMixingFactorByClusterID(size_t cluster_idx) {
    if (cluster_idx >= this->clusters_.size()) {
      BOOST_LOG_TRIVIAL(error) << "getMixingFactorByClusterID() called with "
                                  "cluster_id >= number of clusters.";
      return 1.0;
    }

    return this->clusters_[cluster_idx].mixing_factor;
  }

  inline double getHeadingAtDist(size_t cluster_idx, size_t mean_idx) {
    if (cluster_idx >= this->clusters_.size()) {
      BOOST_LOG_TRIVIAL(error) << "getHeadingAtDist() called with cluster_idx "
                                  ">= number of clusters.";
      BOOST_LOG_TRIVIAL(error) << "Total clusters: " << this->clusters_.size() << ", Cluster ID: " << cluster_idx;
      return 0.0;
    }

    if (mean_idx >= this->clusters_[cluster_idx].heading.size()) {
      BOOST_LOG_TRIVIAL(error) << "getHeadingAtDist() called with mean_idx >= "
                                  "number of traj-means in cluster.";
      BOOST_LOG_TRIVIAL(error) << "Total means: " << this->clusters_[cluster_idx].heading.size()
                               << ", Cluster ID and Mean ID: " << cluster_idx << ", " << mean_idx;
      return 0.0;
    }

    return this->clusters_[cluster_idx].heading[mean_idx];
  }

 protected:
  int M_;

  int K_;

  double stddev_;

  std::vector<GMMTMapCluster> clusters_;

  bgi::rtree<TreeValue, bgi::quadratic<16>> rtree_;
};

/// Shared-ownership handle to a mutable GMMTMap.
using GMMTMapPtr = std::shared_ptr<GMMTMap>;
/// Shared-ownership handle to a read-only GMMTMap.
using GMMTMapConstPtr = std::shared_ptr<const GMMTMap>;

}  // namespace MoD