[mlpack] 212/324: Contribution from Yash to solve #250 and make BallBound usable.
Barak A. Pearlmutter
barak+git at cs.nuim.ie
Sun Aug 17 08:22:11 UTC 2014
This is an automated email from the git hooks/post-receive script.
bap pushed a commit to branch svn-trunk
in repository mlpack.
commit 9cb84b0f85fc790b79401fcb04060dec33d5040d
Author: rcurtin <rcurtin at 9d5b8971-822b-0410-80eb-d18c1038ef23>
Date: Fri Jul 25 15:52:46 2014 +0000
Contribution from Yash to solve #250 and make BallBound usable.
git-svn-id: http://svn.cc.gatech.edu/fastlab/mlpack/trunk@16854 9d5b8971-822b-0410-80eb-d18c1038ef23
---
src/mlpack/core/tree/ballbound.hpp | 72 ++++++++++---
src/mlpack/core/tree/ballbound_impl.hpp | 173 +++++++++++++++++++++-----------
2 files changed, 174 insertions(+), 71 deletions(-)
diff --git a/src/mlpack/core/tree/ballbound.hpp b/src/mlpack/core/tree/ballbound.hpp
index ac579d0..4a83461 100644
--- a/src/mlpack/core/tree/ballbound.hpp
+++ b/src/mlpack/core/tree/ballbound.hpp
@@ -15,29 +15,52 @@ namespace mlpack {
namespace bound {
/**
- * Ball bound that works in the regular Euclidean metric space.
+ * Ball bound encloses a set of points at a specific distance (radius) from a
+ * specific point (center). TMetricType is the custom metric type that defaults
+ * to the Euclidean (L2) distance.
*
* @tparam VecType Type of vector (arma::vec or arma::sp_vec).
+ * @tparam TMetricType metric type used in the distance measure.
*/
-template<typename VecType = arma::vec>
+template<typename VecType = arma::vec,
+ typename TMetricType = metric::LMetric<2, true> >
class BallBound
{
public:
typedef VecType Vec;
+ //! Need this for Binary Space Partition Tree
+ typedef TMetricType MetricType;
private:
+
+ //! The radius of the ball bound.
double radius;
+
+ //! The center of the ball bound.
VecType center;
+ //! The metric used in this bound.
+ TMetricType* metric;
+
+ /**
+ * To know whether this object allocated memory to the metric member
+ * variable. This will be true except in the copy constructor and the
+ * overloaded assignment operator. We need this to know whether we should
+ * delete the metric member variable in the destructor.
+ */
+ bool ownsMetric;
+
public:
- BallBound() : radius(0) { }
+
+ //! Empty Constructor.
+ BallBound();
/**
* Create the ball bound with the specified dimensionality.
*
* @param dimension Dimensionality of ball bound.
*/
- BallBound(const size_t dimension) : radius(0), center(dimension) { }
+ BallBound(const size_t dimension);
/**
* Create the ball bound with the specified radius and center.
@@ -45,8 +68,16 @@ class BallBound
* @param radius Radius of ball bound.
* @param center Center of ball bound.
*/
- BallBound(const double radius, const VecType& center) :
- radius(radius), center(center) { }
+ BallBound(const double radius, const VecType& center);
+
+ //! Copy constructor. To prevent memory leaks.
+ BallBound(const BallBound& other);
+
+ //! For the same reason as the Copy Constructor. To prevent memory leaks.
+ BallBound& operator=(const BallBound& other);
+
+ //! Destructor to release allocated memory.
+ ~BallBound();
//! Get the radius of the ball.
double Radius() const { return radius; }
@@ -58,7 +89,16 @@ class BallBound
//! Modify the center point of the ball.
VecType& Center() { return center; }
- // Get the range in a certain dimension.
+ //! Get the dimensionality of the ball.
+ double Dim() const { return center.n_elem; }
+
+ /**
+ * Get the minimum width of the bound (this is the same as the diameter).
+ * For ball bounds, the width along all dimensions remains the same.
+ */
+ double MinWidth() const { return radius * 2.0; }
+
+ //! Get the range in a certain dimension.
math::Range operator[](const size_t i) const;
/**
@@ -67,13 +107,11 @@ class BallBound
bool Contains(const VecType& point) const;
/**
- * Gets the center.
+ * Place the centroid of BallBound into the given vector.
*
- * Don't really use this directly. This is only here for consistency
- * with DHrectBound, so it can plug in more directly if a "centroid"
- * is needed.
+ * @param centroid Vector which the centroid will be written to.
*/
- void CalculateMidpoint(VecType& centroid) const;
+ void Centroid(VecType& centroid) const { centroid = center; }
/**
* Calculates minimum bound-to-point squared distance.
@@ -133,6 +171,16 @@ class BallBound
const BallBound& operator|=(const MatType& data);
/**
+ * Returns the diameter of the ballbound.
+ */
+ double Diameter() const { return 2 * radius; }
+
+ /**
+ * Returns the distance metric used in this bound.
+ */
+ TMetricType Metric() const { return *metric; }
+
+ /**
* Returns a string representation of this object.
*/
std::string ToString() const;
diff --git a/src/mlpack/core/tree/ballbound_impl.hpp b/src/mlpack/core/tree/ballbound_impl.hpp
index 57e3587..3f64908 100644
--- a/src/mlpack/core/tree/ballbound_impl.hpp
+++ b/src/mlpack/core/tree/ballbound_impl.hpp
@@ -17,9 +17,73 @@
namespace mlpack {
namespace bound {
+//! Empty Constructor.
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::BallBound() :
+ radius(-DBL_MAX),
+ metric(new TMetricType()),
+ ownsMetric(true)
+{ /* Nothing to do. */ }
+
+/**
+ * Create the ball bound with the specified dimensionality.
+ *
+ * @param dimension Dimensionality of ball bound.
+ */
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::BallBound(const size_t dimension) :
+ radius(-DBL_MAX),
+ center(dimension),
+ metric(new TMetricType()),
+ ownsMetric(true)
+{ /* Nothing to do. */ }
+
+/**
+ * Create the ball bound with the specified radius and center.
+ *
+ * @param radius Radius of ball bound.
+ * @param center Center of ball bound.
+ */
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::BallBound(const double radius,
+ const VecType& center) :
+ radius(radius),
+ center(center),
+ metric(new TMetricType()),
+ ownsMetric(true)
+{ /* Nothing to do. */ }
+
+//! Copy Constructor. To prevent memory leaks.
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::BallBound(const BallBound& other) :
+ radius(other.radius),
+ center(other.center),
+ metric(other.metric),
+ ownsMetric(false)
+{ /* Nothing to do. */ }
+
+//! For the same reason as the Copy Constructor. To prevent memory leaks.
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>& BallBound<VecType, TMetricType>::operator=(
+ const BallBound& other)
+{
+ radius = other.radius;
+ center = other.center;
+ metric = other.metric;
+ ownsMetric = false;
+}
+
+//! Destructor to release allocated memory.
+template<typename VecType, typename TMetricType>
+BallBound<VecType, TMetricType>::~BallBound()
+{
+ if (ownsMetric)
+ delete metric;
+}
+
//! Get the range in a certain dimension.
-template<typename VecType>
-math::Range BallBound<VecType>::operator[](const size_t i) const
+template<typename VecType, typename TMetricType>
+math::Range BallBound<VecType, TMetricType>::operator[](const size_t i) const
{
if (radius < 0)
return math::Range();
@@ -30,56 +94,42 @@ math::Range BallBound<VecType>::operator[](const size_t i) const
/**
* Determines if a point is within the bound.
*/
-template<typename VecType>
-bool BallBound<VecType>::Contains(const VecType& point) const
+template<typename VecType, typename TMetricType>
+bool BallBound<VecType, TMetricType>::Contains(const VecType& point) const
{
if (radius < 0)
return false;
else
- return metric::EuclideanDistance::Evaluate(center, point) <= radius;
-}
-
-/**
- * Gets the center.
- *
- * Don't really use this directly. This is only here for consistency
- * with DHrectBound, so it can plug in more directly if a "centroid"
- * is needed.
- */
-template<typename VecType>
-void BallBound<VecType>::CalculateMidpoint(VecType& centroid) const
-{
- centroid = center;
+ return metric->Evaluate(center, point) <= radius;
}
/**
* Calculates minimum bound-to-point squared distance.
*/
-template<typename VecType>
+template<typename VecType, typename TMetricType>
template<typename OtherVecType>
-double BallBound<VecType>::MinDistance(
+double BallBound<VecType, TMetricType>::MinDistance(
const OtherVecType& point,
typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
{
if (radius < 0)
return DBL_MAX;
else
- return math::ClampNonNegative(metric::EuclideanDistance::Evaluate(point,
- center) - radius);
+ return math::ClampNonNegative(metric->Evaluate(point, center) - radius);
}
/**
* Calculates minimum bound-to-bound squared distance.
*/
-template<typename VecType>
-double BallBound<VecType>::MinDistance(const BallBound& other) const
+template<typename VecType, typename TMetricType>
+double BallBound<VecType, TMetricType>::MinDistance(const BallBound& other) const
{
if (radius < 0)
return DBL_MAX;
else
{
- double delta = metric::EuclideanDistance::Evaluate(center, other.center)
- - radius - other.radius;
+ const double delta = metric->Evaluate(center, other.center) - radius -
+ other.radius;
return math::ClampNonNegative(delta);
}
}
@@ -87,29 +137,29 @@ double BallBound<VecType>::MinDistance(const BallBound& other) const
/**
* Computes maximum distance.
*/
-template<typename VecType>
+template<typename VecType, typename TMetricType>
template<typename OtherVecType>
-double BallBound<VecType>::MaxDistance(
+double BallBound<VecType, TMetricType>::MaxDistance(
const OtherVecType& point,
typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
{
if (radius < 0)
return DBL_MAX;
else
- return metric::EuclideanDistance::Evaluate(point, center) + radius;
+ return metric->Evaluate(point, center) + radius;
}
/**
* Computes maximum distance.
*/
-template<typename VecType>
-double BallBound<VecType>::MaxDistance(const BallBound& other) const
+template<typename VecType, typename TMetricType>
+double BallBound<VecType, TMetricType>::MaxDistance(const BallBound& other)
+ const
{
if (radius < 0)
return DBL_MAX;
else
- return metric::EuclideanDistance::Evaluate(other.center, center) + radius
- + other.radius;
+ return metric->Evaluate(other.center, center) + radius + other.radius;
}
/**
@@ -117,9 +167,9 @@ double BallBound<VecType>::MaxDistance(const BallBound& other) const
*
* Example: bound1.MinDistanceSq(other) for minimum squared distance.
*/
-template<typename VecType>
+template<typename VecType, typename TMetricType>
template<typename OtherVecType>
-math::Range BallBound<VecType>::RangeDistance(
+math::Range BallBound<VecType, TMetricType>::RangeDistance(
const OtherVecType& point,
typename boost::enable_if<IsVector<OtherVecType> >* /* junk */) const
{
@@ -127,22 +177,22 @@ math::Range BallBound<VecType>::RangeDistance(
return math::Range(DBL_MAX, DBL_MAX);
else
{
- double dist = metric::EuclideanDistance::Evaluate(center, point);
+ const double dist = metric->Evaluate(center, point);
return math::Range(math::ClampNonNegative(dist - radius),
dist + radius);
}
}
-template<typename VecType>
-math::Range BallBound<VecType>::RangeDistance(
+template<typename VecType, typename TMetricType>
+math::Range BallBound<VecType, TMetricType>::RangeDistance(
const BallBound& other) const
{
if (radius < 0)
return math::Range(DBL_MAX, DBL_MAX);
else
{
- double dist = metric::EuclideanDistance::Evaluate(center, other.center);
- double sumradius = radius + other.radius;
+ const double dist = metric->Evaluate(center, other.center);
+ const double sumradius = radius + other.radius;
return math::Range(math::ClampNonNegative(dist - sumradius),
dist + sumradius);
}
@@ -151,12 +201,12 @@ math::Range BallBound<VecType>::RangeDistance(
/**
* Expand the bound to include the given bound.
*
-template<typename VecType>
+template<typename VecType, typename TMetricType>
const BallBound<VecType>&
-BallBound<VecType>::operator|=(
+BallBound<VecType, TMetricType>::operator|=(
const BallBound<VecType>& other)
{
- double dist = metric::EuclideanDistance::Evaluate(center, other);
+ double dist = metric->Evaluate(center, other);
// Now expand the radius as necessary.
if (dist > radius)
@@ -166,12 +216,15 @@ BallBound<VecType>::operator|=(
}*/
/**
- * Expand the bound to include the given point.
+ * Expand the bound to include the given point. Algorithm adapted from
+ * Jack Ritter, "An Efficient Bounding Sphere" in Graphics Gems (1990).
+ * The difference lies in the way we initialize the ball bound. The way we
+ * expand the bound is the same.
*/
-template<typename VecType>
+template<typename VecType, typename TMetricType>
template<typename MatType>
-const BallBound<VecType>&
-BallBound<VecType>::operator|=(const MatType& data)
+const BallBound<VecType, TMetricType>&
+BallBound<VecType, TMetricType>::operator|=(const MatType& data)
{
if (radius < 0)
{
@@ -179,35 +232,37 @@ BallBound<VecType>::operator|=(const MatType& data)
radius = 0;
}
- // Now iteratively add points. There is probably a closed-form solution to
- // find the minimum bounding circle, and it is probably faster.
- for (size_t i = 1; i < data.n_cols; ++i)
+ // Now iteratively add points.
+ for (size_t i = 0; i < data.n_cols; ++i)
{
- double dist = metric::EuclideanDistance::Evaluate(center, (VecType)
- data.col(i)) - radius;
+ const double dist = metric->Evaluate(center, (VecType) data.col(i));
- if (dist > 0)
+ // See if the new point lies outside the bound.
+ if (dist > radius)
{
- // Move (dist / 2) towards the new point and increase radius by
- // (dist / 2).
+ // Move towards the new point and increase the radius just enough to
+ // accommodate the new point.
arma::vec diff = data.col(i) - center;
- center += 0.5 * diff;
- radius += 0.5 * dist;
+ center += ((dist - radius) / (2 * dist)) * diff;
+ radius = 0.5 * (dist + radius);
}
}
return *this;
}
+
/**
* Returns a string representation of this object.
*/
-template<typename VecType>
-std::string BallBound<VecType>::ToString() const
+template<typename VecType, typename TMetricType>
+std::string BallBound<VecType, TMetricType>::ToString() const
{
std::ostringstream convert;
convert << "BallBound [" << this << "]" << std::endl;
convert << " Radius: " << radius << std::endl;
convert << " Center:" << std::endl << center;
+ convert << " ownsMetric: " << ownsMetric << std::endl;
+ convert << " Metric:" << std::endl << metric->ToString();
return convert.str();
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git
More information about the debian-science-commits
mailing list