[arrayfire] 215/408: FEAT: replace for all backends
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:12:01 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit 981c5e663f5100438c6ff53f676e05aba8aabaf7
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Sat Aug 8 22:38:49 2015 -0400
FEAT: replace for all backends
---
include/af/data.h | 48 ++++++++++++++++++++++
src/api/c/replace.cpp | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++
src/api/cpp/data.cpp | 9 +++++
3 files changed, 166 insertions(+)
diff --git a/include/af/data.h b/include/af/data.h
index ae90e24..acbc0f6 100644
--- a/include/af/data.h
+++ b/include/af/data.h
@@ -510,6 +510,35 @@ namespace af
AFAPI array select(const array &cond, const array &a, const double &b);
/**
+ \param[in] cond is the conditional array
+ \param[in] a is a scalar assigned to \p out when \p cond is true
+ \param[in] b is the array containing elements from the false part of the condition
+ \return the output containing the value \p a when \p cond is true else elements from \p b
+
+ \ingroup data_func_select
+ */
+ AFAPI array select(const array &cond, const double &a, const array &b);
+
+ /**
+ \param[inout] a is the array whose values are kept where \p cond is true and replaced with values from \p b where \p cond is false
+ \param[in] cond is the conditional array
+ \param[in] b is the array containing elements which replace elements in \p a where \p cond is false
+
+ \ingroup data_func_replace
+ */
+ AFAPI void replace(array &a, const array &cond, const array &b);
+
+ /**
+ \param[inout] a is the array whose values are kept where \p cond is true and replaced with \p b where \p cond is false
+ \param[in] cond is the conditional array
+ \param[in] b is the value that replaces elements in \p a where \p cond is false
+
+ \ingroup data_func_replace
+ */
+ AFAPI void replace(array &a, const array &cond, const double &b);
+
+
+ /**
@}
*/
}
@@ -809,6 +838,25 @@ extern "C" {
\ingroup data_func_select
*/
AFAPI af_err af_select_scalar_l(af_array *out, const af_array cond, const double a, const af_array b);
+
+
+ /**
+ \param[inout] a is the array whose values are kept where \p cond is true and replaced by \p b where \p cond is false
+ \param[in] cond is the conditional array
+ \param[in] b is the array containing elements that replace elements of \p a where \p cond is false
+
+ \ingroup data_func_replace
+ */
+ AFAPI af_err af_replace(af_array a, const af_array cond, const af_array b);
+
+ /**
+ \param[inout] a is the array whose values are kept where \p cond is true and replaced by \p b where \p cond is false
+ \param[in] cond is the conditional array
+ \param[in] b is the scalar that replaces elements of \p a where \p cond is false
+
+ \ingroup data_func_replace
+ */
+ AFAPI af_err af_replace_scalar(af_array a, const af_array cond, const double b);
#ifdef __cplusplus
}
#endif
diff --git a/src/api/c/replace.cpp b/src/api/c/replace.cpp
new file mode 100644
index 0000000..1f37988
--- /dev/null
+++ b/src/api/c/replace.cpp
@@ -0,0 +1,109 @@
+/*******************************************************
+ * Copyright (c) 2015, ArrayFire
+ * All rights reserved.
+ *
+ * This file is distributed under 3-clause BSD license.
+ * The complete license agreement can be obtained at:
+ * http://arrayfire.com/licenses/BSD-3-Clause
+ ********************************************************/
+#include <af/array.h>
+#include <af/defines.h>
+#include <af/arith.h>
+#include <af/data.h>
+#include <ArrayInfo.hpp>
+#include <optypes.hpp>
+#include <implicit.hpp>
+#include <err_common.hpp>
+#include <handle.hpp>
+#include <backend.hpp>
+
+#include <select.hpp>
+
+using namespace detail;
+using af::dim4;
+
+// Array-valued replace for element type T (dispatched from af_replace).
+// Writes select(cond, a, b) back into a; assuming select(out, cond, x, y)
+// computes cond ? x : y, elements of a are kept where cond is true and
+// taken from b where cond is false.
+// NOTE(review): the handle a is used both as the output (getWritableArray)
+// and as the "true" input (getArray). This relies on select being
+// element-wise so the in-place write cannot clobber unread inputs, and on
+// the order in which the two accessors are evaluated not mattering;
+// confirm for every backend.
+template<typename T>
+void replace(af_array a, const af_array cond, const af_array b)
+{
+ select(getWritableArray<T>(a), getArray<char>(cond), getArray<T>(a), getArray<T>(b));
+}
+
+// C API: replace elements of array a in place using a conditional mask.
+//
+// Assuming detail::select(out, cond, x, y) computes cond ? x : y, elements
+// of a are kept where cond is true and overwritten with the corresponding
+// element of b where cond is false.
+//
+// Arguments (0-based indices, as used by the ARG/DIM/TYPE error reports):
+//   0: a    - array modified in place
+//   1: cond - b8 mask; each dimension equals min(a dim, b dim)
+//   2: b    - replacement values; same type as a, and each dimension equal
+//             to the corresponding dimension of a or 1
+//
+// Returns AF_SUCCESS; exceptions raised inside the try block are handled by
+// CATCHALL (which presumably returns the matching af_err).
+af_err af_replace(af_array a, const af_array cond, const af_array b)
+{
+    try {
+        // Metadata is only read here, so bind by const reference rather
+        // than copying each ArrayInfo.
+        const ArrayInfo &ainfo = getInfo(a);
+        const ArrayInfo &binfo = getInfo(b);
+        const ArrayInfo &cinfo = getInfo(cond);
+
+        ARG_ASSERT(2, ainfo.getType() == binfo.getType());
+        ARG_ASSERT(1, cinfo.getType() == b8);
+
+        // This check constrains b (argument 2), not cond, so report it
+        // against b.
+        DIM_ASSERT(2, ainfo.ndims() >= binfo.ndims());
+        DIM_ASSERT(1, cinfo.ndims() == std::min(ainfo.ndims(), binfo.ndims()));
+
+        dim4 adims = ainfo.dims();
+        dim4 bdims = binfo.dims();
+        dim4 cdims = cinfo.dims();
+
+        for (int i = 0; i < 4; i++) {
+            DIM_ASSERT(1, cdims[i] == std::min(adims[i], bdims[i]));
+            // b must match a or be broadcastable (size 1) along each axis.
+            DIM_ASSERT(2, adims[i] == bdims[i] || bdims[i] == 1);
+        }
+
+        switch (ainfo.getType()) {
+        case f32: replace<float  >(a, cond, b); break;
+        case f64: replace<double >(a, cond, b); break;
+        case c32: replace<cfloat >(a, cond, b); break;
+        case c64: replace<cdouble>(a, cond, b); break;
+        case s32: replace<int    >(a, cond, b); break;
+        case u32: replace<uint   >(a, cond, b); break;
+        case s64: replace<intl   >(a, cond, b); break;
+        case u64: replace<uintl  >(a, cond, b); break;
+        case u8:  replace<uchar  >(a, cond, b); break;
+        case b8:  replace<char   >(a, cond, b); break;
+        // Dispatch is on a's type, so an unsupported type is argument 0's fault.
+        default:  TYPE_ERROR(0, ainfo.getType());
+        }
+
+    } CATCHALL;
+    return AF_SUCCESS;
+}
+
+// Scalar-valued replace for element type T (dispatched from
+// af_replace_scalar): keeps elements of a where cond is true and writes the
+// scalar b where cond is false. This is the same mask orientation as the
+// array-valued replace<T>, which calls select(a, cond, a, b).
+//
+// NOTE(review): assumes select_scalar<T, flip>(out, cond, x, s) computes
+// (cond ^ flip) ? x : s; confirm against the backend select_scalar kernels.
+// Under that contract flip must be false here. The previous `true` inverted
+// the mask and overwrote exactly the elements the array overload keeps.
+template<typename T>
+void replace_scalar(af_array a, const af_array cond, const double b)
+{
+    select_scalar<T, false>(getWritableArray<T>(a), getArray<char>(cond), getArray<T>(a), b);
+}
+
+// C API: replace elements of array a in place with the scalar b, at the
+// positions selected by the b8 mask cond. The mask orientation is decided by
+// replace_scalar<T>; it is meant to match af_replace (elements kept where
+// cond is true, replaced where cond is false).
+//
+// Arguments (0-based indices, as used by the ARG/DIM/TYPE error reports):
+//   0: a    - array modified in place
+//   1: cond - b8 mask with exactly the same dimensions as a
+//   2: b    - replacement scalar, passed as double for every element type
+//
+// Returns AF_SUCCESS; exceptions raised inside the try block are handled by
+// CATCHALL (which presumably returns the matching af_err).
+af_err af_replace_scalar(af_array a, const af_array cond, const double b)
+{
+    try {
+        // Metadata is only read here, so bind by const reference rather
+        // than copying each ArrayInfo.
+        const ArrayInfo &ainfo = getInfo(a);
+        const ArrayInfo &cinfo = getInfo(cond);
+
+        ARG_ASSERT(1, cinfo.getType() == b8);
+        DIM_ASSERT(1, cinfo.ndims() == ainfo.ndims());
+
+        dim4 adims = ainfo.dims();
+        dim4 cdims = cinfo.dims();
+
+        // A scalar needs no broadcasting, so the mask must cover a exactly.
+        for (int i = 0; i < 4; i++) {
+            DIM_ASSERT(1, cdims[i] == adims[i]);
+        }
+
+        switch (ainfo.getType()) {
+        case f32: replace_scalar<float  >(a, cond, b); break;
+        case f64: replace_scalar<double >(a, cond, b); break;
+        case c32: replace_scalar<cfloat >(a, cond, b); break;
+        case c64: replace_scalar<cdouble>(a, cond, b); break;
+        case s32: replace_scalar<int    >(a, cond, b); break;
+        case u32: replace_scalar<uint   >(a, cond, b); break;
+        case s64: replace_scalar<intl   >(a, cond, b); break;
+        case u64: replace_scalar<uintl  >(a, cond, b); break;
+        case u8:  replace_scalar<uchar  >(a, cond, b); break;
+        case b8:  replace_scalar<char   >(a, cond, b); break;
+        // b is a double and can never carry a bad type; the dispatched type
+        // belongs to a, which is argument 0.
+        default:  TYPE_ERROR(0, ainfo.getType());
+        }
+
+    } CATCHALL;
+    return AF_SUCCESS;
+}
diff --git a/src/api/cpp/data.cpp b/src/api/cpp/data.cpp
index 87bc20f..196fbf8 100644
--- a/src/api/cpp/data.cpp
+++ b/src/api/cpp/data.cpp
@@ -375,4 +375,13 @@ namespace af
return array(res);
}
+    /// Array-valued replace: forwards to the C API af_replace, which updates
+    /// \p a in place using \p cond and \p b. A non-success status is handed
+    /// to AF_THROW (presumably raised as an af::exception).
+    void replace(array &a, const array &cond, const array &b)
+    {
+        const af_err status = af_replace(a.get(), cond.get(), b.get());
+        AF_THROW(status);
+    }
+
+    /// Scalar-valued replace: forwards to the C API af_replace_scalar, which
+    /// updates \p a in place using \p cond and the scalar \p b. A non-success
+    /// status is handed to AF_THROW (presumably raised as an af::exception).
+    void replace(array &a, const array &cond, const double &b)
+    {
+        const af_err status = af_replace_scalar(a.get(), cond.get(), b);
+        AF_THROW(status);
+    }
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list