author    Michael Armbrust <michael@databricks.com>    2016-02-23 11:20:27 -0800
committer Michael Armbrust <michael@databricks.com>    2016-02-23 11:20:27 -0800
commit    c5bfe5d2a22e0e66b27aa28a19785c3aac5d9f2e (patch)
tree      0a624a9ebbacd2b2010ac872b8da60c44e9fc79f /sql
parent    9f4263392e492b5bc0acecec2712438ff9a257b7 (diff)
[SPARK-13440][SQL] ObjectType should accept any ObjectType, If should not care about nullability
The type checking functions of `If` and `UnwrapOption` are fixed to eliminate spurious failures. `UnwrapOption` was checking for an input of `ObjectType`, but `ObjectType`'s accept function was hard coded to return `false`. `If`'s type check was returning a false negative when the two branches differed only by nullability.

Tests added:
 - an end-to-end regression test is added to `DatasetSuite` for the reported failure.
 - all the unit tests in `ExpressionEncoderSuite` are augmented to also confirm successful analysis. These tests are actually what pointed out the additional issues with `If` resolution.

Author: Michael Armbrust <michael@databricks.com>

Closes #11316 from marmbrus/datasetOptions.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala        |  3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala                            |  6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala              | 13
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala    | 19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                                    |  8
6 files changed, 43 insertions(+), 8 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 200c6a05df..c3e9fa33e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -34,7 +34,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
if (predicate.dataType != BooleanType) {
TypeCheckResult.TypeCheckFailure(
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
- } else if (trueValue.dataType != falseValue.dataType) {
+ } else if (trueValue.dataType.asNullable != falseValue.dataType.asNullable) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
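The fix compares the branch types through `asNullable`, which normalizes nullability recursively before the equality check. A self-contained sketch of the idea, using toy types rather than Spark's actual `DataType` hierarchy:

// Toy model of the nullability-insensitive comparison; not Spark's real classes.
sealed trait DType { def asNullable: DType }
case object IntType extends DType { def asNullable: DType = IntType }
case class ArrType(elem: DType, containsNull: Boolean) extends DType {
  def asNullable: DType = ArrType(elem.asNullable, containsNull = true)
}

object IfCheckDemo extends App {
  val trueBranch  = ArrType(IntType, containsNull = false)
  val falseBranch = ArrType(IntType, containsNull = true)

  // Old check: branches differing only in nullability were rejected.
  println(trueBranch == falseBranch)                        // false -> spurious failure
  // New check: normalize nullability on both sides first.
  println(trueBranch.asNullable == falseBranch.asNullable)  // true  -> passes
}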
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index d3b5879777..f9f1f88cec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -45,6 +45,9 @@ object LocalRelation {
case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
extends LeafNode with analysis.MultiInstanceRelation {
+ // A local relation must have resolved output.
+ require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.")
+
/**
* Returns an identical copy of this relation with new exprIds for all attributes. Different
* attributes are required when a relation is going to be included multiple times in the same
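The new `require` is a fail-fast constructor invariant: an unresolved attribute now aborts at construction instead of surfacing as a confusing error deep inside analysis. A hedged sketch of the pattern, with `Attr` as a hypothetical stand-in for Spark's `Attribute`:

// `Attr` is a hypothetical stand-in; `resolved` models Attribute.resolved.
case class Attr(name: String, resolved: Boolean)

case class ToyLocalRelation(output: Seq[Attr]) {
  // Fail at construction time, mirroring the added require above.
  require(output.forall(_.resolved),
    "Unresolved attributes found when constructing LocalRelation.")
}

// ToyLocalRelation(Seq(Attr("obj", resolved = false)))
//   -> throws IllegalArgumentException at construction time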
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
index fca0b799eb..06ee0fbfe9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
@@ -23,8 +23,10 @@ private[sql] object ObjectType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType =
throw new UnsupportedOperationException("null literals can't be casted to ObjectType")
- // No casting or comparison is supported.
- override private[sql] def acceptsType(other: DataType): Boolean = false
+ override private[sql] def acceptsType(other: DataType): Boolean = other match {
+ case ObjectType(_) => true
+ case _ => false
+ }
override private[sql] def simpleString: String = "Object"
}
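A minimal sketch of the fixed pattern match, with hypothetical stand-ins for the private[sql] types: any ObjectType now accepts any other ObjectType regardless of which JVM class it wraps, while non-object types are still rejected.

// Hypothetical stand-ins for the private[sql] DataType hierarchy.
sealed trait DataTypeLike
case class ObjType(cls: Class[_]) extends DataTypeLike
case object IntLike extends DataTypeLike

// Previously hard coded to `false`, which rejected even ObjType inputs.
def acceptsType(other: DataTypeLike): Boolean = other match {
  case ObjType(_) => true
  case _          => false
}

// acceptsType(ObjType(classOf[Option[_]]))  // true
// acceptsType(IntLike)                      // false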
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index e0a95ba8bb..ef825e6062 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -60,7 +60,18 @@ trait AnalysisTest extends PlanTest {
inputPlan: LogicalPlan,
caseSensitive: Boolean = true): Unit = {
val analyzer = getAnalyzer(caseSensitive)
- analyzer.checkAnalysis(analyzer.execute(inputPlan))
+ val analysisAttempt = analyzer.execute(inputPlan)
+ try analyzer.checkAnalysis(analysisAttempt) catch {
+ case a: AnalysisException =>
+ fail(
+ s"""
+ |Failed to Analyze Plan
+ |$inputPlan
+ |
+ |Partial Analysis
+ |$analysisAttempt
+ """.stripMargin, a)
+ }
}
protected def assertAnalysisError(
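The reworked helper wraps the check so a failure reports both the input plan and the partially analyzed plan. The same report-with-context idiom in a generic ScalaTest sketch; the transformation and check here are hypothetical, while `fail(message, cause)` is ScalaTest's standard assertion:

import org.scalatest.funsuite.AnyFunSuite

class ContextualFailureDemo extends AnyFunSuite {
  // Hypothetical stand-ins for a transformation and its validity check.
  def execute(plan: String): String = plan.replace("unresolved scan", "unresolved node")
  def check(plan: String): Unit =
    if (plan.contains("unresolved")) throw new IllegalStateException("unresolved node")

  test("failures report both the input and the partial result") {
    val input = "unresolved scan"
    val attempt = execute(input)
    try check(attempt) catch {
      case e: IllegalStateException =>
        // Both stages appear in the failure message, mirroring assertAnalysisSuccess.
        fail(s"Failed to analyze plan\n$input\n\nPartial analysis\n$attempt", e)
    }
  }
}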
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index e00060f9b6..cca320fae9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -23,12 +23,14 @@ import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, StructType}
+import org.apache.spark.sql.types.{ArrayType, ObjectType, StructType}
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -74,7 +76,7 @@ class JavaSerializable(val value: Int) extends Serializable {
}
}
-class ExpressionEncoderSuite extends SparkFunSuite {
+class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)
implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
@@ -305,6 +307,15 @@ class ExpressionEncoderSuite extends SparkFunSuite {
""".stripMargin, e)
}
+ // Test the correct resolution of serialization / deserialization.
+ val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
+ val inputPlan = LocalRelation(attr)
+ val plan =
+ Project(Alias(encoder.fromRowExpression, "obj")() :: Nil,
+ Project(encoder.namedExpressions,
+ inputPlan))
+ assertAnalysisSuccess(plan)
+
val isCorrect = (input, convertedBack) match {
case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2)
case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2)
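The added check builds a two-level plan: an inner Project that serializes the object into columns (`encoder.namedExpressions`) and an outer Project that deserializes it back (`encoder.fromRowExpression`), then asserts the analyzer resolves both directions. A toy sketch of that round-trip shape, with plain functions standing in for Catalyst expressions:

// Hypothetical toy encoder; the real test wires Catalyst expressions instead.
case class ToyEncoder[T](serialize: T => Seq[Any], deserialize: Seq[Any] => T)

def roundTrips[T](enc: ToyEncoder[T], input: T): Boolean = {
  val row  = enc.serialize(input)   // inner Project: object -> columns
  val back = enc.deserialize(row)   // outer Project: columns -> object
  back == input
}

val optEncoder = ToyEncoder[Option[String]](
  serialize   = o => Seq(o.orNull),
  deserialize = r => Option(r.head.asInstanceOf[String]))

// roundTrips(optEncoder, Some("a"))  // true
// roundTrips(optEncoder, None)       // true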
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 498f007081..14fc37b64a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -613,6 +613,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
" - Input schema: struct<a:string,b:int>\n" +
" - Target schema: struct<_1:string>")
}
+
+ test("SPARK-13440: Resolving option fields") {
+ val df = Seq(1, 2, 3).toDS()
+ val ds = df.as[Option[Int]]
+ checkAnswer(
+ ds.filter(_ => true),
+ Some(1), Some(2), Some(3))
+ }
}
class OuterClass extends Serializable {
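For reference, the regression scenario against the public API, written here with the modern SparkSession entry point (the commit itself predates it). Before this fix, the Option[Int] encoder failed to resolve because UnwrapOption's ObjectType input never passed the type check:

import org.apache.spark.sql.SparkSession

object OptionFieldRepro extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .appName("SPARK-13440 repro")
    .getOrCreate()
  import spark.implicits._

  // Re-encode an Int Dataset as Option[Int]; the option serializer relies on
  // UnwrapOption, whose ObjectType input previously failed analysis.
  val ds = Seq(1, 2, 3).toDS().as[Option[Int]]
  ds.filter(_ => true).collect().foreach(println)  // Some(1), Some(2), Some(3)

  spark.stop()
}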