author     Reynold Xin <rxin@databricks.com>    2015-11-11 12:48:51 -0800
committer  Reynold Xin <rxin@databricks.com>    2015-11-11 12:48:51 -0800
commit     a9a6b80c718008aac7c411dfe46355efe58dee2e
tree       12a068af636bd57d6114e3b28bdb161042b257ed /sql
parent     df97df2b39194f60051f78cce23f0ba6cfe4b1df
[SPARK-11645][SQL] Remove OpenHashSet for the old aggregate.
Author: Reynold Xin <rxin@databricks.com>

Closes #9621 from rxin/SPARK-11645.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala            |   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala  |   7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala                              | 194
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala                               | 103
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala                                       |  11
5 files changed, 5 insertions(+), 316 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5a4bba232b..ccd91d3549 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -33,10 +33,6 @@ import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types._
-// These classes are here to avoid issues with serialization and integration with quasiquotes.
-class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
-class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
-
/**
* Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
*
@@ -205,8 +201,6 @@ class CodeGenContext {
case _: StructType => "InternalRow"
case _: ArrayType => "ArrayData"
case _: MapType => "MapData"
- case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
- case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case udt: UserDefinedType[_] => javaType(udt.sqlType)
case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
case ObjectType(cls) => cls.getName
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 9ef2261414..4c17d02a23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -39,7 +39,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case t: ArrayType if canSupport(t.elementType) => true
case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
- case dt: OpenHashSetUDT => false // it's not a standard UDT
case udt: UserDefinedType[_] => canSupport(udt.sqlType)
case _ => false
}
@@ -309,13 +308,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
in.map(BindReferences.bindReference(_, inputSchema))
def generate(
- expressions: Seq[Expression],
- subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+ expressions: Seq[Expression],
+ subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
create(canonicalize(expressions), subexpressionEliminationEnabled)
}
protected def create(expressions: Seq[Expression]): UnsafeProjection = {
- create(expressions, false)
+ create(expressions, subexpressionEliminationEnabled = false)
}
private def create(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
deleted file mode 100644
index d124d29d53..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ /dev/null
@@ -1,194 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
-
-/** The data type for expressions returning an OpenHashSet as the result. */
-private[sql] class OpenHashSetUDT(
- val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] {
-
- override def sqlType: DataType = ArrayType(elementType)
-
- /** Since we are using OpenHashSet internally, usually it will not be called. */
- override def serialize(obj: Any): Seq[Any] = {
- obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq
- }
-
- /** Since we are using OpenHashSet internally, usually it will not be called. */
- override def deserialize(datum: Any): OpenHashSet[Any] = {
- val iterator = datum.asInstanceOf[Seq[Any]].iterator
- val set = new OpenHashSet[Any]
- while(iterator.hasNext) {
- set.add(iterator.next())
- }
-
- set
- }
-
- override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]]
-
- private[spark] override def asNullable: OpenHashSetUDT = this
-}
-
-/**
- * Creates a new set of the specified type
- */
-case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback {
-
- override def nullable: Boolean = false
-
- override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType)
-
- override def eval(input: InternalRow): Any = {
- new OpenHashSet[Any]()
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- elementType match {
- case IntegerType | LongType =>
- ev.isNull = "false"
- s"""
- ${ctx.javaType(dataType)} ${ev.value} = new ${ctx.javaType(dataType)}();
- """
- case _ => super.genCode(ctx, ev)
- }
- }
-
- override def toString: String = s"new Set($dataType)"
-}
-
-/**
- * Adds an item to a set.
- * For performance, this expression mutates its input during evaluation.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class AddItemToSet(item: Expression, set: Expression)
- extends Expression with CodegenFallback {
-
- override def children: Seq[Expression] = item :: set :: Nil
-
- override def nullable: Boolean = set.nullable
-
- override def dataType: DataType = set.dataType
-
- override def eval(input: InternalRow): Any = {
- val itemEval = item.eval(input)
- val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
-
- if (itemEval != null) {
- if (setEval != null) {
- setEval.add(itemEval)
- setEval
- } else {
- null
- }
- } else {
- setEval
- }
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
- elementType match {
- case IntegerType | LongType =>
- val itemEval = item.gen(ctx)
- val setEval = set.gen(ctx)
- val htype = ctx.javaType(dataType)
-
- ev.isNull = "false"
- ev.value = setEval.value
- itemEval.code + setEval.code + s"""
- if (!${itemEval.isNull} && !${setEval.isNull}) {
- (($htype)${setEval.value}).add(${itemEval.value});
- }
- """
- case _ => super.genCode(ctx, ev)
- }
- }
-
- override def toString: String = s"$set += $item"
-}
-
-/**
- * Combines the elements of two sets.
- * For performance, this expression mutates its left input set during evaluation.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class CombineSets(left: Expression, right: Expression)
- extends BinaryExpression with CodegenFallback {
-
- override def nullable: Boolean = left.nullable
- override def dataType: DataType = left.dataType
-
- override def eval(input: InternalRow): Any = {
- val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
- if(leftEval != null) {
- val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
- if (rightEval != null) {
- val iterator = rightEval.iterator
- while(iterator.hasNext) {
- val rightValue = iterator.next()
- leftEval.add(rightValue)
- }
- }
- leftEval
- } else {
- null
- }
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
- elementType match {
- case IntegerType | LongType =>
- val leftEval = left.gen(ctx)
- val rightEval = right.gen(ctx)
- val htype = ctx.javaType(dataType)
-
- ev.isNull = leftEval.isNull
- ev.value = leftEval.value
- leftEval.code + rightEval.code + s"""
- if (!${leftEval.isNull} && !${rightEval.isNull}) {
- ${leftEval.value}.union((${htype})${rightEval.value});
- }
- """
- case _ => super.genCode(ctx, ev)
- }
- }
-}
-
-/**
- * Returns the number of elements in the input set.
- * Note: this expression is internal and created only by the GeneratedAggregate,
- * we don't need to do type check for it.
- */
-case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback {
-
- override def dataType: DataType = LongType
-
- protected override def nullSafeEval(input: Any): Any =
- input.asInstanceOf[OpenHashSet[Any]].size.toLong
-
- override def toString: String = s"$child.count()"
-}
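
The file deleted above held the set expressions that the old GeneratedAggregate chained together for distinct aggregation. As a point of reference only, here is a minimal Scala sketch (not part of this patch; DistinctCountSketch and its method names are illustrative) of the new-set / add-item / combine / count flow those expressions implemented on top of org.apache.spark.util.collection.OpenHashSet:

import org.apache.spark.util.collection.OpenHashSet

object DistinctCountSketch {
  // NewSet: allocate an empty, mutable set for a partition.
  def newSet(): OpenHashSet[Any] = new OpenHashSet[Any]

  // AddItemToSet: mutate the set in place, ignoring null items.
  def addItem(set: OpenHashSet[Any], item: Any): OpenHashSet[Any] = {
    if (item != null) set.add(item)
    set
  }

  // CombineSets: fold the right set into the left one when merging partitions.
  def combine(left: OpenHashSet[Any], right: OpenHashSet[Any]): OpenHashSet[Any] = {
    val it = right.iterator
    while (it.hasNext) left.add(it.next())
    left
  }

  // CountSet: the final distinct count.
  def count(set: OpenHashSet[Any]): Long = set.size.toLong
}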
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index b19ad4f1c5..8317f648cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -22,19 +22,16 @@ import java.util.{HashMap => JavaHashMap}
import scala.reflect.ClassTag
-import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.twitter.chill.ResourcePool
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.MutablePair
-import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.{SparkConf, SparkEnv}
+
private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
val kryo = super.newKryo()
@@ -43,16 +40,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
- kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
- new HyperLogLogSerializer)
kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer)
kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer)
- // Specific hashsets must come first TODO: Move to core.
- kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
- kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
- kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
- new OpenHashSetSerializer)
kryo.register(classOf[Decimal])
kryo.register(classOf[JavaHashMap[_, _]])
@@ -62,7 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
}
private[execution] class KryoResourcePool(size: Int)
- extends ResourcePool[SerializerInstance](size) {
+ extends ResourcePool[SerializerInstance](size) {
val ser: SparkSqlSerializer = {
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
@@ -116,92 +106,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] {
new java.math.BigDecimal(input.readString())
}
}
-
-private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
- def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
- val bytes = hyperLogLog.getBytes()
- output.writeInt(bytes.length)
- output.writeBytes(bytes)
- }
-
- def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
- val length = input.readInt()
- val bytes = input.readBytes(length)
- HyperLogLog.Builder.build(bytes)
- }
-}
-
-private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
- def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
- val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
- output.writeInt(hs.size)
- val iterator = hs.iterator
- while(iterator.hasNext) {
- val row = iterator.next()
- rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values)
- }
- }
-
- def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
- val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
- val numItems = input.readInt()
- val set = new OpenHashSet[Any](numItems + 1)
- var i = 0
- while (i < numItems) {
- val row =
- new GenericInternalRow(rowSerializer.read(
- kryo,
- input,
- classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
- set.add(row)
- i += 1
- }
- set
- }
-}
-
-private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] {
- def write(kryo: Kryo, output: Output, hs: IntegerHashSet) {
- output.writeInt(hs.size)
- val iterator = hs.iterator
- while(iterator.hasNext) {
- val value: Int = iterator.next()
- output.writeInt(value)
- }
- }
-
- def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = {
- val numItems = input.readInt()
- val set = new IntegerHashSet
- var i = 0
- while (i < numItems) {
- val value = input.readInt()
- set.add(value)
- i += 1
- }
- set
- }
-}
-
-private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] {
- def write(kryo: Kryo, output: Output, hs: LongHashSet) {
- output.writeInt(hs.size)
- val iterator = hs.iterator
- while(iterator.hasNext) {
- val value = iterator.next()
- output.writeLong(value)
- }
- }
-
- def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = {
- val numItems = input.readInt()
- val set = new LongHashSet
- var i = 0
- while (i < numItems) {
- val value = input.readLong()
- set.add(value)
- i += 1
- }
- set
- }
-}
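
The serializers deleted above all follow the same Kryo contract: a length-prefixed write and a matching read, registered against a concrete class. A minimal round-trip sketch of that contract (illustrative only, not part of this patch; LongArraySerializer and KryoRoundTrip are hypothetical names, assuming the Kryo 2/3-style Serializer API used in this file):

import java.io.ByteArrayOutputStream

import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.esotericsoftware.kryo.io.{Input, Output}

// Length-prefixed custom serializer, mirroring the removed pattern.
class LongArraySerializer extends Serializer[Array[Long]] {
  override def write(kryo: Kryo, output: Output, values: Array[Long]): Unit = {
    output.writeInt(values.length)
    values.foreach(v => output.writeLong(v))
  }

  override def read(kryo: Kryo, input: Input, tpe: Class[Array[Long]]): Array[Long] = {
    val n = input.readInt()
    Array.fill(n)(input.readLong())
  }
}

object KryoRoundTrip {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    kryo.register(classOf[Array[Long]], new LongArraySerializer)

    // Serialize to an in-memory buffer.
    val buffer = new ByteArrayOutputStream()
    val output = new Output(buffer)
    kryo.writeObject(output, Array(1L, 2L, 3L))
    output.close()

    // Read the same bytes back with the registered serializer.
    val input = new Input(buffer.toByteArray)
    val restored = kryo.readObject(input, classOf[Array[Long]])
    input.close()
    println(restored.toSeq)  // expected: List(1, 2, 3)
  }
}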
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index e31c528f3a..f602f2fb89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -23,7 +23,6 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.OpenHashSetUDT
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -131,15 +130,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
}
- test("OpenHashSetUDT") {
- val openHashSetUDT = new OpenHashSetUDT(IntegerType)
- val set = new OpenHashSet[Int]
- (1 to 10).foreach(i => set.add(i))
-
- val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set))
- assert(actual.iterator.toSet === set.iterator.toSet)
- }
-
test("UDTs with JSON") {
val data = Seq(
"{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}",
@@ -163,7 +153,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
test("SPARK-10472 UserDefinedType.typeName") {
assert(IntegerType.typeName === "integer")
assert(new MyDenseVectorUDT().typeName === "mydensevector")
- assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset")
}
test("Catalyst type converter null handling for UDTs") {