author     Cheng Lian <lian@databricks.com>    2016-03-10 17:00:17 -0800
committer  Yin Huai <yhuai@databricks.com>     2016-03-10 17:00:17 -0800
commit     1d542785b9949e7f92025e6754973a779cc37c52 (patch)
tree       ceda7492e40c9d9a9231a5011c91e30bf0b1f390 /sql/catalyst
parent     27fe6bacc532184ef6e8a2a24cd07f2c9188004e (diff)
[SPARK-13244][SQL] Migrates DataFrame to Dataset
## What changes were proposed in this pull request?

This PR unifies DataFrame and Dataset by migrating existing DataFrame operations to Dataset and making `DataFrame` a type alias of `Dataset[Row]`.

Most Scala code changes are source compatible, but the Java API is broken because Java knows nothing about Scala type aliases (mostly, `DataFrame` has to be replaced with `Dataset<Row>`). There are several noticeable API changes related to methods that return arrays:

1. `collect`/`take`

   - Old APIs in class `DataFrame`:

     ```scala
     def collect(): Array[Row]
     def take(n: Int): Array[Row]
     ```

   - New APIs in class `Dataset[T]`:

     ```scala
     def collect(): Array[T]
     def take(n: Int): Array[T]

     def collectRows(): Array[Row]
     def takeRows(n: Int): Array[Row]
     ```

   The two specialized methods `collectRows` and `takeRows` are added because Java doesn't support returning generic arrays. For example, `Dataset.collect(): Array[T]` actually returns `Object` instead of `Array<T>` from the Java side. Normally, Java users can fall back to `collectAsList` and `takeAsList`. The two new specialized versions are added to avoid performance regressions in ML related code (but maybe I'm wrong and they are not necessary here).

2. `randomSplit`

   - Old APIs in class `DataFrame`:

     ```scala
     def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame]
     def randomSplit(weights: Array[Double]): Array[DataFrame]
     ```

   - New APIs in class `Dataset[T]`:

     ```scala
     def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
     def randomSplit(weights: Array[Double]): Array[Dataset[T]]
     ```

   This has the same problem as above, but it hasn't been addressed for the Java API yet. We can probably add `randomSplitAsList` to fix it.

3. `groupBy`

   Some of the original `DataFrame.groupBy` methods have signatures that conflict with the original `Dataset.groupBy` methods. To distinguish the two, the typed `Dataset.groupBy` methods are renamed to `groupByKey`.

Other noticeable changes:

1. Datasets always do eager analysis now.

   We used to support disabling DataFrame eager analysis so that a partially analyzed, malformed logical plan could be reported on analysis failure. However, Dataset encoders require eager analysis during Dataset construction. To preserve the error reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument to hold the partially analyzed plan, so that we can check the plan tree when reporting test failures. This plan is passed in by `QueryExecution.assertAnalyzed`.

## How was this patch tested?

Existing tests do the work.

## TODO

- [ ] Fix all tests
- [ ] Re-enable MiMA check
- [ ] Update ScalaDoc (`since`, `group`, and example code)

Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>

Closes #11443 from liancheng/ds-to-df.
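As a call-site illustration of the `groupBy`/`groupByKey` split described above, here is a minimal, hedged sketch; the `Person` class, the sample data, and the `sqlContext` value are assumptions for illustration, not part of this patch:

```scala
// Hypothetical call-site sketch (assumes an existing SQLContext named `sqlContext`).
import sqlContext.implicits._
import org.apache.spark.sql.{DataFrame, Dataset}

case class Person(name: String, age: Long)
val people: Dataset[Person] = Seq(Person("a", 1L), Person("b", 2L)).toDS()

// Relational (untyped) grouping keeps the old name and yields a DataFrame,
// which after this change is just an alias for Dataset[Row]:
val untyped: DataFrame = people.groupBy($"name").count()

// Typed grouping is spelled groupByKey after this change and keeps the key type:
val typed: Dataset[(String, Long)] = people.groupByKey(_.name).count()
```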
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala             |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala  | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala  | 11
3 files changed, 25 insertions(+), 11 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 97f28fad62..d2003fd689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
// TODO: don't swallow original stack trace if it exists
@@ -30,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi
class AnalysisException protected[sql] (
val message: String,
val line: Option[Int] = None,
- val startPosition: Option[Int] = None)
+ val startPosition: Option[Int] = None,
+ val plan: Option[LogicalPlan] = None)
extends Exception with Serializable {
def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = {
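The commit message notes that the partially analyzed plan is passed in by `QueryExecution.assertAnalyzed`. A hedged sketch of what that wiring could look like follows; apart from the `AnalysisException` constructor shown in the hunk above, the member names are assumptions:

```scala
// Hypothetical sketch: re-throw analysis failures with the partially analyzed plan attached.
// `analyzer` and `analyzed` are assumed members of a QueryExecution-like class.
def assertAnalyzed(): Unit = {
  try {
    analyzer.checkAnalysis(analyzed)
  } catch {
    case e: AnalysisException =>
      // Preserve the original message and position, but attach the plan for test reporting.
      val withPlan = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
      withPlan.setStackTrace(e.getStackTrace)
      throw withPlan
  }
}
```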
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index d8f755a39c..902644e735 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -50,7 +50,9 @@ object RowEncoder {
inputObject: Expression,
inputType: DataType): Expression = inputType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => inputObject
+ FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
+
+ case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -137,6 +139,7 @@ object RowEncoder {
private def externalDataTypeFor(dt: DataType): DataType = dt match {
case _ if ScalaReflection.isNativeType(dt) => dt
+ case CalendarIntervalType => dt
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -150,19 +153,23 @@ object RowEncoder {
private def constructorFor(schema: StructType): Expression = {
val fields = schema.zipWithIndex.map { case (f, i) =>
- val field = BoundReference(i, f.dataType, f.nullable)
+ val dt = f.dataType match {
+ case p: PythonUserDefinedType => p.sqlType
+ case other => other
+ }
+ val field = BoundReference(i, dt, f.nullable)
If(
IsNull(field),
- Literal.create(null, externalDataTypeFor(f.dataType)),
+ Literal.create(null, externalDataTypeFor(dt)),
constructorFor(field)
)
}
- CreateExternalRow(fields)
+ CreateExternalRow(fields, schema)
}
private def constructorFor(input: Expression): Expression = input.dataType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => input
+ FloatType | DoubleType | BinaryType | CalendarIntervalType => input
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -216,7 +223,7 @@ object RowEncoder {
"toScalaMap",
keyData :: valueData :: Nil)
- case StructType(fields) =>
+ case schema @ StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
@@ -225,6 +232,6 @@ object RowEncoder {
}
If(IsNull(input),
Literal.create(null, externalDataTypeFor(input.dataType)),
- CreateExternalRow(convertedFields))
+ CreateExternalRow(convertedFields, schema))
}
}
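The `RowEncoder` hunks above let `CalendarIntervalType` values pass through unconverted, like the other native types, and route `PythonUserDefinedType` through its underlying `sqlType`. A hedged usage sketch of the interval case, assuming the `RowEncoder.apply(StructType)` entry point and the two-argument `CalendarInterval` constructor of this era:

```scala
// Hypothetical sketch (not part of this diff): an interval field now round-trips
// through RowEncoder without an external/internal conversion step.
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{CalendarIntervalType, StructField, StructType}
import org.apache.spark.unsafe.types.CalendarInterval

val schema = StructType(StructField("i", CalendarIntervalType) :: Nil)
val encoder = RowEncoder(schema)

// The CalendarInterval object is passed through as-is when serializing the Row.
val internalRow = encoder.toRow(Row(new CalendarInterval(0, 1000L)))
```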
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 75ecbaa453..b95c5dd892 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -388,6 +388,8 @@ case class MapObjects private(
case a: ArrayType => (i: String) => s".getArray($i)"
case _: MapType => (i: String) => s".getMap($i)"
case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
+ case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)"
+ case DateType => (i: String) => s".getInt($i)"
}
private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
@@ -485,7 +487,9 @@ case class MapObjects private(
*
* @param children A list of expression to use as content of the external row.
*/
-case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression {
+case class CreateExternalRow(children: Seq[Expression], schema: StructType)
+ extends Expression with NonSQLExpression {
+
override def dataType: DataType = ObjectType(classOf[Row])
override def nullable: Boolean = false
@@ -494,8 +498,9 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
- val rowClass = classOf[GenericRow].getName
+ val rowClass = classOf[GenericRowWithSchema].getName
val values = ctx.freshName("values")
+ val schemaField = ctx.addReferenceObj("schema", schema)
s"""
boolean ${ev.isNull} = false;
final Object[] $values = new Object[${children.size}];
@@ -510,7 +515,7 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with
}
"""
}.mkString("\n") +
- s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
+ s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);"
}
}
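The `CreateExternalRow` change above makes the generated code build `GenericRowWithSchema` instances instead of plain `GenericRow`s, so deserialized rows carry their `StructType`. A hedged sketch of the observable effect; the `sqlContext` value is assumed:

```scala
// Hypothetical sketch (assumed usage): rows collected from a Dataset[Row] now expose
// their schema, so name-based field access works on the external rows.
val df: DataFrame = sqlContext.range(10)   // column "id", assuming an existing SQLContext
val row: Row = df.collectRows().head       // collectRows() is one of the new APIs above

row.schema            // the attached StructType, rather than null from a plain GenericRow
row.fieldIndex("id")  // resolves via the attached schema
row.getAs[Long]("id") // name-based access works for the same reason
```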