[mlpack] 193/207: Significant refactoring to fix ImputationTest.

Barak A. Pearlmutter barak+git at pearlmutter.net
Thu Mar 23 17:53:53 UTC 2017


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 0f14f37601ce69f75ab4417427eadf24dc396a8e
Author: Ryan Curtin <ryan at ratml.org>
Date:   Sat Mar 18 15:00:27 2017 -0400

    Significant refactoring to fix ImputationTest.
    
    There are a few major changes here:
    
     * We now scan the file before parsing it, to get the full size (and gather
       information for the DatasetMapper); parsing is then a separate pass.
    
     * We don't use the numeric rule, because the DatasetMapper must see all
       possible tokens that could be mapped.
    
     * We no longer abort and restart the transposed parse each time a new
       categorical row is found.  I think this will provide some acceleration,
       but I have not tested it.
---
 src/mlpack/core/data/load_csv.hpp | 248 +++++++++++++++++++++++++-------------
 1 file changed, 166 insertions(+), 82 deletions(-)

diff --git a/src/mlpack/core/data/load_csv.hpp b/src/mlpack/core/data/load_csv.hpp
index f9e3fb3..6dbd2f9 100644
--- a/src/mlpack/core/data/load_csv.hpp
+++ b/src/mlpack/core/data/load_csv.hpp
@@ -13,6 +13,7 @@
 #define MLPACK_CORE_DATA_LOAD_CSV_HPP
 
 #include <boost/spirit/include/qi.hpp>
+#include <boost/algorithm/string/trim.hpp>
 
 #include <mlpack/core.hpp>
 #include <mlpack/core/util/log.hpp>
@@ -58,6 +59,129 @@ public:
   size_t ColSize();
   size_t RowSize();
 
+  /**
+   * Peek at the file to determine the matrix size, assuming a non-transposed
+   * layout: rows is the number of lines and cols is the number of fields on the
+   * first line.  info is re-initialized with one dimension per line and, if
+   * MapPolicy::NeedsFirstPass is true, receives every field via MapFirstPass().
+   *
+   * @param rows Variable to be filled with the number of rows (file lines).
+   * @param cols Variable to be filled with the number of first-line fields.
+   * @param info DatasetMapper to re-initialize and feed during the first pass.
+   */
+  template<typename T, typename MapPolicy>
+  void GetMatrixSize(size_t& rows, size_t& cols, DatasetMapper<MapPolicy>& info)
+  {
+    // Read the file twice: first to count lines (the dimensionality), so that
+    // info can be sized; then to count the fields on the first line and, if the
+    // policy requires it, feed every field to MapFirstPass() (useful if, e.g.,
+    // the MapPolicy must find which dimensions are numeric or categorical).
+
+    // Reset to the start of the file.
+    inFile.clear();
+    inFile.seekg(0, std::ios::beg);
+    rows = 0;
+    cols = 0;
+
+    // First, count the number of rows in the file (this is the dimensionality).
+    std::string line;
+    while (std::getline(inFile, line))
+    {
+      ++rows;
+    }
+    info = DatasetMapper<MapPolicy>(rows);
+
+    // Now, jump back to the beginning of the file.
+    inFile.clear();
+    inFile.seekg(0, std::ios::beg);
+    rows = 0;
+    while (std::getline(inFile, line))
+    {
+      ++rows;
+
+      if (rows == 1)
+      {
+        // Extract the number of columns.
+        auto findColSize = [&cols](iter_type) { ++cols; };
+        boost::spirit::qi::phrase_parse(line.begin(), line.end(),
+            CreateCharRule()[findColSize] % ",", boost::spirit::ascii::space);
+      }
+
+      // The DatasetMapper's "first pass" happens during this second read, since
+      // info could only be sized once the line count was known.
+      if (MapPolicy::NeedsFirstPass)
+      {
+        // Feed each field, trimmed (a lone tab becomes empty), to the MapPolicy.
+        auto firstPassMap = [&](const iter_type& iter)
+        {
+          std::string str(iter.begin(), iter.end());
+          if (str == "\t")
+            str.clear();
+          boost::trim(str);
+          // Lines are dimensions here, so the 0-based line index is used.
+          info.template MapFirstPass<T>(std::move(str), rows - 1);
+        };
+
+        // Now parse the line.
+        boost::spirit::qi::phrase_parse(line.begin(), line.end(),
+            CreateCharRule()[firstPassMap] % ",", boost::spirit::ascii::space);
+      }
+    }
+  }
+
+  template<typename T, typename MapPolicy>
+  void GetTransposeMatrixSize(size_t& rows, size_t& cols, DatasetMapper<MapPolicy>& info)
+  {
+    // Transposed layout: rows is the number of fields on the first line (the
+    // dimensionality) and cols is the number of lines.  info is re-initialized
+    // once the first line is read (so an empty file leaves it unchanged) and,
+    // if the policy requires it, receives every field via MapFirstPass().
+
+    // Reset to the start of the file.
+    inFile.clear();
+    inFile.seekg(0, std::ios::beg);
+    rows = 0;
+    cols = 0;
+
+    std::string line;
+    while (std::getline(inFile, line))
+    {
+      ++cols;
+
+      if (cols == 1)
+      {
+        // Extract the number of dimensions.
+        auto findRowSize = [&rows](iter_type) { ++rows; };
+        boost::spirit::qi::phrase_parse(line.begin(), line.end(),
+            CreateCharRule()[findRowSize] % ",", boost::spirit::ascii::space);
+
+        // Now that we know the dimensionality, initialize the DatasetMapper.
+        info = DatasetMapper<MapPolicy>(rows);
+      }
+
+      // If we need to do a first pass for the DatasetMapper, do it.
+      if (MapPolicy::NeedsFirstPass)
+      {
+        size_t dim = 0;
+
+        // Feed each field, trimmed (a lone tab becomes empty), to the MapPolicy.
+        auto firstPassMap = [&](const iter_type& iter)
+        {
+          std::string str(iter.begin(), iter.end());
+          if (str == "\t")
+            str.clear();
+          boost::trim(str);
+          // Each field's position on the line is its dimension.
+          info.template MapFirstPass<T>(std::move(str), dim++);
+        };
+        // NOTE(review): extra fields (beyond the first line's count) get dim >= rows.
+        // Now parse the line.
+        boost::spirit::qi::phrase_parse(line.begin(), line.end(),
+            CreateCharRule()[firstPassMap] % ",", boost::spirit::ascii::space);
+      }
+    }
+  }
+
 private:
   using iter_type = boost::iterator_range<std::string::iterator>;
 
