aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala)1
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala20
3 files changed, 21 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 013027b199..4d6c1c2651 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -186,7 +186,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
- val index = ordinal.asInstanceOf[Int]
+ val index = ordinal.asInstanceOf[Number].intValue()
if (index >= baseValue.size || index < 0) {
null
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 72fdcebb4c..e0bf07ed18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.types._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 2b0f4618b2..b80911e725 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -26,6 +26,26 @@ import org.apache.spark.unsafe.types.UTF8String
class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
+ /**
+ * Runs through the testFunc for all integral data types.
+ *
+ * @param testFunc a test function that accepts a conversion function to convert an integer
+ * into another data type.
+ */
+ private def testIntegralDataTypes(testFunc: (Int => Any) => Unit): Unit = {
+ testFunc(_.toByte)
+ testFunc(_.toShort)
+ testFunc(identity)
+ testFunc(_.toLong)
+ }
+
+ test("GetArrayItem") {
+ testIntegralDataTypes { convert =>
+ val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
+ checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b")
+ }
+ }
+
test("CreateStruct") {
val row = InternalRow(1, 2, 3)
val c1 = 'a.int.at(0).as("a")