11 files changed, 100 insertions, 56 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 83b034fe77..bf43452e08 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -551,8 +551,8 @@ class DataFrame(object):
         >>> df_as1 = df.alias("df_as1")
         >>> df_as2 = df.alias("df_as2")
         >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
-        >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect()
-        [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)]
+        >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").collect()
+        [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)]
         """
         assert isinstance(alias, basestring), "alias should be a string"
         return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 35df2429db..8095083f33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -316,6 +316,21 @@ abstract class UnaryNode extends LogicalPlan {
   override def children: Seq[LogicalPlan] = child :: Nil
 
   override protected def validConstraints: Set[Expression] = child.constraints
+
+  override def statistics: Statistics = {
+    // There is some per-row overhead in a Row object, so the row size should not be zero
+    // even when there are no columns; this also prevents a divide-by-zero error below.
+    val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
+    val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
+    // Assume there will be the same number of rows as the child has.
+    var sizeInBytes = (child.statistics.sizeInBytes * outputRowSize) / childRowSize
+    if (sizeInBytes == 0) {
+      // sizeInBytes can't be zero, or the sizeInBytes of a BinaryNode (the product of
+      // its children's sizes) would also be zero.
+      sizeInBytes = 1
+    }
+    Statistics(sizeInBytes = sizeInBytes)
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index c98d33d5a4..af43cb3786 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -176,6 +176,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
       Some(children.flatMap(_.maxRows).min)
     }
   }
+
+  override def statistics: Statistics = {
+    val leftSize = left.statistics.sizeInBytes
+    val rightSize = right.statistics.sizeInBytes
+    val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
+    Statistics(sizeInBytes = sizeInBytes)
+  }
 }
 
 case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
@@ -188,6 +195,10 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
     childrenResolved &&
       left.output.length == right.output.length &&
       left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+
+  override def statistics: Statistics = {
+    Statistics(sizeInBytes = left.statistics.sizeInBytes)
+  }
 }
 
 /** Factory for constructing new `Union` nodes. */
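The `UnaryNode.statistics` override above scales the child's size estimate by the ratio of the output row width to the child row width, on the assumption that a unary operator keeps the child's row count. A minimal standalone sketch of that arithmetic, with hypothetical sizes (projecting an Int column out of an Int + String child, using the new 20-byte String default and the 8-byte per-row overhead):

    // Hypothetical values; mirrors (child.statistics.sizeInBytes * outputRowSize) / childRowSize.
    val childSize = BigInt(1000000)  // child.statistics.sizeInBytes
    val childRowSize = (4 + 20) + 8  // IntegerType + StringType + per-row overhead
    val outputRowSize = 4 + 8        // IntegerType + per-row overhead
    val estimate = (childSize * outputRowSize) / childRowSize
    // estimate == 375000: same number of rows, but only 12/32 of each row's bytes

The `Intersect` and `Except` overrides use cruder bounds: an intersect can return at most the smaller child's bytes, and an except at most the left child's.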
@@ -426,6 +437,14 @@ case class Aggregate(
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
 
   override def maxRows: Option[Long] = child.maxRows
+
+  override def statistics: Statistics = {
+    if (groupingExpressions.isEmpty) {
+      Statistics(sizeInBytes = 1)
+    } else {
+      super.statistics
+    }
+  }
 }
 
 case class Window(
@@ -521,9 +540,7 @@ case class Expand(
     AttributeSet(projections.flatten.flatMap(_.references))
 
   override def statistics: Statistics = {
-    // TODO shouldn't we factor in the size of the projection versus the size of the backing child
-    // row?
-    val sizeInBytes = child.statistics.sizeInBytes * projections.length
+    val sizeInBytes = super.statistics.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }
 }
@@ -648,6 +665,17 @@ case class Sample(
     val isTableSample: java.lang.Boolean = false) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
+
+  override def statistics: Statistics = {
+    val ratio = upperBound - lowerBound
+    // A BigInt can't be multiplied by a Double, so scale by an integer percentage instead.
+    var sizeInBytes = child.statistics.sizeInBytes * (ratio * 100).toInt / 100
+    if (sizeInBytes == 0) {
+      sizeInBytes = 1
+    }
+    Statistics(sizeInBytes = sizeInBytes)
+  }
+
   override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
index f2c6f34ea5..c40e140e8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
@@ -47,9 +47,9 @@ class BinaryType private() extends AtomicType {
   }
 
   /**
-   * The default size of a value of the BinaryType is 4096 bytes.
+   * The default size of a value of the BinaryType is 100 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = 100
 
   private[spark] override def asNullable: BinaryType = this
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 71ea5b8841..2e03ddae76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -91,9 +91,9 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
   }
 
   /**
-   * The default size of a value of the DecimalType is 4096 bytes.
+   * The default size of a value of the DecimalType is 8 bytes (precision <= 18) or 16 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16
 
   override def simpleString: String = s"decimal($precision,$scale)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
index a7627a2de1..44a25361f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -38,9 +38,9 @@ class StringType private() extends AtomicType {
   private[sql] val ordering = implicitly[Ordering[InternalType]]
 
   /**
-   * The default size of a value of the StringType is 4096 bytes.
+   * The default size of a value of the StringType is 20 bytes.
    */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = 20
 
   private[spark] override def asNullable: StringType = this
 }
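The `Sample.statistics` hunk above needs a small workaround: Scala's `BigInt` has no multiplication overload taking a `Double`, so the sampling ratio is first converted to an integer percentage. A sketch with hypothetical bounds (0.5 is exactly representable, so the conversion is lossless here; `.toInt` truncates, so other ratios can round the estimate down by up to one percent):

    val childSize = BigInt(10000)          // hypothetical child.statistics.sizeInBytes
    val ratio = 0.5 - 0.0                  // upperBound - lowerBound
    var sizeInBytes = childSize * (ratio * 100).toInt / 100
    // (0.5 * 100).toInt == 50, so sizeInBytes == 10000 * 50 / 100 == 5000
    if (sizeInBytes == 0) sizeInBytes = 1  // keep BinaryNode size products non-zero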
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 7664c30ee7..9d2449f3b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -71,10 +71,7 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
    */
   def userClass: java.lang.Class[UserType]
 
-  /**
-   * The default size of a value of the UserDefinedType is 4096 bytes.
-   */
-  override def defaultSize: Int = 4096
+  override def defaultSize: Int = sqlType.defaultSize
 
   /**
    * For UDT, asNullable will not change the nullability of its internal sqlType and just returns
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index c2bbca7c33..6b85f12521 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -248,15 +248,15 @@ class DataTypeSuite extends SparkFunSuite {
   checkDefaultSize(LongType, 8)
   checkDefaultSize(FloatType, 4)
   checkDefaultSize(DoubleType, 8)
-  checkDefaultSize(DecimalType(10, 5), 4096)
-  checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096)
+  checkDefaultSize(DecimalType(10, 5), 8)
+  checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 16)
   checkDefaultSize(DateType, 4)
   checkDefaultSize(TimestampType, 8)
-  checkDefaultSize(StringType, 4096)
-  checkDefaultSize(BinaryType, 4096)
+  checkDefaultSize(StringType, 20)
+  checkDefaultSize(BinaryType, 100)
   checkDefaultSize(ArrayType(DoubleType, true), 800)
-  checkDefaultSize(ArrayType(StringType, false), 409600)
-  checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
+  checkDefaultSize(ArrayType(StringType, false), 2000)
+  checkDefaultSize(MapType(IntegerType, StringType, true), 2400)
   checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
   checkDefaultSize(structType, 812)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index d26a0b7467..a9e77abbda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.immutable.IndexedSeq
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
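The updated `DataTypeSuite` expectations follow directly from the new scalar defaults plus the pre-existing assumption of 100 elements per array and 100 entries per map (visible in the untouched `ArrayType(DoubleType, true)` case: 100 * 8 = 800). A quick check:

    val stringSize = 20  // new StringType.defaultSize
    val intSize = 4      // IntegerType.defaultSize
    100 * stringSize              // 2000  == ArrayType(StringType, false)
    100 * (intSize + stringSize)  // 2400  == MapType(IntegerType, StringType, true)
    100 * (intSize + 100 * 8)     // 80400 == MapType(IntegerType, ArrayType(DoubleType), false)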
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8f2a0c0351..a1211e4380 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -63,36 +63,40 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   test("join operator selection") {
     sqlContext.cacheManager.clearCache()
 
-    Seq(
-      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
-      ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
-      ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
-      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[SortMergeJoin]), // converted from Right Outer to Inner
-      ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData full outer join testData2 ON key = a",
-        classOf[SortMergeOuterJoin]),
-      ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin]),
-      ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin]),
-      ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
-        classOf[BroadcastNestedLoopJoin])
-    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
+      Seq(
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+        ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
+        ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
+          classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
+          classOf[CartesianProduct]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
+        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+          classOf[SortMergeJoin]),
+        ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+          classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData full outer join testData2 ON key = a",
+          classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
+          classOf[BroadcastNestedLoopJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    }
   }
 
   // ignore("SortMergeJoin shouldn't work on unsortable columns") {
@@ -118,9 +122,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   test("broadcasted hash outer join operator selection") {
     sqlContext.cacheManager.clearCache()
     sql("CACHE TABLE testData")
+    sql("CACHE TABLE testData2")
     Seq(
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
-        classOf[SortMergeOuterJoin]),
+        classOf[BroadcastHashJoin]),
      ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
         classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
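Because the per-type defaults shrank by orders of magnitude, the small test tables now fall below the default `spark.sql.autoBroadcastJoinThreshold`, so the first test pins the threshold to 0 to keep exercising the non-broadcast operators, while the second test also caches `testData2`, presumably so its now-small estimate makes `BroadcastHashJoin` the expected plan for the left outer join. A simplified sketch of the decision the threshold feeds into (the real rule lives in the planner's strategies; the names here are illustrative):

    // With this patch guaranteeing sizeInBytes >= 1, a threshold of 0 disables broadcasting.
    def canBroadcast(sizeInBytes: BigInt, thresholdBytes: Long): Boolean =
      sizeInBytes <= thresholdBytes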
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b4d6f4ecdd..1d9db27e09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -22,7 +22,7 @@ import java.sql.Timestamp
 
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
 import org.apache.spark.sql.test.SQLTestData._
@@ -780,8 +780,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     assert(cp.isEmpty, "should not use CartesianProduct for null-safe join")
 
     val smj = df.queryExecution.sparkPlan.collect {
       case smj: SortMergeJoin => smj
+      case j: BroadcastHashJoin => j
     }
-    assert(smj.size > 0, "should use SortMergeJoin")
+    assert(smj.size > 0, "should use SortMergeJoin or BroadcastHashJoin")
 
     checkAnswer(df, Row(100) :: Nil)
   }
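The `SQLQuerySuite` hunk shows the usual way a test asserts on the chosen physical join: pattern-match over `queryExecution.sparkPlan`. A usage sketch (`df` is any joined DataFrame; imports as in the hunk above):

    val chosenJoins = df.queryExecution.sparkPlan.collect {
      case j: SortMergeJoin => j
      case j: BroadcastHashJoin => j
    }
    assert(chosenJoins.nonEmpty, "should use SortMergeJoin or BroadcastHashJoin")

Accepting either operator keeps the null-safe-join test agnostic to whether the smaller size estimates push the plan under the broadcast threshold.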