[mlpack] 192/207: Refactor for new DatasetMapper API.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:53 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 4121399ad18cbcc3abb12c39e1f6d152fa583779
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Mar 18 15:00:13 2017 -0400

    Refactor for new DatasetMapper API.
---
 src/mlpack/core/data/load_arff_impl.hpp  |   5 +-
 src/mlpack/tests/decision_tree_test.cpp  |  14 +-
 src/mlpack/tests/hoeffding_tree_test.cpp | 322 +++++++++++++++++++++++++++----
 src/mlpack/tests/load_save_test.cpp      |  92 ++++-----
 src/mlpack/tests/serialization_test.cpp  |  40 ++--
 5 files changed, 359 insertions(+), 114 deletions(-)

diff --git a/src/mlpack/core/data/load_arff_impl.hpp b/src/mlpack/core/data/load_arff_impl.hpp
index 8c84ad1..c127f16 100644
--- a/src/mlpack/core/data/load_arff_impl.hpp
+++ b/src/mlpack/core/data/load_arff_impl.hpp
@@ -179,7 +179,10 @@ void LoadARFF(const std::string& filename,
       // What should this token be?
       if (info.Type(col) == Datatype::categorical)
       {
-        matrix(col, row) = info.MapString(*it, col); // We load transposed.
+        // Strip spaces before mapping.
+        std::string token = *it;
+        boost::trim(token);
+        matrix(col, row) = info.template MapString<eT>(token, col); // We load transposed.
       }
       else if (info.Type(col) == Datatype::numeric)
       {
diff --git a/src/mlpack/tests/decision_tree_test.cpp b/src/mlpack/tests/decision_tree_test.cpp
index 719a261..8ba84c8 100644
--- a/src/mlpack/tests/decision_tree_test.cpp
+++ b/src/mlpack/tests/decision_tree_test.cpp
@@ -485,7 +485,7 @@ BOOST_AUTO_TEST_CASE(CategoricalBuildTest)
   arma::Row<size_t> labels(10000);
   for (size_t i = 0; i < 10000; ++i)
   {
-    // One circle every 20000 samples.  Plus some noise. 
+    // One circle every 20000 samples.  Plus some noise.
     const double magnitude = 2.0 + (double(i) / 2000.0) +
         0.5 * mlpack::math::Random();
     const double angle = (i % 2000) * (2 * M_PI) + mlpack::math::Random();
@@ -534,12 +534,12 @@ BOOST_AUTO_TEST_CASE(CategoricalBuildTest)
   di.Type(2) = data::Datatype::categorical;
   di.Type(3) = data::Datatype::categorical;
   // Set mappings.
-  di.MapString("0", 2);
-  di.MapString("1", 2);
-  di.MapString("2", 2);
-  di.MapString("3", 2);
-  di.MapString("0", 3);
-  di.MapString("1", 3);
+  di.MapString<double>("0", 2);
+  di.MapString<double>("1", 2);
+  di.MapString<double>("2", 2);
+  di.MapString<double>("3", 2);
+  di.MapString<double>("0", 3);
+  di.MapString<double>("1", 3);
 
   // Now shuffle the dataset.
   arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0, 9999,
diff --git a/src/mlpack/tests/hoeffding_tree_test.cpp b/src/mlpack/tests/hoeffding_tree_test.cpp
index a5b94bd..9712803 100644
--- a/src/mlpack/tests/hoeffding_tree_test.cpp
+++ b/src/mlpack/tests/hoeffding_tree_test.cpp
@@ -325,7 +325,7 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitSplitTest)
 
   // No training is necessary because we can just call CreateChildren().
   data::DatasetInfo info(3);
-  info.MapString("hello", 0); // Make dimension 0 categorical.
+  info.MapString<size_t>("hello", 0); // Make dimension 0 categorical.
   HoeffdingCategoricalSplit<GiniImpurity>::SplitInfo splitInfo(3);
 
   // Create the children.
@@ -346,15 +346,15 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeNoSplitTest)
 {
   // Make all dimensions categorical.
   data::DatasetInfo info(3);
-  info.MapString("cat1", 0);
-  info.MapString("cat2", 0);
-  info.MapString("cat3", 0);
-  info.MapString("cat4", 0);
-  info.MapString("cat1", 1);
-  info.MapString("cat2", 1);
-  info.MapString("cat3", 1);
-  info.MapString("cat1", 2);
-  info.MapString("cat2", 2);
+  info.MapString<size_t>("cat1", 0);
+  info.MapString<size_t>("cat2", 0);
+  info.MapString<size_t>("cat3", 0);
+  info.MapString<size_t>("cat4", 0);
+  info.MapString<size_t>("cat1", 1);
+  info.MapString<size_t>("cat2", 1);
+  info.MapString<size_t>("cat3", 1);
+  info.MapString<size_t>("cat1", 2);
+  info.MapString<size_t>("cat2", 2);
 
   HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
 
@@ -383,9 +383,9 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeEasySplitTest)
   // will only receive points with class 1.  In the second dimension, all points
   // will have category 0 (so it is useless).
   data::DatasetInfo info(2);
-  info.MapString("cat0", 0);
-  info.MapString("cat1", 0);
-  info.MapString("cat0", 1);
+  info.MapString<size_t>("cat0", 0);
+  info.MapString<size_t>("cat1", 0);
+  info.MapString<size_t>("cat0", 1);
 
   HoeffdingTree<> tree(info, 2, 0.95, 5000, 5000 /* never check for splits */);
 
@@ -411,9 +411,9 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeProbability1SplitTest)
   // will only receive points with class 1.  In the second dimension, all points
   // will have category 0 (so it is useless).
   data::DatasetInfo info(2);
