about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author    Reynold Xin <rxin@databricks.com>  2015-07-04 11:55:04 -0700
committer Reynold Xin <rxin@databricks.com>  2015-07-04 11:55:04 -0700
commit    347cab85cd924ffd326f3d1367b3b156ee08052d (patch)
tree      dd63bf43d36f9ec648cc165d501ddd7c6346746c /sql
parent    48f7aed686afde70a6f0802c6cb37b0cad0509f1 (diff)
download  spark-347cab85cd924ffd326f3d1367b3b156ee08052d.tar.gz
          spark-347cab85cd924ffd326f3d1367b3b156ee08052d.tar.bz2
          spark-347cab85cd924ffd326f3d1367b3b156ee08052d.zip
[SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType.
Author: Reynold Xin <rxin@databricks.com>

Closes #7221 from rxin/implicit-cast-tests and squashes the following commits:

64b13bd [Reynold Xin] Fixed a bug ..
489b732 [Reynold Xin] [SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala     |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala              |  7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala                     |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala                   |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala                       |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala                    |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 25
7 files changed, 42 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 583338da57..476ac2b7cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -40,7 +40,7 @@ trait CheckAnalysis {
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
exprs.flatMap(_.collect {
case e: Generator => true
- }).length >= 1
+ }).nonEmpty
}
def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -85,12 +85,12 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
- case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
+ case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
s"nor is it an aggregate function. " +
"Add to group by or wrap in first() if you don't care which value you get.")
- case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
+ case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index e5dc99fb62..ffefb0e783 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -37,6 +37,9 @@ private[sql] abstract class AbstractDataType {
* Returns true if this data type is a parent of the `childCandidate`.
*/
private[sql] def isParentOf(childCandidate: DataType): Boolean
+
+ /** Readable string representation for the type. */
+ private[sql] def simpleString: String
}
@@ -56,6 +59,10 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends Abst
private[sql] override def defaultConcreteType: DataType = types.head
private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+
+ private[sql] override def simpleString: String = {
+ types.map(_.simpleString).mkString("(", " or ", ")")
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 8ea6cb14c3..43413ec761 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -31,6 +31,8 @@ object ArrayType extends AbstractDataType {
private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
childCandidate.isInstanceOf[ArrayType]
}
+
+ private[sql] override def simpleString: String = "array"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 434fc037aa..127b16ff85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -90,6 +90,8 @@ object DecimalType extends AbstractDataType {
childCandidate.isInstanceOf[DecimalType]
}
+ private[sql] override def simpleString: String = "decimal"
+
val Unlimited: DecimalType = DecimalType(None)
private[sql] object Fixed {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 2b25617ec6..868dea13d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -75,6 +75,8 @@ object MapType extends AbstractDataType {
childCandidate.isInstanceOf[MapType]
}
+ private[sql] override def simpleString: String = "map"
+
/**
* Construct a [[MapType]] object with the given key type and value type.
* The `valueContainsNull` is true.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 7e77b77e73..3b17566d54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -309,6 +309,8 @@ object StructType extends AbstractDataType {
childCandidate.isInstanceOf[StructType]
}
+ private[sql] override def simpleString: String = "struct"
+
def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
def apply(fields: java.util.List[StructField]): StructType = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 60e727c6c7..67d05ab536 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.types._
class HiveTypeCoercionSuite extends PlanTest {
- test("implicit type cast") {
+ test("eligible implicit type cast") {
def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
assert(got.map(_.dataType) == Option(expected),
@@ -68,6 +68,29 @@ class HiveTypeCoercionSuite extends PlanTest {
shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
+
+ shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
+ shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
+ }
+
+ test("ineligible implicit type cast") {
+ def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
+ val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
+ assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
+ }
+
+ shouldNotCast(IntegerType, DateType)
+ shouldNotCast(IntegerType, TimestampType)
+ shouldNotCast(LongType, DateType)
+ shouldNotCast(LongType, TimestampType)
+ shouldNotCast(DecimalType.Unlimited, DateType)
+ shouldNotCast(DecimalType.Unlimited, TimestampType)
+
+ shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
+
+ shouldNotCast(IntegerType, ArrayType)
+ shouldNotCast(IntegerType, MapType)
+ shouldNotCast(IntegerType, StructType)
}
test("tightest common bound for types") {