[arrayfire] 212/408: FEAT: Adding select for CPU backend

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:12:00 UTC 2015


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.

commit 910feb738fec5f69f9edeeca04122938dc4c0da6
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date:   Sat Aug 8 01:48:07 2015 -0400

    FEAT: Adding select for CPU backend
    
    - Also added stubs for CUDA and OpenCL backends
---
 include/af/data.h             |  50 ++++++++++++++
 src/api/c/select.cpp          | 156 ++++++++++++++++++++++++++++++++++++++++++
 src/api/cpp/data.cpp          |  23 +++++++
 src/backend/cpu/select.cpp    | 143 ++++++++++++++++++++++++++++++++++++++
 src/backend/cpu/select.hpp    |  19 +++++
 src/backend/cuda/select.cu    |  50 ++++++++++++++
 src/backend/cuda/select.hpp   |  19 +++++
 src/backend/opencl/select.cpp |  51 ++++++++++++++
 src/backend/opencl/select.hpp |  19 +++++
 9 files changed, 530 insertions(+)

diff --git a/include/af/data.h b/include/af/data.h
index 64f95f5..ae90e24 100644
--- a/include/af/data.h
+++ b/include/af/data.h
@@ -490,6 +490,36 @@ namespace af
     AFAPI array upper(const array &in, bool is_unit_diag=false);
 
     /**
+       \param[in]  cond is the conditional array
+       \param[in]  a is the array containing elements from the true part of the condition
+       \param[in]  b is the array containing elements from the false part of the condition
+       \return  the output containing elements of \p a when \p cond is true else elements from \p b
+
+       \ingroup data_func_select
+    */
+    AFAPI array select(const array &cond, const array  &a, const array  &b);
+
+    /**
+       \param[in]  cond is the conditional array
+       \param[in]  a is the array containing elements from the true part of the condition
+       \param[in]  b is a scalar value used where \p cond is false
+       \return  the output containing elements of \p a when \p cond is true else the value \p b
+
+       \ingroup data_func_select
+    */
+    AFAPI array select(const array &cond, const array  &a, const double &b);
+
+    /**
+       \param[in]  cond is the conditional array
+       \param[in]  a is a scalar value used where \p cond is true
+       \param[in]  b is the array containing elements from the false part of the condition
+       \return  the output containing the value \p a when \p cond is true else elements from \p b
+
+       \ingroup data_func_select
+    */
+    AFAPI array select(const array &cond, const double &a, const array  &b);
+
+    /**
       @}
     */
 }
