[arrayfire] 108/408: FEAT: Adding support for non overlapping batched convolution
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Mon Sep 21 19:11:31 UTC 2015
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch debian/sid
in repository arrayfire.
commit f8f072c6520d683f15f5a4d0baccd27b19e205ea
Author: Pavan Yalamanchili <pavan at arrayfire.com>
Date: Mon Jul 6 14:23:54 2015 -0400
FEAT: Adding support for non overlapping batched convolution
---
src/api/c/convolve.cpp | 14 +++---
src/api/c/fftconvolve.cpp | 111 ++++++++++++++++++++++++++--------------------
2 files changed, 70 insertions(+), 55 deletions(-)
diff --git a/src/api/c/convolve.cpp b/src/api/c/convolve.cpp
index f5f8d12..2886e9f 100644
--- a/src/api/c/convolve.cpp
+++ b/src/api/c/convolve.cpp
@@ -49,13 +49,13 @@ ConvolveBatchKind identifyBatchKind(const dim4 &sDims, const dim4 &fDims)
return CONVOLVE_BATCH_SIGNAL;
else if ((sn>baseDim && sn<=4) && (fn>baseDim && fn<=4)) {
bool doesDimensionsMatch = true;
+ bool isInterleaved = true;
for (dim_t i=baseDim; i<4; i++) {
- if (sDims[i]!=fDims[i]) {
- doesDimensionsMatch = false;
- break;
- }
+ doesDimensionsMatch &= (sDims[i] == fDims[i]);
+ isInterleaved &= (sDims[i] == 1 || fDims[i] == 1);
}
- return (doesDimensionsMatch ? CONVOLVE_BATCH_SAME : CONVOLVE_BATCH_UNSUPPORTED);
+ if (doesDimensionsMatch) return CONVOLVE_BATCH_SAME;
+ return (isInterleaved ? CONVOLVE_BATCH_DIFF : CONVOLVE_BATCH_UNSUPPORTED);
}
else
return CONVOLVE_BATCH_UNSUPPORTED;
@@ -75,7 +75,7 @@ af_err convolve(af_array *out, const af_array signal, const af_array filter)
ConvolveBatchKind convBT = identifyBatchKind<baseDim>(sdims, fdims);
- ARG_ASSERT(1, (convBT != CONVOLVE_BATCH_UNSUPPORTED));
+ ARG_ASSERT(1, (convBT != CONVOLVE_BATCH_UNSUPPORTED && convBT != CONVOLVE_BATCH_DIFF));
af_array output;
switch(stype) {
@@ -144,6 +144,8 @@ bool isFreqDomain(const af_array &signal, const af_array filter, af_conv_domain
dim4 sdims = sInfo.dims();
dim4 fdims = fInfo.dims();
+ if (identifyBatchKind<baseDim>(sdims, fdims) == CONVOLVE_BATCH_DIFF) return true;
+
int batch = 1;
for(int i = 3; i >= baseDim; i--) {
batch *= std::max(fdims[i], sdims[i]);
diff --git a/src/api/c/fftconvolve.cpp b/src/api/c/fftconvolve.cpp
index d915a07..d359119 100644
--- a/src/api/c/fftconvolve.cpp
+++ b/src/api/c/fftconvolve.cpp
@@ -17,52 +17,23 @@
#include <fftconvolve.hpp>
#include <convolve_common.hpp>
#include <dispatch.hpp>
+#include <complex.hpp>
using af::dim4;
using namespace detail;
-template<typename T, typename convT, typename cT, bool isDouble, bool roundOut, dim_t baseDim>
-inline static af_array fftconvolve(const af_array &s, const af_array &f, const bool expand, ConvolveBatchKind kind)
-{
- return getHandle(fftconvolve<T, convT, cT, isDouble, roundOut, baseDim>(getArray<T>(s), castArray<T>(f), expand, kind));
-}
-
-template<dim_t baseDim>
-ConvolveBatchKind identifyBatchKind(const dim4 &sDims, const dim4 &fDims)
-{
- dim_t sn = sDims.ndims();
- dim_t fn = fDims.ndims();
-
- if (sn==baseDim && fn==baseDim)
- return CONVOLVE_BATCH_NONE;
- else if (sn==baseDim && (fn>baseDim && fn<=4))
- return CONVOLVE_BATCH_KERNEL;
- else if ((sn>baseDim && sn<=4) && fn==baseDim)
- return CONVOLVE_BATCH_SIGNAL;
- else if ((sn>baseDim && sn<=4) && (fn>baseDim && fn<=4)) {
- bool doesDimensionsMatch = true;
- for (dim_t i=baseDim; i<4; i++) {
- if (sDims[i]!=fDims[i]) {
- doesDimensionsMatch = false;
- break;
- }
- }
- return (doesDimensionsMatch ? CONVOLVE_BATCH_SAME : CONVOLVE_BATCH_UNSUPPORTED);
- }
- else
- return CONVOLVE_BATCH_UNSUPPORTED;
-}
-
-template<typename T, int baseDim>
+template<typename T, typename convT, typename cT, int baseDim>
static inline
-af_array fftconvcplx(const af_array signal, const af_array filter, bool expand,
- ConvolveBatchKind kind)
+af_array fftconvolve_fallback(const af_array signal, const af_array filter, bool expand)
{
- const Array<T> S = getArray<T>(signal);
- const Array<T> F = castArray<T>(filter);
+ const Array<cT> S = castArray<cT>(signal);
+ const Array<cT> F = castArray<cT>(filter);
const dim4 sdims = S.dims();
const dim4 fdims = F.dims();
- dim4 tdims(1, 1, 1, 1);
+ dim4 odims(1, 1, 1, 1);
+ dim4 psdims(1, 1, 1, 1);
+ dim4 pfdims(1, 1, 1, 1);
+
std::vector<af_seq> index(4);
int count = 1;
@@ -70,10 +41,12 @@ af_array fftconvcplx(const af_array signal, const af_array filter, bool expand,
dim_t tdim_i = sdims[i] + fdims[i] - 1;
// Pad temporary buffers to power of 2 for performance
- tdims[i] = nextpow2(tdim_i);
+ odims[i] = nextpow2(tdim_i);
+ psdims[i] = nextpow2(tdim_i);
+ pfdims[i] = nextpow2(tdim_i);
// The normalization factor
- count *= tdims[i];
+ count *= odims[i];
// Get the indexing params for output
if (expand) {
@@ -87,25 +60,65 @@ af_array fftconvcplx(const af_array signal, const af_array filter, bool expand,
}
for (int i = baseDim; i < 4; i++) {
- tdims[i] = std::max(sdims[i], fdims[i]);
+ odims[i] = std::max(sdims[i], fdims[i]);
+ psdims[i] = sdims[i];
+ pfdims[i] = fdims[i];
index[i] = af_span;
}
// fft(signal)
- Array<T> T1 = fft<T, T, baseDim, false>(S, 1.0, baseDim, tdims.get());
+ Array<cT> T1 = fft<cT, cT, baseDim, false>(S, 1.0, baseDim, psdims.get());
// fft(filter)
- Array<T> T2 = fft<T, T, baseDim, false>(F, 1.0, baseDim, tdims.get());
+ Array<cT> T2 = fft<cT, cT, baseDim, false>(F, 1.0, baseDim, pfdims.get());
// fft(signal) * fft(filter)
- T1 = arithOp<T, af_mul_t>(T1, T2, tdims);
+ T1 = arithOp<cT, af_mul_t>(T1, T2, odims);
// ifft(fft(signal) * fft(filter))
- T1 = ifft<T, baseDim>(T1, 1.0/(double)count, baseDim, tdims.get());
+ T1 = ifft<cT, baseDim>(T1, 1.0/(double)count, baseDim, odims.get());
// Index to proper offsets
- T1 = createSubArray<T>(T1, index);
- return getHandle(T1);
+ T1 = createSubArray<cT>(T1, index);
+
+ if (getInfo(signal).isComplex() || getInfo(filter).isComplex()) {
+ return getHandle(cast<T>(T1));
+ } else {
+ return getHandle(cast<T>(real<convT>(T1)));
+ }
+}
+
+template<typename T, typename convT, typename cT, bool isDouble, bool roundOut, dim_t baseDim>
+inline static af_array fftconvolve(const af_array &s, const af_array &f, const bool expand, ConvolveBatchKind kind)
+{
+ if (kind == CONVOLVE_BATCH_DIFF) return fftconvolve_fallback<T, convT, cT, baseDim>(s, f, expand);
+ else return getHandle(fftconvolve<T, convT, cT, isDouble, roundOut, baseDim>(getArray<T>(s), castArray<T>(f), expand, kind));
+}
+
+template<dim_t baseDim>
+ConvolveBatchKind identifyBatchKind(const dim4 &sDims, const dim4 &fDims)
+{
+ dim_t sn = sDims.ndims();
+ dim_t fn = fDims.ndims();
+
+ if (sn==baseDim && fn==baseDim)
+ return CONVOLVE_BATCH_NONE;
+ else if (sn==baseDim && (fn>baseDim && fn<=4))
+ return CONVOLVE_BATCH_KERNEL;
+ else if ((sn>baseDim && sn<=4) && fn==baseDim)
+ return CONVOLVE_BATCH_SIGNAL;
+ else if ((sn>baseDim && sn<=4) && (fn>baseDim && fn<=4)) {
+ bool doesDimensionsMatch = true;
+ bool isInterleaved = true;
+ for (dim_t i=baseDim; i<4; i++) {
+ doesDimensionsMatch &= (sDims[i] == fDims[i]);
+ isInterleaved &= (sDims[i] == 1 || fDims[i] == 1);
+ }
+ if (doesDimensionsMatch) return CONVOLVE_BATCH_SAME;
+ return (isInterleaved ? CONVOLVE_BATCH_DIFF : CONVOLVE_BATCH_UNSUPPORTED);
+ }
+ else
+ return CONVOLVE_BATCH_UNSUPPORTED;
}
template<dim_t baseDim>
@@ -132,8 +145,8 @@ af_err fft_convolve(af_array *out, const af_array signal, const af_array filter,
case s32: output = fftconvolve<int , float, cfloat, false, true, baseDim>(signal, filter, expand, convBT); break;
case u8: output = fftconvolve<uchar , float, cfloat, false, true, baseDim>(signal, filter, expand, convBT); break;
case b8: output = fftconvolve<char , float, cfloat, false, true, baseDim>(signal, filter, expand, convBT); break;
- case c32: output = fftconvcplx<cfloat , baseDim>(signal, filter, expand, convBT); break;
- case c64: output = fftconvcplx<cdouble, baseDim>(signal, filter, expand, convBT); break;
+ case c32: output = fftconvolve_fallback<cfloat , cfloat , cfloat , baseDim>(signal, filter, expand); break;
+ case c64: output = fftconvolve_fallback<cdouble, cdouble, cdouble, baseDim>(signal, filter, expand); break;
default: TYPE_ERROR(1, stype);
}
std::swap(*out,output);
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/arrayfire.git
More information about the debian-science-commits
mailing list