- author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> (2017-03-24 12:57:56 +0800)
- committer: Wenchen Fan <wenchen@databricks.com> (2017-03-24 12:57:56 +0800)
- commit: bb823ca4b479a00030c4919c2d857d254b2a44d8 (patch)
- tree: 8d201a9819151b96ad9f5e991730a093b72092b6 /sql/core
- parent: d27daa54bd341b29737a6352d9a1055151248ae7 (diff)
[SPARK-19959][SQL] Fix to throw NullPointerException in df[java.lang.Long].collect
## What changes were proposed in this pull request?
This PR fixes a `NullPointerException` thrown from code generated by Catalyst. It occurs when a DataFrame's element type is a nullable boxed primitive such as `java.lang.Long` and whole-stage codegen is used. Although the physical plan keeps `nullable=true` in `input[0, java.lang.Long, true].longValue`, `BoundReference.doGenCode` ignores it, so no null check is emitted for `inputadapter_value` even though `java.lang.Long inputadapter_value` at line 30 of the generated code may hold `null`. This PR checks the nullability and generates the null check when needed.

When we run the following code, we get the `NullPointerException` shown below.
```scala
sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF.collect
```
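The snippet above assumes a `spark-shell` session, where `sparkContext` and the `toDF` implicits are already in scope. A self-contained equivalent might look like the sketch below; the object name and session settings are illustrative, not part of the patch.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical standalone reproducer; in spark-shell the session and the
// toDF implicits are already available.
object Spark19959Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("SPARK-19959").getOrCreate()
    import spark.implicits._
    // The null element in a nullable boxed column is what triggers the bug;
    // whole-stage codegen (the default) must be enabled to hit it.
    spark.sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF.collect()
    spark.stop()
  }
}
```

Before the fix, the `collect` fails with: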
```java
Caused by: java.lang.NullPointerException
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(generated.java:37)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:393)
...
```
Generated code without this PR:
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow serializefromobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 012 */
/* 013 */   public GeneratedIterator(Object[] references) {
/* 014 */     this.references = references;
/* 015 */   }
/* 016 */
/* 017 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 018 */     partitionIndex = index;
/* 019 */     this.inputs = inputs;
/* 020 */     inputadapter_input = inputs[0];
/* 021 */     serializefromobject_result = new UnsafeRow(1);
/* 022 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 023 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 024 */
/* 025 */   }
/* 026 */
/* 027 */   protected void processNext() throws java.io.IOException {
/* 028 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 029 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 030 */       java.lang.Long inputadapter_value = (java.lang.Long)inputadapter_row.get(0, null);
/* 031 */
/* 032 */       boolean serializefromobject_isNull = true;
/* 033 */       long serializefromobject_value = -1L;
/* 034 */       if (!false) {
/* 035 */         serializefromobject_isNull = false;
/* 036 */         if (!serializefromobject_isNull) {
/* 037 */           serializefromobject_value = inputadapter_value.longValue();
/* 038 */         }
/* 039 */
/* 040 */       }
/* 041 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 042 */
/* 043 */       if (serializefromobject_isNull) {
/* 044 */         serializefromobject_rowWriter.setNullAt(0);
/* 045 */       } else {
/* 046 */         serializefromobject_rowWriter.write(0, serializefromobject_value);
/* 047 */       }
/* 048 */       append(serializefromobject_result);
/* 049 */       if (shouldStop()) return;
/* 050 */     }
/* 051 */   }
/* 052 */ }
```
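Note where the failure comes from: line 030 casts `inputadapter_row.get(0, null)` straight to `java.lang.Long` without any null check, and the guard at line 034 is the constant `if (!false)`, so `inputadapter_value.longValue()` at line 037 unboxes a null reference. That line is exactly the `generated.java:37` frame in the stack trace above.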
Generated code with this PR:
```java
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private UnsafeRow serializefromobject_result;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 011 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 012 */
/* 013 */   public GeneratedIterator(Object[] references) {
/* 014 */     this.references = references;
/* 015 */   }
/* 016 */
/* 017 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 018 */     partitionIndex = index;
/* 019 */     this.inputs = inputs;
/* 020 */     inputadapter_input = inputs[0];
/* 021 */     serializefromobject_result = new UnsafeRow(1);
/* 022 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 0);
/* 023 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 024 */
/* 025 */   }
/* 026 */
/* 027 */   protected void processNext() throws java.io.IOException {
/* 028 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 029 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 030 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 031 */       java.lang.Long inputadapter_value = inputadapter_isNull ? null : ((java.lang.Long)inputadapter_row.get(0, null));
/* 032 */
/* 033 */       boolean serializefromobject_isNull = true;
/* 034 */       long serializefromobject_value = -1L;
/* 035 */       if (!inputadapter_isNull) {
/* 036 */         serializefromobject_isNull = false;
/* 037 */         if (!serializefromobject_isNull) {
/* 038 */           serializefromobject_value = inputadapter_value.longValue();
/* 039 */         }
/* 040 */
/* 041 */       }
/* 042 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 043 */
/* 044 */       if (serializefromobject_isNull) {
/* 045 */         serializefromobject_rowWriter.setNullAt(0);
/* 046 */       } else {
/* 047 */         serializefromobject_rowWriter.write(0, serializefromobject_value);
/* 048 */       }
/* 049 */       append(serializefromobject_result);
/* 050 */       if (shouldStop()) return;
/* 051 */     }
/* 052 */   }
/* 053 */ }
```
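Compared with the previous version, the only changes are at lines 030–031 and 035: the row is probed with `isNullAt(0)` first, the boxed value is read only when it is non-null, and the outer guard tests `inputadapter_isNull` instead of the constant `false`. The diffstat below is limited to `sql/core`, so the Catalyst side of the change is not shown here. As a rough sketch of the idea only (the names `NullAwareSlot` and `slotCode` are illustrative, not Spark's API), the nullable-aware decision added to `BoundReference.doGenCode` amounts to:

```scala
// Illustrative model of the fix, not the literal Catalyst patch: when the
// bound input slot is nullable, emit an isNullAt probe before reading the
// value; when it is not nullable, read the value unconditionally as before.
case class NullAwareSlot(isNullVar: String, valueVar: String, code: String)

def slotCode(row: String, ordinal: Int, javaType: String,
             defaultValue: String, nullable: Boolean): NullAwareSlot = {
  val isNull = s"${row}_isNull"
  val value = s"${row}_value"
  val accessor = s"($javaType) $row.get($ordinal, null)" // simplified accessor
  if (nullable) {
    // The guard that was missing before SPARK-19959.
    NullAwareSlot(isNull, value,
      s"""boolean $isNull = $row.isNullAt($ordinal);
         |$javaType $value = $isNull ? $defaultValue : ($accessor);""".stripMargin)
  } else {
    NullAwareSlot("false", value, s"$javaType $value = $accessor;")
  }
}

// slotCode("inputadapter_row", 0, "java.lang.Long", "null", nullable = true).code
// yields a guarded read analogous to lines 030-031 in the generated code above.
```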
## How was this patch tested?
Added a new test case to `DataFrameImplicitsSuite`, covering `java.lang.Integer`, `java.lang.Long`, `java.lang.Float`, and `java.lang.Double`.
Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Closes #17302 from kiszk/SPARK-19959.
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala | 11
1 file changed, 11 insertions, 0 deletions
```diff
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index 094efbaead..63094d1b61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -51,4 +51,15 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
       sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
       (1 to 10).map(i => Row(i.toString)))
   }
+
+  test("SPARK-19959: df[java.lang.Long].collect includes null throws NullPointerException") {
+    checkAnswer(sparkContext.parallelize(Seq[java.lang.Integer](0, null, 2), 1).toDF,
+      Seq(Row(0), Row(null), Row(2)))
+    checkAnswer(sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF,
+      Seq(Row(0L), Row(null), Row(2L)))
+    checkAnswer(sparkContext.parallelize(Seq[java.lang.Float](0.0F, null, 2.0F), 1).toDF,
+      Seq(Row(0.0F), Row(null), Row(2.0F)))
+    checkAnswer(sparkContext.parallelize(Seq[java.lang.Double](0.0D, null, 2.0D), 1).toDF,
+      Seq(Row(0.0D), Row(null), Row(2.0D)))
+  }
 }
```