author    Yin Huai <yhuai@databricks.com>            2015-04-11 19:26:15 -0700
committer Michael Armbrust <michael@databricks.com>  2015-04-11 19:26:15 -0700
commit    6d4e854ffbd7dee9a3cd7b44a00fd9c0e551f5b8 (patch)
tree      8b0c79447539078108b8ff10e8195f5286d90b96
parent    d2383fb5ffafd6b3a56b1ee6e0e035594473e2c8 (diff)
[SPARK-6367][SQL] Use the proper data type for those expressions that are hijacking existing data types.
This PR adds internal UDTs for expressions that are hijacking existing data types. The following UDTs are added:

* `HyperLogLogUDT` (`BinaryType` as the SQL type) for `ApproxCountDistinctPartition`
* `OpenHashSetUDT` (`ArrayType` as the SQL type) for `CollectHashSet`, `NewSet`, `AddItemToSet`, and `CombineSets`.

I am also adding more unit tests for aggregation with code gen enabled.

JIRA: https://issues.apache.org/jira/browse/SPARK-6367

Author: Yin Huai <yhuai@databricks.com>

Closes #5094 from yhuai/expressionType and squashes the following commits:

8bcd11a [Yin Huai] Return types.
61a1d66 [Yin Huai] Merge remote-tracking branch 'upstream/master' into expressionType
e8b4599 [Yin Huai] Merge remote-tracking branch 'upstream/master' into expressionType
2753156 [Yin Huai] Ignore aggregations having sum functions for now.
b5eb259 [Yin Huai] Case object for HyperLogLog type.
00ebdbd [Yin Huai] deserialize/serialize.
54b87ae [Yin Huai] Add UDTs for expressions that return HyperLogLog and OpenHashSet.
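The UDT classes themselves land in catalyst (the test diff below imports them from `org.apache.spark.sql.catalyst.expressions`), so their definitions are not part of this `sql/core` diffstat. As a hedged sketch of the pattern the commit message describes — assuming the `UserDefinedType` contract from `org.apache.spark.sql.types` in this Spark line (`sqlType`, `serialize`, `deserialize`, `userClass`) — a `HyperLogLogUDT` backed by `BinaryType` might look like:

```scala
package org.apache.spark.sql.catalyst.expressions

import com.clearspring.analytics.stream.cardinality.HyperLogLog

import org.apache.spark.sql.types._

// Sketch only: an internal UDT that presents a HyperLogLog sketch to the SQL
// layer as a plain binary column instead of hijacking BinaryType directly.
private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {

  // The SQL-facing representation is an ordinary byte array.
  override def sqlType: DataType = BinaryType

  // Serialize to the sketch's compact byte representation.
  override def serialize(obj: Any): Array[Byte] =
    obj.asInstanceOf[HyperLogLog].getBytes

  // Rebuild the HyperLogLog from those bytes.
  override def deserialize(datum: Any): HyperLogLog =
    HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]])

  override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
}
```

Because `sqlType` is `BinaryType`, everything downstream of the aggregate still sees an ordinary binary column, while typed code can recover the `HyperLogLog` object via `deserialize` — exactly the round trip the new `UserDefinedTypeSuite` test exercises.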
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 24
3 files changed, 38 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 861a2c21ad..3c58e93b45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -68,6 +68,8 @@ case class GeneratedAggregate(
a.collect { case agg: AggregateExpression => agg}
}
+ // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
+ // (in test "aggregation with codegen").
val computeFunctions = aggregatesToCompute.map {
case c @ Count(expr) =>
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
@@ -208,7 +210,8 @@ case class GeneratedAggregate(
currentMax)
case CollectHashSet(Seq(expr)) =>
- val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)()
+ val set =
+ AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
val initialValue = NewSet(expr.dataType)
val addToSet = AddItemToSet(expr, set)
@@ -219,9 +222,10 @@ case class GeneratedAggregate(
set)
case CombineSetsAndCount(inputSet) =>
- val ArrayType(inputType, _) = inputSet.dataType
- val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)()
- val initialValue = NewSet(inputType)
+ val elementType = inputSet.dataType.asInstanceOf[OpenHashSetUDT].elementType
+ val set =
+ AttributeReference("hashSet", new OpenHashSetUDT(elementType), nullable = false)()
+ val initialValue = NewSet(elementType)
val collectSets = CombineSets(set, inputSet)
AggregateEvaluation(
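The hunks above are the consumer side of the new UDT: rather than destructuring `ArrayType(inputType, _)` out of the set's nominal SQL type, `GeneratedAggregate` now reads `elementType` directly off the `OpenHashSetUDT`. A hedged sketch of such a UDT (the real definition lives in catalyst, outside this diff, so details here are illustrative):

```scala
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet

// Sketch only: a UDT that presents an OpenHashSet as an array of its element
// type to the SQL layer while keeping the element type available as a field.
private[sql] class OpenHashSetUDT(val elementType: DataType)
  extends UserDefinedType[OpenHashSet[Any]] {

  // The SQL-facing type is still an array of the element type.
  override def sqlType: DataType = ArrayType(elementType)

  // Serialize by dumping the set's contents into a sequence.
  override def serialize(obj: Any): Seq[Any] =
    obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq

  // Deserialize by re-inserting every element into a fresh set.
  override def deserialize(datum: Any): OpenHashSet[Any] = {
    val set = new OpenHashSet[Any]
    datum.asInstanceOf[Seq[Any]].foreach(set.add)
    set
  }

  override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]]
}
```

Carrying `elementType` as a field means `NewSet` and the other set expressions can be built for the right element type directly, instead of guessing it back out of the array representation.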
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index fb8fc6dbd1..5e453e05e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
@@ -151,10 +152,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
"SELECT count(distinct key) FROM testData3x",
Row(100) :: Nil)
// SUM
- testCodeGen(
- "SELECT value, sum(key) FROM testData3x GROUP BY value",
- (1 to 100).map(i => Row(i.toString, 3 * i)))
- testCodeGen(
+ testCodeGen(
+ "SELECT value, sum(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 3 * i)))
+ testCodeGen(
"SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
Row(5050 * 3, 5050 * 3.0) :: Nil)
// AVERAGE
@@ -192,10 +193,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
testCodeGen(
"SELECT sum('a'), avg('a'), count(null) FROM testData",
Row(0, null, 0) :: Nil)
-
+
dropTempTable("testData3x")
setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
}
+
test("Add Parser of SQL COALESCE()") {
checkAnswer(
sql("""SELECT COALESCE(1, 2)"""),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 902da5c3ba..2672e20dea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -23,13 +23,16 @@ import org.apache.spark.util.Utils
import scala.beans.{BeanInfo, BeanProperty}
+import com.clearspring.analytics.stream.cardinality.HyperLogLog
+
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
-
+import org.apache.spark.util.collection.OpenHashSet
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
@@ -119,4 +122,23 @@ class UserDefinedTypeSuite extends QueryTest {
df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
}
+
+ test("HyperLogLogUDT") {
+ val hyperLogLogUDT = HyperLogLogUDT
+ val hyperLogLog = new HyperLogLog(0.4)
+ (1 to 10).foreach(i => hyperLogLog.offer(Row(i)))
+
+ val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog))
+ assert(actual.cardinality() === hyperLogLog.cardinality())
+ assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes))
+ }
+
+ test("OpenHashSetUDT") {
+ val openHashSetUDT = new OpenHashSetUDT(IntegerType)
+ val set = new OpenHashSet[Int]
+ (1 to 10).foreach(i => set.add(i))
+
+ val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set))
+ assert(actual.iterator.toSet === set.iterator.toSet)
+ }
}