@@ -759,6 +789,36 @@ extern "C" {
     /**
       @}
     */
+
+    /**
+       \param[out] out is the output containing elements of \p a when \p cond is true else elements from \p b
+       \param[in]  cond is the conditional array
+       \param[in]  a is the array containing elements from the true part of the condition
+       \param[in]  b is the array containing elements from the false part of the condition
+
+       \ingroup data_func_select
+    */
+    AFAPI af_err af_select(af_array *out, const af_array cond, const af_array a, const af_array b);
+
+    /**
+       \param[out] out is the output containing elements of \p a when \p cond is true else the value \p b
+       \param[in]  cond is the conditional array
+       \param[in]  a is the array containing elements from the true part of the condition
+       \param[in]  b is a scalar assigned to \p out when \p cond is false
+
+       \ingroup data_func_select
+    */
+    AFAPI af_err af_select_scalar_r(af_array *out, const af_array cond, const af_array a, const double b);
+
+    /**
+       \param[out] out is the output containing the value \p a when \p cond is true else elements from \p b
+       \param[in]  cond is the conditional array
+       \param[in]  a is a scalar assigned to \p out when \p cond is true
+       \param[in]  b is the array containing elements from the false part of the condition
+
+       \ingroup data_func_select
+    */
+    AFAPI af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b);
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/api/c/select.cpp b/src/api/c/select.cpp
new file mode 100644
index 0000000..06eef2a
--- /dev/null
+++ b/src/api/c/select.cpp
@@ -0,0 +1,176 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <af/array.h>
+#include <af/defines.h>
+#include <af/arith.h>
+#include <af/data.h>
+#include <ArrayInfo.hpp>
+#include <optypes.hpp>
+#include <implicit.hpp>
+#include <err_common.hpp>
+#include <handle.hpp>
+#include <backend.hpp>
+#include <select.hpp>
+#include <algorithm>
+#include <utility>
+
+using namespace detail;
+using af::dim4;
+
+// Allocates an output of size odims and fills it through the backend:
+// elements are taken from a where cond is true and from b elsewhere.
+template<typename T>
+af_array select(const af_array cond, const af_array a, const af_array b, const dim4 &odims)
+{
+    Array<T> out = createEmptyArray<T>(odims);
+    select(out, getArray<char>(cond), getArray<T>(a), getArray<T>(b));
+    return getHandle<T>(out);
+}
+
+// C API: *out = cond ? a : b, element-wise.
+//
+// a and b must have the same type and, along every dimension, either equal
+// sizes or a size of 1 on one side. cond must be b8 and match the smaller of
+// the two sizes in every dimension. The output takes the larger of the two.
+af_err af_select(af_array *out, const af_array cond, const af_array a, const af_array b)
+{
+    try {
+        ArrayInfo ainfo = getInfo(a);
+        ArrayInfo binfo = getInfo(b);
+        ArrayInfo cinfo = getInfo(cond);
+
+        ARG_ASSERT(2, ainfo.getType() == binfo.getType());
+        ARG_ASSERT(1, cinfo.getType() == b8);
+
+        DIM_ASSERT(1, cinfo.ndims() == std::min(ainfo.ndims(), binfo.ndims()));
+
+        dim4 adims = ainfo.dims();
+        dim4 bdims = binfo.dims();
+        dim4 cdims = cinfo.dims();
+        dim4 odims(1, 1, 1, 1);
+
+        for (int i = 0; i < 4; i++) {
+            DIM_ASSERT(1, cdims[i] == std::min(adims[i], bdims[i]));
+            DIM_ASSERT(2, adims[i] == bdims[i] || adims[i] == 1 || bdims[i] == 1);
+            odims[i] = std::max(adims[i], bdims[i]);
+        }
+
+        af_array res;
+
+        switch (ainfo.getType()) {
+        case f32: res = select<float  >(cond, a, b, odims); break;
+        case f64: res = select<double >(cond, a, b, odims); break;
+        case c32: res = select<cfloat >(cond, a, b, odims); break;
+        case c64: res = select<cdouble>(cond, a, b, odims); break;
+        case s32: res = select<int    >(cond, a, b, odims); break;
+        case u32: res = select<uint   >(cond, a, b, odims); break;
+        case s64: res = select<intl   >(cond, a, b, odims); break;
+        case u64: res = select<uintl  >(cond, a, b, odims); break;
+        case u8:  res = select<uchar  >(cond, a, b, odims); break;
+        case b8:  res = select<char   >(cond, a, b, odims); break;
+        default:  TYPE_ERROR(2, ainfo.getType());
+        }
+
+        std::swap(*out, res);
+    } CATCHALL;
+    return AF_SUCCESS;
+}
+
+// Allocates an output of size odims and fills it through the backend: where
+// cond is true (false when flip is set) the element comes from a, otherwise
+// it is the scalar b.
+template<typename T, bool flip>
+af_array select_scalar(const af_array cond, const af_array a, const double b, const dim4 &odims)
+{
+    Array<T> out = createEmptyArray<T>(odims);
+    select_scalar<T, flip>(out, getArray<char>(cond), getArray<T>(a), b);
+    return getHandle<T>(out);
+}
+
+// C API: *out = cond ? a : b, element-wise, with b a scalar.
+//
+// cond must be b8 and have exactly the dimensions of a.
+af_err af_select_scalar_r(af_array *out, const af_array cond, const af_array a, const double b)
+{
+    try {
+        ArrayInfo ainfo = getInfo(a);
+        ArrayInfo cinfo = getInfo(cond);
+
+        ARG_ASSERT(1, cinfo.getType() == b8);
+        DIM_ASSERT(1, cinfo.ndims() == ainfo.ndims());
+
+        dim4 adims = ainfo.dims();
+        dim4 cdims = cinfo.dims();
+
+        for (int i = 0; i < 4; i++) {
+            DIM_ASSERT(1, cdims[i] == adims[i]);
+        }
+
+        af_array res;
+
+        switch (ainfo.getType()) {
+        case f32: res = select_scalar<float  , false>(cond, a, b, adims); break;
+        case f64: res = select_scalar<double , false>(cond, a, b, adims); break;
+        case c32: res = select_scalar<cfloat , false>(cond, a, b, adims); break;
+        case c64: res = select_scalar<cdouble, false>(cond, a, b, adims); break;
+        case s32: res = select_scalar<int    , false>(cond, a, b, adims); break;
+        case u32: res = select_scalar<uint   , false>(cond, a, b, adims); break;
+        case s64: res = select_scalar<intl   , false>(cond, a, b, adims); break;
+        case u64: res = select_scalar<uintl  , false>(cond, a, b, adims); break;
+        case u8:  res = select_scalar<uchar  , false>(cond, a, b, adims); break;
+        case b8:  res = select_scalar<char   , false>(cond, a, b, adims); break;
+        default:  TYPE_ERROR(2, ainfo.getType());
+        }
+
+        std::swap(*out, res);
+    } CATCHALL;
+    return AF_SUCCESS;
+}
+
+// C API: *out = cond ? a : b, element-wise, with a a scalar.
+//
+// cond must be b8 and have exactly the dimensions of b. The array b is handed
+// to the backend as its array operand with flip set, so that the scalar a is
+// used wherever cond is true.
+af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b)
+{
+    try {
+        ArrayInfo binfo = getInfo(b);
+        ArrayInfo cinfo = getInfo(cond);
+
+        ARG_ASSERT(1, cinfo.getType() == b8);
+        DIM_ASSERT(1, cinfo.ndims() == binfo.ndims());
+
+        dim4 bdims = binfo.dims();
+        dim4 cdims = cinfo.dims();
+
+        for (int i = 0; i < 4; i++) {
+            DIM_ASSERT(1, cdims[i] == bdims[i]);
+        }
+
+        af_array res;
+
+        switch (binfo.getType()) {
+        case f32: res = select_scalar<float  , true >(cond, b, a, bdims); break;
+        case f64: res = select_scalar<double , true >(cond, b, a, bdims); break;
+        case c32: res = select_scalar<cfloat , true >(cond, b, a, bdims); break;
+        case c64: res = select_scalar<cdouble, true >(cond, b, a, bdims); break;
+        case s32: res = select_scalar<int    , true >(cond, b, a, bdims); break;
+        case u32: res = select_scalar<uint   , true >(cond, b, a, bdims); break;
+        case s64: res = select_scalar<intl   , true >(cond, b, a, bdims); break;
+        case u64: res = select_scalar<uintl  , true >(cond, b, a, bdims); break;
+        case u8:  res = select_scalar<uchar  , true >(cond, b, a, bdims); break;
+        case b8:  res = select_scalar<char   , true >(cond, b, a, bdims); break;
+        default:  TYPE_ERROR(3, binfo.getType()); // b is argument index 3
+        }
+
+        std::swap(*out, res);
+    } CATCHALL;
+    return AF_SUCCESS;
+}
diff --git a/src/api/cpp/data.cpp b/src/api/cpp/data.cpp
index bd9ac18..87bc20f 100644
--- a/src/api/cpp/data.cpp
+++ b/src/api/cpp/data.cpp
@@ -12,6 +12,7 @@
 #include <af/arith.h>
 #include <af/data.h>
 #include <af/traits.hpp>
