author     Cheng Hao <hao.cheng@intel.com>    2015-05-22 01:00:16 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-05-22 01:00:16 -0700
commit     f6f2eeb17910b5d446dfd61839e37dd698d0860f (patch)
tree       dabde9be745a52c4d4af0f2177f795dcb8cc005b /sql
parent     2728c3df6690c2fcd4af3bd1c604c98ef6d509a5 (diff)
[SPARK-7322][SQL] Window functions in DataFrame
This closes #6104.

Author: Cheng Hao <hao.cheng@intel.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #6343 from rxin/window-df and squashes the following commits:

026d587 [Reynold Xin] Address code review feedback.
dc448fe [Reynold Xin] Fixed Hive tests.
9794d9d [Reynold Xin] Moved Java test package.
9331605 [Reynold Xin] Refactored API.
3313e2a [Reynold Xin] Merge pull request #6104 from chenghao-intel/df_window
d625a64 [Cheng Hao] Update the dataframe window API as suggested
c141fb1 [Cheng Hao] hide all of the properties of the WindowFunctionDefinition
3b1865f [Cheng Hao] scaladoc typos
f3fd2d0 [Cheng Hao] polish the unit test
6847825 [Cheng Hao] Add additional analytics functions
57e3bc0 [Cheng Hao] typos
24a08ec [Cheng Hao] scaladoc
28222ed [Cheng Hao] fix bug of range/row Frame
1d91865 [Cheng Hao] style issue
53f89f2 [Cheng Hao] remove the over from the functions.scala
964c013 [Cheng Hao] add more unit tests and window functions
64e18a7 [Cheng Hao] Add Window Function support for DataFrame
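For orientation, here is a minimal sketch of the DataFrame window API this patch adds, assembled from the scaladoc examples in the diff below; the DataFrame `df` and its columns "name", "id", and "price" are hypothetical:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Build a window spec once, then attach it to aggregates via Column.over.
val w = Window.partitionBy("name").orderBy("id")
df.select(
  sum("price").over(w.rangeBetween(Long.MinValue, 2)),  // RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING
  avg("price").over(w.rowsBetween(0, 4)))               // ROWS BETWEEN CURRENT ROW AND 4 FOLLOWING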
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala | 81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala | 175
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 228
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java | 78
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java) | 4
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java) | 0
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java) | 0
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java) | 0
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java) | 0
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java) | 0
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala | 219
13 files changed, 807 insertions(+), 7 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index dc0aeea7c4..6895aa1010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -18,13 +18,13 @@
package org.apache.spark.sql
import scala.language.implicitConversions
-import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
+import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._
@@ -889,6 +889,22 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*/
def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr)
+ /**
+ * Define a windowing column.
+ *
+ * {{{
+ * val w = Window.partitionBy("name").orderBy("id")
+ * df.select(
+ * sum("price").over(w.rangeBetween(Long.MinValue, 2)),
+ * avg("price").over(w.rowsBetween(0, 4))
+ * )
+ * }}}
+ *
+ * @group expr_ops
+ * @since 1.4.0
+ */
+ def over(window: expressions.WindowSpec): Column = window.withAggregate(this)
+
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index d78b4c2f89..3ec1c4a2f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, Unresol
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.json.JacksonGenerator
import org.apache.spark.sql.sources.CreateTableUsingAsSelect
@@ -411,7 +411,7 @@ class DataFrame private[sql](
joined.left,
joined.right,
joinType = Inner,
- Some(expressions.EqualTo(
+ Some(catalyst.expressions.EqualTo(
joined.left.resolve(usingColumn),
joined.right.resolve(usingColumn))))
)
@@ -480,8 +480,9 @@ class DataFrame private[sql](
// By the time we get here, since we have already run analysis, all attributes should've been
// resolved and become AttributeReference.
val cond = plan.condition.map { _.transform {
- case expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) =>
- expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name))
+ case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
+ if a.sameRef(b) =>
+ catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name))
}}
plan.copy(condition = cond)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
new file mode 100644
index 0000000000..d4003b2d9c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * :: Experimental ::
+ * Utility functions for defining windows in DataFrames.
+ *
+ * {{{
+ * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+ * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0)
+ *
+ * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING
+ * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3)
+ * }}}
+ *
+ * @since 1.4.0
+ */
+@Experimental
+object Window {
+
+ /**
+ * Creates a [[WindowSpec]] with the partitioning defined.
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def partitionBy(colName: String, colNames: String*): WindowSpec = {
+ spec.partitionBy(colName, colNames : _*)
+ }
+
+ /**
+ * Creates a [[WindowSpec]] with the partitioning defined.
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def partitionBy(cols: Column*): WindowSpec = {
+ spec.partitionBy(cols : _*)
+ }
+
+ /**
+ * Creates a [[WindowSpec]] with the ordering defined.
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def orderBy(colName: String, colNames: String*): WindowSpec = {
+ spec.orderBy(colName, colNames : _*)
+ }
+
+ /**
+ * Creates a [[WindowSpec]] with the ordering defined.
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def orderBy(cols: Column*): WindowSpec = {
+ spec.orderBy(cols : _*)
+ }
+
+ private def spec: WindowSpec = {
+ new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame)
+ }
+
+}
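Since Window only seeds an empty WindowSpec, the returned spec can be defined once and reused across expressions. A small sketch, assuming a hypothetical DataFrame `df` with columns "country", "date", and "sales":

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, sum}

val byCountry = Window.partitionBy("country").orderBy("date")
// PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
df.select(sum("sales").over(byCountry.rowsBetween(Long.MinValue, 0)))
// PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING
df.select(avg("sales").over(byCountry.rowsBetween(-3, 3)))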
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
new file mode 100644
index 0000000000..c3d2246297
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.{Column, catalyst}
+import org.apache.spark.sql.catalyst.expressions._
+
+
+/**
+ * :: Experimental ::
+ * A window specification that defines the partitioning, ordering, and frame boundaries.
+ *
+ * Use the static methods in [[Window]] to create a [[WindowSpec]].
+ *
+ * @since 1.4.0
+ */
+@Experimental
+class WindowSpec private[sql](
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ frame: catalyst.expressions.WindowFrame) {
+
+ /**
+ * Defines the partitioning columns in a [[WindowSpec]].
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def partitionBy(colName: String, colNames: String*): WindowSpec = {
+ partitionBy((colName +: colNames).map(Column(_)): _*)
+ }
+
+ /**
+ * Defines the partitioning columns in a [[WindowSpec]].
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def partitionBy(cols: Column*): WindowSpec = {
+ new WindowSpec(cols.map(_.expr), orderSpec, frame)
+ }
+
+ /**
+ * Defines the ordering columns in a [[WindowSpec]].
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def orderBy(colName: String, colNames: String*): WindowSpec = {
+ orderBy((colName +: colNames).map(Column(_)): _*)
+ }
+
+ /**
+ * Defines the ordering columns in a [[WindowSpec]].
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def orderBy(cols: Column*): WindowSpec = {
+ val sortOrder: Seq[SortOrder] = cols.map { col =>
+ col.expr match {
+ case expr: SortOrder =>
+ expr
+ case expr: Expression =>
+ SortOrder(expr, Ascending)
+ }
+ }
+ new WindowSpec(partitionSpec, sortOrder, frame)
+ }
+
+ /**
+ * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
+ *
+ * Both `start` and `end` are relative positions from the current row. For example, "0" means
+ * "current row", while "-1" means the row before the current row, and "5" means the fifth row
+ * after the current row.
+ *
+ * @param start boundary start, inclusive.
+ * The frame is unbounded if this is the minimum long value.
+ * @param end boundary end, inclusive.
+ * The frame is unbounded if this is the maximum long value.
+ * @since 1.4.0
+ */
+ def rowsBetween(start: Long, end: Long): WindowSpec = {
+ between(RowFrame, start, end)
+ }
+
+ /**
+ * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
+ *
+ * Both `start` and `end` are relative to the current row. For example, "0" means "current row",
+ * while "-1" means one off before the current row, and "5" means five off after the
+ * current row.
+ *
+ * @param start boundary start, inclusive.
+ * The frame is unbounded if this is the minimum long value.
+ * @param end boundary end, inclusive.
+ * The frame is unbounded if this is the maximum long value.
+ * @since 1.4.0
+ */
+ def rangeBetween(start: Long, end: Long): WindowSpec = {
+ between(RangeFrame, start, end)
+ }
+
+ private def between(typ: FrameType, start: Long, end: Long): WindowSpec = {
+ val boundaryStart = start match {
+ case 0 => CurrentRow
+ case Long.MinValue => UnboundedPreceding
+ case x if x < 0 => ValuePreceding(-start.toInt)
+ case x if x > 0 => ValueFollowing(start.toInt)
+ }
+
+ val boundaryEnd = end match {
+ case 0 => CurrentRow
+ case Long.MaxValue => UnboundedFollowing
+ case x if x < 0 => ValuePreceding(-end.toInt)
+ case x if x > 0 => ValueFollowing(end.toInt)
+ }
+
+ new WindowSpec(
+ partitionSpec,
+ orderSpec,
+ SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd))
+ }
+
+ /**
+ * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression.
+ */
+ private[sql] def withAggregate(aggregate: Column): Column = {
+ val windowExpr = aggregate.expr match {
+ case Average(child) => WindowExpression(
+ UnresolvedWindowFunction("avg", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Sum(child) => WindowExpression(
+ UnresolvedWindowFunction("sum", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Count(child) => WindowExpression(
+ UnresolvedWindowFunction("count", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case First(child) => WindowExpression(
+ // TODO this is a hack for Hive UDAF first_value
+ UnresolvedWindowFunction("first_value", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Last(child) => WindowExpression(
+ // TODO this is a hack for Hive UDAF last_value
+ UnresolvedWindowFunction("last_value", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Min(child) => WindowExpression(
+ UnresolvedWindowFunction("min", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case Max(child) => WindowExpression(
+ UnresolvedWindowFunction("max", child :: Nil),
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case wf: WindowFunction => WindowExpression(
+ wf,
+ WindowSpecDefinition(partitionSpec, orderSpec, frame))
+ case x =>
+ throw new UnsupportedOperationException(s"$x is not supported in window operation.")
+ }
+ new Column(windowExpr)
+ }
+
+}
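The private `between` helper above is where the Long endpoints become catalyst frame boundaries. A sketch of the mapping it implements, for a hypothetical spec `w` built via Window.partitionBy(...).orderBy(...):

// Endpoint -> boundary, per `between`:
//   Long.MinValue -> UnboundedPreceding     -n -> ValuePreceding(n)
//   0             -> CurrentRow              n -> ValueFollowing(n)
//   Long.MaxValue -> UnboundedFollowing
w.rowsBetween(Long.MinValue, 0)    // ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
w.rangeBetween(-1, Long.MaxValue)  // RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING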
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6640631cf0..8775be724e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -37,6 +37,7 @@ import org.apache.spark.util.Utils
* @groupname sort_funcs Sorting functions
* @groupname normal_funcs Non-aggregate functions
* @groupname math_funcs Math functions
+ * @groupname window_funcs Window functions
* @groupname Ungrouped Support functions for DataFrames.
* @since 1.3.0
*/
@@ -321,6 +322,233 @@ object functions {
def max(columnName: String): Column = max(Column(columnName))
//////////////////////////////////////////////////////////////////////////////////////////////
+ // Window functions
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Window function: returns the lag value of the current row of the column,
+ * or null when the lag extends before the beginning of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lag(columnName: String): Column = {
+ lag(columnName, 1)
+ }
+
+ /**
+ * Window function: returns the lag value of the current row of the expression,
+ * or null when the lag extends before the beginning of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lag(e: Column): Column = {
+ lag(e, 1)
+ }
+
+ /**
+ * Window function: returns the lag value of the current row of the expression,
+ * or null when the lag extends before the beginning of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lag(e: Column, count: Int): Column = {
+ lag(e, count, null)
+ }
+
+ /**
+ * Window function: returns the lag value of the current row of the column,
+ * or null when the lag extends before the beginning of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lag(columnName: String, count: Int): Column = {
+ lag(columnName, count, null)
+ }
+
+ /**
+ * Window function: returns the lag value of the current row of the column,
+ * or the given default value when the lag extends before the beginning
+ * of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lag(columnName: String, count: Int, defaultValue: Any): Column = {
+ lag(Column(columnName), count, defaultValue)
+ }
+
+ /**
+ * Window function: returns the lag value of the current row of the expression,
+ * or the given default value when the lag extends before the beginning
+ * of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lag(e: Column, count: Int, defaultValue: Any): Column = {
+ UnresolvedWindowFunction("lag", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
+ }
+
+ /**
+ * Window function: returns the lead value of the current row of the column,
+ * or null when the lead extends beyond the end of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lead(columnName: String): Column = {
+ lead(columnName, 1)
+ }
+
+ /**
+ * Window function: returns the lead value of the current row of the expression,
+ * or null when the lead extends beyond the end of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lead(e: Column): Column = {
+ lead(e, 1)
+ }
+
+ /**
+ * Window function: returns the lead value of the current row of the column,
+ * or null when the lead extends beyond the end of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lead(columnName: String, count: Int): Column = {
+ lead(columnName, count, null)
+ }
+
+ /**
+ * Window function: returns the lead value of the current row of the expression,
+ * or null when the lead extends beyond the end of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lead(e: Column, count: Int): Column = {
+ lead(e, count, null)
+ }
+
+ /**
+ * Window function: returns the lead value of the current row of the column,
+ * or the given default value when the lead extends beyond the end of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lead(columnName: String, count: Int, defaultValue: Any): Column = {
+ lead(Column(columnName), count, defaultValue)
+ }
+
+ /**
+ * Window function: returns the lead value of the current row of the expression,
+ * or the given default value when the lead extends beyond the end of the window.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def lead(e: Column, count: Int, defaultValue: Any): Column = {
+ UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
+ }
+
+ /**
+ * NTILE for specified expression.
+ * NTILE allows easy calculation of tertiles, quartiles, deciles and other
+ * common summary statistics. This function divides an ordered partition into a specified
+ * number of groups called buckets and assigns a bucket number to each row in the partition.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def ntile(e: Column): Column = {
+ UnresolvedWindowFunction("ntile", e.expr :: Nil)
+ }
+
+ /**
+ * NTILE for specified column.
+ * NTILE allows easy calculation of tertiles, quartiles, deciles and other
+ * common summary statistics. This function divides an ordered partition into a specified
+ * number of groups called buckets and assigns a bucket number to each row in the partition.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def ntile(columnName: String): Column = {
+ ntile(Column(columnName))
+ }
+
+ /**
+ * Assigns a unique number (sequentially, starting from 1, as defined by ORDER BY) to each
+ * row within the partition.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def rowNumber(): Column = {
+ UnresolvedWindowFunction("row_number", Nil)
+ }
+
+ /**
+ * The difference between RANK and DENSE_RANK is that DENSE_RANK leaves no gaps in ranking
+ * sequence when there are ties. That is, if you were ranking a competition using DENSE_RANK
+ * and had three people tie for second place, you would say that all three were in second
+ * place and that the next person came in third.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def denseRank(): Column = {
+ UnresolvedWindowFunction("dense_rank", Nil)
+ }
+
+ /**
+ * The difference between RANK and DENSE_RANK is that RANK leaves gaps in the ranking
+ * sequence when there are ties. That is, if you were ranking a competition using RANK
+ * and had three people tie for second place, you would say that all three were in second
+ * place and that the next person came in fifth.
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def rank(): Column = {
+ UnresolvedWindowFunction("rank", Nil)
+ }
+
+ /**
+ * CUME_DIST (defined as the inverse of percentile in some statistical books) computes
+ * the position of a specified value relative to a set of values.
+ * To compute the CUME_DIST of a value x in a set S of size N, you use the formula:
+ * CUME_DIST(x) = number of values in S coming before and including x in the specified order / N
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def cumeDist(): Column = {
+ UnresolvedWindowFunction("cume_dist", Nil)
+ }
+
+ /**
+ * PERCENT_RANK is similar to CUME_DIST, but it uses rank values rather than row counts
+ * in its numerator.
+ * The formula:
+ * (rank of row in its partition - 1) / (number of rows in the partition - 1)
+ *
+ * @group window_funcs
+ * @since 1.4.0
+ */
+ def percentRank(): Column = {
+ UnresolvedWindowFunction("percent_rank", Nil)
+ }
+
+ //////////////////////////////////////////////////////////////////////////////////////////////
// Non-aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
new file mode 100644
index 0000000000..c4828c4717
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.expressions.Window;
+import org.apache.spark.sql.hive.HiveContext;
+import org.apache.spark.sql.hive.test.TestHive$;
+
+public class JavaDataFrameSuite {
+ private transient JavaSparkContext sc;
+ private transient HiveContext hc;
+
+ DataFrame df;
+
+ private void checkAnswer(DataFrame actual, List<Row> expected) {
+ String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+ if (errorMessage != null) {
+ Assert.fail(errorMessage);
+ }
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ hc = TestHive$.MODULE$;
+ sc = new JavaSparkContext(hc.sparkContext());
+
+ List<String> jsonObjects = new ArrayList<String>(10);
+ for (int i = 0; i < 10; i++) {
+ jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}");
+ }
+ df = hc.jsonRDD(sc.parallelize(jsonObjects));
+ df.registerTempTable("window_table");
+ }
+
+ @After
+ public void tearDown() throws IOException {
+ // Clean up tables.
+ hc.sql("DROP TABLE IF EXISTS window_table");
+ }
+
+ @Test
+ public void testWindowFunction() {
+ checkAnswer(
+ df.select(functions.avg("key").over(
+ Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))),
+ hc.sql("SELECT avg(key) " +
+ "OVER (PARTITION BY value " +
+ " ORDER BY key " +
+ " ROWS BETWEEN 1 preceding and 1 following) " +
+ "FROM window_table").collectAsList());
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
index 58fe96adab..64d1ce9293 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -14,7 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.hive;
+
+package test.org.apache.spark.sql.hive;
import java.io.File;
import java.io.IOException;
@@ -36,6 +37,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.QueryTest$;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.hive.test.TestHive$;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
index 6c4f378bc5..6c4f378bc5 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java
index 808e2986d3..808e2986d3 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java
index f33210ebda..f33210ebda 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java
index a369188d47..a369188d47 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java
index 0165591a7c..0165591a7c 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
new file mode 100644
index 0000000000..6cea6776c8
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.{Row, QueryTest}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+class HiveDataFrameWindowSuite extends QueryTest {
+
+ test("reuse window partitionBy") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ val w = Window.partitionBy("key").orderBy("value")
+
+ checkAnswer(
+ df.select(
+ lead("key").over(w),
+ lead("value").over(w)),
+ Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
+ }
+
+ test("reuse window orderBy") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ val w = Window.orderBy("value").partitionBy("key")
+
+ checkAnswer(
+ df.select(
+ lead("key").over(w),
+ lead("value").over(w)),
+ Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
+ }
+
+ test("lead") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+
+ checkAnswer(
+ df.select(
+ lead("value").over(Window.partitionBy($"key").orderBy($"value"))),
+ sql(
+ """SELECT
+ | lead(value) OVER (PARTITION BY key ORDER BY value)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("lag") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+
+ checkAnswer(
+ df.select(
+ lag("value").over(
+ Window.partitionBy($"key")
+ .orderBy($"value"))),
+ sql(
+ """SELECT
+ | lag(value) OVER (PARTITION BY key ORDER BY value)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("lead with default value") {
+ val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
+ (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))),
+ sql(
+ """SELECT
+ | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("lag with default value") {
+ val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
+ (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))),
+ sql(
+ """SELECT
+ | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("rank functions in unspecific window") {
+ val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ $"key",
+ max("key").over(Window.partitionBy("value").orderBy("key")),
+ min("key").over(Window.partitionBy("value").orderBy("key")),
+ mean("key").over(Window.partitionBy("value").orderBy("key")),
+ count("key").over(Window.partitionBy("value").orderBy("key")),
+ sum("key").over(Window.partitionBy("value").orderBy("key")),
+ ntile("key").over(Window.partitionBy("value").orderBy("key")),
+ ntile($"key").over(Window.partitionBy("value").orderBy("key")),
+ rowNumber().over(Window.partitionBy("value").orderBy("key")),
+ denseRank().over(Window.partitionBy("value").orderBy("key")),
+ rank().over(Window.partitionBy("value").orderBy("key")),
+ cumeDist().over(Window.partitionBy("value").orderBy("key")),
+ percentRank().over(Window.partitionBy("value").orderBy("key"))),
+ sql(
+ s"""SELECT
+ |key,
+ |max(key) over (partition by value order by key),
+ |min(key) over (partition by value order by key),
+ |avg(key) over (partition by value order by key),
+ |count(key) over (partition by value order by key),
+ |sum(key) over (partition by value order by key),
+ |ntile(key) over (partition by value order by key),
+ |ntile(key) over (partition by value order by key),
+ |row_number() over (partition by value order by key),
+ |dense_rank() over (partition by value order by key),
+ |rank() over (partition by value order by key),
+ |cume_dist() over (partition by value order by key),
+ |percent_rank() over (partition by value order by key)
+ |FROM window_table""".stripMargin).collect())
+ }
+
+ test("aggregation and rows between") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))),
+ sql(
+ """SELECT
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("aggregation and range betweens") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))),
+ sql(
+ """SELECT
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("aggregation and rows betweens with unbounded") {
+ val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ $"key",
+ last("value").over(
+ Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)),
+ last("value").over(
+ Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)),
+ last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))),
+ sql(
+ """SELECT
+ | key,
+ | last_value(value) OVER
+ | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following),
+ | last_value(value) OVER
+ | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row),
+ | last_value(value) OVER
+ | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("aggregation and range betweens with unbounded") {
+ val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ $"key",
+ last("value").over(
+ Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue))
+ .equalTo("2")
+ .as("last_v"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1))
+ .as("avg_key1"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue))
+ .as("avg_key2"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0))
+ .as("avg_key3")
+ ),
+ sql(
+ """SELECT
+ | key,
+ | last_value(value) OVER
+ | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2",
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following),
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following),
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row)
+ | FROM window_table""".stripMargin).collect())
+ }
+}