[mlpack] 38/207: Fix implementation for move constructor to avoid infinite recursion.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:38 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit ab28ecef72fbc7661b234355c27ad77ed431c425
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Jan 11 12:04:34 2017 -0500

    Fix implementation for move constructor to avoid infinite recursion.
---
 .../methods/decision_stump/decision_stump_impl.hpp |  38 +++++++-
 src/mlpack/tests/decision_stump_test.cpp           | 103 ++++++++++++++++++---
 2 files changed, 127 insertions(+), 14 deletions(-)

diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index cea3d96..228b720 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -54,7 +54,9 @@ DecisionStump<MatType, NoRecursion>::DecisionStump() :
   {
     // Make a fake stump by creating two children.  We create two and not one so
     // that we can be guaranteed that splitOrClassProbs has at least one
-    // element.  The children are identical in functionality though.
+    // element.  The children are identical in functionality though.  These fake
+    // children are necessary, because Predict() depends on a stump having
+    // children.
     children.push_back(new DecisionStump(0, 0, std::move(arma::vec("1.0"))));
     children.push_back(new DecisionStump(0, 0, std::move(arma::vec("1.0"))));
   }
@@ -82,7 +84,22 @@ DecisionStump<MatType, NoRecursion>::DecisionStump(DecisionStump&& other) :
     children(std::move(other.children))
 {
   // Reset the other one.
-  other = DecisionStump();
+  other.classes = 1;
+  other.bucketSize = 0;
+  other.splitDimensionOrLabel = 0;
+  other.splitOrClassProbs.ones(1);
+  if (NoRecursion)
+  {
+    // Make a fake stump by creating two children.  We create two and not one so
+    // that we can be guaranteed that splitOrClassProbs has at least one
+    // element.  The children are identical in functionality though.  These fake
+    // children are necessary, because Predict() depends on a stump having
+    // children.
+    other.children.push_back(new DecisionStump(0, 0,
+        std::move(arma::vec("1.0"))));
+    other.children.push_back(new DecisionStump(0, 0,
+        std::move(arma::vec("1.0"))));
+  }
 }
 
 // Copy assignment operator.
@@ -124,7 +141,22 @@ DecisionStump<MatType, NoRecursion>::operator=(DecisionStump&& other)
   children = std::move(other.children);
 
   // Clear and reinitialize other object.
-  other = DecisionStump();
+  other.classes = 1;
+  other.bucketSize = 0;
+  other.splitDimensionOrLabel = 0;
+  other.splitOrClassProbs.ones(1);
+  if (NoRecursion)
+  {
+    // Make a fake stump by creating two children.  We create two and not one so
+    // that we can be guaranteed that splitOrClassProbs has at least one
+    // element.  The children are identical in functionality though.  These fake
+    // children are necessary, because Predict() depends on a stump having
+    // children.
+    other.children.push_back(new DecisionStump(0, 0,
+        std::move(arma::vec("1.0"))));
+    other.children.push_back(new DecisionStump(0, 0,
+        std::move(arma::vec("1.0"))));
+  }
 
   return *this;
 }
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index 2587e27..a243076 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -389,12 +389,37 @@ BOOST_AUTO_TEST_CASE(DecisionStumpCopyConstructorTest)
   // Check the objects for similarity.
   BOOST_REQUIRE_EQUAL(d.Split().n_elem, copy.Split().n_elem);
   BOOST_REQUIRE_EQUAL(d.Split().n_elem, copy2.Split().n_elem);
-  for (size_t i = 0; i < d.Split().n_elem + 1; ++i)
+  BOOST_REQUIRE_EQUAL(d.NumChildren(), copy.NumChildren());
+  BOOST_REQUIRE_EQUAL(d.NumChildren(), copy2.NumChildren());
+  for (size_t i = 0; i < d.NumChildren(); ++i)
   {
     BOOST_REQUIRE_EQUAL(d.Child(i).Label(), copy.Child(i).Label());
-    CheckMatrices(d.Child(i).Split(), copy.Child(i).Split());
+    BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_rows,
+        copy.Child(i).Split().n_rows);
+    BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_cols,
+        copy.Child(i).Split().n_cols);
+    for (size_t j = 0; j < d.Child(i).Split().n_elem; ++j)
+    {
+      if (std::abs(d.Child(i).Split()[j]) < 1e-5)
+        BOOST_REQUIRE_SMALL(copy.Child(i).Split()[j], 1e-5);
+      else
+        BOOST_REQUIRE_CLOSE(copy.Child(i).Split()[j], d.Child(i).Split()[j],
+            1e-5);
+    }
+
     BOOST_REQUIRE_EQUAL(d.Child(i).Label(), copy2.Child(i).Label());
-    CheckMatrices(d.Child(i).Split(), copy2.Child(i).Split());
+    BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_rows,
+        copy2.Child(i).Split().n_rows);
+    BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_cols,
+        copy2.Child(i).Split().n_cols);
+    for (size_t j = 0; j < d.Child(i).Split().n_elem; ++j)
+    {
+      if (std::abs(d.Child(i).Split()[j]) < 1e-5)
+        BOOST_REQUIRE_SMALL(copy2.Child(i).Split()[j], 1e-5);
+      else
+        BOOST_REQUIRE_CLOSE(copy2.Child(i).Split()[j], d.Child(i).Split()[j],
+            1e-5);
+    }
   }
 }
 
