From 5ad78f62056f2560cd371ee964111a646806d0ff Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 29 Jan 2015 00:01:10 -0800 Subject: [SQL] Various DataFrame DSL update. 1. Added foreach, foreachPartition, flatMap to DataFrame. 2. Added col() in dsl. 3. Support renaming columns in toDataFrame. 4. Support type inference on arrays (in addition to Seq). 5. Updated mllib to use the new DSL. Author: Reynold Xin Closes #4260 from rxin/sql-dsl-update and squashes the following commits: 73466c1 [Reynold Xin] Fixed LogisticRegression. Also added better error message for resolve. fab3ccc [Reynold Xin] Bug fix. d31fcd2 [Reynold Xin] Style fix. 62608c4 [Reynold Xin] [SQL] Various DataFrame DSL update. --- .../spark/sql/catalyst/ScalaReflection.scala | 5 ++- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 5 +++ .../main/scala/org/apache/spark/sql/Column.scala | 12 ++++-- .../scala/org/apache/spark/sql/DataFrame.scala | 47 ++++++++++++++++++++-- .../src/main/scala/org/apache/spark/sql/api.scala | 6 +++ .../scala/org/apache/spark/sql/api/java/dsl.java | 7 ++++ .../apache/spark/sql/api/scala/dsl/package.scala | 21 ++++++++++ 7 files changed, 94 insertions(+), 9 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 191d16fb10..4def65b01f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -57,6 +57,7 @@ trait ScalaReflection { case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (s: Array[_], arrayType: ArrayType) => s.toSeq case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) } @@ -140,7 +141,9 @@ trait ScalaReflection { // Need to decide if we actually need a special type here. 
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 5138942a55..4a66716e0a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -60,6 +60,7 @@ case class OptionalData( case class ComplexData( arrayField: Seq[Int], + arrayField1: Array[Int], arrayFieldContainsNull: Seq[java.lang.Integer], mapField: Map[Int, Long], mapFieldValueContainsNull: Map[Int, java.lang.Long], @@ -131,6 +132,10 @@ class ScalaReflectionSuite extends FunSuite { "arrayField", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField( + "arrayField1", + ArrayType(IntegerType, containsNull = false), + nullable = true), StructField( "arrayFieldContainsNull", ArrayType(IntegerType, containsNull = true), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 7f9a91a032..9be2a03afa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -22,15 +22,19 @@ import scala.language.implicitConversions import org.apache.spark.sql.api.scala.dsl.lit import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.types._ object Column { - def unapply(col: Column): Option[Expression] = Some(col.expr) - + /** + * Creates a [[Column]] based on the given column name. + * Same as [[api.scala.dsl.col]] and [[api.java.dsl.col]]. + */ def apply(colName: String): Column = new Column(colName) + + /** For internal pattern matching. */ + private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr) } @@ -438,7 +442,7 @@ class Column( * @param ordinal * @return */ - override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal)) + override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal)) /** * An expression that gets a field by name in a [[StructField]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ceb5f86bef..050366aea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -118,8 +118,8 @@ class DataFrame protected[sql]( /** Resolves a column name into a Catalyst [[NamedExpression]]. 
*/ protected[sql] def resolve(colName: String): NamedExpression = { - logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse( - throw new RuntimeException(s"""Cannot resolve column name "$colName"""")) + logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(throw new RuntimeException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")) } /** Left here for compatibility reasons. */ @@ -131,6 +131,29 @@ class DataFrame protected[sql]( */ def toDataFrame: DataFrame = this + /** + * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion + * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: + * {{{ + * val rdd: RDD[(Int, String)] = ... + * rdd.toDataFrame // this implicit conversion creates a DataFrame with column name _1 and _2 + * rdd.toDataFrame("id", "name") // this creates a DataFrame with column name "id" and "name" + * }}} + */ + @scala.annotation.varargs + def toDataFrame(colName: String, colNames: String*): DataFrame = { + val newNames = colName +: colNames + require(schema.size == newNames.size, + "The number of columns doesn't match.\n" + + "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + + "New column names: " + newNames.mkString(", ")) + + val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) => + apply(oldName).as(newName) + } + select(newCols :_*) + } + /** Returns the schema of this [[DataFrame]]. */ override def schema: StructType = queryExecution.analyzed.schema @@ -227,7 +250,7 @@ class DataFrame protected[sql]( } /** - * Selects a single column and return it as a [[Column]]. + * Selects column based on the column name and return it as a [[Column]]. */ override def apply(colName: String): Column = colName match { case "*" => @@ -466,6 +489,12 @@ class DataFrame protected[sql]( rdd.map(f) } + /** + * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], + * and then flattening the results. + */ + override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + /** * Returns a new RDD by applying a function to each partition of this DataFrame. */ @@ -473,6 +502,16 @@ class DataFrame protected[sql]( rdd.mapPartitions(f) } + /** + * Applies a function `f` to all rows. + */ + override def foreach(f: Row => Unit): Unit = rdd.foreach(f) + + /** + * Applies a function f to each partition of this [[DataFrame]]. + */ + override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) + /** * Returns the first `n` rows in the [[DataFrame]]. */ @@ -520,7 +559,7 @@ class DataFrame protected[sql]( ///////////////////////////////////////////////////////////////////////////// /** - * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s. + * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. 
*/ override def rdd: RDD[Row] = { val schema = this.schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala index 5eeaf17d71..59634082f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala @@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] { def map[R: ClassTag](f: T => R): RDD[R] + def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R] + def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] + def foreach(f: T => Unit): Unit + + def foreachPartition(f: Iterator[T] => Unit): Unit + def take(n: Int): Array[T] def collect(): Array[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java index 74d7649e08..16702afdb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java @@ -32,6 +32,13 @@ public class dsl { private static package$ scalaDsl = package$.MODULE$; + /** + * Returns a {@link Column} based on the given column name. + */ + public static Column col(String colName) { + return new Column(colName); + } + /** * Creates a column of literal value. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala index 9f2d1427d4..dc851fc504 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.api.scala import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ @@ -37,6 +38,21 @@ package object dsl { /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) +// /** +// * An implicit conversion that turns a RDD of product into a [[DataFrame]]. +// * +// * This method requires an implicit SQLContext in scope. For example: +// * {{{ +// * implicit val sqlContext: SQLContext = ... +// * val rdd: RDD[(Int, String)] = ... +// * rdd.toDataFrame // triggers the implicit here +// * }}} +// */ +// implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext) +// : DataFrame = { +// context.createDataFrame(rdd) +// } + /** Converts $"col name" into an [[Column]]. */ implicit class StringToColumn(val sc: StringContext) extends AnyVal { def $(args: Any*): ColumnName = { @@ -46,6 +62,11 @@ package object dsl { private[this] implicit def toColumn(expr: Expression): Column = new Column(expr) + /** + * Returns a [[Column]] based on the given column name. + */ + def col(colName: String): Column = new Column(colName) + /** * Creates a [[Column]] of literal value. */ -- cgit v1.2.3
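The sketches below illustrate how the APIs added by this patch might be used. They are not part of the commit, and names such as `sc`, `sqlContext`, and `df` are assumed to be in scope for illustration only. First, the RDD-style operations (flatMap, foreach, foreachPartition) now exposed on DataFrame, which simply delegate to the underlying RDD[Row]:

import org.apache.spark.sql.DataFrame

// Assumed for illustration: a DataFrame with an Int column "id" at position 0
// and a String column "name" at position 1.
def rowOps(df: DataFrame): Unit = {
  // flatMap: each input Row may yield zero or more output records.
  val chars = df.flatMap(row => row.getString(1).toSeq)
  println(chars.count())

  // foreach: run a side-effecting function on every Row.
  df.foreach(row => println(row.getInt(0)))

  // foreachPartition: pay per-partition setup cost (e.g. opening a connection)
  // once per partition rather than once per row.
  df.foreachPartition(rows => rows.foreach(println))
}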
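Next, the new col() function, available from both org.apache.spark.sql.api.scala.dsl and org.apache.spark.sql.api.java.dsl, together with the reworked Column.apply. A minimal sketch, again assuming a DataFrame `df` with columns "id" and "name":

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.api.scala.dsl._

def project(df: DataFrame): DataFrame = {
  // col("name") is equivalent to df("name"), Column("name"), and $"name".
  df.select(col("id"), col("name").as("userName"))
}

If a column name cannot be resolved, the reworked resolve() now lists the candidates, e.g. Cannot resolve column name "nmae" among (id, name), instead of echoing only the bad name.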
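The toDataFrame(colName, colNames*) overload renames every column at once, which is mainly useful after converting an RDD of tuples whose columns default to _1, _2, ... The sketch assumes SQLContext.createDataFrame accepts an RDD of Products, as referenced by the commented-out implicit at the bottom of the patch; `sc` is an assumed SparkContext.

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

def renamed(sc: SparkContext, sqlContext: SQLContext): DataFrame = {
  val rdd = sc.parallelize(Seq((1, "alice"), (2, "bob")))
  val df = sqlContext.createDataFrame(rdd) // columns default to "_1" and "_2"
  df.toDataFrame("id", "name")             // columns become "id" and "name"
}

Passing the wrong number of names fails fast with a require() message that lists both the old and the new column names.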
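Finally, schema inference now handles Array[_] element types in addition to Seq[_] and Array[Byte], and convertToCatalyst converts arrays to sequences. A sketch with an illustrative case class, under the same assumptions as above:

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

// Illustrative only: an Array field no longer has to be rewritten as a Seq.
case class Record(id: Int, scores: Array[Double])

def arraySchema(sc: SparkContext, sqlContext: SQLContext): Unit = {
  val df = sqlContext.createDataFrame(sc.parallelize(Seq(
    Record(1, Array(1.0, 2.0)),
    Record(2, Array(3.0)))))
  // "scores" is inferred as ArrayType(DoubleType, containsNull = false);
  // before this patch schemaFor failed with "Only Array[Byte] supported now".
  println(df.schema)
}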