aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorYin Huai <yhuai@databricks.com>2015-04-11 19:26:15 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-11 19:26:15 -0700
commit6d4e854ffbd7dee9a3cd7b44a00fd9c0e551f5b8 (patch)
tree8b0c79447539078108b8ff10e8195f5286d90b96 /sql
parentd2383fb5ffafd6b3a56b1ee6e0e035594473e2c8 (diff)
downloadspark-6d4e854ffbd7dee9a3cd7b44a00fd9c0e551f5b8.tar.gz
spark-6d4e854ffbd7dee9a3cd7b44a00fd9c0e551f5b8.tar.bz2
spark-6d4e854ffbd7dee9a3cd7b44a00fd9c0e551f5b8.zip
[SPARK-6367][SQL] Use the proper data type for those expressions that are hijacking existing data types.
This PR adds internal UDTs for expressions that are hijacking existing data types. The following UDTs are added: * `HyperLogLogUDT` (`BinaryType` as the SQL type) for `ApproxCountDistinctPartition` * `OpenHashSetUDT` (`ArrayType` as the SQL type) for `CollectHashSet`, `NewSet`, `AddItemToSet`, and `CombineSets`. I am also adding more unit tests for aggregation with code gen enabled. JIRA: https://issues.apache.org/jira/browse/SPARK-6367 Author: Yin Huai <yhuai@databricks.com> Closes #5094 from yhuai/expressionType and squashes the following commits: 8bcd11a [Yin Huai] Return types. 61a1d66 [Yin Huai] Merge remote-tracking branch 'upstream/master' into expressionType e8b4599 [Yin Huai] Merge remote-tracking branch 'upstream/master' into expressionType 2753156 [Yin Huai] Ignore aggregations having sum functions for now. b5eb259 [Yin Huai] Case object for HyperLogLog type. 00ebdbd [Yin Huai] deserialize/serialize. 54b87ae [Yin Huai] Add UDTs for expressions that return HyperLogLog and OpenHashSet.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala24
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala35
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala24
6 files changed, 91 insertions, 20 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 406de38d1c..14a855054b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -189,9 +189,10 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
override def children: Seq[Expression] = expressions
override def nullable: Boolean = false
- override def dataType: ArrayType = ArrayType(expressions.head.dataType)
+ override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType)
override def toString: String = s"AddToHashSet(${expressions.mkString(",")})"
- override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this)
+ override def newInstance(): CollectHashSetFunction =
+ new CollectHashSetFunction(expressions, this)
}
case class CollectHashSetFunction(
@@ -250,11 +251,28 @@ case class CombineSetsAndCountFunction(
override def eval(input: Row): Any = seen.size.toLong
}
+/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */
+private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
+
+ override def sqlType: DataType = BinaryType
+
+ /** Since we are using HyperLogLog internally, usually it will not be called. */
+ override def serialize(obj: Any): Array[Byte] =
+ obj.asInstanceOf[HyperLogLog].getBytes
+
+
+ /** Since we are using HyperLogLog internally, usually it will not be called. */
+ override def deserialize(datum: Any): HyperLogLog =
+ HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]])
+
+ override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
+}
+
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
override def nullable: Boolean = false
- override def dataType: DataType = child.dataType
+ override def dataType: DataType = HyperLogLogUDT
override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
override def newInstance(): ApproxCountDistinctPartitionFunction = {
new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d1abf3c0b6..aac56e1568 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -464,7 +464,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val itemEval = expressionEvaluator(item)
val setEval = expressionEvaluator(set)
- val ArrayType(elementType, _) = set.dataType
+ val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
itemEval.code ++ setEval.code ++
q"""
@@ -482,7 +482,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
val leftEval = expressionEvaluator(left)
val rightEval = expressionEvaluator(right)
- val ArrayType(elementType, _) = left.dataType
+ val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
leftEval.code ++ rightEval.code ++
q"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 35faa00782..4c44182278 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -20,6 +20,33 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
+/** The data type for expressions returning an OpenHashSet as the result. */
+private[sql] class OpenHashSetUDT(
+ val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] {
+
+ override def sqlType: DataType = ArrayType(elementType)
+
+ /** Since we are using OpenHashSet internally, usually it will not be called. */
+ override def serialize(obj: Any): Seq[Any] = {
+ obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq
+ }
+
+ /** Since we are using OpenHashSet internally, usually it will not be called. */
+ override def deserialize(datum: Any): OpenHashSet[Any] = {
+ val iterator = datum.asInstanceOf[Seq[Any]].iterator
+ val set = new OpenHashSet[Any]
+ while(iterator.hasNext) {
+ set.add(iterator.next())
+ }
+
+ set
+ }
+
+ override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]]
+
+ private[spark] override def asNullable: OpenHashSetUDT = this
+}
+
/**
* Creates a new set of the specified type
*/
@@ -28,9 +55,7 @@ case class NewSet(elementType: DataType) extends LeafExpression {
override def nullable: Boolean = false
- // We are currently only using these Expressions internally for aggregation. However, if we ever
- // expose these to users we'll want to create a proper type instead of hijacking ArrayType.
- override def dataType: DataType = ArrayType(elementType)
+ override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType)
override def eval(input: Row): Any = {
new OpenHashSet[Any]()
@@ -50,7 +75,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
override def nullable: Boolean = set.nullable
- override def dataType: DataType = set.dataType
+ override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT]
override def eval(input: Row): Any = {
val itemEval = item.eval(input)
@@ -80,7 +105,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
override def nullable: Boolean = left.nullable || right.nullable
- override def dataType: DataType = left.dataType
+ override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT]
override def symbol: String = "++="
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 861a2c21ad..3c58e93b45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -68,6 +68,8 @@ case class GeneratedAggregate(
a.collect { case agg: AggregateExpression => agg}
}
+ // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
+ // (in test "aggregation with codegen").
val computeFunctions = aggregatesToCompute.map {
case c @ Count(expr) =>
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
@@ -208,7 +210,8 @@ case class GeneratedAggregate(
currentMax)
case CollectHashSet(Seq(expr)) =>
- val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)()
+ val set =
+ AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
val initialValue = NewSet(expr.dataType)
val addToSet = AddItemToSet(expr, set)
@@ -219,9 +222,10 @@ case class GeneratedAggregate(
set)
case CombineSetsAndCount(inputSet) =>
- val ArrayType(inputType, _) = inputSet.dataType
- val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)()
- val initialValue = NewSet(inputType)
+ val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
+ val set =
+ AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
+ val initialValue = NewSet(elementType)
val collectSets = CombineSets(set, inputSet)
AggregateEvaluation(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index fb8fc6dbd1..5e453e05e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
@@ -151,10 +152,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
"SELECT count(distinct key) FROM testData3x",
Row(100) :: Nil)
// SUM
- testCodeGen(
- "SELECT value, sum(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 3 * i)))
- testCodeGen(
+ testCodeGen(
+ "SELECT value, sum(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 3 * i)))
+ testCodeGen(
"SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
Row(5050 * 3, 5050 * 3.0) :: Nil)
// AVERAGE
@@ -192,10 +193,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
testCodeGen(
"SELECT sum('a'), avg('a'), count(null) FROM testData",
Row(0, null, 0) :: Nil)
-
+
dropTempTable("testData3x")
setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
}
+
test("Add Parser of SQL COALESCE()") {
checkAnswer(
sql("""SELECT COALESCE(1, 2)"""),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 902da5c3ba..2672e20dea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -23,13 +23,16 @@ import org.apache.spark.util.Utils
import scala.beans.{BeanInfo, BeanProperty}
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-
+import org.apache.spark.util.collection.OpenHashSet
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
@@ -119,4 +122,23 @@ class UserDefinedTypeSuite extends QueryTest {
df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
}
+
+ test("HyperLogLogUDT") {
+ val hyperLogLogUDT = HyperLogLogUDT
+ val hyperLogLog = new HyperLogLog(0.4)
+ (1 to 10).foreach(i => hyperLogLog.offer(Row(i)))
+
+ val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog))
+ assert(actual.cardinality() === hyperLogLog.cardinality())
+ assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes))
+ }
+
+ test("OpenHashSetUDT") {
+ val openHashSetUDT = new OpenHashSetUDT(IntegerType)
+ val set = new OpenHashSet[Int]
+ (1 to 10).foreach(i => set.add(i))
+
+ val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set))
+ assert(actual.iterator.toSet === set.iterator.toSet)
+ }
}