[mlpack] 38/207: Fix implementation for move constructor to avoid infinite recursion.
Barak A. Pearlmutter
barak+git at pearlmutter.net
Thu Mar 23 17:53:38 UTC 2017
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch master
in repository mlpack.
commit ab28ecef72fbc7661b234355c27ad77ed431c425
Author: Ryan Curtin <ryan at ratml.org>
Date: Wed Jan 11 12:04:34 2017 -0500
Fix implementation for move constructor to avoid infinite recursion.
---
.../methods/decision_stump/decision_stump_impl.hpp | 38 +++++++-
src/mlpack/tests/decision_stump_test.cpp | 103 ++++++++++++++++++---
2 files changed, 127 insertions(+), 14 deletions(-)
diff --git a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
index cea3d96..228b720 100644
--- a/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
+++ b/src/mlpack/methods/decision_stump/decision_stump_impl.hpp
@@ -54,7 +54,9 @@ DecisionStump<MatType, NoRecursion>::DecisionStump() :
{
// Make a fake stump by creating two children. We create two and not one so
// that we can be guaranteed that splitOrClassProbs has at least one
- // element. The children are identical in functionality though.
+ // element. The children are identical in functionality though. These fake
+ // children are necessary, because Predict() depends on a stump having
+ // children.
children.push_back(new DecisionStump(0, 0, std::move(arma::vec("1.0"))));
children.push_back(new DecisionStump(0, 0, std::move(arma::vec("1.0"))));
}
@@ -82,7 +84,22 @@ DecisionStump<MatType, NoRecursion>::DecisionStump(DecisionStump&& other) :
children(std::move(other.children))
{
// Reset the other one.
- other = DecisionStump();
+ other.classes = 1;
+ other.bucketSize = 0;
+ other.splitDimensionOrLabel = 0;
+ other.splitOrClassProbs.ones(1);
+ if (NoRecursion)
+ {
+ // Make a fake stump by creating two children. We create two and not one so
+ // that we can be guaranteed that splitOrClassProbs has at least one
+ // element. The children are identical in functionality though. These fake
+ // children are necessary, because Predict() depends on a stump having
+ // children.
+ other.children.push_back(new DecisionStump(0, 0,
+ std::move(arma::vec("1.0"))));
+ other.children.push_back(new DecisionStump(0, 0,
+ std::move(arma::vec("1.0"))));
+ }
}
// Copy assignment operator.
@@ -124,7 +141,22 @@ DecisionStump<MatType, NoRecursion>::operator=(DecisionStump&& other)
children = std::move(other.children);
// Clear and reinitialize other object.
- other = DecisionStump();
+ other.classes = 1;
+ other.bucketSize = 0;
+ other.splitDimensionOrLabel = 0;
+ other.splitOrClassProbs.ones(1);
+ if (NoRecursion)
+ {
+ // Make a fake stump by creating two children. We create two and not one so
+ // that we can be guaranteed that splitOrClassProbs has at least one
+ // element. The children are identical in functionality though. These fake
+ // children are necessary, because Predict() depends on a stump having
+ // children.
+ other.children.push_back(new DecisionStump(0, 0,
+ std::move(arma::vec("1.0"))));
+ other.children.push_back(new DecisionStump(0, 0,
+ std::move(arma::vec("1.0"))));
+ }
return *this;
}
diff --git a/src/mlpack/tests/decision_stump_test.cpp b/src/mlpack/tests/decision_stump_test.cpp
index 2587e27..a243076 100644
--- a/src/mlpack/tests/decision_stump_test.cpp
+++ b/src/mlpack/tests/decision_stump_test.cpp
@@ -389,12 +389,37 @@ BOOST_AUTO_TEST_CASE(DecisionStumpCopyConstructorTest)
// Check the objects for similarity.
BOOST_REQUIRE_EQUAL(d.Split().n_elem, copy.Split().n_elem);
BOOST_REQUIRE_EQUAL(d.Split().n_elem, copy2.Split().n_elem);
- for (size_t i = 0; i < d.Split().n_elem + 1; ++i)
+ BOOST_REQUIRE_EQUAL(d.NumChildren(), copy.NumChildren());
+ BOOST_REQUIRE_EQUAL(d.NumChildren(), copy2.NumChildren());
+ for (size_t i = 0; i < d.NumChildren(); ++i)
{
BOOST_REQUIRE_EQUAL(d.Child(i).Label(), copy.Child(i).Label());
- CheckMatrices(d.Child(i).Split(), copy.Child(i).Split());
+ BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_rows,
+ copy.Child(i).Split().n_rows);
+ BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_cols,
+ copy.Child(i).Split().n_cols);
+ for (size_t j = 0; j < d.Child(i).Split().n_elem; ++j)
+ {
+ if (std::abs(d.Child(i).Split()[j]) < 1e-5)
+ BOOST_REQUIRE_SMALL(copy.Child(i).Split()[j], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(copy.Child(i).Split()[j], d.Child(i).Split()[j],
+ 1e-5);
+ }
+
BOOST_REQUIRE_EQUAL(d.Child(i).Label(), copy2.Child(i).Label());
- CheckMatrices(d.Child(i).Split(), copy2.Child(i).Split());
+ BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_rows,
+ copy2.Child(i).Split().n_rows);
+ BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_cols,
+ copy2.Child(i).Split().n_cols);
+ for (size_t j = 0; j < d.Child(i).Split().n_elem; ++j)
+ {
+ if (std::abs(d.Child(i).Split()[j]) < 1e-5)
+ BOOST_REQUIRE_SMALL(copy2.Child(i).Split()[j], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(copy2.Child(i).Split()[j], d.Child(i).Split()[j],
+ 1e-5);
+ }
}
}
@@ -427,17 +452,41 @@ BOOST_AUTO_TEST_CASE(DecisionStumpMoveConstructorTest)
DecisionStump<> empty; // An empty object to compare against.
BOOST_REQUIRE_EQUAL(d.Split().n_elem, empty.Split().n_elem);
- for (size_t i = 0; i < d.Split().n_elem + 1; ++i)
+ BOOST_REQUIRE_EQUAL(d.NumChildren(), empty.NumChildren());
+ for (size_t i = 0; i < d.NumChildren(); ++i)
{
BOOST_REQUIRE_EQUAL(d.Child(i).Label(), empty.Child(i).Label());
- CheckMatrices(d.Child(i).Split(), empty.Child(i).Split());
+ BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_rows,
+ empty.Child(i).Split().n_rows);
+ BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_cols,
+ empty.Child(i).Split().n_cols);
+ for (size_t j = 0; j < d.Child(i).Split().n_elem; ++j)
+ {
+ if (std::abs(d.Child(i).Split()[j]) < 1e-5)
+ BOOST_REQUIRE_SMALL(empty.Child(i).Split()[j], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(empty.Child(i).Split()[j], d.Child(i).Split()[j],
+ 1e-5);
+ }
}
BOOST_REQUIRE_EQUAL(move.Split().n_elem, copy.Split().n_elem);
- for (size_t i = 0; i < move.Split().n_elem + 1; ++i)
+ BOOST_REQUIRE_EQUAL(move.NumChildren(), copy.NumChildren());
+ for (size_t i = 0; i < move.NumChildren(); ++i)
{
BOOST_REQUIRE_EQUAL(move.Child(i).Label(), copy.Child(i).Label());
- CheckMatrices(move.Child(i).Split(), copy.Child(i).Split());
+ BOOST_REQUIRE_EQUAL(move.Child(i).Split().n_rows,
+ copy.Child(i).Split().n_rows);
+ BOOST_REQUIRE_EQUAL(move.Child(i).Split().n_cols,
+ copy.Child(i).Split().n_cols);
+ for (size_t j = 0; j < move.Child(i).Split().n_elem; ++j)
+ {
+ if (std::abs(move.Child(i).Split()[j]) < 1e-5)
+ BOOST_REQUIRE_SMALL(copy.Child(i).Split()[j], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(copy.Child(i).Split()[j], move.Child(i).Split()[j],
+ 1e-5);
+ }
}
}
@@ -470,18 +519,50 @@ BOOST_AUTO_TEST_CASE(DecisionStumpMoveOperatorTest)
DecisionStump<> empty; // An empty object to compare against.
BOOST_REQUIRE_EQUAL(d.Split().n_elem, empty.Split().n_elem);
- for (size_t i = 0; i < d.Split().n_elem + 1; ++i)
+ BOOST_REQUIRE_EQUAL(d.NumChildren(), empty.NumChildren());
+ for (size_t i = 0; i < d.NumChildren(); ++i)
{
BOOST_REQUIRE_EQUAL(d.Child(i).Label(), empty.Child(i).Label());
- CheckMatrices(d.Child(i).Split(), empty.Child(i).Split());
+ BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_rows,
+ empty.Child(i).Split().n_rows);
+ BOOST_REQUIRE_EQUAL(d.Child(i).Split().n_cols,
+ empty.Child(i).Split().n_cols);
+ for (size_t j = 0; j < d.Child(i).Split().n_elem; ++j)
+ {
+ if (std::abs(d.Child(i).Split()[j]) < 1e-5)
+ BOOST_REQUIRE_SMALL(empty.Child(i).Split()[j], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(empty.Child(i).Split()[j], d.Child(i).Split()[j],
+ 1e-5);
+ }
}
BOOST_REQUIRE_EQUAL(move.Split().n_elem, copy.Split().n_elem);
- for (size_t i = 0; i < move.Split().n_elem + 1; ++i)
+ BOOST_REQUIRE_EQUAL(move.NumChildren(), copy.NumChildren());
+ for (size_t i = 0; i < move.NumChildren(); ++i)
{
BOOST_REQUIRE_EQUAL(move.Child(i).Label(), copy.Child(i).Label());
- CheckMatrices(move.Child(i).Split(), copy.Child(i).Split());
+ BOOST_REQUIRE_EQUAL(move.Child(i).Split().n_rows,
+ copy.Child(i).Split().n_rows);
+ BOOST_REQUIRE_EQUAL(move.Child(i).Split().n_cols,
+ copy.Child(i).Split().n_cols);
+ for (size_t j = 0; j < move.Child(i).Split().n_elem; ++j)
+ {
+ if (std::abs(move.Child(i).Split()[j]) < 1e-5)
+ BOOST_REQUIRE_SMALL(copy.Child(i).Split()[j], 1e-5);
+ else
+ BOOST_REQUIRE_CLOSE(copy.Child(i).Split()[j], move.Child(i).Split()[j],
+ 1e-5);
+ }
}
}
+/**
+ * Test that the decision tree outperforms the decision stump.
+ */
+BOOST_AUTO_TEST_CASE(DecisionTreeVsStumpTest)
+{
+  // TODO: placeholder -- body is empty, so this test currently asserts nothing.
+}
+
BOOST_AUTO_TEST_SUITE_END();
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits mailing list