aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/main
diff options
context:
space:
mode:
authorKazuaki Ishizaki <ishizaki@jp.ibm.com>2017-03-09 22:58:52 -0800
committerWenchen Fan <wenchen@databricks.com>2017-03-09 22:58:52 -0800
commit5949e6c4477fd3cb07a6962dbee48b4416ea65dd (patch)
tree63c9d88ce70e7c145193228693b24e31102eb26f /sql/core/src/main
parent82138e09b9ad8d9609d5c64d6c11244b8f230be7 (diff)
downloadspark-5949e6c4477fd3cb07a6962dbee48b4416ea65dd.tar.gz
spark-5949e6c4477fd3cb07a6962dbee48b4416ea65dd.tar.bz2
spark-5949e6c4477fd3cb07a6962dbee48b4416ea65dd.zip
[SPARK-19008][SQL] Improve performance of Dataset.map by eliminating boxing/unboxing
## What changes were proposed in this pull request? This PR improve performance of Dataset.map() for primitive types by removing boxing/unbox operations. This is based on [the discussion](https://github.com/apache/spark/pull/16391#discussion_r93788919) with cloud-fan. Current Catalyst generates a method call to a `apply()` method of an anonymous function written in Scala. The types of an argument and return value are `java.lang.Object`. As a result, each method call for a primitive value involves a pair of unboxing and boxing for calling this `apply()` method and a pair of boxing and unboxing for returning from this `apply()` method. This PR directly calls a specialized version of a `apply()` method without boxing and unboxing. For example, if types of an arguments ant return value is `int`, this PR generates a method call to `apply$mcII$sp`. This PR supports any combination of `Int`, `Long`, `Float`, and `Double`. The following is a benchmark result using [this program](https://github.com/apache/spark/pull/16391/files) with 4.7x. Here is a Dataset part of this program. Without this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ RDD 1923 / 1952 52.0 19.2 1.0X DataFrame 526 / 548 190.2 5.3 3.7X Dataset 3094 / 3154 32.3 30.9 0.6X ``` With this PR ``` OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic Intel(R) Xeon(R) CPU E5-2667 v3 3.20GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ RDD 1883 / 1892 53.1 18.8 1.0X DataFrame 502 / 642 199.1 5.0 3.7X Dataset 657 / 784 152.2 6.6 2.9X ``` ```java def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { import spark.implicits._ val rdd = spark.sparkContext.range(0, numRows) val ds = spark.range(0, numRows) val func = (l: Long) => l + 1 val benchmark = new Benchmark("back-to-back map", numRows) ... benchmark.addCase("Dataset") { iter => var res = ds.as[Long] var i = 0 while (i < numChains) { res = res.map(func) i += 1 } res.queryExecution.toRdd.foreach(_ => Unit) } benchmark } ``` A motivating example ```java Seq(1, 2, 3).toDS.map(i => i * 7).show ``` Generated code without this PR ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow deserializetoobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 012 */ private int mapelements_argValue; /* 013 */ private UnsafeRow mapelements_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 016 */ private UnsafeRow serializefromobject_result; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 019 */ /* 020 */ public GeneratedIterator(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ inputadapter_input = inputs[0]; /* 028 */ deserializetoobject_result = new UnsafeRow(1); /* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0); /* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 031 */ /* 032 */ mapelements_result = new UnsafeRow(1); /* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0); /* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 035 */ serializefromobject_result = new UnsafeRow(1); /* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 038 */ /* 039 */ } /* 040 */ /* 041 */ protected void processNext() throws java.io.IOException { /* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ int inputadapter_value = inputadapter_row.getInt(0); /* 045 */ /* 046 */ boolean mapelements_isNull = true; /* 047 */ int mapelements_value = -1; /* 048 */ if (!false) { /* 049 */ mapelements_argValue = inputadapter_value; /* 050 */ /* 051 */ mapelements_isNull = false; /* 052 */ if (!mapelements_isNull) { /* 053 */ Object mapelements_funcResult = null; /* 054 */ mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue); /* 055 */ if (mapelements_funcResult == null) { /* 056 */ mapelements_isNull = true; /* 057 */ } else { /* 058 */ mapelements_value = (Integer) mapelements_funcResult; /* 059 */ } /* 060 */ /* 061 */ } /* 062 */ /* 063 */ } /* 064 */ /* 065 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 066 */ /* 067 */ if (mapelements_isNull) { /* 068 */ serializefromobject_rowWriter.setNullAt(0); /* 069 */ } else { /* 070 */ serializefromobject_rowWriter.write(0, mapelements_value); /* 071 */ } /* 072 */ append(serializefromobject_result); /* 073 */ if (shouldStop()) return; /* 074 */ } /* 075 */ } /* 076 */ } ``` Generated code with this PR (lines 48-56 are changed) ```java /* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 006 */ private Object[] references; /* 007 */ private scala.collection.Iterator[] inputs; /* 008 */ private scala.collection.Iterator inputadapter_input; /* 009 */ private UnsafeRow deserializetoobject_result; /* 010 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder; /* 011 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter; /* 012 */ private int mapelements_argValue; /* 013 */ private UnsafeRow mapelements_result; /* 014 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder; /* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter; /* 016 */ private UnsafeRow serializefromobject_result; /* 017 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter; /* 019 */ /* 020 */ public GeneratedIterator(Object[] references) { /* 021 */ this.references = references; /* 022 */ } /* 023 */ /* 024 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 025 */ partitionIndex = index; /* 026 */ this.inputs = inputs; /* 027 */ inputadapter_input = inputs[0]; /* 028 */ deserializetoobject_result = new UnsafeRow(1); /* 029 */ this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 0); /* 030 */ this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1); /* 031 */ /* 032 */ mapelements_result = new UnsafeRow(1); /* 033 */ this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 0); /* 034 */ this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1); /* 035 */ serializefromobject_result = new UnsafeRow(1); /* 036 */ this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0); /* 037 */ this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1); /* 038 */ /* 039 */ } /* 040 */ /* 041 */ protected void processNext() throws java.io.IOException { /* 042 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ int inputadapter_value = inputadapter_row.getInt(0); /* 045 */ /* 046 */ boolean mapelements_isNull = true; /* 047 */ int mapelements_value = -1; /* 048 */ if (!false) { /* 049 */ mapelements_argValue = inputadapter_value; /* 050 */ /* 051 */ mapelements_isNull = false; /* 052 */ if (!mapelements_isNull) { /* 053 */ mapelements_value = ((scala.Function1) references[0]).apply$mcII$sp(mapelements_argValue); /* 054 */ } /* 055 */ /* 056 */ } /* 057 */ /* 058 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 059 */ /* 060 */ if (mapelements_isNull) { /* 061 */ serializefromobject_rowWriter.setNullAt(0); /* 062 */ } else { /* 063 */ serializefromobject_rowWriter.write(0, mapelements_value); /* 064 */ } /* 065 */ append(serializefromobject_result); /* 066 */ if (shouldStop()) return; /* 067 */ } /* 068 */ } /* 069 */ } ``` Java bytecode for methods for `i => i * 7` ```java $ javap -c Test\$\$anonfun\$5\$\$anonfun\$apply\$mcV\$sp\$1.class Compiled from "Test.scala" public final class org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1 extends scala.runtime.AbstractFunction1$mcII$sp implements scala.Serializable { public static final long serialVersionUID; public final int apply(int); Code: 0: aload_0 1: iload_1 2: invokevirtual #18 // Method apply$mcII$sp:(I)I 5: ireturn public int apply$mcII$sp(int); Code: 0: iload_1 1: bipush 7 3: imul 4: ireturn public final java.lang.Object apply(java.lang.Object); Code: 0: aload_0 1: aload_1 2: invokestatic #29 // Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I 5: invokevirtual #31 // Method apply:(I)I 8: invokestatic #35 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer; 11: areturn public org.apache.spark.sql.Test$$anonfun$5$$anonfun$apply$mcV$sp$1(org.apache.spark.sql.Test$$anonfun$5); Code: 0: aload_0 1: invokespecial #42 // Method scala/runtime/AbstractFunction1$mcII$sp."<init>":()V 4: return } ``` ## How was this patch tested? Added new test suites to `DatasetPrimitiveSuite`. Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #17172 from kiszk/SPARK-19008.
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala6
1 files changed, 4 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 199ba5ce69..fdd1bcc94b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
-import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
/**
@@ -219,7 +221,7 @@ case class MapElementsExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
- case _ => classOf[Any => Any] -> "apply"
+ case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
}
val funcObj = Literal.create(func, ObjectType(funcClass))
val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)