[arrayfire] 215/408: FEAT: replace for all backends

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:12:01 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit 981c5e663f5100438c6ff53f676e05aba8aabaf7
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Sat Aug 8 22:38:49 2015 -0400

    FEAT: replace for all backends
---
 include/af/data.h     |  48 ++++++++++++++++++++++
 src/api/c/replace.cpp | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++
 src/api/cpp/data.cpp  |   9 +++++
 3 files changed, 166 insertions(+)

diff --git a/include/af/data.h b/include/af/data.h
index ae90e24..acbc0f6 100644
--- a/include/af/data.h
+++ b/include/af/data.h
@@ -510,6 +510,35 @@ namespace af
     AFAPI array select(const array &cond, const array  &a, const double &b);
 
     /**
+       \param[in]  cond is the conditional array
+       \param[in]  a is a scalar assigned to \p out when \p cond is true
+       \param[in]  b is the array containing elements from the false part of the condition
+       \return  the output containing the value \p a when \p cond is true else elements from \p b
+
+       \ingroup data_func_select
+    */
+    AFAPI array select(const array &cond, const double &a, const array  &b);
+
+    /**
+       \param[inout]  a is the array whose values are kept where \p cond is true and replaced with values from \p b where \p cond is false
+       \param[in]  cond is the conditional array
+       \param[in]  b is the array containing elements which replace elements in \p a where \p cond is false
+
+       \ingroup data_func_replace
+    */
+    AFAPI void replace(array &a, const array  &cond, const array  &b);
+
+    /**
+       \param[inout]  a is the array whose values are kept where \p cond is true and replaced with \p b where \p cond is false
+       \param[in]  cond is the conditional array
+       \param[in]  b is the value that replaces elements in \p a where \p cond is false
+
+       \ingroup data_func_replace
+    */
+    AFAPI void replace(array &a, const array  &cond, const double &b);
+
+
+    /**
       @}
     */
 }
@@ -809,6 +838,25 @@ extern "C" {
        \ingroup data_func_select
     */
     AFAPI af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b);