+#include <af/gfor.h>
 #include "error.hpp"
 
 namespace af
@@ -352,4 +353,32 @@ namespace af
         AF_THROW(af_upper(&res, in.get(), is_unit_diag));
         return array(res);
     }
+
+    // Element-wise select between two arrays: takes a where cond is true and b
+    // where it is false. Thin wrapper over af_select.
+    array select(const array &cond, const array &a, const array &b)
+    {
+        af_array handle = 0;
+        AF_THROW(af_select(&handle, cond.get(), a.get(), b.get()));
+        return array(handle);
+    }
+
+    // Element-wise select with a scalar false branch: a where cond is true, the
+    // value b elsewhere. Thin wrapper over af_select_scalar_r.
+    array select(const array &cond, const array &a, const double &b)
+    {
+        af_array handle = 0;
+        AF_THROW(af_select_scalar_r(&handle, cond.get(), a.get(), b));
+        return array(handle);
+    }
+
+    // Element-wise select with a scalar true branch: the value a where cond is
+    // true, b elsewhere. Thin wrapper over af_select_scalar_l.
+    array select(const array &cond, const double &a, const array &b)
+    {
+        af_array handle = 0;
+        AF_THROW(af_select_scalar_l(&handle, cond.get(), a, b.get()));
+        return array(handle);
+    }
+
 }
