Author:    Michael Armbrust <michael@databricks.com>  2015-02-25 10:13:40 -0800
Committer: Xiangrui Meng <meng@databricks.com>  2015-02-25 10:13:40 -0800
Commit:    f84c799ea0b82abca6a4fad39532c2515743b632
Tree:      a4e2cdfece2a03b2a3d13bde3985821a5a70e932
Parent:    dd077abf2e2949fdfec31074b760b587f00efcf2
[SPARK-5996][SQL] Fix specialized outbound conversions
Author: Michael Armbrust <michael@databricks.com>

Closes #4757 from marmbrus/udtConversions and squashes the following commits:

3714aad [Michael Armbrust] [SPARK-5996][SQL] Fix specialized outbound conversions
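Note: the issue being fixed is that physical operators which answer collect()/take() directly (LocalTableScan and TakeOrdered below) skipped the outbound conversion from Catalyst's internal row representation back to user-facing Scala types, so UDT columns came back in their internal form. A minimal sketch of the symptom, reusing the MyDenseVector UDT from the test suite touched below (illustrative only, not part of the patch):

    // Assumes the same imports as UserDefinedTypeSuite (TestSQLContext.implicits._ for toDF).
    val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
    // take(1) is served by a specialized executeTake(); without the conversion added in
    // this patch, the returned Row still holds the UDT's internal value and this cast fails.
    val vec: MyDenseVector = df.take(1)(0).getAs[MyDenseVector](1)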
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala      |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala      |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala          | 10
3 files changed, 20 insertions(+), 5 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
index d6d8258f46..d3a18b37d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -30,7 +31,9 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo
override def execute() = rdd
- override def executeCollect() = rows.toArray
+ override def executeCollect() =
+ rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray
- override def executeTake(limit: Int) = rows.take(limit).toArray
+ override def executeTake(limit: Int) =
+ rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray
}
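Note on the change above: ScalaReflection.convertRowToScala walks the plan's schema (a StructType built from its output attributes) and maps each internal value, including UDT columns, back to its public Scala type. A minimal sketch of the same pattern in isolation (hypothetical helper name; assumes Catalyst-internal rows plus the plan's schema):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.types.StructType

    // Hypothetical helper mirroring the conversion step added to executeCollect()/executeTake().
    def toOutboundRows(internalRows: Seq[Row], schema: StructType): Array[Row] =
      internalRows.map(ScalaReflection.convertRowToScala(_, schema)).toArray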
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4dc506c21a..710268584c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -134,13 +134,15 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
val ord = new RowOrdering(sortOrder, child.output)
+ private def collectData() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+
// TODO: Is this copying for no reason?
- override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
- .map(ScalaReflection.convertRowToScala(_, this.schema))
+ override def executeCollect() =
+ collectData().map(ScalaReflection.convertRowToScala(_, this.schema))
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- override def execute() = sparkContext.makeRDD(executeCollect(), 1)
+ override def execute() = sparkContext.makeRDD(collectData(), 1)
}
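Note on the refactoring above: the conversion now happens only at the user-facing boundary (executeCollect()), while execute() keeps producing Catalyst-internal rows for downstream operators via the shared collectData(). Previously execute() went through executeCollect(), so converted Scala objects leaked back into the internal pipeline whenever a TakeOrdered plan fed another operator, e.g. in a query shaped like the new test below (sketch, using that test's DataFrame):

    // orderBy(...).limit(...) is planned as TakeOrdered; its execute() output feeds the
    // aggregation, which expects internal rows, while collect() expects converted ones.
    df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()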
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9c098df24c..47fdb55432 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -22,6 +22,7 @@ import java.io.File
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -105,4 +106,13 @@ class UserDefinedTypeSuite extends QueryTest {
tempDir.delete()
pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath)
}
+
+ // Tests to make sure that all operators correctly convert types on the way out.
+ test("Local UDTs") {
+ val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
+ df.collect()(0).getAs[MyDenseVector](1)
+ df.take(1)(0).getAs[MyDenseVector](1)
+ df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+ df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+ }
}
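For context, MyDenseVector and MyDenseVectorUDT are defined earlier in UserDefinedTypeSuite, outside this diff. Roughly, a UDT pairs a user class with an internal SQL representation plus serialize/deserialize hooks; a sketch along those lines (the suite's actual definitions may differ in detail):

    import org.apache.spark.sql.types._

    @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
    class MyDenseVector(val data: Array[Double]) extends Serializable

    class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
      // Internal (Catalyst) representation: a non-null array of doubles.
      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
      // User type -> internal representation.
      override def serialize(obj: Any): Seq[Double] = obj match {
        case v: MyDenseVector => v.data.toSeq
      }
      // Internal representation -> user type (the "outbound" direction this patch fixes).
      override def deserialize(datum: Any): MyDenseVector = datum match {
        case data: Seq[_] => new MyDenseVector(data.map(_.asInstanceOf[Double]).toArray)
      }
      override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
    }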