Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala            | 24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala                  | 35
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala                   | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                                  | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala                           | 24
6 files changed, 91 insertions, 20 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 406de38d1c..14a855054b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -189,9 +189,10 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
override def children: Seq[Expression] = expressions
override def nullable: Boolean = false
- override def dataType: ArrayType = ArrayType(expressions.head.dataType)
+ override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType)
override def toString: String = s"AddToHashSet(${expressions.mkString(",")})"
- override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this)
+ override def newInstance(): CollectHashSetFunction =
+ new CollectHashSetFunction(expressions, this)
}
case class CollectHashSetFunction(
@@ -250,11 +251,28 @@ case class CombineSetsAndCountFunction(
override def eval(input: Row): Any = seen.size.toLong
}
+/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */
+private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
+
+ override def sqlType: DataType = BinaryType
+
+ /** Since we are using HyperLogLog internally, usually it will not be called. */
+ override def serialize(obj: Any): Array[Byte] =
+ obj.asInstanceOf[HyperLogLog].getBytes
+
+ /** Since we are using HyperLogLog internally, usually it will not be called. */
+ override def deserialize(datum: Any): HyperLogLog =
+ HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]])
+
+ override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
+}
+
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
override def nullable: Boolean = false
- override def dataType: DataType = child.dataType
+ override def dataType: DataType = HyperLogLogUDT
override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
override def newInstance(): ApproxCountDistinctPartitionFunction = {
new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
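With ApproxCountDistinctPartition now typed as HyperLogLogUDT, the partial aggregate carries a genuine HyperLogLog sketch instead of borrowing the child's type. A minimal round-trip sketch of the new UDT follows; the demo object is hypothetical, and since HyperLogLogUDT is private[sql] it would need to live under the org.apache.spark.sql package to compile.

// Hypothetical demo; stream-lib's HyperLogLog is already a Spark dependency.
import com.clearspring.analytics.stream.cardinality.HyperLogLog

object HyperLogLogUDTRoundTrip {
  def main(args: Array[String]): Unit = {
    val hll = new HyperLogLog(0.05)        // 5% relative standard deviation
    (1 to 1000).foreach(i => hll.offer(i)) // offer takes any Object; the Int boxes

    // serialize emits the sketch's raw bytes; deserialize rebuilds the sketch.
    val restored = HyperLogLogUDT.deserialize(HyperLogLogUDT.serialize(hll))
    assert(restored.cardinality() == hll.cardinality())
  }
}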
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d1abf3c0b6..aac56e1568 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -464,7 +464,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val itemEval = expressionEvaluator(item)
val setEval = expressionEvaluator(set)
- val ArrayType(elementType, _) = set.dataType
+ val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
itemEval.code ++ setEval.code ++
q"""
@@ -482,7 +482,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val leftEval = expressionEvaluator(left)
val rightEval = expressionEvaluator(right)
- val ArrayType(elementType, _) = left.dataType
+ val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
leftEval.code ++ rightEval.code ++
q"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 35faa00782..4c44182278 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -20,6 +20,33 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
+/** The data type for expressions returning an OpenHashSet as the result. */
+private[sql] class OpenHashSetUDT(
+ val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] {
+
+ override def sqlType: DataType = ArrayType(elementType)
+
+ /** Since we are using OpenHashSet internally, usually it will not be called. */
+ override def serialize(obj: Any): Seq[Any] = {
+ obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq
+ }
+
+ /** Since we are using OpenHashSet internally, usually it will not be called. */
+ override def deserialize(datum: Any): OpenHashSet[Any] = {
+ val iterator = datum.asInstanceOf[Seq[Any]].iterator
+ val set = new OpenHashSet[Any]
+ while(iterator.hasNext) {
+ set.add(iterator.next())
+ }
+
+ set
+ }
+
+ override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]]
+
+ private[spark] override def asNullable: OpenHashSetUDT = this
+}
+
/**
* Creates a new set of the specified type
*/
@@ -28,9 +55,7 @@ case class NewSet(elementType: DataType) extends LeafExpression {
override def nullable: Boolean = false
- // We are currently only using these Expressions internally for aggregation. However, if we ever
- // expose these to users we'll want to create a proper type instead of hijacking ArrayType.
- override def dataType: DataType = ArrayType(elementType)
+ override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType)
override def eval(input: Row): Any = {
new OpenHashSet[Any]()
@@ -50,7 +75,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
override def nullable: Boolean = set.nullable
- override def dataType: DataType = set.dataType
+ override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT]
override def eval(input: Row): Any = {
val itemEval = item.eval(input)
@@ -80,7 +105,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
override def nullable: Boolean = left.nullable || right.nullable
- override def dataType: DataType = left.dataType
+ override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT]
override def symbol: String = "++="
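The serialize/deserialize pair only comes into play when a set leaves the internal aggregation path (for example, when collected to the driver). A minimal round trip, mirroring the test added at the end of this patch; the demo object is hypothetical and needs package access to the private[sql] class:

import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.collection.OpenHashSet

object OpenHashSetUDTRoundTrip {
  def main(args: Array[String]): Unit = {
    val udt = new OpenHashSetUDT(IntegerType)
    val set = new OpenHashSet[Any]
    (1 to 10).foreach(i => set.add(i))

    // serialize flattens the set to a Seq[Any]; deserialize rebuilds it.
    val restored = udt.deserialize(udt.serialize(set))
    assert(restored.iterator.toSet == set.iterator.toSet)
  }
}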
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 861a2c21ad..3c58e93b45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -68,6 +68,8 @@ case class GeneratedAggregate(
a.collect { case agg: AggregateExpression => agg}
}
+ // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
+ // (in test "aggregation with codegen").
val computeFunctions = aggregatesToCompute.map {
case c @ Count(expr) =>
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
@@ -208,7 +210,8 @@ case class GeneratedAggregate(
currentMax)
case CollectHashSet(Seq(expr)) =>
- val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)()
+ val set =
+ AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
val initialValue = NewSet(expr.dataType)
val addToSet = AddItemToSet(expr, set)
@@ -219,9 +222,10 @@ case class GeneratedAggregate(
set)
case CombineSetsAndCount(inputSet) =>
- val ArrayType(inputType, _) = inputSet.dataType
- val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)()
- val initialValue = NewSet(inputType)
+ val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
+ val set =
+ AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
+ val initialValue = NewSet(elementType)
val collectSets = CombineSets(set, inputSet)
AggregateEvaluation(
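For CombineSetsAndCount, the merge-side buffer is now rebuilt from the UDT's element type instead of being pattern-matched out of an ArrayType. The two-phase flow the generated code implements, sketched here with plain Scala collections (Catalyst actually uses OpenHashSet plus code generation, so this shows the data flow only):

object DistinctCountFlow {
  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1, 2, 2, 3), Seq(3, 4, 4, 5))

    // Partial phase (CollectHashSet): NewSet, then AddItemToSet per row.
    val partials = partitions.map(_.foldLeft(Set.empty[Int])(_ + _))

    // Merge phase (CombineSetsAndCount): CombineSets across partials, then count.
    val distinct = partials.reduce(_ union _).size
    assert(distinct == 5)
  }
}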
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index fb8fc6dbd1..5e453e05e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
@@ -151,10 +152,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
"SELECT count(distinct key) FROM testData3x",
Row(100) :: Nil)
// SUM
- testCodeGen(
- "SELECT value, sum(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 3 * i)))
- testCodeGen(
+ testCodeGen(
+ "SELECT value, sum(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 3 * i)))
+ testCodeGen(
"SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
Row(5050 * 3, 5050 * 3.0) :: Nil)
// AVERAGE
@@ -192,10 +193,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
testCodeGen(
"SELECT sum('a'), avg('a'), count(null) FROM testData",
Row(0, null, 0) :: Nil)
-
+
dropTempTable("testData3x")
setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
}
+
test("Add Parser of SQL COALESCE()") {
checkAnswer(
sql("""SELECT COALESCE(1, 2)"""),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 902da5c3ba..2672e20dea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -23,13 +23,16 @@ import org.apache.spark.util.Utils
import scala.beans.{BeanInfo, BeanProperty}
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-
+import org.apache.spark.util.collection.OpenHashSet
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
@@ -119,4 +122,23 @@ class UserDefinedTypeSuite extends QueryTest {
df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
}
+
+ test("HyperLogLogUDT") {
+ val hyperLogLogUDT = HyperLogLogUDT
+ val hyperLogLog = new HyperLogLog(0.4)
+ (1 to 10).foreach(i => hyperLogLog.offer(Row(i)))
+
+ val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog))
+ assert(actual.cardinality() === hyperLogLog.cardinality())
+ assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes))
+ }
+
+ test("OpenHashSetUDT") {
+ val openHashSetUDT = new OpenHashSetUDT(IntegerType)
+ val set = new OpenHashSet[Int]
+ (1 to 10).foreach(i => set.add(i))
+
+ val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set))
+ assert(actual.iterator.toSet === set.iterator.toSet)
+ }
}
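
End to end, these UDTs serve approximate distinct counting. A hedged usage sketch in the style of the suites above; it assumes the 1.x SQL parser accepts the APPROXIMATE COUNT(DISTINCT ...) syntax suggested by the expression's toString, and the table name is made up:

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

object ApproxCountDistinctDemo {
  def main(args: Array[String]): Unit = {
    val rdd = TestSQLContext.sparkContext.parallelize((1 to 1000).map(Tuple1(_)))
    rdd.toDF("id").registerTempTable("ids")
    // Partial aggregates now flow as HyperLogLogUDT-typed sketches; the
    // answer is approximate, close to 1000.
    TestSQLContext.sql("SELECT APPROXIMATE COUNT(DISTINCT id) FROM ids").show()
  }
}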