[PATCH] Support reductions over where()
Richard Guenther
rguenth at tat.physik.uni-tuebingen.de
Thu Nov 20 21:22:38 UTC 2003
Hi!
This patch adds support for reductions over two- and three-argument where()
expressions, as in
double star_mass = sum(where(norm(positions(rh).read(I)) <= star_radius,
(rh * Pooma::cellVolumes(rh))(I)));
which integrates mass over a sphere with radius star_radius inside the
computational domain.
Other cases, such as
int cnt = sum(where(rh.read(I) != 0.0, 1));
still need to be fixed, but they are less important to me at the
moment.
Tested by re-running the existing reduction and where() tests. The new test
case passes.
Ok?
Richard.
2003-11-20 Richard Guenther <richard.guenther at uni-tuebingen.de>
* src/Evaluator/WhereProxy.h: add Element_t typedef and
hasRelations enum.
* src/Evaluator/OpMask.h: add Unwrap<> traits and a ReductionTraits<>
specialization for OpMask<> operators.
* src/Evaluator/Reduction.h: handle WhereProxy<> in the main
reduction evaluator by unwrapping the expression. Unwrap the op
for the final reduction over patch results.
* src/Engine/RemoteEngine.h: unwrap the op for the final reduction over
patch results.
* src/Field/tests/WhereTest.cpp: add tests for reductions over
two- and three-arg where.
diff -Nru a/r2/src/Engine/RemoteEngine.h b/r2/src/Engine/RemoteEngine.h
--- a/r2/src/Engine/RemoteEngine.h Thu Nov 20 22:03:32 2003
+++ b/r2/src/Engine/RemoteEngine.h Thu Nov 20 22:03:32 2003
@@ -2069,7 +2069,7 @@
{
ret = vals[0];
for (j = 1; j < n; j++)
- op(ret, vals[j]);
+ Unwrap<Op>::unwrap(op)(ret, vals[j]);
}
delete [] vals;
diff -Nru a/r2/src/Evaluator/OpMask.h b/r2/src/Evaluator/OpMask.h
--- a/r2/src/Evaluator/OpMask.h Thu Nov 20 22:03:32 2003
+++ b/r2/src/Evaluator/OpMask.h Thu Nov 20 22:03:32 2003
@@ -169,6 +169,35 @@
typedef T1 &Type_t;
};
+/// Unwrap<Op> names the operator underneath any OpMask<> wrappers.
+/// For a plain operator, Op_t is Op itself and unwrap() returns its
+/// argument unchanged.
+template <class Op>
+struct Unwrap {
+ typedef Op Op_t;
+ static inline const Op_t& unwrap(const Op &op) { return op; }
+};
+
+/// For OpMask<Op>, strip the mask by taking its op_m member and recurse,
+/// so nested OpMask<> layers are removed too.
+template <class Op>
+struct Unwrap<OpMask<Op> > {
+ typedef typename Unwrap<Op>::Op_t Op_t;
+ static inline const Op_t& unwrap(const OpMask<Op> &op) { return Unwrap<Op>::unwrap(op.op_m); }
+};
+
+/// Forward declaration so the OpMask<> specialization below can be declared.
+template <class Op, class T>
+struct ReductionTraits;
+
+/// A masked reduction uses the identity element of the operator it wraps.
+template <class Op, class T>
+struct ReductionTraits<OpMask<Op>, T>
+{
+ static T identity() { return ReductionTraits<Op, T>::identity(); }
+};
+
+
//-----------------------------------------------------------------------------
//
//-----------------------------------------------------------------------------
diff -Nru a/r2/src/Evaluator/Reduction.h b/r2/src/Evaluator/Reduction.h
--- a/r2/src/Evaluator/Reduction.h Thu Nov 20 22:03:32 2003
+++ b/r2/src/Evaluator/Reduction.h Thu Nov 20 22:03:32 2003
@@ -53,6 +53,7 @@
#include "Engine/IntersectEngine.h"
#include "Evaluator/ReductionKernel.h"
#include "Evaluator/EvaluatorTags.h"
+#include "Evaluator/WhereProxy.h"
#include "Threads/PoomaCSem.h"
#include <vector>
@@ -109,6 +110,14 @@
return e.centeringSize() == 1 && e.numMaterials() == 1;
}
+ /// Un-wrap where() expression operation and pass on to generic evaluator.
+
+ template<class T, class Op, class Cond, class Expr>
+ void evaluate(T &ret, const Op &op, const WhereProxy<Cond, Expr> &w) const
+ {
+ evaluate(ret, w.opMask(op), w.whereMask());
+ }
+
/// Input an expression and cause it to be reduced.
/// We just pass the buck to a special reduction after updating
/// the expression leafs and checking its validity (we can handle
@@ -249,7 +258,7 @@
ret = vals[0];
for (j = 1; j < n; j++)
- op(ret, vals[j]);
+ Unwrap<Op>::unwrap(op)(ret, vals[j]);
delete [] vals;
}
};
diff -Nru a/r2/src/Evaluator/WhereProxy.h b/r2/src/Evaluator/WhereProxy.h
--- a/r2/src/Evaluator/WhereProxy.h Thu Nov 20 22:03:32 2003
+++ b/r2/src/Evaluator/WhereProxy.h Thu Nov 20 22:03:32 2003
@@ -85,6 +85,10 @@
typedef typename ConvertWhereProxy<ETrait_t,Tree_t>::Make_t MakeFromTree_t;
typedef typename MakeFromTree_t::Expression_t WhereMask_t;
+ typedef typename B::Element_t Element_t;
+
+ enum { hasRelations = B::hasRelations };
+
inline WhereMask_t
whereMask() const
{
diff -Nru a/r2/src/Field/tests/WhereTest.cpp b/r2/src/Field/tests/WhereTest.cpp
--- a/r2/src/Field/tests/WhereTest.cpp Thu Nov 20 22:03:32 2003
+++ b/r2/src/Field/tests/WhereTest.cpp Thu Nov 20 22:03:32 2003
@@ -86,6 +86,7 @@
// Now, we can declare a field.
Centering<2> allFace = canonicalCentering<2>(FaceType, Continuous);
+ Centering<2> allCell = canonicalCentering<2>(CellType, Continuous);
typedef UniformRectilinearMesh<2> Geometry_t;
@@ -103,6 +104,9 @@
Field_t a(allFace, layout, origin, spacings);
Field_t b(allFace, layout, origin, spacings);
Field_t c(allFace, layout, origin, spacings);
+ Field_t d(allCell, layout, origin, spacings);
+ Field_t e(allCell, layout, origin, spacings);
+ Field_t f(allCell, layout, origin, spacings);
PositionsTraits<Geometry_t>::Type_t x = positions(a);
@@ -154,6 +158,21 @@
tester.check("twoarg where result dirtied part, centering one",
all(where(dot(x.subField(0, 1), line) > 8.0,
b.subField(0, 1), c.subField(0, 1)) == a.subField(0, 1)));
+
+ // 2-arg where reduction
+
+ d = 1.0;
+ e = positions(e).read(e.physicalDomain()).comp(0);
+ tester.check("reduction over twoarg where",
+ sum(where(e(e.physicalDomain()) < 4.0, d)) == 4.0*9.0);
+
+ // 3-arg where reduction
+
+ d = 1.0;
+ f = 0.0;
+ e = positions(e).read(e.physicalDomain()).comp(0);
+ tester.check("reduction over threearg where",
+ sum(where(e(e.physicalDomain()) < 4.0, d, f)) == 4.0*9.0);
int ret = tester.results("WhereTest");
Pooma::finalize();
More information about the pooma-dev
mailing list