-  info.MapString("cat0", 0);
-  info.MapString("cat1", 0);
-  info.MapString("cat0", 1);
+  info.MapString<size_t>("cat0", 0);
+  info.MapString<size_t>("cat1", 0);
+  info.MapString<size_t>("cat0", 1);
 
   HoeffdingTree<> split(info, 2, 1.0, 12000, 1 /* always check for splits */);
 
@@ -438,10 +438,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAlmostPerfectSplit)
 {
   // Two categories and two dimensions.
   data::DatasetInfo info(2);
-  info.MapString("cat0", 0);
-  info.MapString("cat1", 0);
-  info.MapString("cat0", 1);
-  info.MapString("cat1", 1);
+  info.MapString<size_t>("cat0", 0);
+  info.MapString<size_t>("cat1", 0);
+  info.MapString<size_t>("cat0", 1);
+  info.MapString<size_t>("cat1", 1);
 
   HoeffdingTree<> split(info, 2, 0.95, 5000, 5000 /* never check for splits */);
 
@@ -473,10 +473,10 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeEqualSplitTest)
 {
   // Two categories and two dimensions.
   data::DatasetInfo info(2);
-  info.MapString("cat0", 0);
-  info.MapString("cat1", 0);
-  info.MapString("cat0", 1);
-  info.MapString("cat1", 1);
+  info.MapString<size_t>("cat0", 0);
+  info.MapString<size_t>("cat1", 0);
+  info.MapString<size_t>("cat0", 1);
+  info.MapString<size_t>("cat1", 1);
 
   HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
 
@@ -504,18 +504,18 @@ using HoeffdingSizeTNumericSplit = HoeffdingNumericSplit<FitnessFunction,
 BOOST_AUTO_TEST_CASE(HoeffdingTreeSimpleDatasetTest)
 {
   DatasetInfo info(3);
-  info.MapString("cat0", 0);
-  info.MapString("cat1", 0);
-  info.MapString("cat2", 0);
-  info.MapString("cat3", 0);
-  info.MapString("cat4", 0);
-  info.MapString("cat5", 0);
-  info.MapString("cat6", 0);
-  info.MapString("cat0", 1);
-  info.MapString("cat1", 1);
-  info.MapString("cat2", 1);
-  info.MapString("cat0", 2);
-  info.MapString("cat1", 2);
+  info.MapString<size_t>("cat0", 0);
+  info.MapString<size_t>("cat1", 0);
+  info.MapString<size_t>("cat2", 0);
+  info.MapString<size_t>("cat3", 0);
+  info.MapString<size_t>("cat4", 0);
+  info.MapString<size_t>("cat5", 0);
+  info.MapString<size_t>("cat6", 0);
+  info.MapString<size_t>("cat0", 1);
+  info.MapString<size_t>("cat1", 1);
+  info.MapString<size_t>("cat2", 1);
+  info.MapString<size_t>("cat0", 2);
+  info.MapString<size_t>("cat1", 2);
 
   // Now generate data.
   arma::Mat<size_t> dataset(3, 9000);
@@ -798,7 +798,7 @@ BOOST_AUTO_TEST_CASE(BinaryNumericHoeffdingTreeTest)
   arma::mat dataset(4, 9000);
   arma::Row<size_t> labels(9000);
   data::DatasetInfo info(4); // All features are numeric, except the fourth.
-  info.MapString("0", 3);
+  info.MapString<double>("0", 3);
   for (size_t i = 0; i < 9000; i += 3)
   {
     dataset(0, i) = mlpack::math::Random();
@@ -984,7 +984,7 @@ BOOST_AUTO_TEST_CASE(ConfidenceChangeTest)
   arma::mat dataset(4, 9000);
   arma::Row<size_t> labels(9000);
   data::DatasetInfo info(4); // All features are numeric, except the fourth.
-  info.MapString("0", 3);
+  info.MapString<double>("0", 3);
   for (size_t i = 0; i < 9000; i += 3)
   {
     dataset(0, i) = mlpack::math::Random();
@@ -1041,7 +1041,7 @@ BOOST_AUTO_TEST_CASE(ParameterChangeTest)
   arma::mat dataset(4, 9000);
   arma::Row<size_t> labels(9000);
   data::DatasetInfo info(4); // All features are numeric, except the fourth.
-  info.MapString("0", 3);
+  info.MapString<double>("0", 3);
   for (size_t i = 0; i < 9000; i += 3)
   {
     dataset(0, i) = mlpack::math::Random();
@@ -1094,7 +1094,7 @@ BOOST_AUTO_TEST_CASE(MultipleSerializationTest)
   arma::mat dataset(4, 9000);
   arma::Row<size_t> labels(9000);
   data::DatasetInfo info(4); // All features are numeric, except the fourth.
-  info.MapString("0", 3);
+  info.MapString<double>("0", 3);
   for (size_t i = 0; i < 9000; i += 3)
   {
     dataset(0, i) = mlpack::math::Random();
@@ -1145,4 +1145,246 @@ BOOST_AUTO_TEST_CASE(MultipleSerializationTest)
   }
 }
 
+// Test the Hoeffding tree model.
+BOOST_AUTO_TEST_CASE(HoeffdingTreeModelTest)
+{
+  // Generate data.
+  arma::mat dataset(4, 3000);
+  arma::Row<size_t> labels(3000);
+  data::DatasetInfo info(4); // All features are numeric, except the fourth.
+  info.MapString<double>("0", 3);
+  for (size_t i = 0; i < 3000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  // Train a model on a simple dataset, for all four types of models, and make
+  // sure we get reasonable results.
+  for (size_t i = 0; i < 4; ++i)
+  {
+    HoeffdingTreeModel m;
+    switch (i)
+    {
+      case 0:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
+        break;
+
+      case 1:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
+        break;
+
+      case 2:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
+        break;
+
+      case 3:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
+        break;
+    }
+
+    // We'll take 5 passes over the data.
+    m.BuildModel(dataset, info, labels, 3, false, 0.99, 1000, 100, 100, 4, 100);
+    for (size_t j = 0; j < 4; ++j)
+      m.Train(dataset, labels, false);
+
+    // Now make sure the performance is reasonable.
+    arma::Row<size_t> predictions, predictions2;
+    arma::rowvec probabilities;
+    m.Classify(dataset, predictions);
+    m.Classify(dataset, predictions2, probabilities);
+
+    size_t correct = 0;
+    for (size_t i = 0; i < 3000; ++i)
+    {
+      // Check consistency of predictions.
+      BOOST_REQUIRE_EQUAL(predictions[i], predictions2[i]);
+
+      if (labels[i] == predictions[i])
+        ++correct;
+    }
+
+    // Require at least 95% accuracy.
+    BOOST_REQUIRE_GT(correct, 2850);
+  }
+}
+
+// Test the Hoeffding tree model in batch mode.
+BOOST_AUTO_TEST_CASE(HoeffdingTreeModelBatchTest)
+{
+  // Generate data.
+  arma::mat dataset(4, 3000);
+  arma::Row<size_t> labels(3000);
+  data::DatasetInfo info(4); // All features are numeric, except the fourth.
+  info.MapString<double>("0", 3);
+  for (size_t i = 0; i < 3000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  // Train a model on a simple dataset, for all four types of models, and make
+  // sure we get reasonable results.
+  for (size_t i = 0; i < 4; ++i)
+  {
+    HoeffdingTreeModel m;
+    switch (i)
+    {
+      case 0:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
+        break;
+
+      case 1:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
+        break;
+
+      case 2:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
+        break;
+
+      case 3:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
+        break;
+    }
+
+    // Train in batch.
+    m.BuildModel(dataset, info, labels, 3, true, 0.99, 1000, 100, 100, 4, 100);
+
+    // Now make sure the performance is reasonable.
+    arma::Row<size_t> predictions, predictions2;
+    arma::rowvec probabilities;
+    m.Classify(dataset, predictions);
+    m.Classify(dataset, predictions2, probabilities);
+
+    size_t correct = 0;
+    for (size_t i = 0; i < 3000; ++i)
+    {
+      // Check consistency of predictions.
+      BOOST_REQUIRE_EQUAL(predictions[i], predictions2[i]);
+
+      if (labels[i] == predictions[i])
+        ++correct;
+    }
+
+    // Require at least 95% accuracy.
+    BOOST_REQUIRE_GT(correct, 2850);
+  }
+}
+
+BOOST_AUTO_TEST_CASE(HoeffdingTreeModelSerializationTest)
+{
+  // Generate data.
+  arma::mat dataset(4, 3000);
+  arma::Row<size_t> labels(3000);
+  data::DatasetInfo info(4); // All features are numeric, except the fourth.
+  info.MapString<double>("0", 3);
+  for (size_t i = 0; i < 3000; i += 3)
+  {
+    dataset(0, i) = mlpack::math::Random();
+    dataset(1, i) = mlpack::math::Random();
+    dataset(2, i) = mlpack::math::Random();
+    dataset(3, i) = 0.0;
+    labels[i] = 0;
+
+    dataset(0, i + 1) = mlpack::math::Random();
+    dataset(1, i + 1) = mlpack::math::Random() - 1.0;
+    dataset(2, i + 1) = mlpack::math::Random() + 0.5;
+    dataset(3, i + 1) = 0.0;
+    labels[i + 1] = 2;
+
+    dataset(0, i + 2) = mlpack::math::Random();
+    dataset(1, i + 2) = mlpack::math::Random() + 1.0;
+    dataset(2, i + 2) = mlpack::math::Random() + 0.8;
+    dataset(3, i + 2) = 0.0;
+    labels[i + 2] = 1;
+  }
+
+  // Train a model on a simple dataset, for all four types of models, and make
+  // sure we get reasonable results.
+  for (size_t i = 0; i < 4; ++i)
+  {
+    HoeffdingTreeModel m, xmlM, textM, binaryM;
+    switch (i)
+    {
+      case 0:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
+        break;
+
+      case 1:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
+        break;
+
+      case 2:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
+        break;
+
+      case 3:
+        m = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
+        break;
+    }
+
+    // Train in batch.
+    m.BuildModel(dataset, info, labels, 3, true, 0.99, 1000, 100, 100, 4, 100);
+    // False training of XML model.
+    xmlM.BuildModel(dataset, info, labels, 3, false, 0.5, 100, 100, 100, 2,
+        100);
+
+    // Now make sure the performance is reasonable.
+    arma::Row<size_t> predictions, predictionsXml, predictionsText,
+        predictionsBinary;
+    arma::rowvec probabilities, probabilitiesXml, probabilitiesText,
+        probabilitiesBinary;
+
+    SerializeObjectAll(m, xmlM, textM, binaryM);
+
+    // Get predictions for all.
+    m.Classify(dataset, predictions, probabilities);
+    xmlM.Classify(dataset, predictionsXml, probabilitiesXml);
+    textM.Classify(dataset, predictionsText, probabilitiesText);
+    binaryM.Classify(dataset, predictionsBinary, probabilitiesBinary);
+
+    for (size_t i = 0; i < 3000; ++i)
+    {
+      // Check consistency of predictions and probabilities.
+      BOOST_REQUIRE_EQUAL(predictions[i], predictionsXml[i]);
+      BOOST_REQUIRE_EQUAL(predictions[i], predictionsText[i]);
+      BOOST_REQUIRE_EQUAL(predictions[i], predictionsBinary[i]);
+
+      BOOST_REQUIRE_CLOSE(probabilities[i], probabilitiesXml[i], 1e-5);
+      BOOST_REQUIRE_CLOSE(probabilities[i], probabilitiesText[i], 1e-5);
+      BOOST_REQUIRE_CLOSE(probabilities[i], probabilitiesBinary[i], 1e-5);
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END();
diff --git a/src/mlpack/tests/load_save_test.cpp b/src/mlpack/tests/load_save_test.cpp
index 474a6cc..1553179 100644
--- a/src/mlpack/tests/load_save_test.cpp
+++ b/src/mlpack/tests/load_save_test.cpp
@@ -916,9 +916,9 @@ BOOST_AUTO_TEST_CASE(DatasetInfoTest)
   }
 
   // Okay.  Add some mappings for dimension 3.
-  const size_t first = di.MapString("test_mapping_1", 3);
-  const size_t second = di.MapString("test_mapping_2", 3);
-  const size_t third = di.MapString("test_mapping_3", 3);
+  const size_t first = di.MapString<size_t>("test_mapping_1", 3);
+  const size_t second = di.MapString<size_t>("test_mapping_2", 3);
+  const size_t third = di.MapString<size_t>("test_mapping_3", 3);
 
   BOOST_REQUIRE_EQUAL(first, 0);
   BOOST_REQUIRE_EQUAL(second, 1);
@@ -1078,10 +1078,10 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest00)
   BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
   BOOST_REQUIRE(info.Type(2) == Datatype::categorical);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("hello", 2), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("goodbye", 2), 1);
-  BOOST_REQUIRE_EQUAL(info.MapString("coffee", 2), 2);
-  BOOST_REQUIRE_EQUAL(info.MapString("confusion", 2), 3);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("hello", 2), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("goodbye", 2), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("coffee", 2), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("confusion", 2), 3);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 2), "hello");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "goodbye");
@@ -1127,8 +1127,8 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest01)
   BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
   BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("", 0), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 0), 1);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "1");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "");
