author     Yin Huai <yhuai@databricks.com>    2015-01-20 13:26:36 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-01-20 13:26:36 -0800
commit     bc20a52b34e826895d0dcc1d783c021ebd456ebd (patch)
tree       894ec677bda57e58c2c5d088ddc8a05261ecfec5 /sql
parent     23e25543beaa5966b5f07365f338ce338fd6d71f (diff)
[SPARK-5287][SQL] Add defaultSizeOf to every data type.
JIRA: https://issues.apache.org/jira/browse/SPARK-5287

This PR only adds `defaultSizeOf` to data types and makes those internal type classes `protected[sql]`. I will use another PR to clean up the type hierarchy of data types.

Author: Yin Huai <yhuai@databricks.com>

Closes #4081 from yhuai/SPARK-5287 and squashes the following commits:

90cec75 [Yin Huai] Update unit test.
e1c600c [Yin Huai] Make internal classes protected[sql].
7eaba68 [Yin Huai] Add `defaultSize` method to data types.
fd425e0 [Yin Huai] Add all native types to NativeType.defaultSizeOf.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala         8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala   15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala                        120
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala                     40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala                      66
5 files changed, 201 insertions(+), 48 deletions(-)
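
The patch replaces the old NativeType.defaultSizeOf lookup map with a defaultSize method declared on DataType and overridden by every concrete type. A minimal sketch of that pattern, using hypothetical standalone names (SizedType, SizedInt, SizedString, SizedArray) rather than Spark's own classes:

    // Sketch of the pattern this patch introduces: a size estimate defined on the
    // type itself instead of an external Map[NativeType, Int]. Names are illustrative.
    abstract class SizedType {
      /** Estimated size in bytes of a single value of this type. */
      def defaultSize: Int
    }

    case object SizedInt extends SizedType {
      override def defaultSize: Int = 4
    }

    case object SizedString extends SizedType {
      // Variable-length data falls back to a coarse 4096-byte estimate.
      override def defaultSize: Int = 4096
    }

    case class SizedArray(elementType: SizedType) extends SizedType {
      // Assume 100 elements, as the patch does for ArrayType.
      override def defaultSize: Int = 100 * elementType.defaultSize
    }

    object SizedTypeDemo extends App {
      println(SizedArray(SizedString).defaultSize)  // 409600
    }
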
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index 8328278544..e2f5c7332d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType
/**
* The data type representing [[DynamicRow]] values.
*/
-case object DynamicType extends DataType
+case object DynamicType extends DataType {
+
+ /**
+ * The default size of a value of the DynamicType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
+}
/**
* Wrap a [[Row]] as a [[DynamicRow]].
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 1483beacc9..9628e93274 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -238,16 +238,11 @@ case class Rollup(
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output = child.output
- override lazy val statistics: Statistics =
- if (output.forall(_.dataType.isInstanceOf[NativeType])) {
- val limit = limitExpr.eval(null).asInstanceOf[Int]
- val sizeInBytes = (limit: Long) * output.map { a =>
- NativeType.defaultSizeOf(a.dataType.asInstanceOf[NativeType])
- }.sum
- Statistics(sizeInBytes = sizeInBytes)
- } else {
- Statistics(sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product)
- }
+ override lazy val statistics: Statistics = {
+ val limit = limitExpr.eval(null).asInstanceOf[Int]
+ val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ Statistics(sizeInBytes = sizeInBytes)
+ }
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
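
With a per-type defaultSize available, the Limit estimate above becomes limit × (sum of the output columns' default sizes), and the operator no longer has to fall back to multiplying child sizes when a non-native type appears in the output. A rough worked example of that arithmetic, assuming a hypothetical plan with output (IntegerType, StringType) and a limit of 3:

    object LimitEstimateDemo extends App {
      // Hand-computed version of the new Limit.statistics for a plan whose
      // output is (IntegerType, StringType) and whose limit expression is 3.
      val columnDefaults = Seq(4, 4096)   // defaultSize of IntegerType, StringType
      val limit = 3L
      val sizeInBytes = limit * columnDefaults.sum
      println(sizeInBytes)                // 12300 bytes
    }
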
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index bcd74603d4..9f30f40a17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -215,6 +215,9 @@ abstract class DataType {
case _ => false
}
+ /** The default size of a value of this data type. */
+ def defaultSize: Int
+
def isPrimitive: Boolean = false
def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
@@ -235,33 +238,25 @@ abstract class DataType {
* @group dataType
*/
@DeveloperApi
-case object NullType extends DataType
+case object NullType extends DataType {
+ override def defaultSize: Int = 1
+}
-object NativeType {
+protected[sql] object NativeType {
val all = Seq(
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
def unapply(dt: DataType): Boolean = all.contains(dt)
-
- val defaultSizeOf: Map[NativeType, Int] = Map(
- IntegerType -> 4,
- BooleanType -> 1,
- LongType -> 8,
- DoubleType -> 8,
- FloatType -> 4,
- ShortType -> 2,
- ByteType -> 1,
- StringType -> 4096)
}
-trait PrimitiveType extends DataType {
+protected[sql] trait PrimitiveType extends DataType {
override def isPrimitive = true
}
-object PrimitiveType {
+protected[sql] object PrimitiveType {
private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
@@ -276,7 +271,7 @@ object PrimitiveType {
}
}
-abstract class NativeType extends DataType {
+protected[sql] abstract class NativeType extends DataType {
private[sql] type JvmType
@transient private[sql] val tag: TypeTag[JvmType]
private[sql] val ordering: Ordering[JvmType]
@@ -300,6 +295,11 @@ case object StringType extends NativeType with PrimitiveType {
private[sql] type JvmType = String
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the StringType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
}
@@ -324,6 +324,11 @@ case object BinaryType extends NativeType with PrimitiveType {
x.length - y.length
}
}
+
+ /**
+ * The default size of a value of the BinaryType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
}
@@ -339,6 +344,11 @@ case object BooleanType extends NativeType with PrimitiveType {
private[sql] type JvmType = Boolean
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the BooleanType is 1 byte.
+ */
+ override def defaultSize: Int = 1
}
@@ -359,6 +369,11 @@ case object TimestampType extends NativeType {
private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
}
+
+ /**
+ * The default size of a value of the TimestampType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
}
@@ -379,10 +394,15 @@ case object DateType extends NativeType {
private[sql] val ordering = new Ordering[JvmType] {
def compare(x: Date, y: Date) = x.compareTo(y)
}
+
+ /**
+ * The default size of a value of the DateType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
}
-abstract class NumericType extends NativeType with PrimitiveType {
+protected[sql] abstract class NumericType extends NativeType with PrimitiveType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
// type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
@@ -392,13 +412,13 @@ abstract class NumericType extends NativeType with PrimitiveType {
}
-object NumericType {
+protected[sql] object NumericType {
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
}
/** Matcher for any expressions that evaluate to [[IntegralType]]s */
-object IntegralType {
+protected[sql] object IntegralType {
def unapply(a: Expression): Boolean = a match {
case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
case _ => false
@@ -406,7 +426,7 @@ object IntegralType {
}
-sealed abstract class IntegralType extends NumericType {
+protected[sql] sealed abstract class IntegralType extends NumericType {
private[sql] val integral: Integral[JvmType]
}
@@ -425,6 +445,11 @@ case object LongType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Long]]
private[sql] val integral = implicitly[Integral[Long]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the LongType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
}
@@ -442,6 +467,11 @@ case object IntegerType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Int]]
private[sql] val integral = implicitly[Integral[Int]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the IntegerType is 4 bytes.
+ */
+ override def defaultSize: Int = 4
}
@@ -459,6 +489,11 @@ case object ShortType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Short]]
private[sql] val integral = implicitly[Integral[Short]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the ShortType is 2 bytes.
+ */
+ override def defaultSize: Int = 2
}
@@ -476,11 +511,16 @@ case object ByteType extends IntegralType {
private[sql] val numeric = implicitly[Numeric[Byte]]
private[sql] val integral = implicitly[Integral[Byte]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+ /**
+ * The default size of a value of the ByteType is 1 byte.
+ */
+ override def defaultSize: Int = 1
}
/** Matcher for any expressions that evaluate to [[FractionalType]]s */
-object FractionalType {
+protected[sql] object FractionalType {
def unapply(a: Expression): Boolean = a match {
case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
case _ => false
@@ -488,7 +528,7 @@ object FractionalType {
}
-sealed abstract class FractionalType extends NumericType {
+protected[sql] sealed abstract class FractionalType extends NumericType {
private[sql] val fractional: Fractional[JvmType]
private[sql] val asIntegral: Integral[JvmType]
}
@@ -530,6 +570,11 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
case None => "DecimalType()"
}
+
+ /**
+ * The default size of a value of the DecimalType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
}
@@ -580,6 +625,11 @@ case object DoubleType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Double]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = DoubleAsIfIntegral
+
+ /**
+ * The default size of a value of the DoubleType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
}
@@ -598,6 +648,11 @@ case object FloatType extends FractionalType {
private[sql] val fractional = implicitly[Fractional[Float]]
private[sql] val ordering = implicitly[Ordering[JvmType]]
private[sql] val asIntegral = FloatAsIfIntegral
+
+ /**
+ * The default size of a value of the FloatType is 4 bytes.
+ */
+ override def defaultSize: Int = 4
}
@@ -636,6 +691,12 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
("type" -> typeName) ~
("elementType" -> elementType.jsonValue) ~
("containsNull" -> containsNull)
+
+ /**
+ * The default size of a value of the ArrayType is 100 * the default size of the element type.
+ * (We assume that there are 100 elements).
+ */
+ override def defaultSize: Int = 100 * elementType.defaultSize
}
@@ -805,6 +866,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
override def length: Int = fields.length
override def iterator: Iterator[StructField] = fields.iterator
+
+ /**
+ * The default size of a value of the StructType is the total default sizes of all field types.
+ */
+ override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
}
@@ -848,6 +914,13 @@ case class MapType(
("keyType" -> keyType.jsonValue) ~
("valueType" -> valueType.jsonValue) ~
("valueContainsNull" -> valueContainsNull)
+
+ /**
+ * The default size of a value of the MapType is
+ * 100 * (the default size of the key type + the default size of the value type).
+ * (We assume that there are 100 elements).
+ */
+ override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
}
@@ -896,4 +969,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
* Class object for the UserType
*/
def userClass: java.lang.Class[UserType]
+
+ /**
+ * The default size of a value of the UserDefinedType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
}
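
Taken together, the overrides above give every composite type a closed-form estimate: an array counts as 100 × its element size, a map as 100 × (key size + value size), and a struct as the sum of its field sizes. A small sketch of how those numbers compose, assuming this patch is applied and the spark-catalyst classes are on the classpath:

    import org.apache.spark.sql.types._

    object DefaultSizeDemo extends App {
      println(ArrayType(DoubleType, true).defaultSize)              // 100 * 8          = 800
      println(ArrayType(StringType, false).defaultSize)             // 100 * 4096       = 409600
      println(MapType(IntegerType, StringType, true).defaultSize)   // 100 * (4 + 4096) = 410000
      println(StructType(Seq(
        StructField("a", IntegerType, nullable = true),
        StructField("b", ArrayType(DoubleType), nullable = false),
        StructField("c", DoubleType, nullable = false))).defaultSize)  // 4 + 800 + 8   = 812
    }
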
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 892195f46e..c147be9f6b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -62,6 +62,7 @@ class DataTypeSuite extends FunSuite {
}
}
+ checkDataTypeJsonRepr(NullType)
checkDataTypeJsonRepr(BooleanType)
checkDataTypeJsonRepr(ByteType)
checkDataTypeJsonRepr(ShortType)
@@ -69,7 +70,9 @@ class DataTypeSuite extends FunSuite {
checkDataTypeJsonRepr(LongType)
checkDataTypeJsonRepr(FloatType)
checkDataTypeJsonRepr(DoubleType)
+ checkDataTypeJsonRepr(DecimalType(10, 5))
checkDataTypeJsonRepr(DecimalType.Unlimited)
+ checkDataTypeJsonRepr(DateType)
checkDataTypeJsonRepr(TimestampType)
checkDataTypeJsonRepr(StringType)
checkDataTypeJsonRepr(BinaryType)
@@ -77,12 +80,39 @@ class DataTypeSuite extends FunSuite {
checkDataTypeJsonRepr(ArrayType(StringType, false))
checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+
val metadata = new MetadataBuilder()
.putString("name", "age")
.build()
- checkDataTypeJsonRepr(
- StructType(Seq(
- StructField("a", IntegerType, nullable = true),
- StructField("b", ArrayType(DoubleType), nullable = false),
- StructField("c", DoubleType, nullable = false, metadata))))
+ val structType = StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", ArrayType(DoubleType), nullable = false),
+ StructField("c", DoubleType, nullable = false, metadata)))
+ checkDataTypeJsonRepr(structType)
+
+ def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
+ test(s"Check the default size of ${dataType}") {
+ assert(dataType.defaultSize === expectedDefaultSize)
+ }
+ }
+
+ checkDefaultSize(NullType, 1)
+ checkDefaultSize(BooleanType, 1)
+ checkDefaultSize(ByteType, 1)
+ checkDefaultSize(ShortType, 2)
+ checkDefaultSize(IntegerType, 4)
+ checkDefaultSize(LongType, 8)
+ checkDefaultSize(FloatType, 4)
+ checkDefaultSize(DoubleType, 8)
+ checkDefaultSize(DecimalType(10, 5), 4096)
+ checkDefaultSize(DecimalType.Unlimited, 4096)
+ checkDefaultSize(DateType, 8)
+ checkDefaultSize(TimestampType, 8)
+ checkDefaultSize(StringType, 4096)
+ checkDefaultSize(BinaryType, 4096)
+ checkDefaultSize(ArrayType(DoubleType, true), 800)
+ checkDefaultSize(ArrayType(StringType, false), 409600)
+ checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
+ checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
+ checkDefaultSize(structType, 812)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c5b6fce5fd..67007b8c09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.planner._
+import org.apache.spark.sql.types._
class PlannerSuite extends FunSuite {
test("unions are collapsed") {
@@ -60,19 +61,62 @@ class PlannerSuite extends FunSuite {
}
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
- val origThreshold = conf.autoBroadcastJoinThreshold
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString)
-
- // Using a threshold that is definitely larger than the small testing table (b) below
- val a = testData.as('a)
- val b = testData.limit(3).as('b)
- val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+ def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
+ setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString)
+ val fields = fieldTypes.zipWithIndex.map {
+ case (dataType, index) => StructField(s"c${index}", dataType, true)
+ } :+ StructField("key", IntegerType, true)
+ val schema = StructType(fields)
+ val row = Row.fromSeq(Seq.fill(fields.size)(null))
+ val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
+ applySchema(rowRDD, schema).registerTempTable("testLimit")
+
+ val planned = sql(
+ """
+ |SELECT l.a, l.b
+ |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
+ """.stripMargin).queryExecution.executedPlan
+
+ val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
+ val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+
+ assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
+ assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+
+ dropTempTable("testLimit")
+ }
- val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
- val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+ val origThreshold = conf.autoBroadcastJoinThreshold
- assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
- assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+ val simpleTypes =
+ NullType ::
+ BooleanType ::
+ ByteType ::
+ ShortType ::
+ IntegerType ::
+ LongType ::
+ FloatType ::
+ DoubleType ::
+ DecimalType(10, 5) ::
+ DecimalType.Unlimited ::
+ DateType ::
+ TimestampType ::
+ StringType ::
+ BinaryType :: Nil
+
+ checkPlan(simpleTypes, newThreshold = 16434)
+
+ val complexTypes =
+ ArrayType(DoubleType, true) ::
+ ArrayType(StringType, false) ::
+ MapType(IntegerType, StringType, true) ::
+ MapType(IntegerType, ArrayType(DoubleType), false) ::
+ StructType(Seq(
+ StructField("a", IntegerType, nullable = true),
+ StructField("b", ArrayType(DoubleType), nullable = false),
+ StructField("c", DoubleType, nullable = false))) :: Nil
+
+ checkPlan(complexTypes, newThreshold = 901617)
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
}
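
Each threshold in this test sits one byte above the estimated size of the single-row limited table, so the broadcast path is chosen precisely because of the new per-type estimates. A rough check of that arithmetic, assuming a LIMIT of 1 and the extra IntegerType "key" column that checkPlan appends:

    object ThresholdCheck extends App {
      // defaultSize of each simple type in the test, plus the IntegerType key column.
      val simpleTypeSizes = Seq(
        1,    // NullType
        1,    // BooleanType
        1,    // ByteType
        2,    // ShortType
        4,    // IntegerType
        8,    // LongType
        4,    // FloatType
        8,    // DoubleType
        4096, // DecimalType(10, 5)
        4096, // DecimalType.Unlimited
        8,    // DateType
        8,    // TimestampType
        4096, // StringType
        4096) // BinaryType
      println(simpleTypeSizes.sum + 4)   // 16433, just under the 16434 threshold

      // Composite types: Array = 100 * element, Map = 100 * (key + value), Struct = sum of fields.
      val complexTypeSizes = Seq(
        100 * 8,           // ArrayType(DoubleType, true)                 = 800
        100 * 4096,        // ArrayType(StringType, false)                = 409600
        100 * (4 + 4096),  // MapType(IntegerType, StringType)            = 410000
        100 * (4 + 800),   // MapType(IntegerType, ArrayType(DoubleType)) = 80400
        4 + 800 + 8)       // StructType(a: Int, b: Array[Double], c: Double) = 812
      println(complexTypeSizes.sum + 4)  // 901616, just under the 901617 threshold
    }
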