Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala      | 15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala   | 34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala                        |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala                       |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala                        |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala                   |  5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala                     | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala                            |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala                                   | 67
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                               |  5
10 files changed, 98 insertions, 54 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 35df2429db..8095083f33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -316,6 +316,21 @@ abstract class UnaryNode extends LogicalPlan {
override def children: Seq[LogicalPlan] = child :: Nil
override protected def validConstraints: Set[Expression] = child.constraints
+
+ override def statistics: Statistics = {
+ // There should be some overhead in a Row object; the size should not be zero when
+ // there are no columns, which helps to prevent a divide-by-zero error.
+ val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
+ val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
+ // Assume the output has the same number of rows as the child.
+ var sizeInBytes = (child.statistics.sizeInBytes * outputRowSize) / childRowSize
+ if (sizeInBytes == 0) {
+ // sizeInBytes can't be zero, or the sizeInBytes of a BinaryNode (the product of
+ // its children's sizes) would also be zero.
+ sizeInBytes = 1
+ }
+ Statistics(sizeInBytes = sizeInBytes)
+ }
}
/**
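
With this change every UnaryNode estimates its output size by scaling the child's
estimate by the ratio of output row width to child row width, each padded with the
8-byte per-Row overhead. A minimal, self-contained sketch of the arithmetic, using
hypothetical column types and a made-up child size:

    object UnaryNodeEstimate {
      // Row width = sum of per-column default sizes plus the assumed 8-byte
      // per-Row overhead that also keeps the denominator non-zero.
      def rowSize(columnSizes: Seq[Int]): BigInt = BigInt(columnSizes.sum + 8)

      def main(args: Array[String]): Unit = {
        val childSizeInBytes = BigInt(1000000)   // child.statistics.sizeInBytes
        val childRow  = Seq(4, 20, 20)           // e.g. (INT, STRING, STRING)
        val outputRow = Seq(4)                   // a projection keeping only the INT
        val estimate = childSizeInBytes * rowSize(outputRow) / rowSize(childRow)
        // Floor at 1 so a BinaryNode's product of child sizes never collapses to 0.
        println(if (estimate == 0) BigInt(1) else estimate)   // prints 230769
      }
    }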
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index c98d33d5a4..af43cb3786 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -176,6 +176,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
Some(children.flatMap(_.maxRows).min)
}
}
+
+ override def statistics: Statistics = {
+ val leftSize = left.statistics.sizeInBytes
+ val rightSize = right.statistics.sizeInBytes
+ val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
+ Statistics(sizeInBytes = sizeInBytes)
+ }
}
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
@@ -188,6 +195,10 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
childrenResolved &&
left.output.length == right.output.length &&
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+
+ override def statistics: Statistics = {
+ Statistics(sizeInBytes = left.statistics.sizeInBytes)
+ }
}
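
The set-operation estimates above are upper bounds: an Intersect result is a subset
of both inputs, so the smaller input bounds it, and an Except result is a subset of
the left input. A hedged sketch with made-up input sizes:

    val leftSize  = BigInt(500)    // left.statistics.sizeInBytes (hypothetical)
    val rightSize = BigInt(300)    // right.statistics.sizeInBytes (hypothetical)
    val intersectEstimate = leftSize.min(rightSize)   // 300: subset of both inputs
    val exceptEstimate    = leftSize                  // 500: subset of the left input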
/** Factory for constructing new `Union` nodes. */
@@ -426,6 +437,14 @@ case class Aggregate(
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
override def maxRows: Option[Long] = child.maxRows
+
+ override def statistics: Statistics = {
+ if (groupingExpressions.isEmpty) {
+ Statistics(sizeInBytes = 1)
+ } else {
+ super.statistics
+ }
+ }
}
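
A global aggregation (no GROUP BY) always produces exactly one row, so the estimate
is clamped to the 1-byte floor instead of the generic row-width scaling. For example
(table and column names assumed from the test fixtures):

    // groupingExpressions is empty for a global aggregate, so sizeInBytes = 1:
    sqlContext.sql("SELECT COUNT(*) FROM testData")
    // A grouped aggregate instead falls through to the UnaryNode row-width estimate:
    sqlContext.sql("SELECT a, COUNT(*) FROM testData2 GROUP BY a")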
case class Window(
@@ -521,9 +540,7 @@ case class Expand(
AttributeSet(projections.flatten.flatMap(_.references))
override def statistics: Statistics = {
- // TODO shouldn't we factor in the size of the projection versus the size of the backing child
- // row?
- val sizeInBytes = child.statistics.sizeInBytes * projections.length
+ val sizeInBytes = super.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}
}
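
This resolves the old TODO: super.statistics (the UnaryNode estimate above) already
scales the child's size by the output-to-child row-width ratio, so Expand only needs
to multiply that by the number of projections. With hypothetical numbers:

    // Child estimated at 1 MB, output row half the width of the child row,
    // and 3 grouping-set projections:
    val perProjection  = BigInt(1 << 20) / 2   // what super.statistics would yield
    val expandEstimate = perProjection * 3     // 1572864 bytes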
@@ -648,6 +665,17 @@ case class Sample(
val isTableSample: java.lang.Boolean = false) extends UnaryNode {
override def output: Seq[Attribute] = child.output
+
+ override def statistics: Statistics = {
+ val ratio = upperBound - lowerBound
+ // BigInt can't be multiplied by a Double, so scale the ratio to an integer percentage.
+ var sizeInBytes = child.statistics.sizeInBytes * (ratio * 100).toInt / 100
+ if (sizeInBytes == 0) {
+ sizeInBytes = 1
+ }
+ Statistics(sizeInBytes = sizeInBytes)
+ }
+
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
}
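
Because scala.math.BigInt defines no multiplication by Double, Sample scales the
sampling ratio to an integer percentage before multiplying. A minimal sketch of the
workaround (1% resolution, floored at 1 byte):

    val childSize = BigInt(12345)                          // hypothetical child estimate
    val ratio     = 0.1                                    // upperBound - lowerBound
    var estimate  = childSize * (ratio * 100).toInt / 100  // 12345 * 10 / 100 = 1234
    if (estimate == 0) estimate = 1                        // keep BinaryNode products non-zero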
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
index f2c6f34ea5..c40e140e8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
@@ -47,9 +47,9 @@ class BinaryType private() extends AtomicType {
}
/**
- * The default size of a value of the BinaryType is 4096 bytes.
+ * The default size of a value of the BinaryType is 100 bytes.
*/
- override def defaultSize: Int = 4096
+ override def defaultSize: Int = 100
private[spark] override def asNullable: BinaryType = this
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 71ea5b8841..2e03ddae76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -91,9 +91,9 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
}
/**
- * The default size of a value of the DecimalType is 4096 bytes.
+ * The default size of a value of the DecimalType is 8 bytes when precision <= 18, and 16 bytes otherwise.
*/
- override def defaultSize: Int = 4096
+ override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16
override def simpleString: String = s"decimal($precision,$scale)"
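
The cutoff works because Decimal stores values with at most Decimal.MAX_LONG_DIGITS
(18) digits in an unscaled Long: Long.MaxValue is 9223372036854775807 (19 digits), so
any 18-digit value fits, while 19 digits may overflow. A quick check:

    assert(BigInt("9" * 18).isValidLong)    // 18 decimal digits always fit in a Long
    assert(!BigInt("9" * 19).isValidLong)   // 19 digits can overflow Long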
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
index a7627a2de1..44a25361f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -38,9 +38,9 @@ class StringType private() extends AtomicType {
private[sql] val ordering = implicitly[Ordering[InternalType]]
/**
- * The default size of a value of the StringType is 4096 bytes.
+ * The default size of a value of the StringType is 20 bytes.
*/
- override def defaultSize: Int = 4096
+ override def defaultSize: Int = 20
private[spark] override def asNullable: StringType = this
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 7664c30ee7..9d2449f3b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -71,10 +71,7 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
*/
def userClass: java.lang.Class[UserType]
- /**
- * The default size of a value of the UserDefinedType is 4096 bytes.
- */
- override def defaultSize: Int = 4096
+ override def defaultSize: Int = sqlType.defaultSize
/**
* For UDT, asNullable will not change the nullability of its internal sqlType and just returns
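
A UDT now inherits its estimate from the type it is serialized to, rather than a flat
4096 bytes; for a hypothetical point UDT backed by a struct of two doubles:

    // Hypothetical: sqlType = StructType of (x: Double, y: Double)
    // defaultSize = sqlType.defaultSize = 8 + 8 = 16 bytes (previously 4096)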
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index c2bbca7c33..6b85f12521 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -248,15 +248,15 @@ class DataTypeSuite extends SparkFunSuite {
checkDefaultSize(LongType, 8)
checkDefaultSize(FloatType, 4)
checkDefaultSize(DoubleType, 8)
- checkDefaultSize(DecimalType(10, 5), 4096)
- checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096)
+ checkDefaultSize(DecimalType(10, 5), 8)
+ checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 16)
checkDefaultSize(DateType, 4)
checkDefaultSize(TimestampType, 8)
- checkDefaultSize(StringType, 4096)
- checkDefaultSize(BinaryType, 4096)
+ checkDefaultSize(StringType, 20)
+ checkDefaultSize(BinaryType, 100)
checkDefaultSize(ArrayType(DoubleType, true), 800)
- checkDefaultSize(ArrayType(StringType, false), 409600)
- checkDefaultSize(MapType(IntegerType, StringType, true), 410000)
+ checkDefaultSize(ArrayType(StringType, false), 2000)
+ checkDefaultSize(MapType(IntegerType, StringType, true), 2400)
checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400)
checkDefaultSize(structType, 812)
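
The updated composite expectations follow from the new scalar defaults, given the
suite's assumption of 100 elements per array and 100 entries per map:

    ArrayType(DoubleType, true)                 : 100 * 8         = 800
    ArrayType(StringType, false)                : 100 * 20        = 2000   (was 100 * 4096 = 409600)
    MapType(IntegerType, StringType, true)      : 100 * (4 + 20)  = 2400   (was 100 * (4 + 4096) = 410000)
    MapType(IntegerType, ArrayType(DoubleType)) : 100 * (4 + 800) = 80400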
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index d26a0b7467..a9e77abbda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution
-import scala.collection.immutable.IndexedSeq
-
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8f2a0c0351..a1211e4380 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -63,36 +63,40 @@ class JoinSuite extends QueryTest with SharedSQLContext {
test("join operator selection") {
sqlContext.cacheManager.clearCache()
- Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
- ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
- ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
- ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
- ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
- ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
- ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
- ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
- ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
- ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
- ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
- ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
- ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
- classOf[SortMergeJoin]), // converted from Right Outer to Inner
- ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
- classOf[SortMergeOuterJoin]),
- ("SELECT * FROM testData full outer join testData2 ON key = a",
- classOf[SortMergeOuterJoin]),
- ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
- classOf[BroadcastNestedLoopJoin]),
- ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
- classOf[BroadcastNestedLoopJoin]),
- ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
- classOf[BroadcastNestedLoopJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
+ Seq(
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
+ ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
+ classOf[CartesianProduct]),
+ ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
+ classOf[CartesianProduct]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+ classOf[SortMergeJoin]),
+ ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+ classOf[SortMergeOuterJoin]),
+ ("SELECT * FROM testData full outer join testData2 ON key = a",
+ classOf[SortMergeOuterJoin]),
+ ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ }
}
// ignore("SortMergeJoin shouldn't work on unsortable columns") {
@@ -118,9 +122,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
test("broadcasted hash outer join operator selection") {
sqlContext.cacheManager.clearCache()
sql("CACHE TABLE testData")
+ sql("CACHE TABLE testData2")
Seq(
("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
- classOf[SortMergeOuterJoin]),
+ classOf[BroadcastHashJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[BroadcastHashJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b4d6f4ecdd..1d9db27e09 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -22,7 +22,7 @@ import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.execution.aggregate
-import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
import org.apache.spark.sql.test.SQLTestData._
@@ -780,8 +780,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
assert(cp.isEmpty, "should not use CartesianProduct for null-safe join")
val smj = df.queryExecution.sparkPlan.collect {
case smj: SortMergeJoin => smj
+ case j: BroadcastHashJoin => j
}
- assert(smj.size > 0, "should use SortMergeJoin")
+ assert(smj.size > 0, "should use SortMergeJoin or BroadcastHashJoin")
checkAnswer(df, Row(100) :: Nil)
}
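
The point of this test is unchanged: a <=> b is still an equi-join key (NULL <=> NULL
is TRUE rather than NULL), so a hash- or sort-based join applies and CartesianProduct
is never needed; only the specific operator may now differ. Illustrative query
(column name assumed):

    // Null-safe equality joins plan as SortMergeJoin or, under the new size
    // estimates, BroadcastHashJoin -- never CartesianProduct.
    sql("SELECT * FROM testData a JOIN testData b ON a.key <=> b.key")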