aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorKazuaki Ishizaki <ishizaki@jp.ibm.com>2017-04-10 10:47:17 +0800
committerWenchen Fan <wenchen@databricks.com>2017-04-10 10:47:17 +0800
commit7a63f5e82758345ff1f3322950f2bbea350c48b9 (patch)
treeccace14effdec6410f87e49f5eb94f55eb51ccc1 /sql/catalyst
parent261eaf5149a8fe479ab4f9c34db892bcedbf5739 (diff)
downloadspark-7a63f5e82758345ff1f3322950f2bbea350c48b9.tar.gz
spark-7a63f5e82758345ff1f3322950f2bbea350c48b9.tar.bz2
spark-7a63f5e82758345ff1f3322950f2bbea350c48b9.zip
[SPARK-20253][SQL] Remove unnecessary nullchecks of a return value from Spark runtime routines in generated Java code
## What changes were proposed in this pull request? This PR elminates unnecessary nullchecks of a return value from known Spark runtime routines. We know whether a given Spark runtime routine returns ``null`` or not (e.g. ``ArrayData.toDoubleArray()`` never returns ``null``). Thus, we can eliminate a null check for the return value from the Spark runtime routine. When we run the following example program, now we get the Java code "Without this PR". In this code, since we know ``ArrayData.toDoubleArray()`` never returns ``null```, we can eliminate null checks at lines 90-92, and 97. ```java val ds = sparkContext.parallelize(Seq(Array(1.1, 2.2)), 1).toDS.cache ds.count ds.map(e => e).show ``` Without this PR ```java /* 050 */ protected void processNext() throws java.io.IOException { /* 051 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 052 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 053 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 054 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 055 */ /* 056 */ ArrayData deserializetoobject_value1 = null; /* 057 */ /* 058 */ if (!inputadapter_isNull) { /* 059 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 060 */ /* 061 */ Double[] deserializetoobject_convertedArray = null; /* 062 */ deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength]; /* 063 */ /* 064 */ int deserializetoobject_loopIndex = 0; /* 065 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 066 */ MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex)); /* 067 */ MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 068 */ /* 069 */ if (MapObjects_loopIsNull2) { /* 070 */ throw new RuntimeException(((java.lang.String) references[0])); /* 071 */ } /* 072 */ if (false) { /* 073 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 074 */ } else { /* 075 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2; /* 076 */ } /* 077 */ /* 078 */ deserializetoobject_loopIndex += 1; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/ /* 082 */ } /* 083 */ boolean deserializetoobject_isNull = true; /* 084 */ double[] deserializetoobject_value = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull = false; /* 087 */ if (!deserializetoobject_isNull) { /* 088 */ Object deserializetoobject_funcResult = null; /* 089 */ deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray(); /* 090 */ if (deserializetoobject_funcResult == null) { /* 091 */ deserializetoobject_isNull = true; /* 092 */ } else { /* 093 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 094 */ } /* 095 */ /* 096 */ } /* 097 */ deserializetoobject_isNull = deserializetoobject_value == null; /* 098 */ } /* 099 */ /* 100 */ boolean mapelements_isNull = true; /* 101 */ double[] mapelements_value = null; /* 102 */ if (!false) { /* 103 */ mapelements_resultIsNull = false; /* 104 */ /* 105 */ if (!mapelements_resultIsNull) { /* 106 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 107 */ mapelements_argValue = deserializetoobject_value; /* 108 */ } /* 109 */ /* 110 */ mapelements_isNull = mapelements_resultIsNull; /* 111 */ if (!mapelements_isNull) { /* 112 */ Object mapelements_funcResult = null; /* 113 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 114 */ if (mapelements_funcResult == null) { /* 115 */ mapelements_isNull = true; /* 116 */ } else { /* 117 */ mapelements_value = (double[]) mapelements_funcResult; /* 118 */ } /* 119 */ /* 120 */ } /* 121 */ mapelements_isNull = mapelements_value == null; /* 122 */ } /* 123 */ /* 124 */ serializefromobject_resultIsNull = false; /* 125 */ /* 126 */ if (!serializefromobject_resultIsNull) { /* 127 */ serializefromobject_resultIsNull = mapelements_isNull; /* 128 */ serializefromobject_argValue = mapelements_value; /* 129 */ } /* 130 */ /* 131 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 132 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 133 */ serializefromobject_isNull = serializefromobject_value == null; /* 134 */ serializefromobject_holder.reset(); /* 135 */ /* 136 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 137 */ /* 138 */ if (serializefromobject_isNull) { /* 139 */ serializefromobject_rowWriter.setNullAt(0); /* 140 */ } else { /* 141 */ // Remember the current cursor so that we can calculate how many bytes are /* 142 */ // written later. /* 143 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 144 */ /* 145 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 146 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 147 */ // grow the global buffer before writing data. /* 148 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 149 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 150 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 151 */ /* 152 */ } else { /* 153 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 154 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 155 */ /* 156 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 157 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 158 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 159 */ } else { /* 160 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 161 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 162 */ } /* 163 */ } /* 164 */ } /* 165 */ /* 166 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 167 */ } /* 168 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 169 */ append(serializefromobject_result); /* 170 */ if (shouldStop()) return; /* 171 */ } /* 172 */ } ``` With this PR (removed most of lines 90-97 in the above code) ```java /* 050 */ protected void processNext() throws java.io.IOException { /* 051 */ while (inputadapter_input.hasNext() && !stopEarly()) { /* 052 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 053 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 054 */ ArrayData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getArray(0)); /* 055 */ /* 056 */ ArrayData deserializetoobject_value1 = null; /* 057 */ /* 058 */ if (!inputadapter_isNull) { /* 059 */ int deserializetoobject_dataLength = inputadapter_value.numElements(); /* 060 */ /* 061 */ Double[] deserializetoobject_convertedArray = null; /* 062 */ deserializetoobject_convertedArray = new Double[deserializetoobject_dataLength]; /* 063 */ /* 064 */ int deserializetoobject_loopIndex = 0; /* 065 */ while (deserializetoobject_loopIndex < deserializetoobject_dataLength) { /* 066 */ MapObjects_loopValue2 = (double) (inputadapter_value.getDouble(deserializetoobject_loopIndex)); /* 067 */ MapObjects_loopIsNull2 = inputadapter_value.isNullAt(deserializetoobject_loopIndex); /* 068 */ /* 069 */ if (MapObjects_loopIsNull2) { /* 070 */ throw new RuntimeException(((java.lang.String) references[0])); /* 071 */ } /* 072 */ if (false) { /* 073 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = null; /* 074 */ } else { /* 075 */ deserializetoobject_convertedArray[deserializetoobject_loopIndex] = MapObjects_loopValue2; /* 076 */ } /* 077 */ /* 078 */ deserializetoobject_loopIndex += 1; /* 079 */ } /* 080 */ /* 081 */ deserializetoobject_value1 = new org.apache.spark.sql.catalyst.util.GenericArrayData(deserializetoobject_convertedArray); /*###*/ /* 082 */ } /* 083 */ boolean deserializetoobject_isNull = true; /* 084 */ double[] deserializetoobject_value = null; /* 085 */ if (!inputadapter_isNull) { /* 086 */ deserializetoobject_isNull = false; /* 087 */ if (!deserializetoobject_isNull) { /* 088 */ Object deserializetoobject_funcResult = null; /* 089 */ deserializetoobject_funcResult = deserializetoobject_value1.toDoubleArray(); /* 090 */ deserializetoobject_value = (double[]) deserializetoobject_funcResult; /* 091 */ /* 092 */ } /* 093 */ /* 094 */ } /* 095 */ /* 096 */ boolean mapelements_isNull = true; /* 097 */ double[] mapelements_value = null; /* 098 */ if (!false) { /* 099 */ mapelements_resultIsNull = false; /* 100 */ /* 101 */ if (!mapelements_resultIsNull) { /* 102 */ mapelements_resultIsNull = deserializetoobject_isNull; /* 103 */ mapelements_argValue = deserializetoobject_value; /* 104 */ } /* 105 */ /* 106 */ mapelements_isNull = mapelements_resultIsNull; /* 107 */ if (!mapelements_isNull) { /* 108 */ Object mapelements_funcResult = null; /* 109 */ mapelements_funcResult = ((scala.Function1) references[1]).apply(mapelements_argValue); /* 110 */ if (mapelements_funcResult == null) { /* 111 */ mapelements_isNull = true; /* 112 */ } else { /* 113 */ mapelements_value = (double[]) mapelements_funcResult; /* 114 */ } /* 115 */ /* 116 */ } /* 117 */ mapelements_isNull = mapelements_value == null; /* 118 */ } /* 119 */ /* 120 */ serializefromobject_resultIsNull = false; /* 121 */ /* 122 */ if (!serializefromobject_resultIsNull) { /* 123 */ serializefromobject_resultIsNull = mapelements_isNull; /* 124 */ serializefromobject_argValue = mapelements_value; /* 125 */ } /* 126 */ /* 127 */ boolean serializefromobject_isNull = serializefromobject_resultIsNull; /* 128 */ final ArrayData serializefromobject_value = serializefromobject_resultIsNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(serializefromobject_argValue); /* 129 */ serializefromobject_isNull = serializefromobject_value == null; /* 130 */ serializefromobject_holder.reset(); /* 131 */ /* 132 */ serializefromobject_rowWriter.zeroOutNullBytes(); /* 133 */ /* 134 */ if (serializefromobject_isNull) { /* 135 */ serializefromobject_rowWriter.setNullAt(0); /* 136 */ } else { /* 137 */ // Remember the current cursor so that we can calculate how many bytes are /* 138 */ // written later. /* 139 */ final int serializefromobject_tmpCursor = serializefromobject_holder.cursor; /* 140 */ /* 141 */ if (serializefromobject_value instanceof UnsafeArrayData) { /* 142 */ final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes(); /* 143 */ // grow the global buffer before writing data. /* 144 */ serializefromobject_holder.grow(serializefromobject_sizeInBytes); /* 145 */ ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor); /* 146 */ serializefromobject_holder.cursor += serializefromobject_sizeInBytes; /* 147 */ /* 148 */ } else { /* 149 */ final int serializefromobject_numElements = serializefromobject_value.numElements(); /* 150 */ serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 8); /* 151 */ /* 152 */ for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) { /* 153 */ if (serializefromobject_value.isNullAt(serializefromobject_index)) { /* 154 */ serializefromobject_arrayWriter.setNullDouble(serializefromobject_index); /* 155 */ } else { /* 156 */ final double serializefromobject_element = serializefromobject_value.getDouble(serializefromobject_index); /* 157 */ serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element); /* 158 */ } /* 159 */ } /* 160 */ } /* 161 */ /* 162 */ serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor); /* 163 */ } /* 164 */ serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize()); /* 165 */ append(serializefromobject_result); /* 166 */ if (shouldStop()) return; /* 167 */ } /* 168 */ } ``` ## How was this patch tested? Add test suites to ``DatasetPrimitiveSuite`` Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #17569 from kiszk/SPARK-20253.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala28
3 files changed, 41 insertions, 33 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 206ae2f0e5..198122759e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -251,19 +251,22 @@ object ScalaReflection extends ScalaReflection {
getPath :: Nil)
case t if t <:< localTypeOf[java.lang.String] =>
- Invoke(getPath, "toString", ObjectType(classOf[String]))
+ Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
- Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+ Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+ returnNullable = false)
case t if t <:< localTypeOf[BigDecimal] =>
- Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
+ Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
case t if t <:< localTypeOf[java.math.BigInteger] =>
- Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))
+ Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
+ returnNullable = false)
case t if t <:< localTypeOf[scala.math.BigInt] =>
- Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))
+ Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
+ returnNullable = false)
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
@@ -284,7 +287,7 @@ object ScalaReflection extends ScalaReflection {
val arrayCls = arrayClassFor(elementType)
if (elementNullable) {
- Invoke(arrayData, "array", arrayCls)
+ Invoke(arrayData, "array", arrayCls, returnNullable = false)
} else {
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => "toIntArray"
@@ -297,7 +300,7 @@ object ScalaReflection extends ScalaReflection {
case other => throw new IllegalStateException("expect primitive array element type " +
"but got " + other)
}
- Invoke(arrayData, primitiveMethod, arrayCls)
+ Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
}
case t if t <:< localTypeOf[Seq[_]] =>
@@ -330,19 +333,21 @@ object ScalaReflection extends ScalaReflection {
Invoke(
MapObjects(
p => deserializerFor(keyType, Some(p), walkedTypePath),
- Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+ Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
+ returnNullable = false),
schemaFor(keyType).dataType),
"array",
- ObjectType(classOf[Array[Any]]))
+ ObjectType(classOf[Array[Any]]), returnNullable = false)
val valueData =
Invoke(
MapObjects(
p => deserializerFor(valueType, Some(p), walkedTypePath),
- Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+ Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
+ returnNullable = false),
schemaFor(valueType).dataType),
"array",
- ObjectType(classOf[Array[Any]]))
+ ObjectType(classOf[Array[Any]]), returnNullable = false)
StaticInvoke(
ArrayBasedMapData.getClass,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e95e97b9dc..0f8282d3b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -89,7 +89,7 @@ object RowEncoder {
udtClass,
Nil,
dataType = ObjectType(udtClass), false)
- Invoke(obj, "serialize", udt, inputObject :: Nil)
+ Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
case TimestampType =>
StaticInvoke(
@@ -136,16 +136,18 @@ object RowEncoder {
case t @ MapType(kt, vt, valueNullable) =>
val keys =
Invoke(
- Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+ Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+ returnNullable = false),
"toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
+ ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
val convertedKeys = serializerFor(keys, ArrayType(kt, false))
val values =
Invoke(
- Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
+ Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+ returnNullable = false),
"toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
+ ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
NewInstance(
@@ -262,17 +264,18 @@ object RowEncoder {
input :: Nil)
case _: DecimalType =>
- Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+ Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+ returnNullable = false)
case StringType =>
- Invoke(input, "toString", ObjectType(classOf[String]))
+ Invoke(input, "toString", ObjectType(classOf[String]), returnNullable = false)
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
MapObjects(deserializerFor(_), input, et),
"array",
- ObjectType(classOf[Array[_]]))
+ ObjectType(classOf[Array[_]]), returnNullable = false)
StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 53842ef348..6d94764f1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -225,25 +225,26 @@ case class Invoke(
getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
} else {
val funcResult = ctx.freshName("funcResult")
+ // If the function can return null, we do an extra check to make sure our null bit is still
+ // set correctly.
+ val assignResult = if (!returnNullable) {
+ s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;"
+ } else {
+ s"""
+ if ($funcResult != null) {
+ ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
+ } else {
+ ${ev.isNull} = true;
+ }
+ """
+ }
s"""
Object $funcResult = null;
${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
- if ($funcResult == null) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
- }
+ $assignResult
"""
}
- // If the function can return null, we do an extra check to make sure our null bit is still set
- // correctly.
- val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
- s"${ev.isNull} = ${ev.value} == null;"
- } else {
- ""
- }
-
val code = s"""
${obj.code}
boolean ${ev.isNull} = true;
@@ -254,7 +255,6 @@ case class Invoke(
if (!${ev.isNull}) {
$evaluate
}
- $postNullCheck
}
"""
ev.copy(code = code)