[mlpack] 192/207: Refactor for new DatasetMapper API.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:53 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit 4121399ad18cbcc3abb12c39e1f6d152fa583779
Author: Ryan Curtin <ryan at ratml.org>
Date: Sat Mar 18 15:00:13 2017 -0400
Refactor for new DatasetMapper API.
---
src/mlpack/core/data/load_arff_impl.hpp | 5 +-
src/mlpack/tests/decision_tree_test.cpp | 14 +-
src/mlpack/tests/hoeffding_tree_test.cpp | 322 +++++++++++++++++++++++++++----
src/mlpack/tests/load_save_test.cpp | 92 ++++-----
src/mlpack/tests/serialization_test.cpp | 40 ++--
5 files changed, 359 insertions(+), 114 deletions(-)
diff --git a/src/mlpack/core/data/load_arff_impl.hpp b/src/mlpack/core/data/load_arff_impl.hpp
index 8c84ad1..c127f16 100644
--- a/src/mlpack/core/data/load_arff_impl.hpp
+++ b/src/mlpack/core/data/load_arff_impl.hpp
@@ -179,7 +179,10 @@ void LoadARFF(const std::string& filename,
// What should this token be?
if (info.Type(col) == Datatype::categorical)
{
- matrix(col, row) = info.MapString(*it, col); // We load transposed.
+ // Strip spaces before mapping.
+ std::string token = *it;
+ boost::trim(token);
+ matrix(col, row) = info.template MapString<eT>(token, col); // We load transposed.
}
else if (info.Type(col) == Datatype::numeric)
{
diff --git a/src/mlpack/tests/decision_tree_test.cpp b/src/mlpack/tests/decision_tree_test.cpp
index 719a261..8ba84c8 100644
--- a/src/mlpack/tests/decision_tree_test.cpp
+++ b/src/mlpack/tests/decision_tree_test.cpp
@@ -485,7 +485,7 @@ BOOST_AUTO_TEST_CASE(CategoricalBuildTest)
arma::Row<size_t> labels(10000);
for (size_t i = 0; i < 10000; ++i)
{
- // One circle every 20000 samples. Plus some noise.
+ // One circle every 20000 samples. Plus some noise.
const double magnitude = 2.0 + (double(i) / 2000.0) +
0.5 * mlpack::math::Random();
const double angle = (i % 2000) * (2 * M_PI) + mlpack::math::Random();
@@ -534,12 +534,12 @@ BOOST_AUTO_TEST_CASE(CategoricalBuildTest)
di.Type(2) = data::Datatype::categorical;
di.Type(3) = data::Datatype::categorical;
// Set mappings.
- di.MapString("0", 2);
- di.MapString("1", 2);
- di.MapString("2", 2);
- di.MapString("3", 2);
- di.MapString("0", 3);
- di.MapString("1", 3);
+ di.MapString<double>("0", 2);
+ di.MapString<double>("1", 2);
+ di.MapString<double>("2", 2);
+ di.MapString<double>("3", 2);
+ di.MapString<double>("0", 3);
+ di.MapString<double>("1", 3);
// Now shuffle the dataset.
arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0, 9999,
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index a5b94bd..9712803 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -325,7 +325,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitSplitTest)
// No training is necessary because we can just call CreateChildren().
data::DatasetInfo info(3);
- info.MapString("hello", 0); // Make dimension 0 categorical.
+ info.MapString<size_t>("hello", 0); // Make dimension 0 categorical.
HoeffdingCategoricalSplit<GiniImpurity>::SplitInfo splitInfo(3);
// Create the children.
@@ -346,15 +346,15 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeNoSplitTest)
{
// Make all dimensions categorical.
data::DatasetInfo info(3);
- info.MapString("cat1", 0);
- info.MapString("cat2", 0);
- info.MapString("cat3", 0);
- info.MapString("cat4", 0);
- info.MapString("cat1", 1);
- info.MapString("cat2", 1);
- info.MapString("cat3", 1);
- info.MapString("cat1", 2);
- info.MapString("cat2", 2);
+ info.MapString<size_t>("cat1", 0);
+ info.MapString<size_t>("cat2", 0);
+ info.MapString<size_t>("cat3", 0);
+ info.MapString<size_t>("cat4", 0);
+ info.MapString<size_t>("cat1", 1);
+ info.MapString<size_t>("cat2", 1);
+ info.MapString<size_t>("cat3", 1);
+ info.MapString<size_t>("cat1", 2);
+ info.MapString<size_t>("cat2", 2);
HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
@@ -383,9 +383,9 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeEasySplitTest)
// will only receive points with class 1. In the second dimension, all points
// will have category 0 (so it is useless).
data::DatasetInfo info(2);
- info.MapString("cat0", 0);
- info.MapString("cat1", 0);
- info.MapString("cat0", 1);
+ info.MapString<size_t>("cat0", 0);
+ info.MapString<size_t>("cat1", 0);
+ info.MapString<size_t>("cat0", 1);
HoeffdingTree<> tree(info, 2, 0.95, 5000, 5000 /* never check for splits */);
@@ -411,9 +411,9 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeProbability1SplitTest)
// will only receive points with class 1. In the second dimension, all points
// will have category 0 (so it is useless).
data::DatasetInfo info(2);
- info.MapString("cat0", 0);
- info.MapString("cat1", 0);
- info.MapString("cat0", 1);
+ info.MapString<size_t>("cat0", 0);
+ info.MapString<size_t>("cat1", 0);
+ info.MapString<size_t>("cat0", 1);
HoeffdingTree<> split(info, 2, 1.0, 12000, 1 /* always check for splits */);
@@ -438,10 +438,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAlmostPerfectSplit)
{
// Two categories and two dimensions.
data::DatasetInfo info(2);
- info.MapString("cat0", 0);
- info.MapString("cat1", 0);
- info.MapString("cat0", 1);
- info.MapString("cat1", 1);
+ info.MapString<size_t>("cat0", 0);
+ info.MapString<size_t>("cat1", 0);
+ info.MapString<size_t>("cat0", 1);
+ info.MapString<size_t>("cat1", 1);
HoeffdingTree<> split(info, 2, 0.95, 5000, 5000 /* never check for splits */);
@@ -473,10 +473,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeEqualSplitTest)
{
// Two categories and two dimensions.
data::DatasetInfo info(2);
- info.MapString("cat0", 0);
- info.MapString("cat1", 0);
- info.MapString("cat0", 1);
- info.MapString("cat1", 1);
+ info.MapString<size_t>("cat0", 0);
+ info.MapString<size_t>("cat1", 0);
+ info.MapString<size_t>("cat0", 1);
+ info.MapString<size_t>("cat1", 1);
HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
@@ -504,18 +504,18 @@ using HoeffdingSizeTNumericSplit = HoeffdingNumericSplit<FitnessFunction,
BOOST_AUTO_TEST_CASE(HoeffdingTreeSimpleDatasetTest)
{
DatasetInfo info(3);
- info.MapString("cat0", 0);
- info.MapString("cat1", 0);
- info.MapString("cat2", 0);
- info.MapString("cat3", 0);
- info.MapString("cat4", 0);
- info.MapString("cat5", 0);
- info.MapString("cat6", 0);
- info.MapString("cat0", 1);
- info.MapString("cat1", 1);
- info.MapString("cat2", 1);
- info.MapString("cat0", 2);
- info.MapString("cat1", 2);
+ info.MapString<size_t>("cat0", 0);
+ info.MapString<size_t>("cat1", 0);
+ info.MapString<size_t>("cat2", 0);
+ info.MapString<size_t>("cat3", 0);
+ info.MapString<size_t>("cat4", 0);
+ info.MapString<size_t>("cat5", 0);
+ info.MapString<size_t>("cat6", 0);
+ info.MapString<size_t>("cat0", 1);
+ info.MapString<size_t>("cat1", 1);
+ info.MapString<size_t>("cat2", 1);
+ info.MapString<size_t>("cat0", 2);
+ info.MapString<size_t>("cat1", 2);
// Now generate data.
arma::Mat<size_t> dataset(3, 9000);
@@ -798,7 +798,7 @@ BOOST_AUTO_TEST_CASE(BinaryNumericHoeffdingTreeTest)
arma::mat dataset(4, 9000);
arma::Row<size_t> labels(9000);
data::DatasetInfo info(4); // All features are numeric, except the fourth.
- info.MapString("0", 3);
+ info.MapString<double>("0", 3);
for (size_t i = 0; i < 9000; i += 3)
{
dataset(0, i) = mlpack::math::Random();
@@ -984,7 +984,7 @@ BOOST_AUTO_TEST_CASE(ConfidenceChangeTest)
arma::mat dataset(4, 9000);
arma::Row<size_t> labels(9000);
data::DatasetInfo info(4); // All features are numeric, except the fourth.
- info.MapString("0", 3);
+ info.MapString<double>("0", 3);
for (size_t i = 0; i < 9000; i += 3)
{
dataset(0, i) = mlpack::math::Random();
@@ -1041,7 +1041,7 @@ BOOST_AUTO_TEST_CASE(ParameterChangeTest)
arma::mat dataset(4, 9000);
arma::Row<size_t> labels(9000);
data::DatasetInfo info(4); // All features are numeric, except the fourth.
- info.MapString("0", 3);
+ info.MapString<double>("0", 3);
for (size_t i = 0; i < 9000; i += 3)
{
dataset(0, i) = mlpack::math::Random();
@@ -1094,7 +1094,7 @@ BOOST_AUTO_TEST_CASE(MultipleSerializationTest)
arma::mat dataset(4, 9000);
arma::Row<size_t> labels(9000);
data::DatasetInfo info(4); // All features are numeric, except the fourth.
- info.MapString("0", 3);
+ info.MapString<double>("0", 3);
for (size_t i = 0; i < 9000; i += 3)
{
dataset(0, i) = mlpack::math::Random();
@@ -1145,4 +1145,246 @@ BOOST_AUTO_TEST_CASE(MultipleSerializationTest)
}
}
+// Test HoeffdingTreeModel in streaming mode, for all four model types.
+BOOST_AUTO_TEST_CASE(HoeffdingTreeModelTest)
+{
+  // Generate 3000 points in three classes; dimension 3 is always 0.
+  arma::mat dataset(4, 3000);
+  arma::Row<size_t> labels(3000);
+  data::DatasetInfo info(4); // All features are numeric, except the fourth.
+  info.MapString<double>("0", 3);
+  for (size_t i = 0; i < 3000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  // Train a model on a simple dataset, for all four types of models, and make
+  // sure we get reasonable results.
+  for (size_t i = 0; i < 4; ++i)
+  {
+    HoeffdingTreeModel m;
+    switch (i)
+    {
+      case 0:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
+        break;
+
+      case 1:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
+        break;
+
+      case 2:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
+        break;
+
+      case 3:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
+        break;
+    }
+
+    // Five passes over the data: one in BuildModel(), then four Train() calls.
+    m.BuildModel(dataset, info, labels, 3, false, 0.99, 1000, 100, 100, 4, 100);
+    for (size_t j = 0; j < 4; ++j)
+      m.Train(dataset, labels, false);
+
+    // Now make sure the performance is reasonable.
+    arma::Row<size_t> predictions, predictions2;
+    arma::rowvec probabilities;
+    m.Classify(dataset, predictions);
+    m.Classify(dataset, predictions2, probabilities);
+
+    size_t correct = 0;
+    for (size_t p = 0; p < 3000; ++p)
+    {
+      // Both Classify() overloads must agree on every point.
+      BOOST_REQUIRE_EQUAL(predictions[p], predictions2[p]);
+
+      if (labels[p] == predictions[p])
+        ++correct;
+    }
+
+    // Require more than 95% accuracy (over 2850 of 3000 points correct).
+    BOOST_REQUIRE_GT(correct, 2850);
+  }
+}
+
+// Test HoeffdingTreeModel in batch mode, for all four model types.
+BOOST_AUTO_TEST_CASE(HoeffdingTreeModelBatchTest)
+{
+  // Generate 3000 points in three classes; dimension 3 is always 0.
+  arma::mat dataset(4, 3000);
+  arma::Row<size_t> labels(3000);
+  data::DatasetInfo info(4); // All features are numeric, except the fourth.
+  info.MapString<double>("0", 3);
+  for (size_t i = 0; i < 3000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  // Train a model on a simple dataset, for all four types of models, and make
+  // sure we get reasonable results.
+  for (size_t i = 0; i < 4; ++i)
+  {
+    HoeffdingTreeModel m;
+    switch (i)
+    {
+      case 0:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
+        break;
+
+      case 1:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
+        break;
+
+      case 2:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
+        break;
+
+      case 3:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
+        break;
+    }
+
+    // Train once in batch mode (batch flag = true); no further Train() calls.
+    m.BuildModel(dataset, info, labels, 3, true, 0.99, 1000, 100, 100, 4, 100);
+
+    // Now make sure the performance is reasonable.
+    arma::Row<size_t> predictions, predictions2;
+    arma::rowvec probabilities;
+    m.Classify(dataset, predictions);
+    m.Classify(dataset, predictions2, probabilities);
+
+    size_t correct = 0;
+    for (size_t p = 0; p < 3000; ++p)
+    {
+      // Both Classify() overloads must agree on every point.
+      BOOST_REQUIRE_EQUAL(predictions[p], predictions2[p]);
+
+      if (labels[p] == predictions[p])
+        ++correct;
+    }
+
+    // Require more than 95% accuracy (over 2850 of 3000 points correct).
+    BOOST_REQUIRE_GT(correct, 2850);
+  }
+}
+
+BOOST_AUTO_TEST_CASE(HoeffdingTreeModelSerializationTest)
+{
+  // Generate 3000 points in three classes; dimension 3 is always 0.
+  arma::mat dataset(4, 3000);
+  arma::Row<size_t> labels(3000);
+  data::DatasetInfo info(4); // All features are numeric, except the fourth.
+  info.MapString<double>("0", 3);
+  for (size_t i = 0; i < 3000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  // For each of the four model types, train a model, round-trip it through
+  // serialization, and check that all copies predict identically.
+  for (size_t i = 0; i < 4; ++i)
+  {
+    HoeffdingTreeModel m, xmlM, textM, binaryM;
+    switch (i)
+    {
+      case 0:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
+        break;
+
+      case 1:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
+        break;
+
+      case 2:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
+        break;
+
+      case 3:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
+        break;
+    }
+
+    // Train in batch.
+    m.BuildModel(dataset, info, labels, 3, true, 0.99, 1000, 100, 100, 4, 100);
+    // Deliberately train xmlM differently, so it only matches m once loaded.
+    xmlM.BuildModel(dataset, info, labels, 3, false, 0.5, 100, 100, 100, 2,
+        100);
+
+    // Outputs of the original model and of the three deserialized copies.
+    arma::Row<size_t> predictions, predictionsXml, predictionsText,
+        predictionsBinary;
+    arma::rowvec probabilities, probabilitiesXml, probabilitiesText,
+        probabilitiesBinary;
+
+    SerializeObjectAll(m, xmlM, textM, binaryM);
+
+    // Get predictions for all.
+    m.Classify(dataset, predictions, probabilities);
+    xmlM.Classify(dataset, predictionsXml, probabilitiesXml);
+    textM.Classify(dataset, predictionsText, probabilitiesText);
+    binaryM.Classify(dataset, predictionsBinary, probabilitiesBinary);
+
+    for (size_t p = 0; p < 3000; ++p)
+    {
+      // Check consistency of predictions and probabilities.
+      BOOST_REQUIRE_EQUAL(predictions[p], predictionsXml[p]);
+      BOOST_REQUIRE_EQUAL(predictions[p], predictionsText[p]);
+      BOOST_REQUIRE_EQUAL(predictions[p], predictionsBinary[p]);
+
+      BOOST_REQUIRE_CLOSE(probabilities[p], probabilitiesXml[p], 1e-5);
+      BOOST_REQUIRE_CLOSE(probabilities[p], probabilitiesText[p], 1e-5);
+      BOOST_REQUIRE_CLOSE(probabilities[p], probabilitiesBinary[p], 1e-5);
+    }
+  }
+}
+
BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index 474a6cc..1553179 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -916,9 +916,9 @@ BOOST_AUTO_TEST_CASE(DatasetInfoTest)
}
// Okay. Add some mappings for dimension 3.
- const size_t first = di.MapString("test_mapping_1", 3);
- const size_t second = di.MapString("test_mapping_2", 3);
- const size_t third = di.MapString("test_mapping_3", 3);
+ const size_t first = di.MapString<size_t>("test_mapping_1", 3);
+ const size_t second = di.MapString<size_t>("test_mapping_2", 3);
+ const size_t third = di.MapString<size_t>("test_mapping_3", 3);
BOOST_REQUIRE_EQUAL(first, 0);
BOOST_REQUIRE_EQUAL(second, 1);
@@ -1078,10 +1078,10 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest00)
BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
BOOST_REQUIRE(info.Type(2) == Datatype::categorical);
- BOOST_REQUIRE_EQUAL(info.MapString("hello", 2), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("goodbye", 2), 1);
- BOOST_REQUIRE_EQUAL(info.MapString("coffee", 2), 2);
- BOOST_REQUIRE_EQUAL(info.MapString("confusion", 2), 3);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("hello", 2), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("goodbye", 2), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("coffee", 2), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("confusion", 2), 3);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 2), "hello");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "goodbye");
@@ -1127,8 +1127,8 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest01)
BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("", 0), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 0), 1);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "1");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "");
@@ -1171,8 +1171,8 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest02)
BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
- BOOST_REQUIRE_EQUAL(info.MapString("", 0), 1);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 0), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 0);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "1");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "");
@@ -1215,8 +1215,8 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest03)
BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
- BOOST_REQUIRE_EQUAL(info.MapString("", 0), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 0), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 1);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "1");
@@ -1259,8 +1259,8 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest04)
BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
- BOOST_REQUIRE_EQUAL(info.MapString("200-DM", 0), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("200-DM", 0), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 1);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "200-DM");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "1");
@@ -1319,24 +1319,24 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest00)
BOOST_REQUIRE(info.Type(5) == Datatype::numeric);
BOOST_REQUIRE(info.Type(6) == Datatype::categorical);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("2", 0), 1);
- BOOST_REQUIRE_EQUAL(info.MapString("hello", 0), 2);
- BOOST_REQUIRE_EQUAL(info.MapString("3", 1), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("4", 1), 1);
- BOOST_REQUIRE_EQUAL(info.MapString("goodbye", 1), 2);
- BOOST_REQUIRE_EQUAL(info.MapString("5", 2), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("6", 2), 1);
- BOOST_REQUIRE_EQUAL(info.MapString("coffee", 2), 2);
- BOOST_REQUIRE_EQUAL(info.MapString("7", 3), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("8", 3), 1);
- BOOST_REQUIRE_EQUAL(info.MapString("confusion", 3), 2);
- BOOST_REQUIRE_EQUAL(info.MapString("9", 4), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("10", 4), 1);
- BOOST_REQUIRE_EQUAL(info.MapString("hello", 4), 2);
- BOOST_REQUIRE_EQUAL(info.MapString("13", 6), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("14", 6), 1);
- BOOST_REQUIRE_EQUAL(info.MapString("confusion", 6), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("2", 0), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("hello", 0), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("3", 1), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("4", 1), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("goodbye", 1), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("5", 2), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("6", 2), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("coffee", 2), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("7", 3), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("8", 3), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("confusion", 3), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("9", 4), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("10", 4), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("hello", 4), 2);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("13", 6), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("14", 6), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("confusion", 6), 2);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "1");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "2");
@@ -1396,8 +1396,8 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest01)
BOOST_REQUIRE(info.Type(2) == Datatype::categorical);
BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
- BOOST_REQUIRE_EQUAL(info.MapString("", 2), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 2), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 2), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 2), 1);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 2), "");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "1");
@@ -1441,8 +1441,8 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest02)
BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
- BOOST_REQUIRE_EQUAL(info.MapString("", 1), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 1), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 1), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 1), 1);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 1), "");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 1), "1");
@@ -1486,8 +1486,8 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest03)
BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
- BOOST_REQUIRE_EQUAL(info.MapString("", 1), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 1), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 1), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 1), 1);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 1), "");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 1), "1");
@@ -1513,6 +1513,11 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest04)
BOOST_REQUIRE_EQUAL(matrix.n_cols, 3);
BOOST_REQUIRE_EQUAL(matrix.n_rows, 4);
+ BOOST_REQUIRE(info.Type(0) == Datatype::categorical);
+ BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
+ BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
+ BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
+
BOOST_REQUIRE_EQUAL(matrix(0, 0), 0);
BOOST_REQUIRE_EQUAL(matrix(0, 1), 1);
BOOST_REQUIRE_EQUAL(matrix(0, 2), 1);
@@ -1526,13 +1531,8 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest04)
BOOST_REQUIRE_EQUAL(matrix(3, 1), 1);
BOOST_REQUIRE_EQUAL(matrix(3, 2), 1);
- BOOST_REQUIRE(info.Type(0) == Datatype::categorical);
- BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
- BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
- BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
-
- BOOST_REQUIRE_EQUAL(info.MapString("200-DM", 1), 0);
- BOOST_REQUIRE_EQUAL(info.MapString("1", 1), 1);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("200-DM", 1), 0);
+ BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 1), 1);
BOOST_REQUIRE_EQUAL(info.UnmapString(0, 1), "200-DM");
BOOST_REQUIRE_EQUAL(info.UnmapString(1, 1), "1");
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 0f27bcc..c42d598 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -1500,8 +1500,8 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitTest)
BOOST_AUTO_TEST_CASE(HoeffdingTreeBeforeSplitTest)
{
data::DatasetInfo info(5);
- info.MapString("0", 2); // Dimension 1 is categorical.
- info.MapString("1", 2);
+ info.MapString<double>("0", 2); // Dimension 1 is categorical.
+ info.MapString<double>("1", 2);
HoeffdingTree<> split(info, 2, 0.99, 15000, 1);
// Train for 2 samples.
@@ -1509,14 +1509,14 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeBeforeSplitTest)
split.Train(arma::vec("-0.3 0.0 0 0.7 0.8"), 1);
data::DatasetInfo wrongInfo(3);
- wrongInfo.MapString("1", 1);
+ wrongInfo.MapString<double>("1", 1);
HoeffdingTree<> xmlSplit(wrongInfo, 7, 0.1, 10, 1);
// Force the binarySplit to split.
data::DatasetInfo binaryInfo(2);
- binaryInfo.MapString("cat0", 0);
- binaryInfo.MapString("cat1", 0);
- binaryInfo.MapString("cat0", 1);
+ binaryInfo.MapString<double>("cat0", 0);
+ binaryInfo.MapString<double>("cat1", 0);
+ binaryInfo.MapString<double>("cat0", 1);
HoeffdingTree<> binarySplit(info, 2, 0.95, 5000, 1);
@@ -1552,9 +1552,9 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAfterSplitTest)
{
// Force the split to split.
data::DatasetInfo info(2);
- info.MapString("cat0", 0);
- info.MapString("cat1", 0);
- info.MapString("cat0", 1);
+ info.MapString<double>("cat0", 0);
+ info.MapString<double>("cat1", 0);
+ info.MapString<double>("cat0", 1);
HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
@@ -1568,12 +1568,12 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAfterSplitTest)
BOOST_REQUIRE_NE(split.SplitDimension(), size_t(-1));
data::DatasetInfo wrongInfo(3);
- wrongInfo.MapString("1", 1);
+ wrongInfo.MapString<double>("1", 1);
HoeffdingTree<> xmlSplit(wrongInfo, 7, 0.1, 10, 1);
data::DatasetInfo binaryInfo(5);
- binaryInfo.MapString("0", 2); // Dimension 2 is categorical.
- binaryInfo.MapString("1", 2);
+ binaryInfo.MapString<double>("0", 2); // Dimension 2 is categorical.
+ binaryInfo.MapString<double>("1", 2);
HoeffdingTree<> binarySplit(binaryInfo, 2, 0.99, 15000, 1);
// Train for 2 samples.
@@ -1645,14 +1645,14 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeTest)
}
// Make the features categorical.
data::DatasetInfo info(2);
- info.MapString("a", 0);
- info.MapString("b", 0);
- info.MapString("c", 0);
- info.MapString("d", 0);
- info.MapString("a", 1);
- info.MapString("b", 1);
- info.MapString("c", 1);
- info.MapString("d", 1);
+ info.MapString<double>("a", 0);
+ info.MapString<double>("b", 0);
+ info.MapString<double>("c", 0);
+ info.MapString<double>("d", 0);
+ info.MapString<double>("a", 1);
+ info.MapString<double>("b", 1);
+ info.MapString<double>("c", 1);
+ info.MapString<double>("d", 1);
HoeffdingTree<> tree(dataset, info, labels, 2, false /* no batch mode */);
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list