@@ -1171,8 +1171,8 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest02)
   BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
   BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("", 0), 1);
-  BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 0), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 0);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "1");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "");
@@ -1215,8 +1215,8 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest03)
   BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
   BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("", 0), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 0), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 1);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "1");
@@ -1259,8 +1259,8 @@ BOOST_AUTO_TEST_CASE(CategoricalCSVLoadTest04)
   BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
   BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("200-DM", 0), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("200-DM", 0), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 1);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "200-DM");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "1");
@@ -1319,24 +1319,24 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest00)
   BOOST_REQUIRE(info.Type(5) == Datatype::numeric);
   BOOST_REQUIRE(info.Type(6) == Datatype::categorical);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("1", 0), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("2", 0), 1);
-  BOOST_REQUIRE_EQUAL(info.MapString("hello", 0), 2);
-  BOOST_REQUIRE_EQUAL(info.MapString("3", 1), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("4", 1), 1);
-  BOOST_REQUIRE_EQUAL(info.MapString("goodbye", 1), 2);
-  BOOST_REQUIRE_EQUAL(info.MapString("5", 2), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("6", 2), 1);
-  BOOST_REQUIRE_EQUAL(info.MapString("coffee", 2), 2);
-  BOOST_REQUIRE_EQUAL(info.MapString("7", 3), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("8", 3), 1);
-  BOOST_REQUIRE_EQUAL(info.MapString("confusion", 3), 2);
-  BOOST_REQUIRE_EQUAL(info.MapString("9", 4), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("10", 4), 1);
-  BOOST_REQUIRE_EQUAL(info.MapString("hello", 4), 2);
-  BOOST_REQUIRE_EQUAL(info.MapString("13", 6), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("14", 6), 1);
-  BOOST_REQUIRE_EQUAL(info.MapString("confusion", 6), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 0), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("2", 0), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("hello", 0), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("3", 1), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("4", 1), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("goodbye", 1), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("5", 2), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("6", 2), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("coffee", 2), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("7", 3), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("8", 3), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("confusion", 3), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("9", 4), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("10", 4), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("hello", 4), 2);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("13", 6), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("14", 6), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("confusion", 6), 2);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 0), "1");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 0), "2");
@@ -1396,8 +1396,8 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest01)
   BOOST_REQUIRE(info.Type(2) == Datatype::categorical);
   BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("", 2), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("1", 2), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 2), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 2), 1);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 2), "");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 2), "1");