diff --git a/src/backend/cpu/select.cpp b/src/backend/cpu/select.cpp
new file mode 100644
index 0000000..286e884
--- /dev/null
+++ b/src/backend/cpu/select.cpp
@@ -0,0 +1,162 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <ArrayInfo.hpp>
+#include <Array.hpp>
+#include <select.hpp>
+#include <err_cpu.hpp>
+
+using af::dim4;
+
+namespace cpu
+{
+    // out = cond ? a : b, element-wise.
+    //
+    // Along a dimension where an operand (a, b or cond) is smaller than out,
+    // af_select only allows that operand's size to be 1, so its offset is kept
+    // fixed and the single element is broadcast. Offsets use dim_t so that
+    // arrays with more than 2^31 elements do not overflow the index math.
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond, const Array<T> &a, const Array<T> &b)
+    {
+        dim4 adims = a.dims();
+        dim4 astrides = a.strides();
+        dim4 bdims = b.dims();
+        dim4 bstrides = b.strides();
+
+        dim4 cdims = cond.dims();
+        dim4 cstrides = cond.strides();
+
+        dim4 odims = out.dims();
+        dim4 ostrides = out.strides();
+
+        // true when the operand spans the full output extent along that axis;
+        // false means it is broadcast (its stride contribution is zeroed).
+        bool is_a_same[] = {adims[0] == odims[0], adims[1] == odims[1],
+                            adims[2] == odims[2], adims[3] == odims[3]};
+
+        bool is_b_same[] = {bdims[0] == odims[0], bdims[1] == odims[1],
+                            bdims[2] == odims[2], bdims[3] == odims[3]};
+
+        bool is_c_same[] = {cdims[0] == odims[0], cdims[1] == odims[1],
+                            cdims[2] == odims[2], cdims[3] == odims[3]};
+
+        const T *aptr = a.get();
+        const T *bptr = b.get();
+        T *optr = out.get();
+        const char *cptr = cond.get();
+
+        for (dim_t l = 0; l < odims[3]; l++) {
+
+            dim_t o_off3   = ostrides[3] * l;
+            dim_t a_off3   = astrides[3] * is_a_same[3] * l;
+            dim_t b_off3   = bstrides[3] * is_b_same[3] * l;
+            dim_t c_off3   = cstrides[3] * is_c_same[3] * l;
+
+            for (dim_t k = 0; k < odims[2]; k++) {
+
+                dim_t o_off2   = ostrides[2] * k + o_off3;
+                dim_t a_off2   = astrides[2] * is_a_same[2] * k + a_off3;
+                dim_t b_off2   = bstrides[2] * is_b_same[2] * k + b_off3;
+                dim_t c_off2   = cstrides[2] * is_c_same[2] * k + c_off3;
+
+                for (dim_t j = 0; j < odims[1]; j++) {
+
+                    dim_t o_off1   = ostrides[1] * j + o_off2;
+                    dim_t a_off1   = astrides[1] * is_a_same[1] * j + a_off2;
+                    dim_t b_off1   = bstrides[1] * is_b_same[1] * j + b_off2;
+                    dim_t c_off1   = cstrides[1] * is_c_same[1] * j + c_off2;
+
+                    for (dim_t i = 0; i < odims[0]; i++) {
+
+                        bool cval = is_c_same[0] ? cptr[c_off1 + i] : cptr[c_off1];
+                        T    aval = is_a_same[0] ? aptr[a_off1 + i] : aptr[a_off1];
+                        T    bval = is_b_same[0] ? bptr[b_off1 + i] : bptr[b_off1];
+                        T    oval = cval ? aval : bval;
+                        optr[o_off1 + i] = oval;
+                    }
+                }
+            }
+        }
+    }
+
+    // out = cond ? a : b, element-wise, where b is a scalar. When flip is true
+    // the condition is inverted (out = cond ? b : a), which lets the C API
+    // implement the scalar-on-the-left case with the same kernel. cond and a
+    // have exactly the dimensions of out (no broadcasting here).
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond, const Array<T> &a, const double &b)
+    {
+        dim4 astrides = a.strides();
+        dim4 cstrides = cond.strides();
+
+        dim4 odims = out.dims();
+        dim4 ostrides = out.strides();
+
+        const T *aptr = a.get();
+        T *optr = out.get();
+        const char *cptr = cond.get();
+
+        // Convert the scalar to T once. Mixing the double directly into the
+        // ?: below would make the expression's type double, which silently
+        // rounds 64-bit integer elements of a that double cannot represent.
+        const T bval = static_cast<T>(b);
+
+        for (dim_t l = 0; l < odims[3]; l++) {
+
+            dim_t o_off3 = ostrides[3] * l;
+            dim_t a_off3 = astrides[3] * l;
+            dim_t c_off3 = cstrides[3] * l;
+
+            for (dim_t k = 0; k < odims[2]; k++) {
+
+                dim_t o_off2 = ostrides[2] * k + o_off3;
+                dim_t a_off2 = astrides[2] * k + a_off3;
+                dim_t c_off2 = cstrides[2] * k + c_off3;
+
+                for (dim_t j = 0; j < odims[1]; j++) {
+
+                    dim_t o_off1 = ostrides[1] * j + o_off2;
+                    dim_t a_off1 = astrides[1] * j + a_off2;
+                    dim_t c_off1 = cstrides[1] * j + c_off2;
+
+                    for (dim_t i = 0; i < odims[0]; i++) {
+                        // Treat any non-zero byte as true before applying flip,
+                        // matching how select() interprets cond.
+                        bool take_a = (cptr[c_off1 + i] != 0) != flip;
+                        optr[o_off1 + i] = take_a ? aptr[a_off1 + i] : bval;
+                    }
+                }
+            }
+        }
+    }
+
+
+#define INSTANTIATE(T)                                              \
+    template void select<T>(Array<T> &out, const Array<char> &cond, \
+                            const Array<T> &a, const Array<T> &b);  \
+    template void select_scalar<T, true >(Array<T> &out,            \
+                                          const Array<char> &cond,  \
+                                          const Array<T> &a,        \
+                                          const double &b);         \
+    template void select_scalar<T, false>(Array<T> &out, const      \
+                                          Array<char> &cond,        \
+                                          const Array<T> &a,        \
+                                          const double &b);         \
+
+    INSTANTIATE(float  )
+    INSTANTIATE(double )
+    INSTANTIATE(cfloat )
+    INSTANTIATE(cdouble)
+    INSTANTIATE(int    )
+    INSTANTIATE(uint   )
+    INSTANTIATE(intl   )
+    INSTANTIATE(uintl  )
+    INSTANTIATE(char   )
+    INSTANTIATE(uchar  )
+}
diff --git a/src/backend/cpu/select.hpp b/src/backend/cpu/select.hpp
new file mode 100644
index 0000000..0d725ac
--- /dev/null
+++ b/src/backend/cpu/select.hpp
@@ -0,0 +1,23 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#pragma once
+#include <Array.hpp>
+
+namespace cpu
+{
+    // out = cond ? a : b, element-wise. An operand whose size along a
+    // dimension differs from out's is broadcast along that dimension.
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond, const Array<T> &a, const Array<T> &b);
+
+    // out = cond ? a : b, element-wise, with b a scalar. When flip is true the
+    // condition is inverted, i.e. out = cond ? b : a.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond, const Array<T> &a, const double &b);
+}
diff --git a/src/backend/cuda/select.cu b/src/backend/cuda/select.cu
new file mode 100644
index 0000000..5204057
--- /dev/null
+++ b/src/backend/cuda/select.cu
@@ -0,0 +1,55 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <ArrayInfo.hpp>
+#include <Array.hpp>
+#include <select.hpp>
+#include <err_cuda.hpp>
+
+namespace cuda
+{
+    // Stub: element-wise select (out = cond ? a : b) has no CUDA kernel yet,
+    // so this only signals CUDA_NOT_SUPPORTED().
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond, const Array<T> &a, const Array<T> &b)
+    {
+        CUDA_NOT_SUPPORTED();
+    }
+
+    // Stub: select with a scalar for the non-array branch; not implemented
+    // for the CUDA backend yet.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond, const Array<T> &a, const double &b)
+    {
+        CUDA_NOT_SUPPORTED();
+    }
+
+    // Explicit instantiations for every type the C API dispatches to.
+#define INSTANTIATE(T)                                              \
+    template void select<T>(Array<T> &out, const Array<char> &cond, \
+                            const Array<T> &a, const Array<T> &b);  \
+    template void select_scalar<T, true >(Array<T> &out,            \
+                                          const Array<char> &cond,  \
+                                          const Array<T> &a,        \
+                                          const double &b);         \
+    template void select_scalar<T, false>(Array<T> &out, const      \
+                                          Array<char> &cond,        \
+                                          const Array<T> &a,        \
+                                          const double &b);         \
+
+    INSTANTIATE(float  )
+    INSTANTIATE(double )
+    INSTANTIATE(cfloat )
+    INSTANTIATE(cdouble)
+    INSTANTIATE(int    )
+    INSTANTIATE(uint   )
+    INSTANTIATE(intl   )
+    INSTANTIATE(uintl  )
+    INSTANTIATE(char   )
+    INSTANTIATE(uchar  )
+}
diff --git a/src/backend/cuda/select.hpp b/src/backend/cuda/select.hpp
new file mode 100644
index 0000000..872fe25
--- /dev/null
+++ b/src/backend/cuda/select.hpp
@@ -0,0 +1,23 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#pragma once
+#include <Array.hpp>
+
+namespace cuda
+{
+    // Backend entry point for element-wise select (out = cond ? a : b).
+    // NOTE(review): the CUDA implementation is currently a stub.
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond, const Array<T> &a, const Array<T> &b);
+
+    // Backend entry point for select with a scalar b; flip inverts cond.
+    // NOTE(review): the CUDA implementation is currently a stub.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond, const Array<T> &a, const double &b);
+}
diff --git a/src/backend/opencl/select.cpp b/src/backend/opencl/select.cpp
new file mode 100644
index 0000000..92bcc2b
--- /dev/null
+++ b/src/backend/opencl/select.cpp
@@ -0,0 +1,55 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <ArrayInfo.hpp>
+#include <Array.hpp>
+#include <select.hpp>
+#include <err_opencl.hpp>
+
+namespace opencl
+{
+    // Stub: element-wise select (out = cond ? a : b) has no OpenCL kernel
+    // yet, so this only signals OPENCL_NOT_SUPPORTED().
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond, const Array<T> &a, const Array<T> &b)
+    {
+        OPENCL_NOT_SUPPORTED();
+    }
+
+    // Stub: select with a scalar for the non-array branch; not implemented
+    // for the OpenCL backend yet.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond, const Array<T> &a, const double &b)
+    {
+        OPENCL_NOT_SUPPORTED();
+    }
+
+    // Explicit instantiations for every type the C API dispatches to.
+#define INSTANTIATE(T)                                              \
+    template void select<T>(Array<T> &out, const Array<char> &cond, \
+                            const Array<T> &a, const Array<T> &b);  \
+    template void select_scalar<T, true >(Array<T> &out,            \
+                                          const Array<char> &cond,  \
+                                          const Array<T> &a,        \
+                                          const double &b);         \
+    template void select_scalar<T, false>(Array<T> &out, const      \
+                                          Array<char> &cond,        \
+                                          const Array<T> &a,        \
+                                          const double &b);         \
+
+    INSTANTIATE(float  )
+    INSTANTIATE(double )
+    INSTANTIATE(cfloat )
+    INSTANTIATE(cdouble)
+    INSTANTIATE(int    )
+    INSTANTIATE(uint   )
+    INSTANTIATE(intl   )
+    INSTANTIATE(uintl  )
+    INSTANTIATE(char   )
+    INSTANTIATE(uchar  )
+}
diff --git a/src/backend/opencl/select.hpp b/src/backend/opencl/select.hpp
new file mode 100644
index 0000000..5bc2f60
--- /dev/null
+++ b/src/backend/opencl/select.hpp
@@ -0,0 +1,23 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#pragma once
+#include <Array.hpp>
+
+namespace opencl
+{
+    // Backend entry point for element-wise select (out = cond ? a : b).
+    // NOTE(review): the OpenCL implementation is currently a stub.
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond, const Array<T> &a, const Array<T> &b);
+
+    // Backend entry point for select with a scalar b; flip inverts cond.
+    // NOTE(review): the OpenCL implementation is currently a stub.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond, const Array<T> &a, const double &b);
+}

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git



More information about the debian-science-commits mailing list