 sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala                       | 7 ++++++-
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala  | 6 +++---
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 6 +++++-
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                | 6 ++++++
 4 files changed, 20 insertions(+), 5 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 3a70d25534..d794f034f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.DateUtils
+import org.apache.spark.sql.types.{StructType, DateUtils}
object Row {
/**
@@ -123,6 +123,11 @@ trait Row extends Serializable {
def length: Int
/**
+ * Schema for the row.
+ */
+ def schema: StructType = null
+
+ /**
* Returns the value at position i. If the value is null, null is returned. The following
* is a mapping between Spark SQL types and return types:
*
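
Note: the new schema accessor defaults to null on the base Row trait, so only rows explicitly constructed with a schema (GenericRowWithSchema, below) report one. A minimal sketch of that behavior, assuming a row built through the Row.apply factory:

    import org.apache.spark.sql.Row

    // A row built directly carries no schema; the trait-level default is null.
    // Rows returned by DataFrame.collect()/head() go through convertRowToScala
    // (changed below) and therefore carry the DataFrame's schema.
    val row = Row(1, "a")
    assert(row.schema == null)
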
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 11fd443733..d6126c24fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import java.sql.Timestamp
import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
@@ -91,9 +91,9 @@ trait ScalaReflection {
def convertRowToScala(r: Row, schema: StructType): Row = {
// TODO: This is very slow!!!
- new GenericRow(
+ new GenericRowWithSchema(
r.toSeq.zip(schema.fields.map(_.dataType))
- .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
+ .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema)
}
/** Returns a Sequence of attributes for the given case class type. */
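
The conversion walks the row's values in lockstep with the schema's field types; the only behavioral change is that the result now remembers its schema. A self-contained sketch of the pattern, with a hypothetical identity convertToScala standing in for catalyst's real converter:

    import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, Row}
    import org.apache.spark.sql.types.{DataType, StructType}

    // Hypothetical stand-in: the real convertToScala maps catalyst-internal
    // values (e.g. dates) to their external Scala representations.
    def convertToScala(value: Any, dataType: DataType): Any = value

    def convertRowToScala(r: Row, schema: StructType): Row =
      new GenericRowWithSchema(
        r.toSeq.zip(schema.fields.map(_.dataType))
          .map { case (value, dataType) => convertToScala(value, dataType) }
          .toArray,
        schema)
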
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 73ec7a6d11..faa3667718 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.{StructType, NativeType}
/**
@@ -149,6 +149,10 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
def copy() = this
}
+class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
+ extends GenericRow(values) {
+}
+
class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(null)
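
GenericRowWithSchema adds nothing beyond the overridden schema field, so constructing one directly shows the new behavior. A small sketch with a hand-built schema (field names are illustrative):

    import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // Illustrative two-column schema.
    val schema = StructType(Seq(
      StructField("key", IntegerType),
      StructField("value", StringType)))

    val row = new GenericRowWithSchema(Array(1, "a"), schema)
    assert(row.schema == schema) // overridden accessor returns the attached schema
    assert(row.getInt(0) == 1)   // otherwise behaves exactly like a GenericRow
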
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 524571d9cc..0da619def1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -89,6 +89,12 @@ class DataFrameSuite extends QueryTest {
testData.collect().toSeq)
}
+ test("head and take") {
+ assert(testData.take(2) === testData.collect().take(2))
+ assert(testData.head(2) === testData.collect().take(2))
+ assert(testData.head(2).head.schema === testData.schema)
+ }
+
test("self join") {
val df1 = testData.select(testData("key")).as('df1)
val df2 = testData.select(testData("key")).as('df2)
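
End to end, the new test exercises the path through convertRowToScala: rows surfaced by head/take now carry the DataFrame's schema. A runnable sketch against a local context using Spark 1.3-era APIs (sc, sqlContext, and df are set up here, not taken from the suite):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("head-take"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c"))).toDF("key", "value")

    // head and take return the same rows, and collected rows report the schema.
    assert(df.take(2).toSeq == df.collect().take(2).toSeq)
    assert(df.head(2).head.schema == df.schema)

    sc.stop()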