[arrayfire] 88/248: statistics functions fixes
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Tue Nov 17 15:54:07 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch dfsg-clean
in repository arrayfire.
commit c19058f9fbc4a6c38dd9640951314ccf6b58d0f4
Author: pradeep <pradeep at arrayfire.com>
Date: Tue Sep 29 18:51:10 2015 -0400
statistics functions fixes
Earlier, some of the stats functions were casting the input
array to float/double always. After this commit, this is done
only in the few cases where there is no alternative.
---
src/api/c/covariance.cpp | 10 ++++----
src/api/c/mean.cpp | 62 ++++++++++++++++++++++++------------------------
src/api/c/stats.h | 45 ++++++++++++++++++++++-------------
src/api/c/stdev.cpp | 20 +++++++---------
src/api/c/var.cpp | 10 ++++----
5 files changed, 81 insertions(+), 66 deletions(-)
diff --git a/src/api/c/covariance.cpp b/src/api/c/covariance.cpp
index 1050b72..f8bb9c4 100644
--- a/src/api/c/covariance.cpp
+++ b/src/api/c/covariance.cpp
@@ -27,14 +27,16 @@ using namespace detail;
template<typename T, typename cType>
static af_array cov(const af_array& X, const af_array& Y, const bool isbiased)
{
- Array<cType> xArr = cast<cType>(getArray<T>(X));
- Array<cType> yArr = cast<cType>(getArray<T>(Y));
+ Array<T> _x = getArray<T>(X);
+ Array<T> _y = getArray<T>(Y);
+ Array<cType> xArr = cast<cType>(_x);
+ Array<cType> yArr = cast<cType>(_y);
dim4 xDims = xArr.dims();
dim_t N = isbiased ? xDims[0] : xDims[0]-1;
- Array<cType> xmArr = createValueArray<cType>(xDims, mean<cType>(xArr));
- Array<cType> ymArr = createValueArray<cType>(xDims, mean<cType>(yArr));
+ Array<cType> xmArr = createValueArray<cType>(xDims, mean<T, cType>(_x));
+ Array<cType> ymArr = createValueArray<cType>(xDims, mean<T, cType>(_y));
Array<cType> nArr = createValueArray<cType>(xDims, scalar<cType>(N));
Array<cType> diffX = detail::arithOp<cType, af_sub_t>(xArr, xmArr, xDims);
diff --git a/src/api/c/mean.cpp b/src/api/c/mean.cpp
index 09844ef..1cbee32 100644
--- a/src/api/c/mean.cpp
+++ b/src/api/c/mean.cpp
@@ -22,26 +22,26 @@
using namespace detail;
-template<typename T>
-static T mean(const af_array &in)
+template<typename Ti, typename To>
+static To mean(const af_array &in)
{
/* following function is defined in stats.h */
- return mean<T>(castArray<T>(in)); /* defined in stats.h */
+ return mean<Ti, To>(getArray<Ti>(in)); /* defined in stats.h */
}
template<typename T>
static T mean(const af_array &in, const af_array &weights)
{
- typedef typename baseOutType<T>::type bType;
+ typedef typename baseOutType<T>::type Tw;
/* following function is defined in stats.h */
- return mean<T, bType>(castArray<T>(in), castArray<bType>(weights));
+ return mean<T, Tw>(castArray<T>(in), castArray<Tw>(weights));
}
-template<typename T>
+template<typename Ti, typename To>
static af_array mean(const af_array &in, const dim_t dim)
{
/* following function is defined in stats.h */
- return getHandle<T>(mean<T>(castArray<T>(in), dim));
+ return getHandle<To>(mean<Ti, To>(getArray<Ti>(in), dim));
}
template<typename T>
@@ -60,18 +60,18 @@ af_err af_mean(af_array *out, const af_array in, const dim_t dim)
ArrayInfo info = getInfo(in);
af_dtype type = info.getType();
switch(type) {
- case f64: output = mean< double>(in, dim); break;
- case f32: output = mean< float >(in, dim); break;
- case s32: output = mean< float >(in, dim); break;
- case u32: output = mean< float >(in, dim); break;
- case s64: output = mean< double>(in, dim); break;
- case u64: output = mean< double>(in, dim); break;
- case s16: output = mean< float >(in, dim); break;
- case u16: output = mean< float >(in, dim); break;
- case u8: output = mean< float >(in, dim); break;
- case b8: output = mean< float >(in, dim); break;
- case c32: output = mean< cfloat>(in, dim); break;
- case c64: output = mean<cdouble>(in, dim); break;
+ case f64: output = mean<double , double>(in, dim); break;
+ case f32: output = mean<float , float >(in, dim); break;
+ case s32: output = mean<int , float >(in, dim); break;
+ case u32: output = mean<unsigned, float >(in, dim); break;
+ case s64: output = mean<intl , double>(in, dim); break;
+ case u64: output = mean<uintl , double>(in, dim); break;
+ case s16: output = mean<short , float >(in, dim); break;
+ case u16: output = mean<ushort , float >(in, dim); break;
+ case u8: output = mean<uchar , float >(in, dim); break;
+ case b8: output = mean<char , float >(in, dim); break;
+ case c32: output = mean<cfloat , cfloat>(in, dim); break;
+ case c64: output = mean<cdouble , cdouble>(in, dim); break;
default : TYPE_ERROR(1, type);
}
std::swap(*out, output);
@@ -120,23 +120,23 @@ af_err af_mean_all(double *realVal, double *imagVal, const af_array in)
ArrayInfo info = getInfo(in);
af_dtype type = info.getType();
switch(type) {
- case f64: *realVal = mean<double>(in); break;
- case f32: *realVal = mean< float>(in); break;
- case s32: *realVal = mean< float>(in); break;
- case u32: *realVal = mean< float>(in); break;
- case s64: *realVal = mean<double>(in); break;
- case u64: *realVal = mean<double>(in); break;
- case s16: *realVal = mean< float>(in); break;
- case u16: *realVal = mean< float>(in); break;
- case u8: *realVal = mean< float>(in); break;
- case b8: *realVal = mean< float>(in); break;
+ case f64: *realVal = mean<double , double>(in); break;
+ case f32: *realVal = mean<float , float >(in); break;
+ case s32: *realVal = mean<int , float >(in); break;
+ case u32: *realVal = mean<unsigned, float >(in); break;
+ case s64: *realVal = mean<intl , double>(in); break;
+ case u64: *realVal = mean<uintl , double>(in); break;
+ case s16: *realVal = mean<short , float >(in); break;
+ case u16: *realVal = mean<ushort , float >(in); break;
+ case u8: *realVal = mean<uchar , float >(in); break;
+ case b8: *realVal = mean<char , float >(in); break;
case c32: {
- cfloat tmp = mean<cfloat>(in);
+ cfloat tmp = mean<cfloat, cfloat>(in);
*realVal = real(tmp);
*imagVal = imag(tmp);
} break;
case c64: {
- cdouble tmp = mean<cdouble>(in);
+ cdouble tmp = mean<cdouble, cdouble>(in);
*realVal = real(tmp);
*imagVal = imag(tmp);
} break;
diff --git a/src/api/c/stats.h b/src/api/c/stats.h
index 0e74942..6dca227 100644
--- a/src/api/c/stats.h
+++ b/src/api/c/stats.h
@@ -40,39 +40,52 @@ struct baseOutType {
float>::type type;
};
-template<typename T>
-inline T mean(const Array<T>& in)
+template<typename Ti, typename To>
+inline To mean(const Array<Ti>& in)
{
- T out = reduce_all<af_add_t, T, T>(in);
- T result = division(out, in.elements());
+ To out = reduce_all<af_add_t, Ti, To>(in);
+ To result = division(out, in.elements());
return result;
}
-template<typename T, typename wType>
-inline T mean(const Array<T>& in, const Array<wType>& weights)
+template<typename T, typename Tw>
+static T mean(const Array<T>& input, const Array<Tw>& weights)
{
- Array<T> wts = cast<T>(weights);
-
- dim4 iDims = in.dims();
+ dim4 iDims = input.dims();
- Array<T> wtdInput = arithOp<T, af_mul_t>(in, wts, iDims);
+ Array<T> wtdInput = arithOp<T, af_mul_t>(input, weights, iDims);
T wtdSum = reduce_all<af_add_t, T, T>(wtdInput);
- wType wtsSum = reduce_all<af_add_t, wType, wType>(weights);
+ T wtsSum = reduce_all<af_add_t, T, T>(weights);
return division(wtdSum, wtsSum);
}
-template<typename T>
-inline Array<T> mean(const Array<T>& in, dim_t dim)
+#define COMPLEX_TYPE_SPECILIZATION(T, Tw) \
+template<>\
+T mean<T, Tw>(const Array<T>& input, const Array<Tw>& weights)\
+{\
+ Array<T> wts = cast<T, Tw>(weights);\
+ dim4 iDims = input.dims();\
+ Array<T> wtdInput = arithOp<T, af_mul_t>(input, wts, iDims);\
+ T wtdSum = reduce_all<af_add_t, T, T>(wtdInput);\
+ Tw wtsSum = reduce_all<af_add_t, Tw, Tw>(weights);\
+ return division(wtdSum, wtsSum);\
+}
+
+COMPLEX_TYPE_SPECILIZATION(cfloat, float)
+COMPLEX_TYPE_SPECILIZATION(cdouble, double)
+
+template<typename Ti, typename To>
+inline Array<To> mean(const Array<Ti>& in, dim_t dim)
{
- Array<T> redArr = reduce<af_add_t, T, T>(in, dim);
+ Array<To> redArr = reduce<af_add_t, Ti, To>(in, dim);
dim4 iDims = in.dims();
dim4 oDims = redArr.dims();
- Array<T> cnstArr = createValueArray<T>(oDims, scalar<T>(iDims[dim]));
- Array<T> result = arithOp<T, af_div_t>(redArr, cnstArr, oDims);
+ Array<To> cnstArr = createValueArray<To>(oDims, scalar<To>(iDims[dim]));
+ Array<To> result = arithOp<To, af_div_t>(redArr, cnstArr, oDims);
return result;
}
diff --git a/src/api/c/stdev.cpp b/src/api/c/stdev.cpp
index cf871bd..59c9653 100644
--- a/src/api/c/stdev.cpp
+++ b/src/api/c/stdev.cpp
@@ -28,15 +28,12 @@ using namespace detail;
template<typename inType, typename outType>
static outType stdev(const af_array& in)
{
- Array<outType> input = cast<outType>(getArray<inType>(in));
-
- Array<outType> meanCnst= createValueArray<outType>(input.dims(), mean<outType>(input));
-
- Array<outType> diff = detail::arithOp<outType, af_sub_t>(input, meanCnst, input.dims());
-
- Array<outType> diffSq = detail::arithOp<outType, af_mul_t>(diff, diff, diff.dims());
-
- outType result = division(reduce_all<af_add_t, outType, outType>(diffSq), input.elements());
+ Array<inType> _in = getArray<inType>(in);
+ Array<outType> input = cast<outType>(_in);
+ Array<outType> meanCnst = createValueArray<outType>(input.dims(), mean<inType, outType>(_in));
+ Array<outType> diff = detail::arithOp<outType, af_sub_t>(input, meanCnst, input.dims());
+ Array<outType> diffSq = detail::arithOp<outType, af_mul_t>(diff, diff, diff.dims());
+ outType result = division(reduce_all<af_add_t, outType, outType>(diffSq), input.elements());
return sqrt(result);
}
@@ -44,10 +41,11 @@ static outType stdev(const af_array& in)
template<typename inType, typename outType>
static af_array stdev(const af_array& in, int dim)
{
- Array<outType> input = cast<outType>(getArray<inType>(in));
+ Array<inType> _in = getArray<inType>(in);
+ Array<outType> input = cast<outType>(_in);
dim4 iDims = input.dims();
- Array<outType> meanArr = mean<outType>(input, dim);
+ Array<outType> meanArr = mean<inType, outType>(_in, dim);
/* now tile meanArr along dim and use it for variance computation */
dim4 tileDims(1);
diff --git a/src/api/c/var.cpp b/src/api/c/var.cpp
index a6bf434..59a651b 100644
--- a/src/api/c/var.cpp
+++ b/src/api/c/var.cpp
@@ -26,9 +26,10 @@ using namespace detail;
template<typename inType, typename outType>
static outType varAll(const af_array& in, const bool isbiased)
{
- Array<outType> input = cast<outType>(getArray<inType>(in));
+ Array<inType> inArr = getArray<inType>(in);
+ Array<outType> input = cast<outType>(inArr);
- Array<outType> meanCnst= createValueArray<outType>(input.dims(), mean<outType>(input));
+ Array<outType> meanCnst= createValueArray<outType>(input.dims(), mean<inType, outType>(inArr));
Array<outType> diff = arithOp<outType, af_sub_t>(input, meanCnst, input.dims());
@@ -65,10 +66,11 @@ static outType varAll(const af_array& in, const af_array weights)
template<typename inType, typename outType>
static af_array var(const af_array& in, const bool isbiased, int dim)
{
- Array<outType> input = cast<outType>(getArray<inType>(in));
+ Array<inType> _in = getArray<inType>(in);
+ Array<outType> input = cast<outType>(_in);
dim4 iDims = input.dims();
- Array<outType> meanArr = mean<outType>(input, dim);
+ Array<outType> meanArr = mean<inType, outType>(_in, dim);
/* now tile meanArr along dim and use it for variance computation */
dim4 tileDims(1);
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list