diff options
author | Cheng Lian <lian@databricks.com> | 2016-03-10 17:00:17 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2016-03-10 17:00:17 -0800 |
commit | 1d542785b9949e7f92025e6754973a779cc37c52 (patch) | |
tree | ceda7492e40c9d9a9231a5011c91e30bf0b1f390 /sql/catalyst | |
parent | 27fe6bacc532184ef6e8a2a24cd07f2c9188004e (diff) | |
download | spark-1d542785b9949e7f92025e6754973a779cc37c52.tar.gz spark-1d542785b9949e7f92025e6754973a779cc37c52.tar.bz2 spark-1d542785b9949e7f92025e6754973a779cc37c52.zip |
[SPARK-13244][SQL] Migrates DataFrame to Dataset
## What changes were proposed in this pull request?
This PR unifies DataFrame and Dataset by migrating existing DataFrame operations to Dataset and make `DataFrame` a type alias of `Dataset[Row]`.
Most Scala code changes are source compatible, but Java API is broken as Java knows nothing about Scala type alias (mostly replacing `DataFrame` with `Dataset<Row>`).
There are several noticeable API changes related to those returning arrays:
1. `collect`/`take`
- Old APIs in class `DataFrame`:
```scala
def collect(): Array[Row]
def take(n: Int): Array[Row]
```
- New APIs in class `Dataset[T]`:
```scala
def collect(): Array[T]
def take(n: Int): Array[T]
def collectRows(): Array[Row]
def takeRows(n: Int): Array[Row]
```
Two specialized methods `collectRows` and `takeRows` are added because Java doesn't support returning generic arrays. Thus, for example, `DataFrame.collect(): Array[T]` actually returns `Object` instead of `Array<T>` from Java side.
Normally, Java users may fall back to `collectAsList` and `takeAsList`. The two new specialized versions are added to avoid performance regression in ML related code (but maybe I'm wrong and they are not necessary here).
1. `randomSplit`
- Old APIs in class `DataFrame`:
```scala
def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame]
def randomSplit(weights: Array[Double]): Array[DataFrame]
```
- New APIs in class `Dataset[T]`:
```scala
def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
def randomSplit(weights: Array[Double]): Array[Dataset[T]]
```
Similar problem as above, but hasn't been addressed for Java API yet. We can probably add `randomSplitAsList` to fix this one.
1. `groupBy`
Some original `DataFrame.groupBy` methods have conflicting signature with original `Dataset.groupBy` methods. To distinguish these two, typed `Dataset.groupBy` methods are renamed to `groupByKey`.
Other noticeable changes:
1. Dataset always do eager analysis now
We used to support disabling DataFrame eager analysis to help reporting partially analyzed malformed logical plan on analysis failure. However, Dataset encoders requires eager analysi during Dataset construction. To preserve the error reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument to hold the partially analyzed plan, so that we can check the plan tree when reporting test failures. This plan is passed by `QueryExecution.assertAnalyzed`.
## How was this patch tested?
Existing tests do the work.
## TODO
- [ ] Fix all tests
- [ ] Re-enable MiMA check
- [ ] Update ScalaDoc (`since`, `group`, and example code)
Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>
Closes #11443 from liancheng/ds-to-df.
Diffstat (limited to 'sql/catalyst')
3 files changed, 25 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 97f28fad62..d2003fd689 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan // TODO: don't swallow original stack trace if it exists @@ -30,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, - val startPosition: Option[Int] = None) + val startPosition: Option[Int] = None, + val plan: Option[LogicalPlan] = None) extends Exception with Serializable { def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d8f755a39c..902644e735 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -50,7 +50,9 @@ object RowEncoder { inputObject: Expression, inputType: DataType): Expression = inputType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => inputObject + FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject + + case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType) case udt: UserDefinedType[_] => val obj = NewInstance( @@ -137,6 +139,7 @@ object RowEncoder { private def externalDataTypeFor(dt: DataType): DataType = dt match { case _ if ScalaReflection.isNativeType(dt) => dt + case CalendarIntervalType => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -150,19 +153,23 @@ object RowEncoder { private def constructorFor(schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => - val field = BoundReference(i, f.dataType, f.nullable) + val dt = f.dataType match { + case p: PythonUserDefinedType => p.sqlType + case other => other + } + val field = BoundReference(i, dt, f.nullable) If( IsNull(field), - Literal.create(null, externalDataTypeFor(f.dataType)), + Literal.create(null, externalDataTypeFor(dt)), constructorFor(field) ) } - CreateExternalRow(fields) + CreateExternalRow(fields, schema) } private def constructorFor(input: Expression): Expression = input.dataType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => input + FloatType | DoubleType | BinaryType | CalendarIntervalType => input case udt: UserDefinedType[_] => val obj = NewInstance( @@ -216,7 +223,7 @@ object RowEncoder { "toScalaMap", keyData :: valueData :: Nil) - case StructType(fields) => + case schema @ StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), @@ -225,6 +232,6 @@ object RowEncoder { } If(IsNull(input), Literal.create(null, externalDataTypeFor(input.dataType)), - CreateExternalRow(convertedFields)) + CreateExternalRow(convertedFields, schema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 75ecbaa453..b95c5dd892 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -388,6 +388,8 @@ case class MapObjects private( case a: ArrayType => (i: String) => s".getArray($i)" case _: MapType => (i: String) => s".getMap($i)" case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType) + case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)" + case DateType => (i: String) => s".getInt($i)" } private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { @@ -485,7 +487,9 @@ case class MapObjects private( * * @param children A list of expression to use as content of the external row. */ -case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression { +case class CreateExternalRow(children: Seq[Expression], schema: StructType) + extends Expression with NonSQLExpression { + override def dataType: DataType = ObjectType(classOf[Row]) override def nullable: Boolean = false @@ -494,8 +498,9 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val rowClass = classOf[GenericRow].getName + val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") + val schemaField = ctx.addReferenceObj("schema", schema) s""" boolean ${ev.isNull} = false; final Object[] $values = new Object[${children.size}]; @@ -510,7 +515,7 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with } """ }.mkString("\n") + - s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" + s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);" } } |