diff options
author | Kazuaki Ishizaki <ishizaki@jp.ibm.com> | 2017-03-09 22:58:52 -0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2017-03-09 22:58:52 -0800 |
commit | 5949e6c4477fd3cb07a6962dbee48b4416ea65dd (patch) | |
tree | 63c9d88ce70e7c145193228693b24e31102eb26f | |
parent | 82138e09b9ad8d9609d5c64d6c11244b8f230be7 (diff) | |
download | spark-5949e6c4477fd3cb07a6962dbee48b4416ea65dd.tar.gz spark-5949e6c4477fd3cb07a6962dbee48b4416ea65dd.tar.bz2 spark-5949e6c4477fd3cb07a6962dbee48b4416ea65dd.zip |
[SPARK-19008][SQL] Improve performance of Dataset.map by eliminating boxing/unboxing
## What changes were proposed in this pull request?
This PR improves the performance of `Dataset.map()` for primitive types by removing boxing/unboxing operations. This is based on [the discussion](https://github.com/apache/spark/pull/16391#discussion_r93788919) with cloud-fan.
Currently, Catalyst generates a method call to the `apply()` method of an anonymous function written in Scala. The types of its argument and return value are `java.lang.Object`. As a result, each method call for a primitive value involves a pair of boxing and unboxing for calling this `apply()` method and another pair of boxing and unboxing for returning from this `apply()` method.
This PR directly calls a specialized version of the `apply()` method without boxing and unboxing. For example, if the types of the argument and the return value are both `int`, this PR generates a method call to `apply$mcII$sp`. This PR supports any combination of `Int`, `Long`, `Float`, and `Double`, plus `Boolean` as a return type.
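The following is a benchmark result using [this program](https://github.com/apache/spark/pull/16391/files), showing a 4.7x speedup for the Dataset case (best time 3094 ms without this PR vs. 657 ms with it). The Dataset part of this program is shown after the results.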
Without this PR
```
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz
back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
RDD 1923 / 1952 52.0 19.2 1.0X
DataFrame 526 / 548 190.2 5.3 3.7X
Dataset 3094 / 3154 32.3 30.9 0.6X
```
With this PR
```
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz
back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
RDD 1883 / 1892 53.1 18.8 1.0X
DataFrame 502 / 642 199.1 5.0 3.7X
Dataset 657 / 784 152.2 6.6 2.9X
```
```scala
def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._
val rdd = spark.sparkContext.range(0, numRows)
val ds = spark.range(0, numRows)
val func = (l: Long) => l + 1
val benchmark = new Benchmark("back-to-back map", numRows)
...
benchmark.addCase("Dataset") { iter =>
var res = ds.as[Long]
var i = 0
while (i < numChains) {
res = res.map(func)
i += 1
}
res.queryExecution.toRdd.foreach(_ => Unit)
}
benchmark
}
```
A motivating example
```scala
Seq(1, 2, 3).toDS.map(i => i * 7).show
```
Generated code without this PR
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator inputadapter_input;
/* 009 */ private UnsafeRow deserializetoobject_result;
/* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 012 */ private int mapelements_argValue;
/* 013 */ private UnsafeRow mapelements_result;
/* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 016 */ private UnsafeRow serializefromobject_result;
/* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 019 */
/* 020 */ public GeneratedIterator(Object[] references) {
/* 021 */ this.references = references;
/* 022 */ }
/* 023 */
/* 024 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */ partitionIndex = index;
/* 026 */ this.inputs = inputs;
/* 027 */ inputadapter_input = inputs[0];
/* 028 */ deserializetoobject_result = new UnsafeRow(1);
/* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0);
/* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 031 */
/* 032 */ mapelements_result = new UnsafeRow(1);
/* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0);
/* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 035 */ serializefromobject_result = new UnsafeRow(1);
/* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 038 */
/* 039 */ }
/* 040 */
/* 041 */ protected void processNext() throws java.io.IOException {
/* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) {
/* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 044 */ int inputadapter_value = inputadapter_row.getInt(0);
/* 045 */
/* 046 */ boolean mapelements_isNull = true;
/* 047 */ int mapelements_value = -1;
/* 048 */ if (!false) {
/* 049 */ mapelements_argValue = inputadapter_value;
/* 050 */
/* 051 */ mapelements_isNull = false;
/* 052 */ if (!mapelements_isNull) {
/* 053 */ Object mapelements_funcResult = null;
/* 054 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 055 */ if (mapelements_funcResult == null) {
/* 056 */ mapelements_isNull = true;
/* 057 */ } else {
/* 058 */ mapelements_value = (Integer) mapelements_funcResult;
/* 059 */ }
/* 060 */
/* 061 */ }
/* 062 */
/* 063 */ }
/* 064 */
/* 065 */ serializefromobject_rowWriter.zeroOutNullBytes();
/* 066 */
/* 067 */ if (mapelements_isNull) {
/* 068 */ serializefromobject_rowWriter.setNullAt(0);
/* 069 */ } else {
/* 070 */ serializefromobject_rowWriter.write(0, mapelements_value);
/* 071 */ }
/* 072 */ append(serializefromobject_result);
/* 073 */ if (shouldStop()) return;
/* 074 */ }
/* 075 */ }
/* 076 */ }
```
Generated code with this PR (lines 053-060 of the listing above are replaced by the single specialized call at line 053)
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */ private Object[] references;
/* 007 */ private scala.collection.Iterator[] inputs;
/* 008 */ private scala.collection.Iterator inputadapter_input;
/* 009 */ private UnsafeRow deserializetoobject_result;
/* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 012 */ private int mapelements_argValue;
/* 013 */ private UnsafeRow mapelements_result;
/* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 016 */ private UnsafeRow serializefromobject_result;
/* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 019 */
/* 020 */ public GeneratedIterator(Object[] references) {
/* 021 */ this.references = references;
/* 022 */ }
/* 023 */
/* 024 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */ partitionIndex = index;
/* 026 */ this.inputs = inputs;
/* 027 */ inputadapter_input = inputs[0];
/* 028 */ deserializetoobject_result = new UnsafeRow(1);
/* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0);
/* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 031 */
/* 032 */ mapelements_result = new UnsafeRow(1);
/* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0);
/* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 035 */ serializefromobject_result = new UnsafeRow(1);
/* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 038 */
/* 039 */ }
/* 040 */
/* 041 */ protected void processNext() throws java.io.IOException {
/* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) {
/* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 044 */ int inputadapter_value = inputadapter_row.getInt(0);
/* 045 */
/* 046 */ boolean mapelements_isNull = true;
/* 047 */ int mapelements_value = -1;
/* 048 */ if (!false) {
/* 049 */ mapelements_argValue = inputadapter_value;
/* 050 */
/* 051 */ mapelements_isNull = false;
/* 052 */ if (!mapelements_isNull) {
/* 053 */ mapelements_value = ((scala.Function1) references[0]).apply$mcII$sp(mapelements_argValue);
/* 054 */ }
/* 055 */
/* 056 */ }
/* 057 */
/* 058 */ serializefromobject_rowWriter.zeroOutNullBytes();
/* 059 */
/* 060 */ if (mapelements_isNull) {
/* 061 */ serializefromobject_rowWriter.setNullAt(0);
/* 062 */ } else {
/* 063 */ serializefromobject_rowWriter.write(0, mapelements_value);
/* 064 */ }
/* 065 */ append(serializefromobject_result);
/* 066 */ if (shouldStop()) return;
/* 067 */ }
/* 068 */ }
/* 069 */ }
```
Java bytecode of the methods generated for `i => i * 7`
```java
$ javap -c Test\$\$anonfun\$5\$\$anonfun\$apply\$mcV\$sp\$1.class
Compiled from "Test.scala"
public final class org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1 extends scala.runtime.AbstractFunction1$mcII$sp implements scala.Serializable {
public static final long serialVersionUID;
public final int apply(int);
Code:
0: aload_0
1: iload_1
2: invokevirtual #18 // Method apply$mcII$sp:(I)I
5: ireturn
public int apply$mcII$sp(int);
Code:
0: iload_1
1: bipush 7
3: imul
4: ireturn
public final java.lang.Object apply(java.lang.Object);
Code:
0: aload_0
1: aload_1
2: invokestatic #29 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
5: invokevirtual #31 // Method apply:(I)I
8: invokestatic #35 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
11: areturn
public org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1(org.apache.spark.sql.Test$$anonfun$5);
Code:
0: aload_0
1: invokespecial #42 // Method scala/runtime/AbstractFunction1$mcII$sp."<init>":()V
4: return
}
```
## How was this patch tested?
Added new test cases (`mapPrimitive` and `filterPrimitive`) to `DatasetPrimitiveSuite`.
Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Closes #17172 from kiszk/SPARK-19008.
4 files changed, 208 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 617239f56c..7f4462e583 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object CatalystSerde { def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { @@ -211,13 +212,48 @@ case class TypedFilter( def typedCondition(input: Expression): Expression = { val (funcClass, methodName) = func match { case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call" - case _ => classOf[Any => Boolean] -> "apply" + case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) Invoke(funcObj, methodName, BooleanType, input :: Nil) } } +object FunctionUtils { + private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = { + dt match { + case BooleanType if isOutput => Some("Z") + case IntegerType => Some("I") + case LongType => Some("J") + case FloatType => Some("F") + case DoubleType => Some("D") + case _ => None + } + } + + def getFunctionOneName(outputDT: DataType, inputDT: DataType): (Class[_], String) = { + // load "scala.Function1" using Java API to avoid requirements of type parameters + Utils.classForName("scala.Function1") -> { + // if a pair of an argument and return types is one of specific types + // whose specialized method (apply$mc..$sp) is generated by scalac, + // Catalyst generated a direct method call to the specialized method. 
+ // The followings are references for this specialization: + // http://www.scala-lang.org/api/2.12.0/scala/Function1.html + // https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/ + // SpecializeTypes.scala + // http://www.cakesolutions.net/teamblogs/scala-dissection-functions + // http://axel22.github.io/2013/11/03/specialization-quirks.html + val inputType = getMethodType(inputDT, false) + val outputType = getMethodType(outputDT, true) + if (inputType.isDefined && outputType.isDefined) { + s"apply$$mc${outputType.get}${inputType.get}$$sp" + } else { + "apply" + } + } + } +} + /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { def apply[T : Encoder, U : Encoder]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 199ba5ce69..fdd1bcc94b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState import org.apache.spark.sql.execution.streaming.KeyedStateImpl -import org.apache.spark.sql.types.{DataType, ObjectType, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** @@ -219,7 +221,7 @@ case class MapElementsExec( override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val (funcClass, methodName) = func match { case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" - 
case _ => classOf[Any => Any] -> "apply" + case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType) } val funcObj = Literal.create(func, ObjectType(funcClass)) val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 66d94d6016..1a0672b887 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -31,6 +31,49 @@ object DatasetBenchmark { case class Data(l: Long, s: String) + def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ + + val rdd = spark.sparkContext.range(0, numRows) + val ds = spark.range(0, numRows) + val df = ds.toDF("l") + val func = (l: Long) => l + 1 + + val benchmark = new Benchmark("back-to-back map long", numRows) + + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.select($"l" + 1 as "l") + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = ds.as[Long] + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ @@ -72,6 +115,49 @@ object DatasetBenchmark { benchmark } + def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ + + val rdd = spark.sparkContext.range(1, numRows) + val ds = spark.range(1, numRows) + val df = ds.toDF("l") + val func = 
(l: Long) => l % 2L == 0L + + val benchmark = new Benchmark("back-to-back filter Long", numRows) + + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = res.filter(func) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.filter($"l" % 2L === 0L) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset") { iter => + var res = ds.as[Long] + var i = 0 + while (i < numChains) { + res = res.filter(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ @@ -165,9 +251,22 @@ object DatasetBenchmark { val numRows = 100000000 val numChains = 10 - val benchmark = backToBackMap(spark, numRows, numChains) - val benchmark2 = backToBackFilter(spark, numRows, numChains) - val benchmark3 = aggregate(spark, numRows) + val benchmark0 = backToBackMapLong(spark, numRows, numChains) + val benchmark1 = backToBackMap(spark, numRows, numChains) + val benchmark2 = backToBackFilterLong(spark, numRows, numChains) + val benchmark3 = backToBackFilter(spark, numRows, numChains) + val benchmark4 = aggregate(spark, numRows) + + /* + OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic + Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz + back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + RDD 1883 / 1892 53.1 18.8 1.0X + DataFrame 502 / 642 199.1 5.0 3.7X + Dataset 657 / 784 152.2 6.6 2.9X + */ + benchmark0.run() /* OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 @@ -178,7 +277,18 @@ object DatasetBenchmark { DataFrame 2647 / 3116 37.8 26.5 1.3X Dataset 4781 / 5155 20.9 47.8 0.7X */ - 
benchmark.run() + benchmark1.run() + + /* + OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic + Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz + back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + RDD 846 / 1120 118.1 8.5 1.0X + DataFrame 270 / 329 370.9 2.7 3.1X + Dataset 545 / 789 183.5 5.4 1.6X + */ + benchmark2.run() /* OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 @@ -189,7 +299,7 @@ object DatasetBenchmark { DataFrame 59 / 72 1695.4 0.6 22.8X Dataset 2777 / 2805 36.0 27.8 0.5X */ - benchmark2.run() + benchmark3.run() /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1 @@ -201,6 +311,6 @@ object DatasetBenchmark { Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X */ - benchmark3.run() + benchmark4.run() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 6b50cb3e48..82b707537e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -62,6 +62,40 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 2, 3, 4) } + test("mapPrimitive") { + val dsInt = Seq(1, 2, 3).toDS() + checkDataset(dsInt.map(_ > 1), false, true, true) + checkDataset(dsInt.map(_ + 1), 2, 3, 4) + checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsInt.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsLong = Seq(1L, 2L, 3L).toDS() + checkDataset(dsLong.map(_ > 1), false, true, true) + checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 
8589934594L, 8589934595L) + checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsFloat = Seq(1F, 2F, 3F).toDS() + checkDataset(dsFloat.map(_ > 1), false, true, true) + checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L) + checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F) + checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsDouble = Seq(1D, 2D, 3D).toDS() + checkDataset(dsDouble.map(_ > 1), false, true, true) + checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4) + checkDataset(dsDouble.map(e => (e + 8589934592L).toLong), + 8589934593L, 8589934594L, 8589934595L) + checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F) + checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D) + + val dsBoolean = Seq(true, false).toDS() + checkDataset(dsBoolean.map(e => !e), false, true) + } + test("filter") { val ds = Seq(1, 2, 3, 4).toDS() checkDataset( @@ -69,6 +103,23 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 2, 4) } + test("filterPrimitive") { + val dsInt = Seq(1, 2, 3).toDS() + checkDataset(dsInt.filter(_ > 1), 2, 3) + + val dsLong = Seq(1L, 2L, 3L).toDS() + checkDataset(dsLong.filter(_ > 1), 2L, 3L) + + val dsFloat = Seq(1F, 2F, 3F).toDS() + checkDataset(dsFloat.filter(_ > 1), 2F, 3F) + + val dsDouble = Seq(1D, 2D, 3D).toDS() + checkDataset(dsDouble.filter(_ > 1), 2D, 3D) + + val dsBoolean = Seq(true, false).toDS() + checkDataset(dsBoolean.filter(e => !e), false) + } + test("foreach") { val ds = Seq(1, 2, 3).toDS() val acc = sparkContext.longAccumulator |