@@ -1441,8 +1441,8 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest02)
   BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
   BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("", 1), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("1", 1), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 1), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 1), 1);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 1), "");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 1), "1");
@@ -1486,8 +1486,8 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest03)
   BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
   BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
 
-  BOOST_REQUIRE_EQUAL(info.MapString("", 1), 0);
-  BOOST_REQUIRE_EQUAL(info.MapString("1", 1), 1);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("", 1), 0);
+  BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 1), 1);
 
   BOOST_REQUIRE_EQUAL(info.UnmapString(0, 1), "");
   BOOST_REQUIRE_EQUAL(info.UnmapString(1, 1), "1");
@@ -1513,6 +1513,11 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest04)
     BOOST_REQUIRE_EQUAL(matrix.n_cols, 3);
     BOOST_REQUIRE_EQUAL(matrix.n_rows, 4);
 
+    BOOST_REQUIRE(info.Type(0) == Datatype::categorical);
+    BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
+    BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
+    BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
+
     BOOST_REQUIRE_EQUAL(matrix(0, 0), 0);
     BOOST_REQUIRE_EQUAL(matrix(0, 1), 1);
     BOOST_REQUIRE_EQUAL(matrix(0, 2), 1);
@@ -1526,13 +1531,8 @@ BOOST_AUTO_TEST_CASE(CategoricalNontransposedCSVLoadTest04)
     BOOST_REQUIRE_EQUAL(matrix(3, 1), 1);
     BOOST_REQUIRE_EQUAL(matrix(3, 2), 1);
 
