author    Reynold Xin <rxin@databricks.com>    2016-01-01 19:23:06 -0800
committer Reynold Xin <rxin@databricks.com>    2016-01-01 19:23:06 -0800
commit    44ee920fd49d35b421ae562ea99bcc8f2b98ced6 (patch)
tree      fe94b7a91dda2ef27d7be6507baa83a339050846
parent    0da7bd50ddf0fb9e0e8aeadb9c7fb3edf6f0ee6e (diff)
Revert "[SPARK-12286][SPARK-12290][SPARK-12294][SPARK-12284][SQL] always output UnsafeRow"
This reverts commit 0da7bd50ddf0fb9e0e8aeadb9c7fb3edf6f0ee6e.
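
This revert restores Spark's earlier row-format negotiation scheme: each physical operator again advertises whether it produces UnsafeRows (outputsUnsafeRows) and which formats it can consume (canProcessUnsafeRows / canProcessSafeRows), and a planner rule, EnsureRowFormats (re-added in rowFormatConverters.scala below), inserts ConvertToSafe / ConvertToUnsafe nodes wherever those declarations disagree. For orientation before the diff, here is a minimal, self-contained Scala sketch of that pattern; every name in it (RowFormatSketch, Plan, Op, ToSafe, ToUnsafe, ensureRowFormats) is a simplified stand-in for illustration, not Spark's actual API.

object RowFormatSketch {

  // Toy plan node with the three capability flags this patch re-adds to SparkPlan.
  sealed trait Plan {
    def children: Seq[Plan]
    def withChildren(cs: Seq[Plan]): Plan
    def outputsUnsafeRows: Boolean = false
    def canProcessUnsafeRows: Boolean = false
    def canProcessSafeRows: Boolean = true
  }

  // Converter nodes playing the role of ConvertToUnsafe / ConvertToSafe.
  final case class ToUnsafe(child: Plan) extends Plan {
    def children: Seq[Plan] = Seq(child)
    def withChildren(cs: Seq[Plan]): Plan = copy(child = cs.head)
    override def outputsUnsafeRows: Boolean = true
  }
  final case class ToSafe(child: Plan) extends Plan {
    def children: Seq[Plan] = Seq(child)
    def withChildren(cs: Seq[Plan]): Plan = copy(child = cs.head)
    override def canProcessUnsafeRows: Boolean = true
    override def canProcessSafeRows: Boolean = false
  }

  // A generic operator whose capabilities are supplied explicitly.
  final case class Op(
      name: String,
      children: Seq[Plan],
      override val outputsUnsafeRows: Boolean,
      override val canProcessUnsafeRows: Boolean,
      override val canProcessSafeRows: Boolean) extends Plan {
    def withChildren(cs: Seq[Plan]): Plan = copy(children = cs)
  }

  // Bottom-up pass mirroring the three cases of EnsureRowFormats: safe-only
  // parents get ToSafe children, unsafe-only parents get ToUnsafe children,
  // and parents that accept both unify mixed children to the unsafe format.
  def ensureRowFormats(plan: Plan): Plan = {
    val p =
      if (plan.children.isEmpty) plan
      else plan.withChildren(plan.children.map(ensureRowFormats))
    val kids = p.children
    if (kids.isEmpty) p
    else if (p.canProcessSafeRows && !p.canProcessUnsafeRows)
      p.withChildren(kids.map(c => if (c.outputsUnsafeRows) ToSafe(c) else c))
    else if (p.canProcessUnsafeRows &&
             (!p.canProcessSafeRows || kids.map(_.outputsUnsafeRows).toSet.size != 1))
      p.withChildren(kids.map(c => if (!c.outputsUnsafeRows) ToUnsafe(c) else c))
    else p
  }

  def main(args: Array[String]): Unit = {
    val safeScan = Op("SafeScan", Nil, outputsUnsafeRows = false,
      canProcessUnsafeRows = false, canProcessSafeRows = true)
    val unsafeScan = Op("UnsafeScan", Nil, outputsUnsafeRows = true,
      canProcessUnsafeRows = true, canProcessSafeRows = false)
    // A Union-like parent accepts both formats but needs its children to agree,
    // so the safe child gets wrapped in ToUnsafe.
    val union = Op("Union", Seq(safeScan, unsafeScan),
      outputsUnsafeRows = true, canProcessUnsafeRows = true, canProcessSafeRows = true)
    println(ensureRowFormats(union))
  }
}
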
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 58
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala | 108
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala | 54
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala | 164
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala | 16
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala | 15
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala | 3
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala | 31
34 files changed, 574 insertions(+), 74 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 022303239f..eadf5cba6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -904,7 +904,8 @@ class SQLContext private[sql](
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches = Seq(
- Batch("Add exchange", Once, EnsureRequirements(self))
+ Batch("Add exchange", Once, EnsureRequirements(self)),
+ Batch("Add row converters", Once, EnsureRowFormats)
)
}
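
The two batches above run strictly in sequence: EnsureRequirements first, then the row-format pass this revert re-adds. A minimal, self-contained sketch of that batch chaining; Rule, Batch, and runBatches here are illustrative stand-ins, not Spark's actual RuleExecutor API.

object BatchChainingSketch {
  type Rule[T] = T => T
  final case class Batch[T](name: String, rules: Seq[Rule[T]])

  // Apply each batch once, in order, threading the plan through every rule.
  def runBatches[T](batches: Seq[Batch[T]])(plan: T): T =
    batches.foldLeft(plan) { (current, batch) =>
      batch.rules.foldLeft(current)((p, rule) => rule(p))
    }

  def main(args: Array[String]): Unit = {
    // The "plan" here is just a log of which rules have been applied.
    val batches = Seq(
      Batch[List[String]]("Add exchange", Seq(_ :+ "EnsureRequirements")),
      Batch[List[String]]("Add row converters", Seq(_ :+ "EnsureRowFormats")))
    println(runBatches(batches)(Nil)) // List(EnsureRequirements, EnsureRowFormats)
  }
}
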
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 7b4161930b..62cbc518e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair
@@ -49,14 +50,26 @@ case class Exchange(
case None => ""
}
- val simpleNodeName = "Exchange"
+ val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange"
s"$simpleNodeName$extraInfo"
}
+ /**
+ * Returns true iff we can support the data type, and we are not doing range partitioning.
+ */
+ private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning]
+
override def outputPartitioning: Partitioning = newPartitioning
override def output: Seq[Attribute] = child.output
+ // This setting is somewhat counterintuitive:
+ // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row,
+ // so the planner inserts a converter to convert data into UnsafeRow if needed.
+ override def outputsUnsafeRows: Boolean = tungstenMode
+ override def canProcessSafeRows: Boolean = !tungstenMode
+ override def canProcessUnsafeRows: Boolean = tungstenMode
+
/**
* Determines whether records must be defensively copied before being sent to the shuffle.
* Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -117,7 +130,15 @@ case class Exchange(
}
}
- private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+ @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
+
+ private val serializer: Serializer = {
+ if (tungstenMode) {
+ new UnsafeRowSerializer(child.output.size)
+ } else {
+ new SparkSqlSerializer(sparkConf)
+ }
+ }
override protected def doPrepare(): Unit = {
// If an ExchangeCoordinator is needed, we register this Exchange operator
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index fc508bfafa..5c01af011d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, AttributeSet, GenericMutableRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
import org.apache.spark.sql.types.DataType
@@ -99,19 +99,10 @@ private[sql] case class PhysicalRDD(
rdd: RDD[InternalRow],
override val nodeName: String,
override val metadata: Map[String, String] = Map.empty,
- isUnsafeRow: Boolean = false)
+ override val outputsUnsafeRows: Boolean = false)
extends LeafNode {
- protected override def doExecute(): RDD[InternalRow] = {
- if (isUnsafeRow) {
- rdd
- } else {
- rdd.mapPartitionsInternal { iter =>
- val proj = UnsafeProjection.create(schema)
- iter.map(proj)
- }
- }
- }
+ protected override def doExecute(): RDD[InternalRow] = rdd
override def simpleString: String = {
val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index c3683cc4e7..91530bd637 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -41,11 +41,20 @@ case class Expand(
// as UNKNOWN partitioning
override def outputPartitioning: Partitioning = UnknownPartitioning(0)
+ override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
override def references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))
- private[this] val projection =
- (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output)
+ private[this] val projection = {
+ if (outputsUnsafeRows) {
+ (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output)
+ } else {
+ (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)()
+ }
+ }
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 4db88a09d8..0c613e91b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -64,7 +64,6 @@ case class Generate(
child.execute().mapPartitionsInternal { iter =>
val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
val joinedRow = new JoinedRow
- val proj = UnsafeProjection.create(output, output)
iter.flatMap { row =>
// we should always set the left (child output)
@@ -78,14 +77,13 @@ case class Generate(
} ++ LazyIterator(() => boundGenerator.terminate()).map { row =>
// we leave the left side as the last element of its child output
// keep it the same as Hive does
- proj(joinedRow.withRight(row))
+ joinedRow.withRight(row)
}
}
} else {
child.execute().mapPartitionsInternal { iter =>
- val proj = UnsafeProjection.create(output, output)
- (iter.flatMap(row => boundGenerator.eval(row)) ++
- LazyIterator(() => boundGenerator.terminate())).map(proj)
+ iter.flatMap(row => boundGenerator.eval(row)) ++
+ LazyIterator(() => boundGenerator.terminate())
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
index 59057bf966..ba7f6287ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.Attribute
/**
@@ -29,20 +29,15 @@ private[sql] case class LocalTableScan(
output: Seq[Attribute],
rows: Seq[InternalRow]) extends LeafNode {
- private val unsafeRows: Array[InternalRow] = {
- val proj = UnsafeProjection.create(output, output)
- rows.map(r => proj(r).copy()).toArray
- }
-
- private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows)
+ private lazy val rdd = sqlContext.sparkContext.parallelize(rows)
protected override def doExecute(): RDD[InternalRow] = rdd
override def executeCollect(): Array[InternalRow] = {
- unsafeRows
+ rows.toArray
}
override def executeTake(limit: Int): Array[InternalRow] = {
- unsafeRows.take(limit)
+ rows.take(limit).toArray
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 73dc8cb984..24207cb46f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -39,6 +39,10 @@ case class Sort(
testSpillFrequency: Int = 0)
extends UnaryNode {
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
override def output: Seq[Attribute] = child.output
override def outputOrdering: Seq[SortOrder] = sortOrder
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index f20f32aace..fe9b2ad4a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -97,6 +97,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Specifies sort order for each partition requirements on the input data for this operator. */
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+ /** Specifies whether this operator outputs UnsafeRows */
+ def outputsUnsafeRows: Boolean = false
+
+ /** Specifies whether this operator is capable of processing UnsafeRows */
+ def canProcessUnsafeRows: Boolean = false
+
+ /**
+ * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows
+ * that are not UnsafeRows).
+ */
+ def canProcessSafeRows: Boolean = true
/**
* Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
@@ -104,6 +115,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* Concrete implementations of SparkPlan should override doExecute instead.
*/
final def execute(): RDD[InternalRow] = {
+ if (children.nonEmpty) {
+ val hasUnsafeInputs = children.exists(_.outputsUnsafeRows)
+ val hasSafeInputs = children.exists(!_.outputsUnsafeRows)
+ assert(!(hasSafeInputs && hasUnsafeInputs),
+ "Child operators should output rows in the same format")
+ assert(canProcessSafeRows || canProcessUnsafeRows,
+ "Operator must be able to process at least one row format")
+ assert(!hasSafeInputs || canProcessSafeRows,
+ "Operator will receive safe rows as input but cannot process safe rows")
+ assert(!hasUnsafeInputs || canProcessUnsafeRows,
+ "Operator will receive unsafe rows as input but cannot process unsafe rows")
+ }
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
prepare()
doExecute()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index b79d93d7ca..c941d673c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -100,6 +100,8 @@ case class Window(
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def canProcessUnsafeRows: Boolean = true
+
/**
* Create a bound ordering object for a given frame type and offset. A bound ordering object is
* used to determine which input row lies within the frame boundaries of an output row.
@@ -257,16 +259,16 @@ case class Window(
* @return the final resulting projection.
*/
private[this] def createResultProjection(
- expressions: Seq[Expression]): UnsafeProjection = {
+ expressions: Seq[Expression]): MutableProjection = {
val references = expressions.zipWithIndex.map{ case (e, i) =>
// Results of window expressions will be on the right side of child's output
BoundReference(child.output.size + i, e.dataType, e.nullable)
}
val unboundToRefMap = expressions.zip(references).toMap
val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
- UnsafeProjection.create(
+ newMutableProjection(
projectList ++ patchedWindowExpression,
- child.output)
+ child.output)()
}
protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index 01d076678f..c4587ba677 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -49,6 +49,10 @@ case class SortBasedAggregate(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = false
+ override def canProcessSafeRows: Boolean = true
+
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
override def requiredChildDistribution: List[Distribution] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 6501634ff9..ac920aa8bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -87,10 +87,6 @@ class SortBasedAggregationIterator(
// The aggregation buffer used by the sort-based aggregation.
private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer
- // An SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be
- // compared to MutableRow (aggregation buffer) directly.
- private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))
-
protected def initialize(): Unit = {
if (inputIterator.hasNext) {
initializeBuffer(sortBasedAggregationBuffer)
@@ -114,7 +110,7 @@ class SortBasedAggregationIterator(
// We create a variable to track if we see the next group.
var findNextPartition = false
// firstRowInNextGroup is the first row of this group. We first process it.
- processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup))
+ processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
// The search will stop when we see the next group or there is no
// input row left in the iter.
@@ -126,7 +122,7 @@ class SortBasedAggregationIterator(
// Check if the current row belongs the current input row.
if (currentGroupingKey == groupingKey) {
- processRow(sortBasedAggregationBuffer, safeProj(currentRow))
+ processRow(sortBasedAggregationBuffer, currentRow)
} else {
// We find a new group.
findNextPartition = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 999ebb768a..9d758eb3b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -49,6 +49,10 @@ case class TungstenAggregate(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
override def producedAttributes: AttributeSet =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index af7237ef25..f19d72f067 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -36,6 +36,10 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
override private[sql] lazy val metrics = Map(
"numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
protected override def doExecute(): RDD[InternalRow] = {
@@ -76,6 +80,12 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
}
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+ override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+
+ override def canProcessUnsafeRows: Boolean = true
+
+ override def canProcessSafeRows: Boolean = true
}
/**
@@ -98,6 +108,10 @@ case class Sample(
{
override def output: Seq[Attribute] = child.output
+ override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
@@ -121,6 +135,8 @@ case class Range(
output: Seq[Attribute])
extends LeafNode {
+ override def outputsUnsafeRows: Boolean = true
+
protected override def doExecute(): RDD[InternalRow] = {
sqlContext
.sparkContext
@@ -183,6 +199,9 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan {
}
}
}
+ override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows)
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
protected override def doExecute(): RDD[InternalRow] =
sparkContext.union(children.map(_.execute()))
}
@@ -249,14 +268,12 @@ case class TakeOrderedAndProject(
// and this ordering needs to be created on the driver in order to be passed into Spark core code.
private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
+ // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
+ @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+
private def collectData(): Array[InternalRow] = {
val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
- if (projectList.isDefined) {
- val proj = UnsafeProjection.create(projectList.get, child.output)
- data.map(r => proj(r).copy())
- } else {
- data
- }
+ projection.map(data.map(_)).getOrElse(data)
}
override def executeCollect(): Array[InternalRow] = {
@@ -294,6 +311,10 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode {
protected override def doExecute(): RDD[InternalRow] = {
child.execute().coalesce(numPartitions, shuffle = false)
}
+
+ override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
}
/**
@@ -306,6 +327,10 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {
protected override def doExecute(): RDD[InternalRow] = {
left.execute().map(_.copy()).subtract(right.execute().map(_.copy()))
}
+
+ override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows)
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
}
/**
@@ -318,6 +343,10 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode {
protected override def doExecute(): RDD[InternalRow] = {
left.execute().map(_.copy()).intersection(right.execute().map(_.copy()))
}
+
+ override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows)
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
}
/**
@@ -342,6 +371,10 @@ case class MapPartitions[T, U](
child: SparkPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = outputSet
+ override def canProcessSafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def outputsUnsafeRows: Boolean = true
+
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
val tBoundEncoder = tEncoder.bind(child.output)
@@ -361,6 +394,11 @@ case class AppendColumns[T, U](
child: SparkPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = AttributeSet(newColumns)
+ // We are using an unsafe combiner.
+ override def canProcessSafeRows: Boolean = false
+ override def canProcessUnsafeRows: Boolean = true
+ override def outputsUnsafeRows: Boolean = true
+
override def output: Seq[Attribute] = child.output ++ newColumns
override protected def doExecute(): RDD[InternalRow] = {
@@ -390,6 +428,10 @@ case class MapGroups[K, T, U](
child: SparkPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = outputSet
+ override def canProcessSafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def outputsUnsafeRows: Boolean = true
+
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingAttributes) :: Nil
@@ -430,6 +472,10 @@ case class CoGroup[Key, Left, Right, Result](
right: SparkPlan) extends BinaryNode {
override def producedAttributes: AttributeSet = outputSet
+ override def canProcessSafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def outputsUnsafeRows: Boolean = true
+
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index d80912309b..aa7a668e0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
+import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.storage.StorageLevel
import org.apache.spark.{Accumulable, Accumulator, Accumulators}
@@ -39,7 +39,9 @@ private[sql] object InMemoryRelation {
storageLevel: StorageLevel,
child: SparkPlan,
tableName: Option[String]): InMemoryRelation =
- new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
+ new InMemoryRelation(child.output, useCompression, batchSize, storageLevel,
+ if (child.outputsUnsafeRows) child else ConvertToUnsafe(child),
+ tableName)()
}
/**
@@ -224,6 +226,8 @@ private[sql] case class InMemoryColumnarTableScan(
// The cached version does not change the outputOrdering of the original SparkPlan.
override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering
+ override def outputsUnsafeRows: Boolean = true
+
private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
// Returned filter predicate should return false iff it is impossible for the input expression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 54275c2cc1..aab177b2e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -46,8 +46,15 @@ case class BroadcastNestedLoopJoin(
case BuildLeft => (right, left)
}
+ override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+
private[this] def genResultProjection: InternalRow => InternalRow = {
+ if (outputsUnsafeRows) {
UnsafeProjection.create(schema)
+ } else {
+ identity[InternalRow]
+ }
}
override def outputPartitioning: Partitioning = streamed.outputPartitioning
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index d9fa4c6b83..81bfe4e67c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -81,6 +81,10 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output ++ right.output
+ override def canProcessSafeRows: Boolean = false
+ override def canProcessUnsafeRows: Boolean = true
+ override def outputsUnsafeRows: Boolean = true
+
override private[sql] lazy val metrics = Map(
"numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
"numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 7f9d9daa5a..fb961d97c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -44,6 +44,10 @@ trait HashJoin {
override def output: Seq[Attribute] = left.output ++ right.output
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
protected def buildSideKeyGenerator: Projection =
UnsafeProjection.create(buildKeys, buildPlan.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 6d464d6946..c6e5868187 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -64,6 +64,10 @@ trait HashOuterJoin {
s"HashOuterJoin should not take $x as the JoinType")
}
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
protected def buildKeyGenerator: Projection =
UnsafeProjection.create(buildKeys, buildPlan.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 3e0f74cd98..f23a1830e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -33,6 +33,10 @@ trait HashSemiJoin {
override def output: Seq[Attribute] = left.output
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
protected def leftKeyGenerator: Projection =
UnsafeProjection.create(leftKeys, left.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index 82498ee395..efa7b49410 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -42,6 +42,9 @@ case class LeftSemiJoinBNL(
override def output: Seq[Attribute] = left.output
+ override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+
/** The Streamed Relation */
override def left: SparkPlan = streamed
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 812f881d06..4bf7b521c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -53,6 +53,10 @@ case class SortMergeJoin(
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
// This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
index c3a2bfc59c..7ce38ebdb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -89,6 +89,10 @@ case class SortMergeOuterJoin(
keys.map(SortOrder(_, Ascending))
}
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+
private def createLeftKeyGenerator(): Projection =
UnsafeProjection.create(leftKeys, left.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index e46217050b..6a882c9234 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -69,6 +69,18 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin
*/
def close(): Unit
+ /** Specifies whether this operator outputs UnsafeRows */
+ def outputsUnsafeRows: Boolean = false
+
+ /** Specifies whether this operator is capable of processing UnsafeRows */
+ def canProcessUnsafeRows: Boolean = false
+
+ /**
+ * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows
+ * that are not UnsafeRows).
+ */
+ def canProcessSafeRows: Boolean = true
+
/**
* Returns the content through the [[Iterator]] interface.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
index b7fa0c0202..7321fc66b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala
@@ -47,7 +47,11 @@ case class NestedLoopJoinNode(
}
private[this] def genResultProjection: InternalRow => InternalRow = {
- UnsafeProjection.create(schema)
+ if (outputsUnsafeRows) {
+ UnsafeProjection.create(schema)
+ } else {
+ identity[InternalRow]
+ }
}
private[this] var currentRow: InternalRow = _
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
index efb4b09c16..defcec95fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
@@ -351,6 +351,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
def children: Seq[SparkPlan] = child :: Nil
+ override def outputsUnsafeRows: Boolean = false
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -396,14 +400,13 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val unpickle = new Unpickler
val row = new GenericMutableRow(1)
val joined = new JoinedRow
- val resultProj = UnsafeProjection.create(output, output)
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
row(0) = EvaluatePython.fromJava(result, udf.dataType)
- resultProj(joined(queue.poll(), row))
+ joined(queue.poll(), row)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
new file mode 100644
index 0000000000..5f8fc2de8b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.rules.Rule
+
+/**
+ * Converts Java-object-based rows into [[UnsafeRow]]s.
+ */
+case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def outputsUnsafeRows: Boolean = true
+ override def canProcessUnsafeRows: Boolean = false
+ override def canProcessSafeRows: Boolean = true
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ val convertToUnsafe = UnsafeProjection.create(child.schema)
+ iter.map(convertToUnsafe)
+ }
+ }
+}
+
+/**
+ * Converts [[UnsafeRow]]s back into Java-object-based rows.
+ */
+case class ConvertToSafe(child: SparkPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def outputsUnsafeRows: Boolean = false
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = false
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType))
+ iter.map(convertToSafe)
+ }
+ }
+}
+
+private[sql] object EnsureRowFormats extends Rule[SparkPlan] {
+
+ private def onlyHandlesSafeRows(operator: SparkPlan): Boolean =
+ operator.canProcessSafeRows && !operator.canProcessUnsafeRows
+
+ private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean =
+ operator.canProcessUnsafeRows && !operator.canProcessSafeRows
+
+ private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean =
+ operator.canProcessSafeRows && operator.canProcessUnsafeRows
+
+ override def apply(operator: SparkPlan): SparkPlan = operator.transformUp {
+ case operator: SparkPlan if onlyHandlesSafeRows(operator) =>
+ if (operator.children.exists(_.outputsUnsafeRows)) {
+ operator.withNewChildren {
+ operator.children.map {
+ c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c
+ }
+ }
+ } else {
+ operator
+ }
+ case operator: SparkPlan if onlyHandlesUnsafeRows(operator) =>
+ if (operator.children.exists(!_.outputsUnsafeRows)) {
+ operator.withNewChildren {
+ operator.children.map {
+ c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+ }
+ }
+ } else {
+ operator
+ }
+ case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) =>
+ if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) {
+ // If this operator's children produce both unsafe and safe rows,
+ // convert everything into unsafe rows.
+ operator.withNewChildren {
+ operator.children.map {
+ c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c
+ }
+ }
+ } else {
+ operator
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 87bff3295f..911d12e93e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -28,7 +28,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
val input = (1 to 1000).map(Tuple1.apply)
checkAnswer(
input.toDF(),
- plan => Exchange(SinglePartition, plan),
+ plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))),
input.map(Row.fromTuple)
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala
new file mode 100644
index 0000000000..faef76d52a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.IntegerType
+
+class ExpandSuite extends SparkPlanTest with SharedSQLContext {
+ import testImplicits.localSeqToDataFrameHolder
+
+ private def testExpand(f: SparkPlan => SparkPlan): Unit = {
+ val input = (1 to 1000).map(Tuple1.apply)
+ val projections = Seq.tabulate(2) { i =>
+ Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil
+ }
+ val attributes = projections.head.map(_.toAttribute)
+ checkAnswer(
+ input.toDF(),
+ plan => Expand(projections, attributes, f(plan)),
+ input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j)))
+ )
+ }
+
+ test("inheriting child row type") {
+ val exprs = AttributeReference("a", IntegerType, false)() :: Nil
+ val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty)))
+ assert(plan.outputsUnsafeRows, "Expand should inherit the created row type from its child.")
+ }
+
+ test("expanding UnsafeRows") {
+ testExpand(ConvertToUnsafe)
+ }
+
+ test("expanding SafeRows") {
+ testExpand(identity)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
new file mode 100644
index 0000000000..2328899bb2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{ArrayType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
+
+ private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
+ case c: ConvertToUnsafe => c
+ case c: ConvertToSafe => c
+ }
+
+ private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name"))
+ assert(!outputsSafe.outputsUnsafeRows)
+ private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name"))
+ assert(outputsUnsafe.outputsUnsafeRows)
+
+ test("planner should insert unsafe->safe conversions when required") {
+ val plan = Limit(10, outputsUnsafe)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
+ }
+
+ test("filter can process unsafe rows") {
+ val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(getConverters(preparedPlan).size === 1)
+ assert(preparedPlan.outputsUnsafeRows)
+ }
+
+ test("filter can process safe rows") {
+ val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(getConverters(preparedPlan).isEmpty)
+ assert(!preparedPlan.outputsUnsafeRows)
+ }
+
+ test("coalesce can process unsafe rows") {
+ val plan = Coalesce(1, outputsUnsafe)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(getConverters(preparedPlan).size === 1)
+ assert(preparedPlan.outputsUnsafeRows)
+ }
+
+ test("except can process unsafe rows") {
+ val plan = Except(outputsUnsafe, outputsUnsafe)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(getConverters(preparedPlan).size === 2)
+ assert(preparedPlan.outputsUnsafeRows)
+ }
+
+ test("except requires all of its input rows' formats to agree") {
+ val plan = Except(outputsSafe, outputsUnsafe)
+ assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(preparedPlan.outputsUnsafeRows)
+ }
+
+ test("intersect can process unsafe rows") {
+ val plan = Intersect(outputsUnsafe, outputsUnsafe)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(getConverters(preparedPlan).size === 2)
+ assert(preparedPlan.outputsUnsafeRows)
+ }
+
+ test("intersect requires all of its input rows' formats to agree") {
+ val plan = Intersect(outputsSafe, outputsUnsafe)
+ assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(preparedPlan.outputsUnsafeRows)
+ }
+
+ test("execute() fails an assertion if inputs rows are of different formats") {
+ val e = intercept[AssertionError] {
+ Union(Seq(outputsSafe, outputsUnsafe)).execute()
+ }
+ assert(e.getMessage.contains("format"))
+ }
+
+ test("union requires all of its input rows' formats to agree") {
+ val plan = Union(Seq(outputsSafe, outputsUnsafe))
+ assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(preparedPlan.outputsUnsafeRows)
+ }
+
+ test("union can process safe rows") {
+ val plan = Union(Seq(outputsSafe, outputsSafe))
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(!preparedPlan.outputsUnsafeRows)
+ }
+
+ test("union can process unsafe rows") {
+ val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
+ val preparedPlan = sqlContext.prepareForExecution.execute(plan)
+ assert(preparedPlan.outputsUnsafeRows)
+ }
+
+ test("round trip with ConvertToUnsafe and ConvertToSafe") {
+ val input = Seq(("hello", 1), ("world", 2))
+ checkAnswer(
+ sqlContext.createDataFrame(input),
+ plan => ConvertToSafe(ConvertToUnsafe(plan)),
+ input.map(Row.fromTuple)
+ )
+ }
+
+ test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
+ SQLContext.setActive(sqlContext)
+ val schema = ArrayType(StringType)
+ val rows = (1 to 100).map { i =>
+ InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))
+ }
+ val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows)
+
+ val plan =
+ DummyPlan(
+ ConvertToSafe(
+ ConvertToUnsafe(relation)))
+ assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString))
+ }
+}
+
+case class DummyPlan(child: SparkPlan) extends UnaryNode {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions { iter =>
+ // This `DummyPlan` is in safe mode, so we don't need to copy the values we hold
+ // on to from the incoming rows.
+ // We cache all the strings here to make sure the UTF8Strings inside the incoming
+ // safe InternalRows have been deep copied.
+ val strings = new scala.collection.mutable.ArrayBuffer[UTF8String]
+ iter.foreach { row =>
+ strings += row.getArray(0).getUTF8String(0)
+ }
+ strings.map(InternalRow(_)).iterator
+ }
+ }
+
+ override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)())
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index af971dfc6f..e5d34be4c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -99,7 +99,7 @@ class SortSuite extends SparkPlanTest with SharedSQLContext {
)
checkThatPlansAgree(
inputDf,
- p => Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23),
+ p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)),
ReferenceSort(sortOrder, global = true, _: SparkPlan),
sortAnswers = false
)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index 1588728bdb..8141136de5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -132,17 +132,11 @@ case class HiveTableScan(
}
}
- protected override def doExecute(): RDD[InternalRow] = {
- val rdd = if (!relation.hiveQlTable.isPartitioned) {
- hadoopReader.makeRDDForTable(relation.hiveQlTable)
- } else {
- hadoopReader.makeRDDForPartitionedTable(
- prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
- }
- rdd.mapPartitionsInternal { iter =>
- val proj = UnsafeProjection.create(schema)
- iter.map(proj)
- }
+ protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
+ hadoopReader.makeRDDForTable(relation.hiveQlTable)
+ } else {
+ hadoopReader.makeRDDForPartitionedTable(
+ prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
}
override def output: Seq[Attribute] = attributes
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 44dc68e6ba..f936cf565b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -28,17 +28,18 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg}
import org.apache.hadoop.hive.serde2.Serializer
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption
import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.mapred.{FileOutputFormat, JobConf}
+import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, Attribute}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.hive._
import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.SerializableJobConf
import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.util.SerializableJobConf
private[hive]
case class InsertIntoHiveTable(
@@ -100,17 +101,15 @@ case class InsertIntoHiveTable(
writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
- val proj = FromUnsafeProjection(child.schema)
iterator.foreach { row =>
var i = 0
- val safeRow = proj(row)
while (i < fieldOIs.length) {
- outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(safeRow.get(i, dataTypes(i)))
+ outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
i += 1
}
writerContainer
- .getLocalFileWriter(safeRow, table.schema)
+ .getLocalFileWriter(row, table.schema)
.write(serializer.serialize(outputData, standardOI))
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 6ccd417819..a61e162f48 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -213,8 +213,7 @@ case class ScriptTransformation(
child.execute().mapPartitions { iter =>
if (iter.hasNext) {
- val proj = UnsafeProjection.create(schema)
- processIterator(iter).map(proj)
+ processIterator(iter)
} else {
// If the input iterator has no rows then do not launch the external script.
Iterator.empty
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index efbf9988dd..665e87e3e3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -27,6 +27,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
+import org.apache.spark.sql.execution.ConvertToUnsafe
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
@@ -688,6 +689,36 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString)
}
}
+
+ test("HadoopFsRelation produces UnsafeRow") {
+ withTempTable("test_unsafe") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ sqlContext.range(3).write.format(dataSourceName).save(path)
+ sqlContext.read
+ .format(dataSourceName)
+ .option("dataSchema", new StructType().add("id", LongType, nullable = false).json)
+ .load(path)
+ .registerTempTable("test_unsafe")
+
+ val df = sqlContext.sql(
+ """SELECT COUNT(*)
+ |FROM test_unsafe a JOIN test_unsafe b
+ |WHERE a.id = b.id
+ """.stripMargin)
+
+ val plan = df.queryExecution.executedPlan
+
+ assert(
+ plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty,
+ s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s):
+ |$plan
+ """.stripMargin)
+
+ checkAnswer(df, Row(3))
+ }
+ }
+ }
}
// This class is used to test SPARK-8578. We should not use any custom output committer when