author     Wenchen Fan <cloud0fan@outlook.com>    2015-06-25 22:44:26 -0700
committer  Davies Liu <davies@databricks.com>     2015-06-25 22:44:26 -0700
commit     1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb (patch)
tree       85061123f7de88e73e317304ee8fec392f54e7db /sql/core
parent     40360112c417b5432564f4bcb8a9100f4066b55e (diff)
[SPARK-8635] [SQL] improve performance of CatalystTypeConverters
In `CatalystTypeConverters.createToCatalystConverter`, we add special handling for primitive types. We can apply this strategy to more places to improve performance.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7018 from cloud-fan/converter and squashes the following commits:

8b16630 [Wenchen Fan] another fix
326c82c [Wenchen Fan] optimize type converter
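To illustrate the pattern the commit message refers to, here is a minimal sketch (not part of the patch): the converter is resolved once per schema and then reused for every row, so the type dispatch is paid up front rather than per value. The schema, object name and values are invented; only CatalystTypeConverters.createToScalaConverter is taken from the first hunk below, and the object sits in org.apache.spark.sql because the converters are an internal API.

    // A minimal sketch, not part of the patch: resolve the converter once per
    // schema, then reuse it for every row.
    package org.apache.spark.sql

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    private[sql] object ConverterSketch {
      // Hypothetical schema, only here to have something concrete to convert.
      val schema = StructType(Seq(
        StructField("id", IntegerType),
        StructField("name", StringType)))

      // Type dispatch happens once, when the converter is created...
      val toScala: Any => Any = CatalystTypeConverters.createToScalaConverter(schema)

      // ...so converting a partition costs one function call per row, exactly as
      // DataFrame.rdd does in the first hunk below:
      //   rows.map(converter(_).asInstanceOf[Row])
    }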
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                    | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala   | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala             | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala       | 4
6 files changed, 10 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f3f0f53053..0db4df34f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1418,12 +1418,14 @@ class DataFrame private[sql](
lazy val rdd: RDD[Row] = {
// use a local variable to make sure the map closure doesn't capture the whole DataFrame
val schema = this.schema
- queryExecution.executedPlan.execute().mapPartitions { rows =>
+ internalRowRdd.mapPartitions { rows =>
val converter = CatalystTypeConverters.createToScalaConverter(schema)
rows.map(converter(_).asInstanceOf[Row])
}
}
+ private[sql] def internalRowRdd = queryExecution.executedPlan.execute()
+
/**
* Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
* @group rdd
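Before the remaining hunks, a hedged sketch of what the new helper gives internal callers (the object and method names here are invented; only rdd and internalRowRdd come from the hunk above): df.rdd still converts each InternalRow to an external Row, while df.internalRowRdd exposes the executed plan's rows directly, which is what the call sites in the hunks that follow switch to.

    // Sketch only: the two RDD views of a DataFrame after this change. Because
    // internalRowRdd is private[sql], this is written inside org.apache.spark.sql.
    package org.apache.spark.sql

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow

    private[sql] object RowViewSketch {
      def views(df: DataFrame): (RDD[Row], RDD[InternalRow]) = {
        // Public API: pays for a Catalyst-to-Scala conversion on every row.
        val external: RDD[Row] = df.rdd
        // New helper: the executed plan's InternalRows, with no per-row conversion.
        val internal: RDD[InternalRow] = df.internalRowRdd
        (external, internal)
      }
    }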
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index 8df1da037c..3ebbf96090 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging {
(name, originalSchema.fields(index).dataType)
}
- val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
+ val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
while (i < numCols) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 93383e5a62..252c611d02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -81,7 +81,7 @@ private[sql] object StatFunctions extends Logging {
s"with dataType ${data.get.dataType} not supported.")
}
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
- df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
+ df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)(
seqOp = (counter, row) => {
counter.add(row.getDouble(0), row.getDouble(1))
},
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index a8f56f4767..ce16e050c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
output: Seq[Attribute],
rdd: RDD[Row]): RDD[InternalRow] = {
if (relation.relation.needConversion) {
- execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType))
+ execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
} else {
rdd.map(_.asInstanceOf[InternalRow])
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index fb6173f58e..dbb369cf45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
writerContainer.driverSideSetup()
try {
- df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+ df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _)
writerContainer.commitJob()
relation.refresh()
} catch { case cause: Throwable =>
@@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
writerContainer.driverSideSetup()
try {
- df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
+ df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _)
writerContainer.commitJob()
relation.refresh()
} catch { case cause: Throwable =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 79eac930e5..de0ed0c042 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -88,9 +88,9 @@ case class AllDataTypesScan(
UTF8String.fromString(s"varchar_$i"),
Seq(i, i + 1),
Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))),
- Map(i -> i.toString),
+ Map(i -> UTF8String.fromString(i.toString)),
Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)),
- Row(i, i.toString),
+ Row(i, UTF8String.fromString(i.toString)),
Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")),
InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1))))))
}