-    BOOST_REQUIRE(info.Type(0) == Datatype::categorical);
-    BOOST_REQUIRE(info.Type(1) == Datatype::numeric);
-    BOOST_REQUIRE(info.Type(2) == Datatype::numeric);
-    BOOST_REQUIRE(info.Type(3) == Datatype::numeric);
-
-    BOOST_REQUIRE_EQUAL(info.MapString("200-DM", 1), 0);
-    BOOST_REQUIRE_EQUAL(info.MapString("1", 1), 1);
+    BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("200-DM", 1), 0);
+    BOOST_REQUIRE_EQUAL(info.MapString<arma::uword>("1", 1), 1);
 
     BOOST_REQUIRE_EQUAL(info.UnmapString(0, 1), "200-DM");
     BOOST_REQUIRE_EQUAL(info.UnmapString(1, 1), "1");
diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp
index 0f27bcc..c42d598 100644
--- a/src/mlpack/tests/serialization_test.cpp
+++ b/src/mlpack/tests/serialization_test.cpp
@@ -1500,8 +1500,8 @@ BOOST_AUTO_TEST_CASE(HoeffdingCategoricalSplitTest)
 BOOST_AUTO_TEST_CASE(HoeffdingTreeBeforeSplitTest)
 {
   data::DatasetInfo info(5);
-  info.MapString("0", 2); // Dimension 1 is categorical.
-  info.MapString("1", 2);
+  info.MapString<double>("0", 2); // Dimension 1 is categorical.
+  info.MapString<double>("1", 2);
   HoeffdingTree<> split(info, 2, 0.99, 15000, 1);
 
   // Train for 2 samples.
@@ -1509,14 +1509,14 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeBeforeSplitTest)
   split.Train(arma::vec("-0.3 0.0 0 0.7 0.8"), 1);
 
   data::DatasetInfo wrongInfo(3);
