Diffstat (limited to 'sql')
5 files changed, 201 insertions, 48 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index 8328278544..e2f5c7332d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType
 /**
  * The data type representing [[DynamicRow]] values.
  */
-case object DynamicType extends DataType
+case object DynamicType extends DataType {
+
+  /**
+   * The default size of a value of the DynamicType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
+}
 
 /**
  * Wrap a [[Row]] as a [[DynamicRow]].
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 1483beacc9..9628e93274 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -238,16 +238,11 @@ case class Rollup(
 case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output = child.output
 
-  override lazy val statistics: Statistics =
-    if (output.forall(_.dataType.isInstanceOf[NativeType])) {
-      val limit = limitExpr.eval(null).asInstanceOf[Int]
-      val sizeInBytes = (limit: Long) * output.map { a =>
-        NativeType.defaultSizeOf(a.dataType.asInstanceOf[NativeType])
-      }.sum
-      Statistics(sizeInBytes = sizeInBytes)
-    } else {
-      Statistics(sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product)
-    }
+  override lazy val statistics: Statistics = {
+    val limit = limitExpr.eval(null).asInstanceOf[Int]
+    val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    Statistics(sizeInBytes = sizeInBytes)
+  }
 }
 
 case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index bcd74603d4..9f30f40a17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -215,6 +215,9 @@ abstract class DataType {
     case _ => false
   }
 
+  /** The default size of a value of this data type. */
+  def defaultSize: Int
+
   def isPrimitive: Boolean = false
 
   def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
@@ -235,33 +238,25 @@ abstract class DataType {
  * @group dataType
  */
 @DeveloperApi
-case object NullType extends DataType
+case object NullType extends DataType {
+  override def defaultSize: Int = 1
+}
 
 
-object NativeType {
+protected[sql] object NativeType {
   val all = Seq(
     IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
 
   def unapply(dt: DataType): Boolean = all.contains(dt)
-
-  val defaultSizeOf: Map[NativeType, Int] = Map(
-    IntegerType -> 4,
-    BooleanType -> 1,
-    LongType -> 8,
-    DoubleType -> 8,
-    FloatType -> 4,
-    ShortType -> 2,
-    ByteType -> 1,
-    StringType -> 4096)
 }
 
 
-trait PrimitiveType extends DataType {
+protected[sql] trait PrimitiveType extends DataType {
   override def isPrimitive = true
 }
 
 
-object PrimitiveType {
+protected[sql] object PrimitiveType {
   private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
   private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
@@ -276,7 +271,7 @@ object PrimitiveType {
   }
 }
 
-abstract class NativeType extends DataType {
+protected[sql] abstract class NativeType extends DataType {
   private[sql] type JvmType
   @transient private[sql] val tag: TypeTag[JvmType]
   private[sql] val ordering: Ordering[JvmType]
@@ -300,6 +295,11 @@ case object StringType extends NativeType with PrimitiveType {
   private[sql] type JvmType = String
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the StringType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
 }
@@ -324,6 +324,11 @@ case object BinaryType extends NativeType with PrimitiveType {
       x.length - y.length
     }
   }
+
+  /**
+   * The default size of a value of the BinaryType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
 }
@@ -339,6 +344,11 @@ case object BooleanType extends NativeType with PrimitiveType {
   private[sql] type JvmType = Boolean
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the BooleanType is 1 byte.
+   */
+  override def defaultSize: Int = 1
 }
@@ -359,6 +369,11 @@ case object TimestampType extends NativeType {
   private[sql] val ordering = new Ordering[JvmType] {
     def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
   }
+
+  /**
+   * The default size of a value of the TimestampType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
 }
@@ -379,10 +394,15 @@ case object DateType extends NativeType {
   private[sql] val ordering = new Ordering[JvmType] {
     def compare(x: Date, y: Date) = x.compareTo(y)
   }
+
+  /**
+   * The default size of a value of the DateType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
 }
 
 
-abstract class NumericType extends NativeType with PrimitiveType {
+protected[sql] abstract class NumericType extends NativeType with PrimitiveType {
   // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
   // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
   // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
@@ -392,13 +412,13 @@ abstract class NumericType extends NativeType with PrimitiveType {
 }
 
 
-object NumericType {
+protected[sql] object NumericType {
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
 }
 
 
 /** Matcher for any expressions that evaluate to [[IntegralType]]s */
-object IntegralType {
+protected[sql] object IntegralType {
   def unapply(a: Expression): Boolean = a match {
     case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
     case _ => false
@@ -406,7 +426,7 @@ object IntegralType {
 }
 
 
-sealed abstract class IntegralType extends NumericType {
+protected[sql] sealed abstract class IntegralType extends NumericType {
   private[sql] val integral: Integral[JvmType]
 }
@@ -425,6 +445,11 @@ case object LongType extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Long]]
   private[sql] val integral = implicitly[Integral[Long]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the LongType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
 }
@@ -442,6 +467,11 @@ case object IntegerType extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Int]]
   private[sql] val integral = implicitly[Integral[Int]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the IntegerType is 4 bytes.
+   */
+  override def defaultSize: Int = 4
 }
@@ -459,6 +489,11 @@ case object ShortType extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Short]]
   private[sql] val integral = implicitly[Integral[Short]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the ShortType is 2 bytes.
+   */
+  override def defaultSize: Int = 2
 }
@@ -476,11 +511,16 @@ case object ByteType extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Byte]]
   private[sql] val integral = implicitly[Integral[Byte]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
+
+  /**
+   * The default size of a value of the ByteType is 1 byte.
+   */
+  override def defaultSize: Int = 1
 }
 
 
 /** Matcher for any expressions that evaluate to [[FractionalType]]s */
-object FractionalType {
+protected[sql] object FractionalType {
   def unapply(a: Expression): Boolean = a match {
     case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
     case _ => false
@@ -488,7 +528,7 @@ object FractionalType {
 }
 
 
-sealed abstract class FractionalType extends NumericType {
+protected[sql] sealed abstract class FractionalType extends NumericType {
   private[sql] val fractional: Fractional[JvmType]
   private[sql] val asIntegral: Integral[JvmType]
 }
@@ -530,6 +570,11 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
     case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
     case None => "DecimalType()"
   }
+
+  /**
+   * The default size of a value of the DecimalType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
 }
@@ -580,6 +625,11 @@ case object DoubleType extends FractionalType {
   private[sql] val fractional = implicitly[Fractional[Double]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
   private[sql] val asIntegral = DoubleAsIfIntegral
+
+  /**
+   * The default size of a value of the DoubleType is 8 bytes.
+   */
+  override def defaultSize: Int = 8
 }
@@ -598,6 +648,11 @@ case object FloatType extends FractionalType {
   private[sql] val fractional = implicitly[Fractional[Float]]
   private[sql] val ordering = implicitly[Ordering[JvmType]]
   private[sql] val asIntegral = FloatAsIfIntegral
+
+  /**
+   * The default size of a value of the FloatType is 4 bytes.
+   */
+  override def defaultSize: Int = 4
 }
@@ -636,6 +691,12 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
     ("type" -> typeName) ~
       ("elementType" -> elementType.jsonValue) ~
       ("containsNull" -> containsNull)
+
+  /**
+   * The default size of a value of the ArrayType is 100 * the default size of the element type.
+   * (We assume that there are 100 elements).
+   */
+  override def defaultSize: Int = 100 * elementType.defaultSize
 }
@@ -805,6 +866,11 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   override def length: Int = fields.length
 
   override def iterator: Iterator[StructField] = fields.iterator
+
+  /**
+   * The default size of a value of the StructType is the total default sizes of all field types.
+   */
+  override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
 }
@@ -848,6 +914,13 @@ case class MapType(
     ("keyType" -> keyType.jsonValue) ~
       ("valueType" -> valueType.jsonValue) ~
       ("valueContainsNull" -> valueContainsNull)
+
+  /**
+   * The default size of a value of the MapType is
+   * 100 * (the default size of the key type + the default size of the value type).
+   * (We assume that there are 100 elements).
+   */
+  override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
 }
@@ -896,4 +969,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
    * Class object for the UserType
    */
   def userClass: java.lang.Class[UserType]
+
+  /**
+   * The default size of a value of the UserDefinedType is 4096 bytes.
+   */
+  override def defaultSize: Int = 4096
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 892195f46e..c147be9f6b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -62,6 +62,7 @@ class DataTypeSuite extends FunSuite {
     }
   }
 
+  checkDataTypeJsonRepr(NullType)
   checkDataTypeJsonRepr(BooleanType)
   checkDataTypeJsonRepr(ByteType)
   checkDataTypeJsonRepr(ShortType)
@@ -69,7 +70,9 @@ class DataTypeSuite extends FunSuite {
   checkDataTypeJsonRepr(LongType)
   checkDataTypeJsonRepr(FloatType)
   checkDataTypeJsonRepr(DoubleType)
+  checkDataTypeJsonRepr(DecimalType(10, 5))
   checkDataTypeJsonRepr(DecimalType.Unlimited)
+  checkDataTypeJsonRepr(DateType)
   checkDataTypeJsonRepr(TimestampType)
   checkDataTypeJsonRepr(StringType)
   checkDataTypeJsonRepr(BinaryType)
@@ -77,12 +80,39 @@ class DataTypeSuite extends FunSuite {
   checkDataTypeJsonRepr(ArrayType(StringType, false))
   checkDataTypeJsonRepr(MapType(IntegerType, StringType, true))
   checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false))
+
   val metadata = new MetadataBuilder()
     .putString("name", "age")
     .build()
-  checkDataTypeJsonRepr(
-    StructType(Seq(
-      StructField("a", IntegerType, nullable = true),
-      StructField("b", ArrayType(DoubleType), nullable = false),
-      StructField("c", DoubleType, nullable = false, metadata))))
+  val structType = StructType(Seq(
+    StructField("a", IntegerType, nullable = true),
+    StructField("b", ArrayType(DoubleType), nullable = false),
+    StructField("c", DoubleType, nullable = false, metadata)))
+  checkDataTypeJsonRepr(structType)
+
+  def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = {
+    test(s"Check the default size of ${dataType}") {
+      assert(dataType.defaultSize === expectedDefaultSize)
+    }
+  }
+
+  checkDefaultSize(NullType, 1)
+  checkDefaultSize(BooleanType, 1)
+  checkDefaultSize(ByteType, 1)
+  checkDefaultSize(ShortType, 2)
+  checkDefaultSize(IntegerType, 4)
+  checkDefaultSize(LongType, 8)
+  checkDefaultSize(FloatType, 4)
+  checkDefaultSize(DoubleType, 8)
+  checkDefaultSize(DecimalType(10, 5), 4096)
+  checkDefaultSize(DecimalType.Unlimited, 4096)
+  checkDefaultSize(DateType, 8)
+  checkDefaultSize(TimestampType, 8)
+  checkDefaultSize(StringType, 4096)
+  checkDefaultSize(BinaryType, 4096)
+  checkDefaultSize(ArrayType(DoubleType, true), 800)
+  checkDefaultSize(ArrayType(StringType, false), 409600)
+  checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
+  checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
+  checkDefaultSize(structType, 812)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c5b6fce5fd..67007b8c09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.test.TestSQLContext.planner._
+import org.apache.spark.sql.types._
 
 class PlannerSuite extends FunSuite {
   test("unions are collapsed") {
@@ -60,19 +61,62 @@ class PlannerSuite extends FunSuite {
   }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
-    val origThreshold = conf.autoBroadcastJoinThreshold
-    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString)
-
-    // Using a threshold that is definitely larger than the small testing table (b) below
-    val a = testData.as('a)
-    val b = testData.limit(3).as('b)
-    val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan
+    def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
+      setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString)
+      val fields = fieldTypes.zipWithIndex.map {
+        case (dataType, index) => StructField(s"c${index}", dataType, true)
+      } :+ StructField("key", IntegerType, true)
+      val schema = StructType(fields)
+      val row = Row.fromSeq(Seq.fill(fields.size)(null))
+      val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
+      applySchema(rowRDD, schema).registerTempTable("testLimit")
+
+      val planned = sql(
+        """
+          |SELECT l.a, l.b
+          |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key)
+        """.stripMargin).queryExecution.executedPlan
+
+      val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
+      val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+
+      assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
+      assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+
+      dropTempTable("testLimit")
+    }
 
-    val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
-    val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join }
+    val origThreshold = conf.autoBroadcastJoinThreshold
 
-    assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
-    assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
+    val simpleTypes =
+      NullType ::
+      BooleanType ::
+      ByteType ::
+      ShortType ::
+      IntegerType ::
+      LongType ::
+      FloatType ::
+      DoubleType ::
+      DecimalType(10, 5) ::
+      DecimalType.Unlimited ::
+      DateType ::
+      TimestampType ::
+      StringType ::
+      BinaryType :: Nil
+
+    checkPlan(simpleTypes, newThreshold = 16434)
+
+    val complexTypes =
+      ArrayType(DoubleType, true) ::
+      ArrayType(StringType, false) ::
+      MapType(IntegerType, StringType, true) ::
+      MapType(IntegerType, ArrayType(DoubleType), false) ::
+      StructType(Seq(
+        StructField("a", IntegerType, nullable = true),
+        StructField("b", ArrayType(DoubleType), nullable = false),
+        StructField("c", DoubleType, nullable = false))) :: Nil
+
+    checkPlan(complexTypes, newThreshold = 901617)
 
     setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString)
   }