author    Michael Armbrust <michael@databricks.com>  2015-07-08 22:05:58 -0700
committer Reynold Xin <rxin@databricks.com>          2015-07-08 22:05:58 -0700
commit    768907eb7b0d3c11a420ef281454e36167011c89 (patch)
tree      7ef8a6fa83aea6d71c12546abbe431c3ec400404
parent    aba5784dab24c03ddad89f7a1b5d3d0dc8d109be (diff)
[SPARK-8926][SQL] Good errors for ExpectsInputType expressions
For example:

    cannot resolve 'testfunction(null)' due to data type mismatch: argument 1 is
    expected to be of type int, however, null is of type datetype.

Author: Michael Armbrust <michael@databricks.com>

Closes #7303 from marmbrus/expectsTypeErrors and squashes the following commits:

c654a0e [Michael Armbrust] fix udts and make errors pretty
137160d [Michael Armbrust] style
5428fda [Michael Armbrust] style
10fac82 [Michael Armbrust] [SPARK-8926][SQL] Good errors for ExpectsInputType expressions
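The patch standardizes on a per-argument message of the shape quoted above. A minimal sketch of that template (illustrative only; `mismatch` is a made-up helper, and the real string is assembled inside ExpectsInputTypes.checkInputDataTypes in the diff below):

    def mismatch(idx: Int, expected: String, arg: String, actual: String): String =
      s"argument $idx is expected to be of type $expected, however, $arg is of type $actual."

    // mismatch(1, "int", "null", "datetype")
    // => "argument 1 is expected to be of type int, however, null is of type datetype."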
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala | 30
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala | 5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 167
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 126
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 2
13 files changed, 256 insertions(+), 143 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 5367b7f330..8cb71995eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -702,11 +702,19 @@ object HiveTypeCoercion {
@Nullable val ret: Expression = (inType, expectedType) match {
// If the expected type is already a parent of the input type, no need to cast.
- case _ if expectedType.isParentOf(inType) => e
+ case _ if expectedType.isSameType(inType) => e
// Cast null type (usually from null literals) into target types
case (NullType, target) => Cast(e, target.defaultConcreteType)
+ // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
+ // already a number, leave it as is.
+ case (_: NumericType, NumericType) => e
+
+ // If the function accepts any numeric type and the input is a string, we follow the Hive
+ // convention and cast that input into a double.
+ case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
+
// Implicit cast among numeric types
// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to unlimited precision decimal.
@@ -732,7 +740,7 @@ object HiveTypeCoercion {
// First see if we can find our input type in the type collection. If we can, then just
// use the current expression; otherwise, find the first one we can implicitly cast.
case (_, TypeCollection(types)) =>
- if (types.exists(_.isParentOf(inType))) {
+ if (types.exists(_.isSameType(inType))) {
e
} else {
types.flatMap(implicitCast(e, _)).headOption.orNull
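The two new rules slot into implicitCast's match in a deliberate order: identical types pass through untouched, null literals are cast to the target's default concrete type, any numeric input satisfies an abstract NumericType expectation as is, and strings fall back to the Hive string-to-double convention. A self-contained sketch of that dispatch order, assuming a toy type hierarchy (the sealed trait and case objects below are stand-ins, not Spark's AbstractDataType API):

    sealed trait Ty
    case object IntTy extends Ty
    case object DoubleTy extends Ty
    case object StringTy extends Ty
    case object NullTy extends Ty
    case object AnyNumeric extends Ty  // stand-in for the abstract NumericType

    def implicitCast(in: Ty, expected: Ty): Option[Ty] = (in, expected) match {
      case _ if in == expected            => Some(in)        // same type: no cast needed
      case (NullTy, t)                    => Some(t)         // null literal: cast to target
                                                             // (Spark uses defaultConcreteType)
      case (IntTy | DoubleTy, AnyNumeric) => Some(in)        // numeric input: leave as is
      case (StringTy, AnyNumeric)         => Some(DoubleTy)  // Hive convention: string -> double
      case _                              => None            // no legal implicit cast
    }

    // implicitCast(IntTy, AnyNumeric)    == Some(IntTy)
    // implicitCast(StringTy, AnyNumeric) == Some(DoubleTy)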
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index 916e30154d..986cc09499 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -37,7 +37,16 @@ trait ExpectsInputTypes { self: Expression =>
def inputTypes: Seq[AbstractDataType]
override def checkInputDataTypes(): TypeCheckResult = {
- // TODO: implement proper type checking.
- TypeCheckResult.TypeCheckSuccess
+ val mismatches = children.zip(inputTypes).zipWithIndex.collect {
+ case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
+ s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
+ s"however, ${child.prettyString} is of type ${child.dataType.simpleString}."
+ }
+
+ if (mismatches.isEmpty) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ TypeCheckResult.TypeCheckFailure(mismatches.mkString(" "))
+ }
}
}
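At a call site, the rewritten check turns each misaligned (child, expected) pair into one sentence and joins them. A hedged usage sketch, reusing the `TestFunction` helper that this patch adds in AnalysisErrorSuite further down (assumes the Catalyst imports from that suite are in scope):

    val fn = TestFunction(
      children   = Literal.create(null, DateType) :: Nil,
      inputTypes = IntegerType :: Nil)

    fn.checkInputDataTypes() match {
      case TypeCheckResult.TypeCheckFailure(msg) =>
        // "Argument 1 is expected to be of type int, however, null is of type date."
        println(msg)
      case TypeCheckResult.TypeCheckSuccess =>
        println("all argument types line up")
    }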
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index fb1b47e946..ad75fa2e31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -34,9 +34,16 @@ private[sql] abstract class AbstractDataType {
private[sql] def defaultConcreteType: DataType
/**
- * Returns true if this data type is a parent of the `childCandidate`.
+ * Returns true if this data type is the same type as `other`. This is different from equality,
+ * as equality also considers data type parametrization, such as decimal precision.
*/
- private[sql] def isParentOf(childCandidate: DataType): Boolean
+ private[sql] def isSameType(other: DataType): Boolean
+
+ /**
+ * Returns true if `other` is an acceptable input type for a function that expects this,
+ * possibly abstract, DataType.
+ */
+ private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
/** Readable string representation for the type. */
private[sql] def simpleString: String
@@ -58,11 +65,14 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")
- private[sql] override def defaultConcreteType: DataType = types.head.defaultConcreteType
+ override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType
+
+ override private[sql] def isSameType(other: DataType): Boolean = false
- private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+ override private[sql] def acceptsType(other: DataType): Boolean =
+ types.exists(_.isSameType(other))
- private[sql] override def simpleString: String = {
+ override private[sql] def simpleString: String = {
types.map(_.simpleString).mkString("(", " or ", ")")
}
}
@@ -108,7 +118,7 @@ abstract class NumericType extends AtomicType {
}
-private[sql] object NumericType {
+private[sql] object NumericType extends AbstractDataType {
/**
* Enables matching against NumericType for expressions:
* {{{
@@ -117,6 +127,14 @@ private[sql] object NumericType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
+
+ override private[sql] def defaultConcreteType: DataType = DoubleType
+
+ override private[sql] def simpleString: String = "numeric"
+
+ override private[sql] def isSameType(other: DataType): Boolean = false
+
+ override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
}
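The diff above separates two questions that isParentOf used to conflate: "is this the same (unparameterized) type?" and "is this input acceptable?". NumericType answers false to the first for every concrete type, so implicitCast is still free to rewrite the child, while answering true to the second for any numeric member. A stand-in model of that contract (Spark's methods are private[sql]; the names below are illustrative, not the real API):

    sealed trait Concrete
    case object IntC extends Concrete
    case object DoubleC extends Concrete
    case object StringC extends Concrete

    trait Abstract {
      def isSameType(other: Concrete): Boolean
      def acceptsType(other: Concrete): Boolean = isSameType(other)  // default: exact match
    }

    object Numeric extends Abstract {
      def isSameType(other: Concrete): Boolean = false  // never "the same" as one member...
      override def acceptsType(other: Concrete): Boolean =
        other == IntC || other == DoubleC               // ...but accepts every member
    }

    // Numeric.isSameType(IntC)  == false -> implicitCast may still rewrite the child
    // Numeric.acceptsType(IntC) == true  -> checkInputDataTypes reports no mismatch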
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 43413ec761..76ca7a84c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -26,13 +26,13 @@ object ArrayType extends AbstractDataType {
/** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true)
- private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+ override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
- private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
- childCandidate.isInstanceOf[ArrayType]
+ override private[sql] def isSameType(other: DataType): Boolean = {
+ other.isInstanceOf[ArrayType]
}
- private[sql] override def simpleString: String = "array"
+ override private[sql] def simpleString: String = "array"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index a4c2da8e05..57718228e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -76,9 +76,9 @@ abstract class DataType extends AbstractDataType {
*/
private[spark] def asNullable: DataType
- private[sql] override def defaultConcreteType: DataType = this
+ override private[sql] def defaultConcreteType: DataType = this
- private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate
+ override private[sql] def isSameType(other: DataType): Boolean = this == other
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 127b16ff85..a1cafeab17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -84,13 +84,13 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
/** Extra factory methods and pattern matchers for Decimals */
object DecimalType extends AbstractDataType {
- private[sql] override def defaultConcreteType: DataType = Unlimited
+ override private[sql] def defaultConcreteType: DataType = Unlimited
- private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
- childCandidate.isInstanceOf[DecimalType]
+ override private[sql] def isSameType(other: DataType): Boolean = {
+ other.isInstanceOf[DecimalType]
}
- private[sql] override def simpleString: String = "decimal"
+ override private[sql] def simpleString: String = "decimal"
val Unlimited: DecimalType = DecimalType(None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 868dea13d9..ddead10bc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -69,13 +69,13 @@ case class MapType(
object MapType extends AbstractDataType {
- private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType)
+ override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)
- private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
- childCandidate.isInstanceOf[MapType]
+ override private[sql] def isSameType(other: DataType): Boolean = {
+ other.isInstanceOf[MapType]
}
- private[sql] override def simpleString: String = "map"
+ override private[sql] def simpleString: String = "map"
/**
* Construct a [[MapType]] object with the given key type and value type.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index e2d3f53f7d..e0b8ff9178 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -303,13 +303,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
object StructType extends AbstractDataType {
- private[sql] override def defaultConcreteType: DataType = new StructType
+ override private[sql] def defaultConcreteType: DataType = new StructType
- private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
- childCandidate.isInstanceOf[StructType]
+ override private[sql] def isSameType(other: DataType): Boolean = {
+ other.isInstanceOf[StructType]
}
- private[sql] override def simpleString: String = "struct"
+ override private[sql] def simpleString: String = "struct"
private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match {
case t: StructType => t
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 6b20505c60..e47cfb4833 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -77,5 +77,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
* For UDT, asNullable will not change the nullability of its internal sqlType and just returns
* itself.
*/
- private[spark] override def asNullable: UserDefinedType[UserType] = this
+ override private[spark] def asNullable: UserDefinedType[UserType] = this
+
+ override private[sql] def acceptsType(dataType: DataType) =
+ this.getClass == dataType.getClass
}
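For UDTs the patch compares runtime classes, so a user-defined type accepts inputs of exactly its own UDT and nothing else. A self-contained model of that rule (PointUDT and VectorUDT are hypothetical; Spark performs the same getClass comparison on real UserDefinedType instances):

    abstract class Udt {
      def acceptsType(other: Udt): Boolean = this.getClass == other.getClass
    }
    class PointUDT extends Udt
    class VectorUDT extends Udt

    // (new PointUDT).acceptsType(new PointUDT)  == true
    // (new PointUDT).acceptsType(new VectorUDT) == false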
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
new file mode 100644
index 0000000000..73236c3acb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.{InternalRow, SimpleCatalystConf}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+
+case class TestFunction(
+ children: Seq[Expression],
+ inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes {
+ override def nullable: Boolean = true
+ override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+ override def dataType: DataType = StringType
+}
+
+case class UnresolvedTestPlan() extends LeafNode {
+ override lazy val resolved = false
+ override def output: Seq[Attribute] = Nil
+}
+
+class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
+ import AnalysisSuite._
+
+ def errorTest(
+ name: String,
+ plan: LogicalPlan,
+ errorMessages: Seq[String],
+ caseSensitive: Boolean = true): Unit = {
+ test(name) {
+ val error = intercept[AnalysisException] {
+ if (caseSensitive) {
+ caseSensitiveAnalyze(plan)
+ } else {
+ caseInsensitiveAnalyze(plan)
+ }
+ }
+
+ errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase))
+ }
+ }
+
+ val dateLit = Literal.create(null, DateType)
+
+ errorTest(
+ "single invalid type, single arg",
+ testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
+ "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" ::
+ "null is of type date" ::Nil)
+
+ errorTest(
+ "single invalid type, second arg",
+ testRelation.select(
+ TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
+ "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" ::
+ "null is of type date" ::Nil)
+
+ errorTest(
+ "multiple invalid type",
+ testRelation.select(
+ TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
+ "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
+ "expected to be of type int" :: "null is of type date" ::Nil)
+
+ errorTest(
+ "unresolved window function",
+ testRelation2.select(
+ WindowExpression(
+ UnresolvedWindowFunction(
+ "lead",
+ UnresolvedAttribute("c") :: Nil),
+ WindowSpecDefinition(
+ UnresolvedAttribute("a") :: Nil,
+ SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
+ UnspecifiedFrame)).as('window)),
+ "lead" :: "window functions currently requires a HiveContext" :: Nil)
+
+ errorTest(
+ "too many generators",
+ listRelation.select(Explode('list).as('a), Explode('list).as('b)),
+ "only one generator" :: "explode" :: Nil)
+
+ errorTest(
+ "unresolved attributes",
+ testRelation.select('abcd),
+ "cannot resolve" :: "abcd" :: Nil)
+
+ errorTest(
+ "bad casts",
+ testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
+ "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+
+ errorTest(
+ "non-boolean filters",
+ testRelation.where(Literal(1)),
+ "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil)
+
+ errorTest(
+ "missing group by",
+ testRelation2.groupBy('a)('b),
+ "'b'" :: "group by" :: Nil
+ )
+
+ errorTest(
+ "ambiguous field",
+ nestedRelation.select($"top.duplicateField"),
+ "Ambiguous reference to fields" :: "duplicateField" :: Nil,
+ caseSensitive = false)
+
+ errorTest(
+ "ambiguous field due to case insensitivity",
+ nestedRelation.select($"top.differentCase"),
+ "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil,
+ caseSensitive = false)
+
+ errorTest(
+ "missing field",
+ nestedRelation2.select($"top.c"),
+ "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil,
+ caseSensitive = false)
+
+ errorTest(
+ "catch all unresolved plan",
+ UnresolvedTestPlan(),
+ "unresolved" :: Nil)
+
+
+ test("SPARK-6452 regression test") {
+ // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
+ val plan =
+ Aggregate(
+ Nil,
+ Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
+ LocalRelation(
+ AttributeReference("a", IntegerType)(exprId = ExprId(2))))
+
+ assert(plan.resolved)
+
+ val message = intercept[AnalysisException] {
+ caseSensitiveAnalyze(plan)
+ }.getMessage
+
+ assert(message.contains("resolved attribute(s) a#1 missing from a#2"))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 77ca080f36..58df1de983 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
+object AnalysisSuite {
val caseSensitiveConf = new SimpleCatalystConf(true)
val caseInsensitiveConf = new SimpleCatalystConf(false)
@@ -61,25 +61,28 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
val nestedRelation = LocalRelation(
AttributeReference("top", StructType(
StructField("duplicateField", StringType) ::
- StructField("duplicateField", StringType) ::
- StructField("differentCase", StringType) ::
- StructField("differentcase", StringType) :: Nil
+ StructField("duplicateField", StringType) ::
+ StructField("differentCase", StringType) ::
+ StructField("differentcase", StringType) :: Nil
))())
val nestedRelation2 = LocalRelation(
AttributeReference("top", StructType(
StructField("aField", StringType) ::
- StructField("bField", StringType) ::
- StructField("cField", StringType) :: Nil
+ StructField("bField", StringType) ::
+ StructField("cField", StringType) :: Nil
))())
val listRelation = LocalRelation(
AttributeReference("list", ArrayType(IntegerType))())
- before {
- caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
- caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
- }
+ caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+ caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+}
+
+
+class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
+ import AnalysisSuite._
test("union project *") {
val plan = (1 to 100)
@@ -149,91 +152,6 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
}
- def errorTest(
- name: String,
- plan: LogicalPlan,
- errorMessages: Seq[String],
- caseSensitive: Boolean = true): Unit = {
- test(name) {
- val error = intercept[AnalysisException] {
- if (caseSensitive) {
- caseSensitiveAnalyze(plan)
- } else {
- caseInsensitiveAnalyze(plan)
- }
- }
-
- errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase))
- }
- }
-
- errorTest(
- "unresolved window function",
- testRelation2.select(
- WindowExpression(
- UnresolvedWindowFunction(
- "lead",
- UnresolvedAttribute("c") :: Nil),
- WindowSpecDefinition(
- UnresolvedAttribute("a") :: Nil,
- SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
- UnspecifiedFrame)).as('window)),
- "lead" :: "window functions currently requires a HiveContext" :: Nil)
-
- errorTest(
- "too many generators",
- listRelation.select(Explode('list).as('a), Explode('list).as('b)),
- "only one generator" :: "explode" :: Nil)
-
- errorTest(
- "unresolved attributes",
- testRelation.select('abcd),
- "cannot resolve" :: "abcd" :: Nil)
-
- errorTest(
- "bad casts",
- testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
- "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
-
- errorTest(
- "non-boolean filters",
- testRelation.where(Literal(1)),
- "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil)
-
- errorTest(
- "missing group by",
- testRelation2.groupBy('a)('b),
- "'b'" :: "group by" :: Nil
- )
-
- errorTest(
- "ambiguous field",
- nestedRelation.select($"top.duplicateField"),
- "Ambiguous reference to fields" :: "duplicateField" :: Nil,
- caseSensitive = false)
-
- errorTest(
- "ambiguous field due to case insensitivity",
- nestedRelation.select($"top.differentCase"),
- "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil,
- caseSensitive = false)
-
- errorTest(
- "missing field",
- nestedRelation2.select($"top.c"),
- "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil,
- caseSensitive = false)
-
- case class UnresolvedTestPlan() extends LeafNode {
- override lazy val resolved = false
- override def output: Seq[Attribute] = Nil
- }
-
- errorTest(
- "catch all unresolved plan",
- UnresolvedTestPlan(),
- "unresolved" :: Nil)
-
test("divide should be casted into fractional types") {
val testRelation2 = LocalRelation(
@@ -258,22 +176,4 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
assert(pl(3).dataType == DecimalType.Unlimited)
assert(pl(4).dataType == DoubleType)
}
-
- test("SPARK-6452 regression test") {
- // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
- val plan =
- Aggregate(
- Nil,
- Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
- LocalRelation(
- AttributeReference("a", IntegerType)(exprId = ExprId(2))))
-
- assert(plan.resolved)
-
- val message = intercept[AnalysisException] {
- caseSensitiveAnalyze(plan)
- }.getMessage
-
- assert(message.contains("resolved attribute(s) a#1 missing from a#2"))
- }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 93db33d44e..6e3aa0eebe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -77,6 +77,14 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
+
+ shouldCast(StringType, NumericType, DoubleType)
+
+ // NumericType should not be changed when function accepts any of them.
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
+ DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
+ shouldCast(tpe, NumericType, tpe)
+ }
}
test("ineligible implicit type cast") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 439d8cab5f..bbc39b892b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -359,7 +359,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
hiveconf.set(key, value)
}
- private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = {
+ override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = {
setConf(entry.key, entry.stringConverter(value))
}