-  wrongInfo.MapString("1", 1);
+  wrongInfo.MapString<double>("1", 1);
   HoeffdingTree<> xmlSplit(wrongInfo, 7, 0.1, 10, 1);
 
   // Force the binarySplit to split.
   data::DatasetInfo binaryInfo(2);
-  binaryInfo.MapString("cat0", 0);
-  binaryInfo.MapString("cat1", 0);
-  binaryInfo.MapString("cat0", 1);
+  binaryInfo.MapString<double>("cat0", 0);
+  binaryInfo.MapString<double>("cat1", 0);
+  binaryInfo.MapString<double>("cat0", 1);
 
   HoeffdingTree<> binarySplit(info, 2, 0.95, 5000, 1);
 
@@ -1552,9 +1552,9 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAfterSplitTest)
 {
   // Force the split to split.
   data::DatasetInfo info(2);
-  info.MapString("cat0", 0);
-  info.MapString("cat1", 0);
-  info.MapString("cat0", 1);
+  info.MapString<double>("cat0", 0);
+  info.MapString<double>("cat1", 0);
+  info.MapString<double>("cat0", 1);
 
   HoeffdingTree<> split(info, 2, 0.95, 5000, 1);
 
@@ -1568,12 +1568,12 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeAfterSplitTest)
   BOOST_REQUIRE_NE(split.SplitDimension(), size_t(-1));
 
   data::DatasetInfo wrongInfo(3);
