[mlpack] 75/207: First pass on command-line program.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:42 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit e1e6af85ec1e8baf7dd06dd3d24a5b755fd781a1
Author: Ryan Curtin <ryan at ratml.org>
Date: Thu Aug 4 17:58:05 2016 -0400
First pass on command-line program.
---
src/mlpack/methods/dbscan/CMakeLists.txt | 2 +
src/mlpack/methods/dbscan/dbscan.hpp | 8 +-
src/mlpack/methods/dbscan/dbscan_impl.hpp | 11 ++-
src/mlpack/methods/dbscan/dbscan_main.cpp | 118 ++++++++++++++++++++++++++++++
4 files changed, 133 insertions(+), 6 deletions(-)
diff --git a/src/mlpack/methods/dbscan/CMakeLists.txt b/src/mlpack/methods/dbscan/CMakeLists.txt
index 92ecf71..7c7e67a 100644
--- a/src/mlpack/methods/dbscan/CMakeLists.txt
+++ b/src/mlpack/methods/dbscan/CMakeLists.txt
@@ -14,3 +14,5 @@ endforeach()
# Append sources (with directory name) to list of all mlpack sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
+
+add_cli_executable(dbscan)
diff --git a/src/mlpack/methods/dbscan/dbscan.hpp b/src/mlpack/methods/dbscan/dbscan.hpp
index 94d7e99..abe11d2 100644
--- a/src/mlpack/methods/dbscan/dbscan.hpp
+++ b/src/mlpack/methods/dbscan/dbscan.hpp
@@ -28,7 +28,9 @@ class DBSCAN
* @param minPoints Minimum number of points for each cluster.
+ * @param rangeSearch Range search object to use; it is taken by value and a
+ *     copy is stored in the DBSCAN object.
+ * @param pointSelector Point selection policy object to use; it is taken by
+ *     value and a copy is stored in the DBSCAN object.
*/
DBSCAN(const double epsilon,
- const size_t minPoints);
+ const size_t minPoints,
+ RangeSearchType rangeSearch = RangeSearchType(),
+ PointSelectionPolicy pointSelector = PointSelectionPolicy());
template<typename MatType>
size_t Cluster(const MatType& data,
@@ -46,10 +48,10 @@ class DBSCAN
arma::mat& centroids);
private:
- RangeSearchType rangeSearch;
- PointSelectionPolicy pointSelector;
+ //! Radius of each range search.
double epsilon;
+ //! Minimum number of points for each cluster.
size_t minPoints;
+ // rangeSearch and pointSelector are declared after epsilon and minPoints so
+ // that declaration order (which determines initialization order) matches the
+ // order of the constructor's member initializer list.
+ //! Range search object (a copy of the constructor argument).
+ RangeSearchType rangeSearch;
+ //! Point selection policy (a copy of the constructor argument).
+ PointSelectionPolicy pointSelector;
template<typename MatType>
size_t ProcessPoint(const MatType& data,
diff --git a/src/mlpack/methods/dbscan/dbscan_impl.hpp b/src/mlpack/methods/dbscan/dbscan_impl.hpp
index ea89450..a2caecc 100644
--- a/src/mlpack/methods/dbscan/dbscan_impl.hpp
+++ b/src/mlpack/methods/dbscan/dbscan_impl.hpp
@@ -13,10 +13,15 @@ namespace mlpack {
namespace dbscan {
+// Store the clustering parameters, along with copies of the given range search
+// object and point selection policy.
template<typename RangeSearchType, typename PointSelectionPolicy>
-DBSCAN<RangeSearchType, PointSelectionPolicy>::DBSCAN(const double epsilon,
- const size_t minPoints) :
+DBSCAN<RangeSearchType, PointSelectionPolicy>::DBSCAN(
+ const double epsilon,
+ const size_t minPoints,
+ RangeSearchType rangeSearch,
+ PointSelectionPolicy pointSelector) :
+ // In each initializer below, the argument names the constructor parameter
+ // (which shadows the member of the same name), not the member itself.
epsilon(epsilon),
- minPoints(minPoints)
+ minPoints(minPoints),
+ rangeSearch(rangeSearch),
+ pointSelector(pointSelector)
{
// Nothing to do.
}
diff --git a/src/mlpack/methods/dbscan/dbscan_main.cpp b/src/mlpack/methods/dbscan/dbscan_main.cpp
new file mode 100644
index 0000000..bba4a4c
--- /dev/null
+++ b/src/mlpack/methods/dbscan/dbscan_main.cpp
@@ -0,0 +1,118 @@
+/**
+ * @file dbscan_main.cpp
+ * @author Ryan Curtin
+ *
+ * Implementation of program to run DBSCAN.
+ */
+#include "dbscan.hpp"
+
+#include <mlpack/core/tree/binary_space_tree.hpp>
+#include <mlpack/core/tree/rectangle_tree.hpp>
+#include <mlpack/core/tree/cover_tree.hpp>
+
+using namespace mlpack;
+using namespace mlpack::range;
+using namespace mlpack::dbscan;
+using namespace mlpack::metric;
+using namespace mlpack::tree;
+using namespace std;
+
+PROGRAM_INFO("DBSCAN clustering",
+    "This program implements the DBSCAN algorithm for clustering.");
+
+// Input and output files.
+PARAM_STRING_IN_REQ("input_file", "Input dataset to cluster.", "i");
+PARAM_STRING_OUT("assignments_file", "Output file for assignments of each "
+    "point.", "a");
+PARAM_STRING_OUT("centroids_file", "File to save output centroids to.", "C");
+
+// Clustering parameters.
+PARAM_DOUBLE_IN("epsilon", "Radius of each range search.", "e", 1.0);
+PARAM_INT_IN("min_size", "Minimum number of points for a cluster.", "m", 5);
+
+// Range search configuration.  The single-tree flag is named "single_mode"
+// because that is the name the rest of this program queries with
+// CLI::HasParam() and reports in its warning messages; declaring it under a
+// different name would leave those lookups referring to a nonexistent option.
+PARAM_STRING_IN("tree_type", "If using single-tree or dual-tree search, the "
+    "type of tree to use ('kd', 'r', 'r-star', 'x', 'hilbert-r', 'r-plus', "
+    "'r-plus-plus', 'cover', 'ball').", "t", "kd");
+PARAM_FLAG("single_mode", "If set, single-tree range search (not dual-tree) "
+    "will be used.", "S");
+PARAM_FLAG("naive", "If set, brute-force range search (not tree-based) "
+    "will be used.", "N");
+
+/**
+ * Actually run the clustering, and process the output.
+ *
+ * Reads "epsilon" and "min_size" from the command line (rejecting
+ * non-positive values), loads "input_file", clusters it with DBSCAN using a
+ * copy of the given range search object, and saves the centroids and/or
+ * assignments to whichever output files were requested.
+ *
+ * @param rs Range search object to use; single-tree mode is enabled on it if
+ *     the "single_mode" parameter is set.
+ */
+template<typename RangeSearchType>
+void RunDBSCAN(RangeSearchType rs = RangeSearchType())
+{
+  if (CLI::HasParam("single_mode"))
+    rs.SingleMode() = true;
+
+  // Validate the clustering parameters before doing any expensive work.
+  const double epsilon = CLI::GetParam<double>("epsilon");
+  if (epsilon <= 0.0)
+    Log::Fatal << "--epsilon must be positive (got " << epsilon << ")!"
+        << endl;
+
+  // "min_size" is declared with PARAM_INT_IN, so it must be retrieved as an
+  // int; asking for a size_t does not match the declared parameter type.
+  const int minSizeIn = CLI::GetParam<int>("min_size");
+  if (minSizeIn <= 0)
+    Log::Fatal << "--min_size must be positive (got " << minSizeIn << ")!"
+        << endl;
+  const size_t minSize = static_cast<size_t>(minSizeIn);
+
+  // Load dataset.  Passing fatal = true makes a load failure an error instead
+  // of silently clustering an empty matrix.
+  arma::mat dataset;
+  data::Load(CLI::GetParam<string>("input_file"), dataset, true);
+
+  DBSCAN<RangeSearchType> d(epsilon, minSize, rs);
+
+  // If possible, avoid the overhead of calculating centroids.
+  arma::Row<size_t> assignments;
+  if (CLI::HasParam("centroids_file"))
+  {
+    arma::mat centroids;
+    d.Cluster(dataset, assignments, centroids);
+    data::Save(CLI::GetParam<string>("centroids_file"), centroids, false);
+  }
+  else
+  {
+    d.Cluster(dataset, assignments);
+  }
+
+  if (CLI::HasParam("assignments_file"))
+    data::Save(CLI::GetParam<string>("assignments_file"), assignments, false,
+        false); // No transpose.
+}
+
+/**
+ * Program entry point: parse the command line, warn about options that will
+ * have no effect, then run DBSCAN with either brute-force range search
+ * (--naive) or the tree type selected by --tree_type.
+ */
+int main(int argc, char** argv)
+{
+  CLI::ParseCommandLine(argc, argv);
+
+  if (!CLI::HasParam("assignments_file") && !CLI::HasParam("centroids_file"))
+    Log::Warn << "Neither --assignments_file nor --centroids_file are "
+        << "specified; no output will be saved!" << endl;
+
+  if (CLI::HasParam("single_mode") && CLI::HasParam("naive"))
+    Log::Warn << "--single_mode ignored because --naive is specified." << endl;
+
+  // Fire off naive search if needed.  Return right afterwards: falling through
+  // would run the tree-based dispatch below as well, clustering the data (and
+  // overwriting the output files) a second time.
+  if (CLI::HasParam("naive"))
+  {
+    RangeSearch<> rs(true);
+    RunDBSCAN(rs);
+    return 0;
+  }
+
+  // Otherwise, select the range search tree type.
+  const string treeType = CLI::GetParam<string>("tree_type");
+  if (treeType == "kd")
+    RunDBSCAN<RangeSearch<>>();
+  else if (treeType == "cover")
+    RunDBSCAN<RangeSearch<EuclideanDistance, arma::mat, StandardCoverTree>>();
+  else if (treeType == "r")
+    RunDBSCAN<RangeSearch<EuclideanDistance, arma::mat, RTree>>();
+  else if (treeType == "r-star")
+    RunDBSCAN<RangeSearch<EuclideanDistance, arma::mat, RStarTree>>();
+  else if (treeType == "x")
+    RunDBSCAN<RangeSearch<EuclideanDistance, arma::mat, XTree>>();
+  else if (treeType == "hilbert-r")
+    RunDBSCAN<RangeSearch<EuclideanDistance, arma::mat, HilbertRTree>>();
+  else if (treeType == "r-plus")
+    RunDBSCAN<RangeSearch<EuclideanDistance, arma::mat, RPlusTree>>();
+  else if (treeType == "r-plus-plus")
+    RunDBSCAN<RangeSearch<EuclideanDistance, arma::mat, RPlusPlusTree>>();
+  else if (treeType == "ball")
+    RunDBSCAN<RangeSearch<EuclideanDistance, arma::mat, BallTree>>();
+  else
+  {
+    Log::Fatal << "Unknown tree type specified! Valid choices are 'kd', "
+        << "'cover', 'r', 'r-star', 'x', 'hilbert-r', 'r-plus', 'r-plus-plus',"
+        << " and 'ball'." << endl;
+  }
+
+  return 0;
+}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list