[shark] 01/58: Add functions to export generated model

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Wed Mar 16 10:05:26 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository shark.

commit 8af15e7812b14b97193eae733fa2ddee061b9324
Author: Michele Mastropietro <michele.mastropietro at gmail.com>
Date:   Tue Nov 3 20:19:42 2015 +0100

    Add functions to export generated model
---
 include/shark/Models/Trees/CARTClassifier.h | 16 +++++++++++++++-
 include/shark/Models/Trees/RFClassifier.h   | 26 ++++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/include/shark/Models/Trees/CARTClassifier.h b/include/shark/Models/Trees/CARTClassifier.h
index b9ed396..b602fbf 100644
--- a/include/shark/Models/Trees/CARTClassifier.h
+++ b/include/shark/Models/Trees/CARTClassifier.h
@@ -100,7 +100,16 @@ public:
 	/// Constructor taking the splitMatrix as argument
 	CARTClassifier(SplitMatrixType const& splitMatrix)
 	{
-		setSplitMatrix(splitMatrix);
+		m_splitMatrix=splitMatrix; // stored verbatim: this overload no longer goes through setSplitMatrix
+	}
+
+	/// Constructor taking the splitMatrix; it is passed through setSplitMatrix (optimized) if \p optimize is true, else stored verbatim.
+	CARTClassifier(SplitMatrixType const& splitMatrix, bool optimize) // NOTE(review): an int 2nd argument may be ambiguous with the (splitMatrix, attributes) overload below if that takes std::size_t -- confirm
+	{
+		if (optimize)
+			setSplitMatrix(splitMatrix);
+		else
+			m_splitMatrix=splitMatrix;
 	}
 
 	/// Constructor taking the splitMatrix as argument as well as maximum number of attributes
@@ -147,6 +156,11 @@ public:
 		optimizeSplitMatrix(m_splitMatrix);
 	}
 	
+	/// Returns the split matrix encoding this tree, by const reference so callers are not forced to copy it.
+	SplitMatrixType const& getSplitMatrix() const {
+		return m_splitMatrix;
+	}
+
 	/// \brief The model does not have any parameters.
 	std::size_t numberOfParameters()const{
 		return 0;
diff --git a/include/shark/Models/Trees/RFClassifier.h b/include/shark/Models/Trees/RFClassifier.h
index eface6f..0354a49 100644
--- a/include/shark/Models/Trees/RFClassifier.h
+++ b/include/shark/Models/Trees/RFClassifier.h
@@ -40,6 +40,8 @@
 
 namespace shark {
 
+typedef CARTClassifier<RealVector>::SplitMatrixType SplitMatrixType; ///< Split matrix type of a single CART tree. NOTE(review): generic name at shark namespace scope; consider nesting it in RFClassifier.
+typedef std::vector<SplitMatrixType> ForestInfo; ///< One split matrix per tree, as produced by getForestInfo() and consumed by setForestInfo().
 
 ///
 /// \brief Random Forest Classifier.
@@ -114,6 +116,30 @@ public:
 		m_inputDimension = in;
 	}
 
+	/// Returns the split matrix of every tree in the forest, in ensemble order.
+	ForestInfo getForestInfo() const {
+		ForestInfo finfo(m_models.size());
+		for (std::size_t i=0; i<m_models.size(); ++i) finfo[i]=m_models[i].getSplitMatrix();
+		return finfo;
+	}
+
+	/// Replaces the forest by trees built from \p finfo, weighted by \p weights (all 1 if empty).
+	void setForestInfo(ForestInfo const& finfo, std::vector<double> const& weights = std::vector<double>()) {
+		std::size_t n_tree = finfo.size();
+		if (!weights.empty() && weights.size() != n_tree) // validate before modifying any state
+			throw SHARKEXCEPTION("Weights must be the same number as trees");
+		m_models.clear(); // replace, not append: stale trees/weights would skew the ensemble
+		m_weight.clear();
+		m_weightSum = 0;
+		for (std::size_t i=0; i<n_tree; ++i)
+		{
+			double w = weights.empty() ? 1.0 : weights[i];
+			m_models.push_back(finfo[i]);
+			m_weight.push_back(w);
+			m_weightSum+=w;
+		}
+	}
+
 protected:
 	// Dimension of label in the regression case, number of classes in the classification case.
 	std::size_t m_labelDimension;

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/shark.git



More information about the debian-science-commits mailing list