@@ -427,17 +452,41 @@ BOOST_AUTO_TEST_CASE(DecisionStumpMoveConstructorTest)
   DecisionStump<> empty; // An empty object to compare against.
 
   BOOST_REQUIRE_EQUAL(d.Split().n_elem, empty.Split().n_elem);
-  for (size_t i = 0; i < d.Split().n_elem + 1; ++i)
+  BOOST_REQUIRE_EQUAL(d.NumChildren(), empty.NumChildren());
+  for (size_t i = 0; i < d.NumChildren(); ++i)
   {
     BOOST_REQUIRE_EQUAL(d.Child(i).Label(), empty.Child(i).Label());
-    CheckMatrices(d.Child(i).Split(), empty.Child(i).Split());
+    BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_rows,
+        empty.Child(i).Split().n_rows);
+    BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_cols,
+        empty.Child(i).Split().n_cols);
+    for (size_t j = 0; j < d.Child(i).Split().n_elem; ++j)
+    {
+      if (std::abs(d.Child(i).Split()[j]) < 1e-5)
+        BOOST_REQUIRE_SMALL(empty.Child(i).Split()[j], 1e-5);
+      else
+        BOOST_REQUIRE_CLOSE(empty.Child(i).Split()[j], d.Child(i).Split()[j],
+            1e-5);
+    }
   }
 
   BOOST_REQUIRE_EQUAL(move.Split().n_elem, copy.Split().n_elem);
-  for (size_t i = 0; i < move.Split().n_elem + 1; ++i)
+  BOOST_REQUIRE_EQUAL(move.NumChildren(), copy.NumChildren());
+  for (size_t i = 0; i < move.NumChildren(); ++i)
   {
     BOOST_REQUIRE_EQUAL(move.Child(i).Label(), copy.Child(i).Label());
-    CheckMatrices(move.Child(i).Split(), copy.Child(i).Split());
+    BOOST_REQUIRE_EQUAL(move.Child(i).Split().n_rows,
+        copy.Child(i).Split().n_rows);
+    BOOST_REQUIRE_EQUAL(move.Child(i).Split().n_cols,
+        copy.Child(i).Split().n_cols);
+    for (size_t j = 0; j < move.Child(i).Split().n_elem; ++j)
+    {
+      if (std::abs(move.Child(i).Split()[j]) < 1e-5)
+        BOOST_REQUIRE_SMALL(copy.Child(i).Split()[j], 1e-5);
+      else
+        BOOST_REQUIRE_CLOSE(copy.Child(i).Split()[j], move.Child(i).Split()[j],
+            1e-5);
+    }
   }
 }
 
@@ -470,18 +519,50 @@ BOOST_AUTO_TEST_CASE(DecisionStumpMoveOperatorTest)
   DecisionStump<> empty; // An empty object to compare against.
 
   BOOST_REQUIRE_EQUAL(d.Split().n_elem, empty.Split().n_elem);
-  for (size_t i = 0; i < d.Split().n_elem + 1; ++i)
+  BOOST_REQUIRE_EQUAL(d.NumChildren(), empty.NumChildren());
+  for (size_t i = 0; i < d.NumChildren(); ++i)
   {
     BOOST_REQUIRE_EQUAL(d.Child(i).Label(), empty.Child(i).Label());
-    CheckMatrices(d.Child(i).Split(), empty.Child(i).Split());
+    BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_rows,
+        empty.Child(i).Split().n_rows);
+    BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_cols,
+        empty.Child(i).Split().n_cols);
+    for (size_t j = 0; j < d.Child(i).Split().n_elem; ++j)
+    {
+      if (std::abs(d.Child(i).Split()[j]) < 1e-5)
+        BOOST_REQUIRE_SMALL(empty.Child(i).Split()[j], 1e-5);
+      else
+        BOOST_REQUIRE_CLOSE(empty.Child(i).Split()[j], d.Child(i).Split()[j],
+            1e-5);
+    }
   }
 
   BOOST_REQUIRE_EQUAL(move.Split().n_elem, copy.Split().n_elem);
-  for (size_t i = 0; i < move.Split().n_elem + 1; ++i)
+  BOOST_REQUIRE_EQUAL(move.NumChildren(), copy.NumChildren());
+  for (size_t i = 0; i < move.NumChildren(); ++i)
   {
     BOOST_REQUIRE_EQUAL(move.Child(i).Label(), copy.Child(i).Label());
-    CheckMatrices(move.Child(i).Split(), copy.Child(i).Split());
+    BOOST_REQUIRE_EQUAL(move.Child(i).Split().n_rows,
+        copy.Child(i).Split().n_rows);
+    BOOST_REQUIRE_EQUAL(move.Child(i).Split().n_cols,
+        copy.Child(i).Split().n_cols);
+    for (size_t j = 0; j < move.Child(i).Split().n_elem; ++j)
+    {
+      if (std::abs(move.Child(i).Split()[j]) < 1e-5)
+        BOOST_REQUIRE_SMALL(copy.Child(i).Split()[j], 1e-5);
+      else
+        BOOST_REQUIRE_CLOSE(copy.Child(i).Split()[j], move.Child(i).Split()[j],
+            1e-5);
+    }
   }
 }
 
+/**
+ * Test that the decision tree outperforms the decision stump (placeholder).
+ */
+BOOST_AUTO_TEST_CASE(DecisionTreeVsStumpTest)
+{
+  // TODO: not yet implemented; the body contains no checks, so it tests nothing.
+}
+
 BOOST_AUTO_TEST_SUITE_END();

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list