@@ -89,51 +213,45 @@ private:
   {
     using namespace boost::spirit;
 
+    // Get the size of the matrix.
+    size_t rows, cols;
+    GetMatrixSize<T>(rows, cols, infoSet);
+
+    // Set up output matrix.
+    inout.set_size(rows, cols);
     size_t row = 0;
     size_t col = 0;
-    infoSet = DatasetMapper<PolicyType>(RowSize());
+
+    // Reset file position.
     std::string line;
-    inout.set_size(infoSet.Dimensionality(), ColSize());
     inFile.clear();
     inFile.seekg(0, std::ios::beg);
 
-    auto setNum = [&](T val)
-    {
-      inout(row, col++) = val;
-    };
     auto setCharClass = [&](iter_type const &iter)
     {
       std::string str(iter.begin(), iter.end());
-      if(str == "\t")
+      if (str == "\t")
       {
         str.clear();
       }
-      inout(row, col++) =
-          static_cast<T>(infoSet.MapString(std::move(str),
-                                           row));
+      boost::trim(str);
+
+      inout(row, col++) = infoSet.template MapString<T>(std::move(str), row);
     };
 
-    auto numRule = CreateNumRule<T>();
     auto charRule = CreateCharRule();
-    while(std::getline(inFile, line))
+    while (std::getline(inFile, line))
     {
-      auto begin = line.begin();
       //parse the numbers from a line(ex : 1,2,3,4), if the parser find the number
       //it will execute the setNum function
-      qi::phrase_parse(begin, line.end(), numRule[setNum] % ",", ascii::space);
-      if(col != inout.n_cols)
+      const bool canParse = qi::phrase_parse(line.begin(), line.end(),
+          charRule[setCharClass] % ",", ascii::space);
+
+      if (!canParse)
       {
-        begin = line.begin();
-        col = 0;
-        const bool canParse = qi::phrase_parse(begin, line.end(),
-                                               charRule[setCharClass] % ",",
-                                               ascii::space);
-        if(!canParse)
-        {
-          throw std::runtime_error("LoadCSV cannot parse categories");
-          break;
-        }
+        throw std::runtime_error("LoadCSV cannot parse categories");
       }
+
       ++row; col = 0;
     }
   }
