aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-07-31 23:11:22 -0700
committerDavies Liu <davies.liu@gmail.com>2015-07-31 23:11:22 -0700
commit67ad4e21fc68336b0ad6f9a363fb5ebb51f592bf (patch)
treede2371bb74fec1fa6b93391809e30bbad202ca3f /sql
parent3320b0ba262159c0c7209ce39b353c93c597077d (diff)
downloadspark-67ad4e21fc68336b0ad6f9a363fb5ebb51f592bf.tar.gz
spark-67ad4e21fc68336b0ad6f9a363fb5ebb51f592bf.tar.bz2
spark-67ad4e21fc68336b0ad6f9a363fb5ebb51f592bf.zip
[SPARK-8232] [SQL] Add sort_array support
Add expression `sort_array` support. Author: Cheng Hao <hao.cheng@intel.com> This patch had conflicts when merged, resolved by Committer: Davies Liu <davies.liu@gmail.com> Closes #7581 from chenghao-intel/sort_array and squashes the following commits: 664c960 [Cheng Hao] update the sort_array by using the ArrayData 276d2d5 [Cheng Hao] add empty line 0edab9c [Cheng Hao] Add asending/descending support for sort_array 80fc0f8 [Cheng Hao] Add type checking a42b678 [Cheng Hao] Add sort_array support
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala80
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala19
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala51
5 files changed, 166 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ee44cbcba6..6e144518bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -233,6 +233,7 @@ object FunctionRegistry {
// collection functions
expression[Size]("size"),
+ expression[SortArray]("sort_array"),
// misc functions
expression[Crc32]("crc32"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 1a00dbc254..0a530596a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -16,7 +16,10 @@
*/
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import java.util.Comparator
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types._
/**
@@ -39,3 +42,78 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
}
}
+
+/**
+ * Sorts the input array in ascending / descending order according to the natural ordering of
+ * the array elements and returns it.
+ */
+case class SortArray(base: Expression, ascendingOrder: Expression)
+ extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+
+ def this(e: Expression) = this(e, Literal(true))
+
+ override def left: Expression = base
+ override def right: Expression = ascendingOrder
+ override def dataType: DataType = base.dataType
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
+
+ override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
+ case _ @ ArrayType(n: AtomicType, _) => TypeCheckResult.TypeCheckSuccess
+ case _ @ ArrayType(n, _) => TypeCheckResult.TypeCheckFailure(
+ s"Type $n is not the AtomicType, we can not perform the ordering operations")
+ case other =>
+ TypeCheckResult.TypeCheckFailure(s"ArrayType(AtomicType) is expected, but we got $other")
+ }
+
+ @transient
+ private lazy val lt = {
+ val ordering = base.dataType match {
+ case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+ }
+
+ new Comparator[Any]() {
+ override def compare(o1: Any, o2: Any): Int = {
+ if (o1 == null && o2 == null) {
+ 0
+ } else if (o1 == null) {
+ -1
+ } else if (o2 == null) {
+ 1
+ } else {
+ ordering.compare(o1, o2)
+ }
+ }
+ }
+ }
+
+ @transient
+ private lazy val gt = {
+ val ordering = base.dataType match {
+ case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+ }
+
+ new Comparator[Any]() {
+ override def compare(o1: Any, o2: Any): Int = {
+ if (o1 == null && o2 == null) {
+ 0
+ } else if (o1 == null) {
+ 1
+ } else if (o2 == null) {
+ -1
+ } else {
+ -ordering.compare(o1, o2)
+ }
+ }
+ }
+ }
+
+ override def nullSafeEval(array: Any, ascending: Any): Any = {
+ val data = array.asInstanceOf[ArrayData].toArray().asInstanceOf[Array[AnyRef]]
+ java.util.Arrays.sort(
+ data,
+ if (ascending.asInstanceOf[Boolean]) lt else gt)
+ new GenericArrayData(data.asInstanceOf[Array[Any]])
+ }
+
+ override def prettyName: String = "sort_array"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 28c41b5716..2c7e85c446 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -43,4 +43,26 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}
+
+ test("Sort Array") {
+ val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+ val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
+ val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
+
+ checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
+ checkEvaluation(new SortArray(a1), Seq[Integer]())
+ checkEvaluation(new SortArray(a2), Seq("a", "b"))
+ checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
+ checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
+ checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
+ checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
+ checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
+ checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
+ checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
+ checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
+ checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
+
+ checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 57bb00a741..3c9421f5cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2223,19 +2223,30 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Returns length of array or map
+ * Returns length of array or map.
+ *
* @group collection_funcs
* @since 1.5.0
*/
- def size(columnName: String): Column = size(Column(columnName))
+ def size(e: Column): Column = Size(e.expr)
/**
- * Returns length of array or map
+ * Sorts the input array for the given column in ascending order,
+ * according to the natural ordering of the array elements.
+ *
* @group collection_funcs
* @since 1.5.0
*/
- def size(column: Column): Column = Size(column.expr)
+ def sort_array(e: Column): Column = sort_array(e, true)
+ /**
+ * Sorts the input array for the given column in ascending / descending order,
+ * according to the natural ordering of the array elements.
+ *
+ * @group collection_funcs
+ * @since 1.5.0
+ */
+ def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr)
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 1baec5d376..46921d1425 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -267,6 +267,53 @@ class DataFrameFunctionsSuite extends QueryTest {
)
}
+ test("sort_array function") {
+ val df = Seq(
+ (Array[Int](2, 1, 3), Array("b", "c", "a")),
+ (Array[Int](), Array[String]()),
+ (null, null)
+ ).toDF("a", "b")
+ checkAnswer(
+ df.select(sort_array($"a"), sort_array($"b")),
+ Seq(
+ Row(Seq(1, 2, 3), Seq("a", "b", "c")),
+ Row(Seq[Int](), Seq[String]()),
+ Row(null, null))
+ )
+ checkAnswer(
+ df.select(sort_array($"a", false), sort_array($"b", false)),
+ Seq(
+ Row(Seq(3, 2, 1), Seq("c", "b", "a")),
+ Row(Seq[Int](), Seq[String]()),
+ Row(null, null))
+ )
+ checkAnswer(
+ df.selectExpr("sort_array(a)", "sort_array(b)"),
+ Seq(
+ Row(Seq(1, 2, 3), Seq("a", "b", "c")),
+ Row(Seq[Int](), Seq[String]()),
+ Row(null, null))
+ )
+ checkAnswer(
+ df.selectExpr("sort_array(a, true)", "sort_array(b, false)"),
+ Seq(
+ Row(Seq(1, 2, 3), Seq("c", "b", "a")),
+ Row(Seq[Int](), Seq[String]()),
+ Row(null, null))
+ )
+
+ val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
+ assert(intercept[AnalysisException] {
+ df2.selectExpr("sort_array(a)").collect()
+ }.getMessage().contains("Type ArrayType(IntegerType,false) is not the AtomicType, " +
+ "we can not perform the ordering operations"))
+
+ val df3 = Seq(("xxx", "x")).toDF("a", "b")
+ assert(intercept[AnalysisException] {
+ df3.selectExpr("sort_array(a)").collect()
+ }.getMessage().contains("ArrayType(AtomicType) is expected, but we got StringType"))
+ }
+
test("array size function") {
val df = Seq(
(Array[Int](1, 2), "x"),
@@ -274,7 +321,7 @@ class DataFrameFunctionsSuite extends QueryTest {
(Array[Int](1, 2, 3), "z")
).toDF("a", "b")
checkAnswer(
- df.select(size("a")),
+ df.select(size($"a")),
Seq(Row(2), Row(0), Row(3))
)
checkAnswer(
@@ -290,7 +337,7 @@ class DataFrameFunctionsSuite extends QueryTest {
(Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
).toDF("a", "b")
checkAnswer(
- df.select(size("a")),
+ df.select(size($"a")),
Seq(Row(2), Row(0), Row(3))
)
checkAnswer(