path: root/sql/core
author    Michael Armbrust <michael@databricks.com>    2014-08-23 16:19:10 -0700
committer Michael Armbrust <michael@databricks.com>    2014-08-23 16:19:10 -0700
commit    7e191fe29bb09a8560cd75d453c4f7f662dff406 (patch)
tree      38c9db34f2ff9dfd3042cacc3ca6fe351c576ece /sql/core
parent    2fb1c72ea21e137c8b60a72e5aecd554c71b16e1 (diff)
[SPARK-2554][SQL] CountDistinct partial aggregation and object allocation improvements
Author: Michael Armbrust <michael@databricks.com>
Author: Gregory Owen <greowen@gmail.com>

Closes #1935 from marmbrus/countDistinctPartial and squashes the following commits:

5c7848d [Michael Armbrust] turn off caching in the constructor
8074a80 [Michael Armbrust] fix tests
32d216f [Michael Armbrust] reynolds comments
c122cca [Michael Armbrust] Address comments, add tests
b2e8ef3 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
fae38f4 [Michael Armbrust] Fix style
fdca896 [Michael Armbrust] cleanup
93d0f64 [Michael Armbrust] metastore concurrency fix.
db44a30 [Michael Armbrust] JIT hax.
3868f6c [Michael Armbrust] Merge pull request #9 from GregOwen/countDistinctPartial
c9e67de [Gregory Owen] Made SpecificRow and types serializable by Kryo
2b46c4b [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
8ff6402 [Michael Armbrust] Add specific row.
58d15f1 [Michael Armbrust] disable codegen logging
87d101d [Michael Armbrust] Fix isNullAt bug
abee26d [Michael Armbrust] WIP
27984d0 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into countDistinctPartial
57ae3b1 [Michael Armbrust] Fix order dependent test
b3d0f64 [Michael Armbrust] Add golden files.
c1f7114 [Michael Armbrust] Improve tests / fix serialization.
f31b8ad [Michael Armbrust] more fixes
38c7449 [Michael Armbrust] comments and style
9153652 [Michael Armbrust] better toString
d494598 [Michael Armbrust] Fix tests now that the planner is better
41fbd1d [Michael Armbrust] Never try and create an empty hash set.
050bb97 [Michael Armbrust] Skip no-arg constructors for kryo,
bd08239 [Michael Armbrust] WIP
213ada8 [Michael Armbrust] First draft of partially aggregated and code generated count distinct / max
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala            |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala   | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala   | 86
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala      |  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala                |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala       |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala         |  8
8 files changed, 137 insertions, 13 deletions
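
Before this patch, a GROUP BY that contained COUNT(DISTINCT ...) fell back to a single, non-partial Aggregate; after it, the planner emits a partial/final pair (see the PlannerSuite changes at the bottom of this diff). A quick usage sketch, not part of the patch, mirroring the query shape used in PlannerSuite; the import paths and the testData table are assumed to be set up as in Spark's SQL test suite:

    // Imports assumed to match the SQL test suite setup.
    import org.apache.spark.sql.test.TestSQLContext._
    import org.apache.spark.sql.TestData._
    import org.apache.spark.sql.catalyst.expressions.CountDistinct

    // COUNT(DISTINCT key) per value. The executed plan should now contain two
    // aggregation operators (partial, then final) instead of one big Aggregate.
    val query = testData.groupBy('value)(CountDistinct('key :: Nil))
    println(query.queryExecution.executedPlan)
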
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 463a1d32d7..be9f155253 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -175,7 +175,7 @@ case class Aggregate(
private[this] val resultProjection =
new InterpretedMutableProjection(
resultExpressions, computedSchema ++ namedGroups.map(_._2))
- private[this] val joinedRow = new JoinedRow
+ private[this] val joinedRow = new JoinedRow4
override final def hasNext: Boolean = hashTableIter.hasNext
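
The only change to Aggregate.scala is swapping JoinedRow for JoinedRow4 (HashJoin gets JoinedRow2, GeneratedAggregate JoinedRow3, and the Parquet scan JoinedRow5 further down). These appear to be structurally identical row wrappers introduced under the "JIT hax" commit above: giving each operator its own concrete class keeps the hot apply/isNullAt call sites monomorphic so the JIT can devirtualize and inline them. A rough standalone sketch of the idea, with illustrative names rather than Spark's actual classes:

    // Two copies of the same wrapper: if the aggregate only ever constructs
    // JoinedRowA and the join only JoinedRowB, each operator's inner loop sees a
    // single receiver class, so the JVM can inline apply() aggressively.
    final class JoinedRowA(left: Array[Any], right: Array[Any]) {
      def apply(i: Int): Any = if (i < left.length) left(i) else right(i - left.length)
    }
    final class JoinedRowB(left: Array[Any], right: Array[Any]) {
      def apply(i: Int): Any = if (i < left.length) left(i) else right(i - left.length)
    }
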
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 4a26934c49..31ad5e8aab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -103,6 +103,40 @@ case class GeneratedAggregate(
updateCount :: updateSum :: Nil,
result
)
+
+ case m @ Max(expr) =>
+ val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
+ val initialValue = Literal(null, expr.dataType)
+ val updateMax = MaxOf(currentMax, expr)
+
+ AggregateEvaluation(
+ currentMax :: Nil,
+ initialValue :: Nil,
+ updateMax :: Nil,
+ currentMax)
+
+ case CollectHashSet(Seq(expr)) =>
+ val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)()
+ val initialValue = NewSet(expr.dataType)
+ val addToSet = AddItemToSet(expr, set)
+
+ AggregateEvaluation(
+ set :: Nil,
+ initialValue :: Nil,
+ addToSet :: Nil,
+ set)
+
+ case CombineSetsAndCount(inputSet) =>
+ val ArrayType(inputType, _) = inputSet.dataType
+ val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)()
+ val initialValue = NewSet(inputType)
+ val collectSets = CombineSets(set, inputSet)
+
+ AggregateEvaluation(
+ set :: Nil,
+ initialValue :: Nil,
+ collectSets :: Nil,
+ CountSet(set))
}
val computationSchema = computeFunctions.flatMap(_.schema)
@@ -151,7 +185,7 @@ case class GeneratedAggregate(
(namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
- val joinedRow = new JoinedRow
+ val joinedRow = new JoinedRow3
if (groupingExpressions.isEmpty) {
// TODO: Codegening anything other than the updateProjection is probably over kill.
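
The three new cases teach GeneratedAggregate the pieces a COUNT(DISTINCT ...) is split into for partial aggregation (the rewrite itself lives in catalyst, outside this sql/core diff): Max gets its own code-generated update, CollectHashSet builds a per-partition hash set of the distinct values, and CombineSetsAndCount merges those sets on the final pass and returns their size. A standalone sketch of that two-phase shape using plain Scala collections, not the catalyst expressions themselves:

    // Partial phase: each partition independently collects its distinct keys.
    val partitions: Seq[Seq[Int]] = Seq(Seq(1, 2, 2), Seq(2, 3), Seq(3, 3, 4))
    val partialSets: Seq[Set[Int]] = partitions.map(_.toSet)      // CollectHashSet / AddItemToSet

    // Final phase: merge the per-partition sets, then count the survivors.
    val distinctCount = partialSets.reduce(_ union _).size        // CombineSets, then CountSet
    assert(distinctCount == 4)
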
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 34654447a5..077e6ebc5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -28,9 +28,13 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool}
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.{SerializerInstance, KryoSerializer}
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.util.MutablePair
import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet}
+
private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
val kryo = new Kryo()
@@ -41,6 +45,13 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
new HyperLogLogSerializer)
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
+
+ // Specific hashsets must come first TODO: Move to core.
+ kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
+ kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
+ kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
+ new OpenHashSetSerializer)
+
kryo.setReferences(false)
kryo.setClassLoader(Utils.getSparkClassLoader)
new AllScalaRegistrar().apply(kryo)
@@ -109,3 +120,78 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
HyperLogLog.Builder.build(bytes)
}
}
+
+private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
+ def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
+ val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
+ output.writeInt(hs.size)
+ val iterator = hs.iterator
+ while(iterator.hasNext) {
+ val row = iterator.next()
+ rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values)
+ }
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
+ val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
+ val numItems = input.readInt()
+ val set = new OpenHashSet[Any](numItems + 1)
+ var i = 0
+ while (i < numItems) {
+ val row =
+ new GenericRow(rowSerializer.read(
+ kryo,
+ input,
+ classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
+ set.add(row)
+ i += 1
+ }
+ set
+ }
+}
+
+private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] {
+ def write(kryo: Kryo, output: Output, hs: IntegerHashSet) {
+ output.writeInt(hs.size)
+ val iterator = hs.iterator
+ while(iterator.hasNext) {
+ val value: Int = iterator.next()
+ output.writeInt(value)
+ }
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = {
+ val numItems = input.readInt()
+ val set = new IntegerHashSet
+ var i = 0
+ while (i < numItems) {
+ val value = input.readInt()
+ set.add(value)
+ i += 1
+ }
+ set
+ }
+}
+
+private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] {
+ def write(kryo: Kryo, output: Output, hs: LongHashSet) {
+ output.writeInt(hs.size)
+ val iterator = hs.iterator
+ while(iterator.hasNext) {
+ val value = iterator.next()
+ output.writeLong(value)
+ }
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = {
+ val numItems = input.readInt()
+ val set = new LongHashSet
+ var i = 0
+ while (i < numItems) {
+ val value = input.readLong()
+ set.add(value)
+ i += 1
+ }
+ set
+ }
+}
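
All three serializers use the same length-prefixed layout: write the set's size, then each element, and rebuild the set on read. The OpenHashSet variant additionally round-trips each element as a GenericRow's values array and pre-sizes the set with numItems + 1, matching the "never try and create an empty hash set" commit. A self-contained sketch of the same pattern against a plain scala.collection.mutable.HashSet[Long], so it compiles without Spark's private[sql] classes; names here are illustrative only:

    import com.esotericsoftware.kryo.{Kryo, Serializer}
    import com.esotericsoftware.kryo.io.{Input, Output}
    import scala.collection.mutable

    class LongSetSerializer extends Serializer[mutable.HashSet[Long]] {
      def write(kryo: Kryo, output: Output, set: mutable.HashSet[Long]) {
        output.writeInt(set.size)                 // length prefix
        set.foreach(v => output.writeLong(v))     // then the raw elements, unboxed
      }

      def read(kryo: Kryo, input: Input,
          tpe: Class[mutable.HashSet[Long]]): mutable.HashSet[Long] = {
        val numItems = input.readInt()
        val result = new mutable.HashSet[Long]
        var i = 0
        while (i < numItems) {
          result.add(input.readLong())
          i += 1
        }
        result
      }
    }

    object LongSetRoundTrip extends App {
      // Round trip: register, serialize to a buffer, deserialize, compare.
      val kryo = new Kryo()
      kryo.register(classOf[mutable.HashSet[Long]], new LongSetSerializer)

      val original = mutable.HashSet(1L, 2L, 42L)
      val buffer = new Output(4096)
      kryo.writeObject(buffer, original)

      val copy = kryo.readObject(new Input(buffer.toBytes), classOf[mutable.HashSet[Long]])
      assert(copy == original)
    }
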
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f0c958fdb5..517b77804a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
import org.apache.spark.sql.parquet._
@@ -148,7 +149,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
- case _: Sum | _: Count => false
+ case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
+ // The generated set implementation is pretty limited ATM.
+ case CollectHashSet(exprs) if exprs.size == 1 &&
+ Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
case _ => true
}
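
canBeCodeGened reads as a double negation: the code-generated hash-aggregate path is used only if no aggregate expression falls through to the `case _ => true` branch, so the cases returning false are the supported ones. Sum, Count, and now Max and CombineSetsAndCount qualify unconditionally, while CollectHashSet qualifies only for a single Int or Long argument, presumably because the specialized hash sets registered with Kryo above (IntegerHashSet/LongHashSet) are the only ones generated so far. A tiny standalone illustration of the predicate shape, with stand-in types rather than catalyst's:

    sealed trait Agg
    case object Sum extends Agg
    case object Count extends Agg
    case object First extends Agg    // stands in for anything codegen can't handle yet

    def canBeCodeGened(aggs: Seq[Agg]): Boolean = !aggs.exists {
      case Sum | Count => false      // supported: does not block code generation
      case _           => true       // unsupported: forces the interpreted path
    }

    assert(canBeCodeGened(Seq(Sum, Count)))
    assert(!canBeCodeGened(Seq(Sum, First)))
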
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index b08f9aacc1..2890a563be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -92,7 +92,7 @@ trait HashJoin {
private[this] var currentMatchPosition: Int = -1
// Mutable per row objects.
- private[this] val joinRow = new JoinedRow
+ private[this] val joinRow = new JoinedRow2
private[this] val joinKeys = streamSideKeyGenerator()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 0a3b59cbc2..ef4526ec03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -23,7 +23,7 @@ import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
import parquet.schema.MessageType
import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.parquet.CatalystConverter.FieldType
/**
@@ -278,14 +278,14 @@ private[parquet] class CatalystGroupConverter(
*/
private[parquet] class CatalystPrimitiveRowConverter(
protected[parquet] val schema: Array[FieldType],
- protected[parquet] var current: ParquetRelation.RowType)
+ protected[parquet] var current: MutableRow)
extends CatalystConverter {
// This constructor is used for the root converter only
def this(attributes: Array[Attribute]) =
this(
attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)),
- new ParquetRelation.RowType(attributes.length))
+ new SpecificMutableRow(attributes.map(_.dataType)))
protected [parquet] val converters: Array[Converter] =
schema.zipWithIndex.map {
@@ -299,7 +299,7 @@ private[parquet] class CatalystPrimitiveRowConverter(
override val parent = null
// Should be only called in root group converter!
- override def getCurrentRecord: ParquetRelation.RowType = current
+ override def getCurrentRecord: Row = current
override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
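
The Parquet root converter now writes into a SpecificMutableRow (added to catalyst by this patch) instead of the generic ParquetRelation.RowType, which is where the "object allocation improvements" show up on the read path: each column gets a mutable, type-specialized slot that is updated in place, so primitive columns avoid boxing a fresh java.lang.Integer/Long per record. A rough standalone sketch of that idea; the names below are illustrative, not catalyst's actual classes:

    // One mutable slot per column, specialized by type.
    sealed abstract class MutableSlot { var isNull: Boolean = true }
    final class MutableIntSlot  extends MutableSlot { var value: Int  = 0 }
    final class MutableLongSlot extends MutableSlot { var value: Long = 0L }

    final class SpecificRowSketch(slots: Array[MutableSlot]) {
      def setInt(i: Int, v: Int): Unit = {
        val s = slots(i).asInstanceOf[MutableIntSlot]
        s.value = v; s.isNull = false        // updated in place, no boxing
      }
      def getInt(i: Int): Int = slots(i).asInstanceOf[MutableIntSlot].value
      def isNullAt(i: Int): Boolean = slots(i).isNull
    }
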
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index f6cfab736d..a5a5d139a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -139,7 +139,7 @@ case class ParquetTableScan(
partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
new Iterator[Row] {
- private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null)
+ private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null)
def hasNext = iter.hasNext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 76b1724471..37d64f0de7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -45,16 +45,16 @@ class PlannerSuite extends FunSuite {
assert(aggregations.size === 2)
}
- test("count distinct is not partially aggregated") {
+ test("count distinct is partially aggregated") {
val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
val planned = HashAggregation(query)
- assert(planned.isEmpty)
+ assert(planned.nonEmpty)
}
- test("mixed aggregates are not partially aggregated") {
+ test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed
val planned = HashAggregation(query)
- assert(planned.isEmpty)
+ assert(planned.nonEmpty)
}
}
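
The two flipped tests now only assert that HashAggregation produces a plan at all. A slightly stronger check, sketched in the suite's own style (the collect-by-nodeName trick mirrors the existing "count is partially aggregated" test whose assertion appears at the top of this hunk; treat the sketch as an assumption, not part of the patch), would also verify that the plan really has both a partial and a final aggregation stage:

    test("count distinct has partial and final aggregation stages") {
      val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed
      val planned = HashAggregation(query).head
      val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
      assert(aggregations.size === 2)
    }
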