[shark] 01/58: Add functions to export generated model
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Wed Mar 16 10:05:26 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch master
in repository shark.
commit 8af15e7812b14b97193eae733fa2ddee061b9324
Author: Michele Mastropietro <michele.mastropietro at gmail.com>
Date: Tue Nov 3 20:19:42 2015 +0100
Add functions to export generated model
---
include/shark/Models/Trees/CARTClassifier.h | 16 +++++++++++++++-
include/shark/Models/Trees/RFClassifier.h | 26 ++++++++++++++++++++++++++
2 files changed, 41 insertions(+), 1 deletion(-)
diff --git a/include/shark/Models/Trees/CARTClassifier.h b/include/shark/Models/Trees/CARTClassifier.h
index b9ed396..b602fbf 100644
--- a/include/shark/Models/Trees/CARTClassifier.h
+++ b/include/shark/Models/Trees/CARTClassifier.h
@@ -100,7 +100,16 @@ public:
/// Constructor taking the splitMatrix as argument
CARTClassifier(SplitMatrixType const& splitMatrix)
{
- setSplitMatrix(splitMatrix);
+ m_splitMatrix=splitMatrix; // stored as-is; NOTE(review): used to be optimized via setSplitMatrix, confirm existing callers are fine
+ }
+
+ /// Constructor taking the splitMatrix as argument. If optimize is true the matrix
+ /// goes through setSplitMatrix (optimized); otherwise it is stored unchanged.
+ CARTClassifier(SplitMatrixType const& splitMatrix, bool optimize)
+ {
+ if (!optimize)
+ m_splitMatrix=splitMatrix;
+ else setSplitMatrix(splitMatrix);
}
/// Constructor taking the splitMatrix as argument as well as maximum number of attributes
@@ -147,6 +156,11 @@ public:
optimizeSplitMatrix(m_splitMatrix);
}
+ /// Return the tree's split matrix (m_splitMatrix) by const reference, avoiding a copy;
+ SplitMatrixType const& getSplitMatrix() const { // the reference is valid while this classifier lives and the matrix is not replaced
+ return m_splitMatrix;
+ }
+
/// \brief The model does not have any parameters.
std::size_t numberOfParameters()const{
return 0;
diff --git a/include/shark/Models/Trees/RFClassifier.h b/include/shark/Models/Trees/RFClassifier.h
index eface6f..0354a49 100644
--- a/include/shark/Models/Trees/RFClassifier.h
+++ b/include/shark/Models/Trees/RFClassifier.h
@@ -40,6 +40,8 @@
namespace shark {
+typedef CARTClassifier<RealVector>::SplitMatrixType SplitMatrixType; ///< split matrix of one CART tree over RealVector inputs (NOTE(review): namespace-scope name, may clash with other shark headers)
+typedef std::vector<CARTClassifier<RealVector>::SplitMatrixType> ForestInfo; ///< one split matrix per tree, as exchanged by getForestInfo/setForestInfo
///
/// \brief Random Forest Classifier.
@@ -114,6 +116,30 @@ public:
m_inputDimension = in;
}
+ /// Export the split matrix of every tree in the forest, in the order of m_models.
+ ForestInfo getForestInfo() const {
+ ForestInfo finfo(m_models.size());
+ for (std::size_t i=0; i<m_models.size(); ++i)
+ finfo[i]=m_models[i].getSplitMatrix();
+ return finfo;
+ }
+
+ /// Replace the forest by the trees in finfo; weights holds one weight per tree (all 1 if empty).
+ /// Throws, leaving the model unchanged, if weights is non-empty and weights.size() != finfo.size().
+ void setForestInfo(ForestInfo const& finfo, std::vector<double> const& weights = std::vector<double>()) {
+ std::size_t n_tree = finfo.size();
+ if (!weights.empty() && weights.size() != n_tree) // validate before touching any state
+ throw SHARKEXCEPTION("Weights must be the same number as trees");
+ m_models.resize(n_tree);
+ m_weight.clear(); // replace, rather than append to, the weights of a previous forest
+ m_weightSum = 0;
+ for (std::size_t i=0; i<n_tree; ++i) {
+ m_models[i]=finfo[i];
+ m_weight.push_back(weights.empty() ? 1.0 : weights[i]);
+ m_weightSum+=m_weight.back();
+ }
+ }
+
protected:
// Dimension of label in the regression case, number of classes in the classification case.
std::size_t m_labelDimension;
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/shark.git
More information about the debian-science-commits mailing list