@@ -141,88 +259,54 @@ private:
   template<typename T, typename PolicyType>
   void TranposeParse(arma::Mat<T> &inout, DatasetMapper<PolicyType> &infoSet)
   {
-    infoSet = DatasetMapper<PolicyType>(ColSize());
-    inout.set_size(infoSet.Dimensionality(), RowSize());
-    size_t parseTime = 0;
-    std::set<size_t> mapCols;
-    while(!TranposeParseImpl(inout, infoSet, mapCols))
-    {      
-      ++parseTime; //avoid infinite loop
-      infoSet = DatasetMapper<PolicyType>(inout.n_rows);
-      if(parseTime == inout.n_rows)
-      {
-        return;
-      }
-    }
+    // Get the matrix size.  This also re-initializes infoSet for this file.
+    size_t rows, cols;
+    GetTransposeMatrixSize<T>(rows, cols, infoSet);
+
+    // Size the output, then fill it (TranposeParseImpl's result is unused).
+    inout.set_size(rows, cols);
+    TranposeParseImpl(inout, infoSet);
   }
 
   template<typename T, typename PolicyType>
-  bool TranposeParseImpl(arma::Mat<T> &inout, DatasetMapper<PolicyType> &infoSet,
-                         std::set<size_t> &mapCols)
+  bool TranposeParseImpl(arma::Mat<T>& inout,
+                         DatasetMapper<PolicyType>& infoSet)
   {
     using namespace boost::spirit;
 
     size_t row = 0;
     size_t col = 0;
-    size_t progress = 0;
     std::string line;
     inFile.clear();
     inFile.seekg(0, std::ios::beg);
-    auto setNum = [&](T val)
-    {
-      if(mapCols.find(progress) != std::end(mapCols))
-      {
-        inout(row, col) =
-            static_cast<T>(infoSet.MapString(std::to_string(val),
-                                             progress));
-      }
-      else
-      {
-        inout(row, col) = val;
-      }
-      ++progress; ++row;
-    };
+
     auto setCharClass = [&](iter_type const &iter)
     {
-      if(mapCols.find(progress) != std::end(mapCols))
-      {
-        std::string str(iter.begin(), iter.end());
-        if(str == "\t")
-        {
-          str.clear();
-        }
-        inout(row, col) =
-            static_cast<T>(infoSet.MapString(std::move(str),
-                                             progress));
-      }
-      else
-      {
-        mapCols.insert(progress);
-      }
-      ++progress; ++row;
+      // All parsed values must be mapped.
+      std::string str(iter.begin(), iter.end());
+      if (str == "\t")
+        str.clear();
+      boost::trim(str);
+
+      inout(row, col) = infoSet.template MapString<T>(std::move(str), row);
+      ++row;
     };
 
-    auto numRule = CreateNumRule<T>();
     auto charRule = CreateCharRule();
-    while(std::getline(inFile, line))
+    while (std::getline(inFile, line))
     {
-      auto begin = line.begin();
       row = 0;
-      progress = 0;
-      const size_t oldSize = mapCols.size();
       //parse number of characters from a line, it will execute setNum if it is number,
       //else execute setCharClass, "|" means "if not a, then b"
-      const bool canParse = qi::phrase_parse(begin, line.end(),
-                                             (numRule[setNum] | charRule[setCharClass]) % ",",
+      // Assemble the rule
+
+      const bool canParse = qi::phrase_parse(line.begin(), line.end(),
+                                             charRule[setCharClass] % ",",
                                              ascii::space);
       if(!canParse)
       {
         throw std::runtime_error("LoadCSV cannot parse categories");
       }
-      if(mapCols.size() > oldSize)
-      {
-        return false;
-      }
       ++col;
     }
 

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list