+
+
+    /**
+       \param[inout] a is the array whose values are kept where \p cond is true and replaced by \p b where \p cond is false
+       \param[in]  cond is the conditional array
+       \param[in]  b is the array containing elements that replace elements of \p a where \p cond is false
+       \return     \ref AF_SUCCESS if the execution completes properly, otherwise an error code
+       \ingroup data_func_replace
+    */
+    AFAPI af_err af_replace(af_array a, const af_array cond, const af_array b);
+
+    /**
+       \param[inout] a is the array whose values are kept where \p cond is true and replaced by \p b where \p cond is false
+       \param[in]  cond is the conditional array
+       \param[in]  b is the scalar that replaces the elements of \p a where \p cond is false
+       \return     \ref AF_SUCCESS if the execution completes properly, otherwise an error code
+       \ingroup data_func_replace
+    */
+    AFAPI af_err af_replace_scalar(af_array a, const af_array cond, const double b);
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/api/c/replace.cpp b/src/api/c/replace.cpp
new file mode 100644
index 0000000..1f37988
--- /dev/null
+++ b/src/api/c/replace.cpp
@@ -0,0 +1,109 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <af/array.h>
+#include <af/defines.h>
+#include <af/arith.h>
+#include <af/data.h>
+#include <ArrayInfo.hpp>
+#include <optypes.hpp>
+#include <implicit.hpp>
+#include <err_common.hpp>
+#include <handle.hpp>
+#include <backend.hpp>
+
+#include <select.hpp>
+
+using namespace detail;
+using af::dim4;
+
+template<typename T>
+void replace(af_array a, const af_array cond, const af_array b)
+{   // In place: a = cond ? a : b, i.e. a keeps its values where cond is true and takes b's elsewhere.
+    select(getWritableArray<T>(a), getArray<char>(cond), getArray<T>(a), getArray<T>(b)); // NOTE(review): output aliases input a; assumes the backend select is safe elementwise in place and that no other handle shares a's buffer - confirm
+}
+
+af_err af_replace(af_array a, const af_array cond, const af_array b)
+{   // In place: a = cond ? a : b (a is kept where cond is true, taken from b where it is false).
+    try {
+        const ArrayInfo &ainfo = getInfo(a);
+        const ArrayInfo &binfo = getInfo(b);
+        const ArrayInfo &cinfo = getInfo(cond);
+        // Error indices follow the C signature: 0 = a, 1 = cond, 2 = b.
+        ARG_ASSERT(2, ainfo.getType() == binfo.getType());
+        ARG_ASSERT(1, cinfo.getType() == b8);
+        // b may not have more dims than a: the result is written back into a.
+        DIM_ASSERT(2, ainfo.ndims() >= binfo.ndims());
+        DIM_ASSERT(1, cinfo.ndims() == std::min(ainfo.ndims(), binfo.ndims()));
+
+        dim4 adims = ainfo.dims();
+        dim4 bdims = binfo.dims();
+        dim4 cdims = cinfo.dims();
+        // Per dim: cond spans min(a, b); b matches a or is 1 there (presumably broadcast by select).
+        for (int i = 0; i < 4; i++) {
+            DIM_ASSERT(1, cdims[i] == std::min(adims[i], bdims[i]));
+            DIM_ASSERT(2, adims[i] == bdims[i] || bdims[i] == 1);
+        }
+        // a and b share an element type (checked above), so dispatch on a's.
+        switch (ainfo.getType()) {
+        case f32: replace<float  >(a, cond, b); break;
+        case f64: replace<double >(a, cond, b); break;
+        case c32: replace<cfloat >(a, cond, b); break;
+        case c64: replace<cdouble>(a, cond, b); break;
+        case s32: replace<int    >(a, cond, b); break;
+        case u32: replace<uint   >(a, cond, b); break;
+        case s64: replace<intl   >(a, cond, b); break;
+        case u64: replace<uintl  >(a, cond, b); break;
+        case u8:  replace<uchar  >(a, cond, b); break;
+        case b8:  replace<char   >(a, cond, b); break;
+        default:  TYPE_ERROR(2, ainfo.getType());
+        }
+
+    } CATCHALL;
+    return AF_SUCCESS;
+}
+
+template<typename T>
+void replace_scalar(af_array a, const af_array cond, const double b)
+{   // In place: a = cond ? a : b, matching replace() and the af_replace_scalar docs (b fills where cond is false).
+    select_scalar<T, false>(getWritableArray<T>(a), getArray<char>(cond), getArray<T>(a), b); // flip=false keeps a on the true branch; flip=true (as af_select_scalar_l uses) puts b there - verify against backend select_scalar
+}
+
+af_err af_replace_scalar(af_array a, const af_array cond, const double b)
+{   // Replaces elements of a with the scalar b according to cond, in place (see replace_scalar<T>).
+    try {
+        const ArrayInfo &ainfo = getInfo(a);
+        const ArrayInfo &cinfo = getInfo(cond);
+        // Error indices follow the C signature: 0 = a, 1 = cond, 2 = b (the scalar).
+        ARG_ASSERT(1, cinfo.getType() == b8);
+        DIM_ASSERT(1, cinfo.ndims() == ainfo.ndims());
+
+        dim4 adims = ainfo.dims();
+        dim4 cdims = cinfo.dims();
+        // No broadcasting in the scalar form: cond must match a in every dimension.
+        for (int i = 0; i < 4; i++) {
+            DIM_ASSERT(1, cdims[i] == adims[i]);
+        }
+        // An unsupported element type belongs to a (argument 0), not to the scalar b.
+        switch (ainfo.getType()) {
+        case f32: replace_scalar<float  >(a, cond, b); break;
+        case f64: replace_scalar<double >(a, cond, b); break;
+        case c32: replace_scalar<cfloat >(a, cond, b); break;
+        case c64: replace_scalar<cdouble>(a, cond, b); break;
+        case s32: replace_scalar<int    >(a, cond, b); break;
+        case u32: replace_scalar<uint   >(a, cond, b); break;
+        case s64: replace_scalar<intl   >(a, cond, b); break;
+        case u64: replace_scalar<uintl  >(a, cond, b); break;
+        case u8:  replace_scalar<uchar  >(a, cond, b); break;
+        case b8:  replace_scalar<char   >(a, cond, b); break;
+        default:  TYPE_ERROR(0, ainfo.getType());
+        }
+
+    } CATCHALL;
+    return AF_SUCCESS;
+}
diff --git a/src/api/cpp/data.cpp b/src/api/cpp/data.cpp
index 87bc20f..196fbf8 100644
--- a/src/api/cpp/data.cpp
+++ b/src/api/cpp/data.cpp
@@ -375,4 +375,13 @@ namespace af
         return array(res);
     }
 
+    void replace(array &a, const array &cond, const array &b)
+    {   // Updates a in place through its handle; a non-success af_err goes to AF_THROW (presumably throws af::exception).
+        AF_THROW(af_replace(a.get(), cond.get(), b.get()));
+    }
+
+    void replace(array &a, const array &cond, const double &b)
+    {   // Scalar overload: forwards to af_replace_scalar, updating a in place; errors go to AF_THROW (presumably throws af::exception).
+        AF_THROW(af_replace_scalar(a.get(), cond.get(), b));
+    }
 }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list