[arrayfire] 212/408: FEAT: Adding select for CPU backend
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:12:00 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit 910feb738fec5f69f9edeeca04122938dc4c0da6
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Sat Aug 8 01:48:07 2015 -0400
FEAT: Adding select for CPU backend
- Also added stubs for CUDA and OpenCL backends
---
include/af/data.h | 50 ++++++++++++++
src/api/c/select.cpp | 156 ++++++++++++++++++++++++++++++++++++++++++
src/api/cpp/data.cpp | 23 +++++++
src/backend/cpu/select.cpp | 143 ++++++++++++++++++++++++++++++++++++++
src/backend/cpu/select.hpp | 19 +++++
src/backend/cuda/select.cu | 50 ++++++++++++++
src/backend/cuda/select.hpp | 19 +++++
src/backend/opencl/select.cpp | 51 ++++++++++++++
src/backend/opencl/select.hpp | 19 +++++
9 files changed, 530 insertions(+)
diff --git a/include/af/data.h b/include/af/data.h
index 64f95f5..ae90e24 100644
--- a/include/af/data.h
+++ b/include/af/data.h
@@ -490,6 +490,26 @@ namespace af
AFAPI array upper(const array &in, bool is_unit_diag=false);
/**
+    \param[in] cond is the conditional array (must be of type b8)
+    \param[in] a is the array containing elements from the true part of the condition
+    \param[in] b is the array containing elements from the false part of the condition (same type as \p a)
+    \return the output containing elements of \p a when \p cond is true else elements from \p b
+
+    \ingroup data_func_select
+    */
+    AFAPI array select(const array &cond, const array &a, const array &b);
+
+    /**
+    \param[in] cond is the conditional array (must be of type b8, same dimensions as \p a)
+    \param[in] a is the array containing elements from the true part of the condition
+    \param[in] b is a scalar used in the output wherever \p cond is false
+    \return the output containing elements of \p a when \p cond is true else the value \p b
+
+    \ingroup data_func_select
+    */
+    AFAPI array select(const array &cond, const array &a, const double &b);
+
+    /**
+    \param[in] cond is the conditional array (must be of type b8, same dimensions as \p b)
+    \param[in] a is a scalar used in the output wherever \p cond is true
+    \param[in] b is the array containing elements from the false part of the condition
+    \return the output containing the value \p a when \p cond is true else elements from \p b
+
+    \ingroup data_func_select
+    */
+    AFAPI array select(const array &cond, const double &a, const array &b);
+
+ /**
@}
*/
}
@@ -759,6 +779,36 @@ extern "C" {
/**
@}
*/
+
+    /**
+    \param[out] out is the output containing elements of \p a when \p cond is true else elements from \p b
+    \param[in] cond is the conditional array; it must be of type b8
+    \param[in] a is the array containing elements from the true part of the condition
+    \param[in] b is the array containing elements from the false part of the condition; it must have the same type as \p a
+    \return \ref AF_SUCCESS if the execution completes properly
+
+    Along every dimension \p a and \p b must either have equal extents or one
+    of them must be 1, in which case it is broadcast. \p out takes the larger
+    of the two extents and \p cond must have the smaller one.
+
+    \ingroup data_func_select
+    */
+    AFAPI af_err af_select(af_array *out, const af_array cond, const af_array a, const af_array b);
+
+    /**
+    \param[out] out is the output containing elements of \p a when \p cond is true else the value \p b
+    \param[in] cond is the conditional array; it must be of type b8 and have the same dimensions as \p a
+    \param[in] a is the array containing elements from the true part of the condition
+    \param[in] b is a scalar, converted to the type of \p a, used in \p out wherever \p cond is false
+    \return \ref AF_SUCCESS if the execution completes properly
+
+    \ingroup data_func_select
+    */
+    AFAPI af_err af_select_scalar_r(af_array *out, const af_array cond, const af_array a, const double b);
+
+    /**
+    \param[out] out is the output containing the value \p a when \p cond is true else elements from \p b
+    \param[in] cond is the conditional array; it must be of type b8 and have the same dimensions as \p b
+    \param[in] a is a scalar, converted to the type of \p b, used in \p out wherever \p cond is true
+    \param[in] b is the array containing elements from the false part of the condition
+    \return \ref AF_SUCCESS if the execution completes properly
+
+    \ingroup data_func_select
+    */
+    AFAPI af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b);
#ifdef __cplusplus
}
#endif
diff --git a/src/api/c/select.cpp b/src/api/c/select.cpp
new file mode 100644
index 0000000..06eef2a
--- /dev/null
+++ b/src/api/c/select.cpp
@@ -0,0 +1,156 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <af/array.h>
+#include <af/defines.h>
+#include <af/arith.h>
+#include <af/data.h>
+#include <ArrayInfo.hpp>
+#include <optypes.hpp>
+#include <implicit.hpp>
+#include <err_common.hpp>
+#include <handle.hpp>
+#include <backend.hpp>
+#include <select.hpp>
+
+using namespace detail;
+using af::dim4;
+
+// Allocates an output of the (already broadcast) shape odims and lets the
+// backend fill it with cond ? a : b; returns a new handle owning the result.
+template<typename T>
+af_array select(const af_array cond, const af_array a, const af_array b, const dim4 &odims)
+{
+    Array<T> result = createEmptyArray<T>(odims);
+
+    const Array<char> &condArr = getArray<char>(cond);
+    const Array<T>    &aArr    = getArray<T>(a);
+    const Array<T>    &bArr    = getArray<T>(b);
+    select(result, condArr, aArr, bArr);
+
+    return getHandle<T>(result);
+}
+
+af_err af_select(af_array *out, const af_array cond, const af_array a, const af_array b)
+{
+    try {
+        ArrayInfo aInfo    = getInfo(a);
+        ArrayInfo bInfo    = getInfo(b);
+        ArrayInfo condInfo = getInfo(cond);
+
+        // Both value operands must share a type; the predicate must be b8.
+        ARG_ASSERT(2, aInfo.getType() == bInfo.getType());
+        ARG_ASSERT(1, condInfo.getType() == b8);
+
+        DIM_ASSERT(1, condInfo.ndims() == std::min(aInfo.ndims(), bInfo.ndims()));
+
+        dim4 aDims    = aInfo.dims();
+        dim4 bDims    = bInfo.dims();
+        dim4 condDims = condInfo.dims();
+        dim4 outDims(1, 1, 1, 1);
+
+        // Per dimension: a and b agree or one of them is 1 (broadcast);
+        // cond must have the smaller extent and the output takes the larger.
+        for (int d = 0; d < 4; d++) {
+            dim_t minExtent = std::min(aDims[d], bDims[d]);
+            dim_t maxExtent = std::max(aDims[d], bDims[d]);
+
+            DIM_ASSERT(1, condDims[d] == minExtent);
+            DIM_ASSERT(2, aDims[d] == bDims[d] || aDims[d] == 1 || bDims[d] == 1);
+            outDims[d] = maxExtent;
+        }
+
+        af_array result;
+        switch (aInfo.getType()) {
+            case f32: result = select<float  >(cond, a, b, outDims); break;
+            case f64: result = select<double >(cond, a, b, outDims); break;
+            case c32: result = select<cfloat >(cond, a, b, outDims); break;
+            case c64: result = select<cdouble>(cond, a, b, outDims); break;
+            case s32: result = select<int    >(cond, a, b, outDims); break;
+            case u32: result = select<uint   >(cond, a, b, outDims); break;
+            case s64: result = select<intl   >(cond, a, b, outDims); break;
+            case u64: result = select<uintl  >(cond, a, b, outDims); break;
+            case u8:  result = select<uchar  >(cond, a, b, outDims); break;
+            case b8:  result = select<char   >(cond, a, b, outDims); break;
+            default:  TYPE_ERROR(2, aInfo.getType());
+        }
+
+        std::swap(*out, result);
+    } CATCHALL;
+
+    return AF_SUCCESS;
+}
+
+// Allocates an output of shape odims and lets the backend fill it from the
+// array a and the scalar b; flip selects which operand the true branch of
+// cond maps to. Returns a new handle owning the result.
+template<typename T, bool flip>
+af_array select_scalar(const af_array cond, const af_array a, const double b, const dim4 &odims)
+{
+    Array<T> result = createEmptyArray<T>(odims);
+
+    const Array<char> &condArr = getArray<char>(cond);
+    const Array<T>    &aArr    = getArray<T>(a);
+    select_scalar<T, flip>(result, condArr, aArr, b);
+
+    return getHandle<T>(result);
+}
+
+af_err af_select_scalar_r(af_array *out, const af_array cond, const af_array a, const double b)
+{
+    try {
+        ArrayInfo aInfo    = getInfo(a);
+        ArrayInfo condInfo = getInfo(cond);
+
+        // cond must be b8 and exactly match the shape of a (no broadcasting).
+        ARG_ASSERT(1, condInfo.getType() == b8);
+        DIM_ASSERT(1, condInfo.ndims() == aInfo.ndims());
+
+        dim4 aDims    = aInfo.dims();
+        dim4 condDims = condInfo.dims();
+        for (int d = 0; d < 4; d++) {
+            DIM_ASSERT(1, condDims[d] == aDims[d]);
+        }
+
+        // flip == false: the array fills the true branch, the scalar the false one.
+        af_array result;
+        switch (aInfo.getType()) {
+            case f32: result = select_scalar<float  , false>(cond, a, b, aDims); break;
+            case f64: result = select_scalar<double , false>(cond, a, b, aDims); break;
+            case c32: result = select_scalar<cfloat , false>(cond, a, b, aDims); break;
+            case c64: result = select_scalar<cdouble, false>(cond, a, b, aDims); break;
+            case s32: result = select_scalar<int    , false>(cond, a, b, aDims); break;
+            case u32: result = select_scalar<uint   , false>(cond, a, b, aDims); break;
+            case s64: result = select_scalar<intl   , false>(cond, a, b, aDims); break;
+            case u64: result = select_scalar<uintl  , false>(cond, a, b, aDims); break;
+            case u8:  result = select_scalar<uchar  , false>(cond, a, b, aDims); break;
+            case b8:  result = select_scalar<char   , false>(cond, a, b, aDims); break;
+            default:  TYPE_ERROR(2, aInfo.getType());
+        }
+
+        std::swap(*out, result);
+    } CATCHALL;
+
+    return AF_SUCCESS;
+}
+
+// out = cond ? a : b, where a is a scalar and b an array of the same shape as
+// cond. Implemented with the flipped scalar kernel: the array b is passed as
+// the kernel's array operand and the scalar a as its scalar operand.
+af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b)
+{
+    try {
+        ArrayInfo bInfo    = getInfo(b);
+        ArrayInfo condInfo = getInfo(cond);
+
+        // cond must be b8 and exactly match the shape of b (no broadcasting).
+        ARG_ASSERT(1, condInfo.getType() == b8);
+        DIM_ASSERT(1, condInfo.ndims() == bInfo.ndims());
+
+        dim4 bDims    = bInfo.dims();
+        dim4 condDims = condInfo.dims();
+        for (int d = 0; d < 4; d++) {
+            DIM_ASSERT(1, condDims[d] == bDims[d]);
+        }
+
+        af_array result;
+        switch (bInfo.getType()) {
+            case f32: result = select_scalar<float  , true>(cond, b, a, bDims); break;
+            case f64: result = select_scalar<double , true>(cond, b, a, bDims); break;
+            case c32: result = select_scalar<cfloat , true>(cond, b, a, bDims); break;
+            case c64: result = select_scalar<cdouble, true>(cond, b, a, bDims); break;
+            case s32: result = select_scalar<int    , true>(cond, b, a, bDims); break;
+            case u32: result = select_scalar<uint   , true>(cond, b, a, bDims); break;
+            case s64: result = select_scalar<intl   , true>(cond, b, a, bDims); break;
+            case u64: result = select_scalar<uintl  , true>(cond, b, a, bDims); break;
+            case u8:  result = select_scalar<uchar  , true>(cond, b, a, bDims); break;
+            case b8:  result = select_scalar<char   , true>(cond, b, a, bDims); break;
+            // b is argument index 3 (out = 0, cond = 1, a = 2, b = 3); report
+            // the unsupported type against b, not against the scalar a.
+            default:  TYPE_ERROR(3, bInfo.getType());
+        }
+
+        std::swap(*out, result);
+    } CATCHALL;
+
+    return AF_SUCCESS;
+}
diff --git a/src/api/cpp/data.cpp b/src/api/cpp/data.cpp
index bd9ac18..87bc20f 100644
--- a/src/api/cpp/data.cpp
+++ b/src/api/cpp/data.cpp
@@ -12,6 +12,7 @@
#include <af/arith.h>
#include <af/data.h>
#include <af/traits.hpp>
+#include <af/gfor.h>
#include "error.hpp"
namespace af
@@ -352,4 +353,26 @@ namespace af
AF_THROW(af_upper(&res, in.get(), is_unit_diag));
return array(res);
}
+
+    // array/array form: forwards to af_select and wraps the new handle.
+    array select(const array &cond, const array &a, const array &b)
+    {
+        af_array out = 0;
+        AF_THROW(af_select(&out, cond.get(), a.get(), b.get()));
+        return array(out);
+    }
+
+    // array/scalar form (scalar on the false branch): forwards to af_select_scalar_r.
+    array select(const array &cond, const array &a, const double &b)
+    {
+        af_array out = 0;
+        AF_THROW(af_select_scalar_r(&out, cond.get(), a.get(), b));
+        return array(out);
+    }
+
+    // scalar/array form (scalar on the true branch): forwards to af_select_scalar_l.
+    array select(const array &cond, const double &a, const array &b)
+    {
+        af_array out = 0;
+        AF_THROW(af_select_scalar_l(&out, cond.get(), a, b.get()));
+        return array(out);
+    }
+
}
diff --git a/src/backend/cpu/select.cpp b/src/backend/cpu/select.cpp
new file mode 100644
index 0000000..286e884
--- /dev/null
+++ b/src/backend/cpu/select.cpp
@@ -0,0 +1,143 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <ArrayInfo.hpp>
+#include <Array.hpp>
+#include <select.hpp>
+#include <err_cpu.hpp>
+
+using af::dim4;
+
+namespace cpu
+{
+    // Element-wise out = cond ? a : b with broadcasting.
+    //
+    // Along any dimension where an input's extent differs from out's (the C
+    // API only lets that happen when the input extent is 1), the input's
+    // effective stride is zero, so its single element is reused across that
+    // dimension. All offsets are dim_t (the element type of dim4) so that
+    // index arithmetic cannot overflow for arrays with more than 2^31 elements.
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond, const Array<T> &a, const Array<T> &b)
+    {
+        dim4 adims    = a.dims();
+        dim4 astrides = a.strides();
+        dim4 bdims    = b.dims();
+        dim4 bstrides = b.strides();
+        dim4 cdims    = cond.dims();
+        dim4 cstrides = cond.strides();
+        dim4 odims    = out.dims();
+        dim4 ostrides = out.strides();
+
+        // Effective strides: the real stride where the input spans the
+        // output extent, zero where it is broadcast.
+        dim_t a_step[4], b_step[4], c_step[4];
+        for (int d = 0; d < 4; d++) {
+            a_step[d] = (adims[d] == odims[d]) ? astrides[d] : 0;
+            b_step[d] = (bdims[d] == odims[d]) ? bstrides[d] : 0;
+            c_step[d] = (cdims[d] == odims[d]) ? cstrides[d] : 0;
+        }
+
+        const T    *aptr = a.get();
+        const T    *bptr = b.get();
+        const char *cptr = cond.get();
+        T          *optr = out.get();
+
+        for (dim_t l = 0; l < odims[3]; l++) {
+            const dim_t o_off3 = ostrides[3] * l;
+            const dim_t a_off3 = a_step[3] * l;
+            const dim_t b_off3 = b_step[3] * l;
+            const dim_t c_off3 = c_step[3] * l;
+
+            for (dim_t k = 0; k < odims[2]; k++) {
+                const dim_t o_off2 = ostrides[2] * k + o_off3;
+                const dim_t a_off2 = a_step[2] * k + a_off3;
+                const dim_t b_off2 = b_step[2] * k + b_off3;
+                const dim_t c_off2 = c_step[2] * k + c_off3;
+
+                for (dim_t j = 0; j < odims[1]; j++) {
+                    const dim_t o_off1 = ostrides[1] * j + o_off2;
+                    const dim_t a_off1 = a_step[1] * j + a_off2;
+                    const dim_t b_off1 = b_step[1] * j + b_off2;
+                    const dim_t c_off1 = c_step[1] * j + c_off2;
+
+                    for (dim_t i = 0; i < odims[0]; i++) {
+                        const bool cval = cptr[c_off1 + c_step[0] * i] != 0;
+                        optr[o_off1 + ostrides[0] * i] = cval ? aptr[a_off1 + a_step[0] * i]
+                                                              : bptr[b_off1 + b_step[0] * i];
+                    }
+                }
+            }
+        }
+    }
+
+    // Element-wise select between an array and a scalar:
+    //   flip == false : out = cond ? a : b
+    //   flip == true  : out = cond ? b : a
+    // cond, a and out are indexed with the same loop bounds (out's dims); the
+    // C API callers in this patch check that cond and a have equal dimensions.
+    // No broadcasting is performed.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond, const Array<T> &a, const double &b)
+    {
+        dim4 astrides = a.strides();
+        dim4 cstrides = cond.strides();
+        dim4 odims    = out.dims();
+        dim4 ostrides = out.strides();
+
+        // Convert the scalar to T once. Leaving it as double inside the
+        // ternary would promote every selected array element to double too,
+        // silently losing precision for 64-bit integers (intl/uintl) whose
+        // magnitude exceeds 2^53.
+        const T bval = static_cast<T>(b);
+
+        const T    *aptr = a.get();
+        const char *cptr = cond.get();
+        T          *optr = out.get();
+
+        for (dim_t l = 0; l < odims[3]; l++) {
+            const dim_t o_off3 = ostrides[3] * l;
+            const dim_t a_off3 = astrides[3] * l;
+            const dim_t c_off3 = cstrides[3] * l;
+
+            for (dim_t k = 0; k < odims[2]; k++) {
+                const dim_t o_off2 = ostrides[2] * k + o_off3;
+                const dim_t a_off2 = astrides[2] * k + a_off3;
+                const dim_t c_off2 = cstrides[2] * k + c_off3;
+
+                for (dim_t j = 0; j < odims[1]; j++) {
+                    const dim_t o_off1 = ostrides[1] * j + o_off2;
+                    const dim_t a_off1 = astrides[1] * j + a_off2;
+                    const dim_t c_off1 = cstrides[1] * j + c_off2;
+
+                    for (dim_t i = 0; i < odims[0]; i++) {
+                        // Normalise cond to bool before comparing with flip: a
+                        // raw XOR with the char value mis-handles values other
+                        // than 0/1 when flip is set (1 ^ 2 == 3, still "true").
+                        const bool take_a = (cptr[c_off1 + cstrides[0] * i] != 0) != flip;
+                        optr[o_off1 + ostrides[0] * i] = take_a ? aptr[a_off1 + astrides[0] * i]
+                                                                : bval;
+                    }
+                }
+            }
+        }
+    }
+
+#define INSTANTIATE(T)                                                      \
+    template void select<T>(Array<T> &out, const Array<char> &cond,        \
+                            const Array<T> &a, const Array<T> &b);         \
+    template void select_scalar<T, true >(Array<T> &out,                   \
+                                          const Array<char> &cond,         \
+                                          const Array<T> &a,               \
+                                          const double &b);                \
+    template void select_scalar<T, false>(Array<T> &out,                   \
+                                          const Array<char> &cond,         \
+                                          const Array<T> &a,               \
+                                          const double &b);
+
+    INSTANTIATE(float  )
+    INSTANTIATE(double )
+    INSTANTIATE(cfloat )
+    INSTANTIATE(cdouble)
+    INSTANTIATE(int    )
+    INSTANTIATE(uint   )
+    INSTANTIATE(intl   )
+    INSTANTIATE(uintl  )
+    INSTANTIATE(char   )
+    INSTANTIATE(uchar  )
+
+#undef INSTANTIATE
+}
diff --git a/src/backend/cpu/select.hpp b/src/backend/cpu/select.hpp
new file mode 100644
index 0000000..0d725ac
--- /dev/null
+++ b/src/backend/cpu/select.hpp
@@ -0,0 +1,19 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#pragma once
+#include <Array.hpp>
+
+namespace cpu
+{
+    // Fills the pre-allocated out with cond ? a : b element by element; a, b
+    // and cond are reused along any dimension where their extent is smaller
+    // than out's (broadcast).
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond,
+                const Array<T> &a, const Array<T> &b);
+
+    // Fills the pre-allocated out with cond ? a : b (flip == false) or
+    // cond ? b : a (flip == true), where b is a scalar converted to T.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond,
+                       const Array<T> &a, const double &b);
+}
diff --git a/src/backend/cuda/select.cu b/src/backend/cuda/select.cu
new file mode 100644
index 0000000..5204057
--- /dev/null
+++ b/src/backend/cuda/select.cu
@@ -0,0 +1,50 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <ArrayInfo.hpp>
+#include <Array.hpp>
+#include <select.hpp>
+#include <err_cuda.hpp>
+
+namespace cuda
+{
+    // Stub: this commit only implements select on the CPU backend, so the
+    // body just invokes CUDA_NOT_SUPPORTED().
+    template<typename T>
+    void select(Array<T> &/*out*/, const Array<char> &/*cond*/,
+                const Array<T> &/*a*/, const Array<T> &/*b*/)
+    {
+        CUDA_NOT_SUPPORTED();
+    }
+
+    // Stub for the array/scalar variants; the body just invokes CUDA_NOT_SUPPORTED().
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &/*out*/, const Array<char> &/*cond*/,
+                       const Array<T> &/*a*/, const double &/*b*/)
+    {
+        CUDA_NOT_SUPPORTED();
+    }
+
+#define INSTANTIATE(T)                                                      \
+    template void select<T>(Array<T> &out, const Array<char> &cond,        \
+                            const Array<T> &a, const Array<T> &b);         \
+    template void select_scalar<T, true >(Array<T> &out,                   \
+                                          const Array<char> &cond,         \
+                                          const Array<T> &a,               \
+                                          const double &b);                \
+    template void select_scalar<T, false>(Array<T> &out,                   \
+                                          const Array<char> &cond,         \
+                                          const Array<T> &a,               \
+                                          const double &b);
+
+    INSTANTIATE(float  )
+    INSTANTIATE(double )
+    INSTANTIATE(cfloat )
+    INSTANTIATE(cdouble)
+    INSTANTIATE(int    )
+    INSTANTIATE(uint   )
+    INSTANTIATE(intl   )
+    INSTANTIATE(uintl  )
+    INSTANTIATE(char   )
+    INSTANTIATE(uchar  )
+
+#undef INSTANTIATE
+}
diff --git a/src/backend/cuda/select.hpp b/src/backend/cuda/select.hpp
new file mode 100644
index 0000000..872fe25
--- /dev/null
+++ b/src/backend/cuda/select.hpp
@@ -0,0 +1,19 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#pragma once
+#include <Array.hpp>
+
+namespace cuda
+{
+    // Backend hook for af_select (intended: out = cond ? a : b). The
+    // definitions in select.cu are currently stubs calling CUDA_NOT_SUPPORTED().
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond,
+                const Array<T> &a, const Array<T> &b);
+
+    // Backend hook for af_select_scalar_{l,r}; also a stub in select.cu.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond,
+                       const Array<T> &a, const double &b);
+}
diff --git a/src/backend/opencl/select.cpp b/src/backend/opencl/select.cpp
new file mode 100644
index 0000000..92bcc2b
--- /dev/null
+++ b/src/backend/opencl/select.cpp
@@ -0,0 +1,51 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <ArrayInfo.hpp>
+#include <Array.hpp>
+#include <select.hpp>
+#include <err_opencl.hpp>
+
+namespace opencl
+{
+    // Stub: this commit only implements select on the CPU backend, so the
+    // body just invokes OPENCL_NOT_SUPPORTED().
+    template<typename T>
+    void select(Array<T> &/*out*/, const Array<char> &/*cond*/,
+                const Array<T> &/*a*/, const Array<T> &/*b*/)
+    {
+        OPENCL_NOT_SUPPORTED();
+    }
+
+    // Stub for the array/scalar variants; the body just invokes OPENCL_NOT_SUPPORTED().
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &/*out*/, const Array<char> &/*cond*/,
+                       const Array<T> &/*a*/, const double &/*b*/)
+    {
+        OPENCL_NOT_SUPPORTED();
+    }
+
+#define INSTANTIATE(T)                                                      \
+    template void select<T>(Array<T> &out, const Array<char> &cond,        \
+                            const Array<T> &a, const Array<T> &b);         \
+    template void select_scalar<T, true >(Array<T> &out,                   \
+                                          const Array<char> &cond,         \
+                                          const Array<T> &a,               \
+                                          const double &b);                \
+    template void select_scalar<T, false>(Array<T> &out,                   \
+                                          const Array<char> &cond,         \
+                                          const Array<T> &a,               \
+                                          const double &b);
+
+    INSTANTIATE(float  )
+    INSTANTIATE(double )
+    INSTANTIATE(cfloat )
+    INSTANTIATE(cdouble)
+    INSTANTIATE(int    )
+    INSTANTIATE(uint   )
+    INSTANTIATE(intl   )
+    INSTANTIATE(uintl  )
+    INSTANTIATE(char   )
+    INSTANTIATE(uchar  )
+
+#undef INSTANTIATE
+}
diff --git a/src/backend/opencl/select.hpp b/src/backend/opencl/select.hpp
new file mode 100644
index 0000000..5bc2f60
--- /dev/null
+++ b/src/backend/opencl/select.hpp
@@ -0,0 +1,19 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#pragma once
+#include <Array.hpp>
+
+namespace opencl
+{
+    // Backend hook for af_select (intended: out = cond ? a : b). The
+    // definitions in select.cpp are currently stubs calling OPENCL_NOT_SUPPORTED().
+    template<typename T>
+    void select(Array<T> &out, const Array<char> &cond,
+                const Array<T> &a, const Array<T> &b);
+
+    // Backend hook for af_select_scalar_{l,r}; also a stub in select.cpp.
+    template<typename T, bool flip>
+    void select_scalar(Array<T> &out, const Array<char> &cond,
+                       const Array<T> &a, const double &b);
+}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list