-  wrongInfo.MapString("1", 1);
+  wrongInfo.MapString<double>("1", 1);
   HoeffdingTree<> xmlSplit(wrongInfo, 7, 0.1, 10, 1);
 
   data::DatasetInfo binaryInfo(5);
-  binaryInfo.MapString("0", 2); // Dimension 2 is categorical.
-  binaryInfo.MapString("1", 2);
+  binaryInfo.MapString<double>("0", 2); // Dimension 2 is categorical.
+  binaryInfo.MapString<double>("1", 2);
   HoeffdingTree<> binarySplit(binaryInfo, 2, 0.99, 15000, 1);
 
   // Train for 2 samples.
@@ -1645,14 +1645,14 @@ BOOST_AUTO_TEST_CASE(HoeffdingTreeTest)
   }
   // Make the features categorical.
   data::DatasetInfo info(2);
-  info.MapString("a", 0);
-  info.MapString("b", 0);
-  info.MapString("c", 0);
-  info.MapString("d", 0);
-  info.MapString("a", 1);
-  info.MapString("b", 1);
-  info.MapString("c", 1);
-  info.MapString("d", 1);
+  info.MapString<double>("a", 0);
+  info.MapString<double>("b", 0);
+  info.MapString<double>("c", 0);
+  info.MapString<double>("d", 0);
+  info.MapString<double>("a", 1);
+  info.MapString<double>("b", 1);
+  info.MapString<double>("c", 1);
+  info.MapString<double>("d", 1);
 
   HoeffdingTree<> tree(dataset, info, labels, 2, false /* no batch mode */);
 

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list