diff --git a/src/main/java/org/apache/datasketches/hll/HllUnion.java b/src/main/java/org/apache/datasketches/hll/HllUnion.java index dca2cd7da..2738d68d7 100644 --- a/src/main/java/org/apache/datasketches/hll/HllUnion.java +++ b/src/main/java/org/apache/datasketches/hll/HllUnion.java @@ -176,7 +176,7 @@ public double getEstimate() { checkRebuildCurMinNumKxQ(gadget); return gadget.getEstimate(); } - + /** * Gets the effective lgConfigK for the HllUnion operator, which may be less than * lgMaxK. @@ -320,6 +320,14 @@ public void update(final HllSketch sketch) { gadget.hllSketchImpl = unionImpl(sketch, gadget, lgMaxK); } + /** + * Update this HllUnion operator with the given HllUnion. + * @param union the given HllUnion. + */ + public void update(final HllUnion union) { + gadget.hllSketchImpl = unionImpl(union.gadget, gadget, lgMaxK); + } + @Override void couponUpdate(final int coupon) { if (coupon == EMPTY) { return; } diff --git a/src/test/java/org/apache/datasketches/hll/UnionCaseTest.java b/src/test/java/org/apache/datasketches/hll/UnionCaseTest.java index 84ec7ce9a..608585f5e 100644 --- a/src/test/java/org/apache/datasketches/hll/UnionCaseTest.java +++ b/src/test/java/org/apache/datasketches/hll/UnionCaseTest.java @@ -31,6 +31,8 @@ import static org.testng.Assert.assertTrue; import java.lang.foreign.MemorySegment; +//import java.lang.invoke.MethodHandles; +//import java.lang.invoke.VarHandle; import org.apache.datasketches.common.SketchesStateException; import org.testng.annotations.Test; @@ -42,8 +44,8 @@ public class UnionCaseTest { private static final String LS = System.getProperty("line.separator"); long v = 0; final static int maxLgK = 12; - HllSketch source; - //HllUnion union; + HllSketch skSource; + HllUnion uSource; String hfmt = "%10s%10s%10s%10s%10s%10s%10s%10s%10s%10s%10s" + LS; String hdr = String.format(hfmt, "caseNum","srcLgKStr","gdtLgKStr","srcType","gdtType", "srcSeg","gdtSeg","srcMode","gdtMode","srcOoof","gdtOoof"); @@ -52,48 +54,68 @@ public class UnionCaseTest { public void checkAllCases() { print(hdr); for (int i = 0; i < 24; i++) { - checkCase(i, HLL_4, false); + checkCase(i, HLL_4, false, false); } println(""); print(hdr); for (int i = 0; i < 24; i++) { - checkCase(i, HLL_6, false); + checkCase(i, HLL_6, false, false); } println(""); print(hdr); for (int i = 0; i < 24; i++) { - checkCase(i, HLL_8, false); + checkCase(i, HLL_8, false, false); } println(""); print(hdr); for (int i = 0; i < 24; i++) { - checkCase(i, HLL_4, true); + checkCase(i, HLL_8, false, true); //srcUnion + } + println(""); + + print(hdr); + for (int i = 0; i < 24; i++) { + checkCase(i, HLL_4, true, false); } println(""); print(hdr); for (int i = 0; i < 24; i++) { - checkCase(i, HLL_6, true); + checkCase(i, HLL_6, true, false); } println(""); print(hdr); for (int i = 0; i < 24; i++) { - checkCase(i, HLL_8, true); + checkCase(i, HLL_8, true, false); + } + println(""); + + print(hdr); + for (int i = 0; i < 24; i++) { + checkCase(i, HLL_8, true, true); //srcUnion } println(""); } - private void checkCase(final int caseNum, final TgtHllType srcType, final boolean srcSeg) { - source = getSource(caseNum, srcType, srcSeg); + private void checkCase(final int caseNum, final TgtHllType srcType, final boolean srcSeg, final boolean srcUnion) { + if (srcUnion) { + uSource = getUnionSrc(caseNum); + } else { + skSource = getSkSource(caseNum, srcType, srcSeg); + } final boolean gdtSeg = (caseNum & 1) > 0; final HllUnion union = getUnion(caseNum, gdtSeg); - union.update(source); + if (srcUnion) { + union.update(uSource); + } else { + union.update(skSource); + } final int totalU = getSrcCount(caseNum, maxLgK) + getUnionCount(caseNum); - output(caseNum, source, union, totalU); + output(caseNum, skSource, union, totalU); } private void output(final int caseNum, final HllSketch source, final HllUnion union, final int totalU) { @@ -121,7 +143,7 @@ private void output(final int caseNum, final HllSketch source, final HllUnion un assertTrue(err < rse, "Err: " + err + ", RSE: " + rse); } - private HllSketch getSource(final int caseNum, final TgtHllType tgtHllType, final boolean useMemorySegment) { + private HllSketch getSkSource(final int caseNum, final TgtHllType tgtHllType, final boolean useMemorySegment) { final int srcLgK = getSrcLgK(caseNum, maxLgK); final int srcU = getSrcCount(caseNum, maxLgK); if (useMemorySegment) { @@ -131,9 +153,18 @@ private HllSketch getSource(final int caseNum, final TgtHllType tgtHllType, fina } } + private HllUnion getUnionSrc(final int caseNum) { + final int srcLgK = getSrcLgK(caseNum, maxLgK); + final int srcU = getSrcCount(caseNum, maxLgK); + final HllSketch sk = buildHeapSketch(srcLgK, HLL_8, srcU); + final HllUnion u = new HllUnion(maxLgK); + u.update(sk); + return u; + } + private HllUnion getUnion(final int caseNum, final boolean useMemorySegment) { final int unionU = getUnionCount(caseNum); - return (useMemorySegment) ? buildMemorSegmentUnion(maxLgK, unionU) : buildHeapUnion(maxLgK, unionU); + return (useMemorySegment) ? buildMemorySegmentUnion(maxLgK, unionU) : buildHeapUnion(maxLgK, unionU); } private static int getUnionCount(final int caseNum) { @@ -394,7 +425,7 @@ private HllUnion buildHeapUnion(final int lgMaxK, final int n) { return u; } - private HllUnion buildMemorSegmentUnion(final int lgMaxK, final int n) { + private HllUnion buildMemorySegmentUnion(final int lgMaxK, final int n) { final int bytes = HllSketch.getMaxUpdatableSerializationBytes(lgMaxK, TgtHllType.HLL_8); final MemorySegment wseg = MemorySegment.ofArray(new byte[bytes]); final HllUnion u = new HllUnion(lgMaxK, wseg);