author    Reynold Xin <rxin@databricks.com>  2015-01-29 00:01:10 -0800
committer Reynold Xin <rxin@databricks.com>  2015-01-29 00:01:10 -0800
commit    5ad78f62056f2560cd371ee964111a646806d0ff (patch)
tree      c5db8104a00b4a835db77bf7f7116622b47c8cc3 /sql
parent    a63be1a18f7b7d77f7deef2abc9a5be6ad24ae28 (diff)
[SQL] Various DataFrame DSL update.
1. Added foreach, foreachPartition, flatMap to DataFrame.
2. Added col() in dsl.
3. Support renaming columns in toDataFrame.
4. Support type inference on arrays (in addition to Seq).
5. Updated mllib to use the new DSL.

Author: Reynold Xin <rxin@databricks.com>

Closes #4260 from rxin/sql-dsl-update and squashes the following commits:

73466c1 [Reynold Xin] Fixed LogisticRegression. Also added better error message for resolve.
fab3ccc [Reynold Xin] Bug fix.
d31fcd2 [Reynold Xin] Style fix.
62608c4 [Reynold Xin] [SQL] Various DataFrame DSL update.
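Taken together, the changes allow code along the following lines. This is a rough sketch, not taken from the patch: `sqlContext`, `df` (a two-column DataFrame) and the column names are illustrative.

    import org.apache.spark.sql.api.scala.dsl._   // brings col(), lit(), $"..." into scope

    // (2) col() in the dsl, usable wherever a Column is expected
    df.select(col("id"), col("name"))

    // (1) RDD-style operations directly on the DataFrame
    df.foreach(row => println(row))
    val ids = df.flatMap(row => if (row.isNullAt(0)) Nil else Seq(row.getInt(0)))

    // (3) renaming columns, e.g. when the DataFrame came from an RDD of tuples (_1, _2)
    val renamed = df.toDataFrame("id", "name")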
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala        5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala   5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                              12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                           47
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api.scala                                  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java                          7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala               21
7 files changed, 94 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 191d16fb10..4def65b01f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -57,6 +57,7 @@ trait ScalaReflection {
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+ case (s: Array[_], arrayType: ArrayType) => s.toSeq
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
@@ -140,7 +141,9 @@ trait ScalaReflection {
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
- sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
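With the new case, schemaFor infers an ArrayType for Array[T] fields (Array[Byte] still maps to BinaryType via the earlier case) instead of failing. A minimal sketch of the intended behaviour; the case class is made up for illustration:

    import org.apache.spark.sql.catalyst.ScalaReflection

    case class Tagged(id: Int, tags: Array[String])

    // before this patch: sys.error("Only Array[Byte] supported now, use Seq instead of ...")
    // after: the tags field is inferred as ArrayType(StringType, containsNull = true)
    val schema = ScalaReflection.schemaFor[Tagged]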
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 5138942a55..4a66716e0a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -60,6 +60,7 @@ case class OptionalData(
case class ComplexData(
arrayField: Seq[Int],
+ arrayField1: Array[Int],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: Map[Int, Long],
mapFieldValueContainsNull: Map[Int, java.lang.Long],
@@ -132,6 +133,10 @@ class ScalaReflectionSuite extends FunSuite {
ArrayType(IntegerType, containsNull = false),
nullable = true),
StructField(
+ "arrayField1",
+ ArrayType(IntegerType, containsNull = false),
+ nullable = true),
+ StructField(
"arrayFieldContainsNull",
ArrayType(IntegerType, containsNull = true),
nullable = true),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 7f9a91a032..9be2a03afa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,15 +22,19 @@ import scala.language.implicitConversions
import org.apache.spark.sql.api.scala.dsl.lit
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.types._
object Column {
- def unapply(col: Column): Option[Expression] = Some(col.expr)
-
+ /**
+ * Creates a [[Column]] based on the given column name.
+ * Same as [[api.scala.dsl.col]] and [[api.java.dsl.col]].
+ */
def apply(colName: String): Column = new Column(colName)
+
+ /** For internal pattern matching. */
+ private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr)
}
@@ -438,7 +442,7 @@ class Column(
* @param ordinal
* @return
*/
- override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
+ override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
/**
* An expression that gets a field by name in a [[StructField]].
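The net effect for callers: Column("...") remains the public way to build a named column (same as the new col() helpers), while extractor-style matching on the underlying expression is now internal to sql. A brief sketch with made-up column names:

    import org.apache.spark.sql.Column

    val age: Column = Column("age")          // same as api.scala.dsl.col("age") / api.java.dsl.col("age")
    val firstTag = Column("tags").getItem(0) // getItem now wraps the ordinal with Literal directly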
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ceb5f86bef..050366aea8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -118,8 +118,8 @@ class DataFrame protected[sql](
/** Resolves a column name into a Catalyst [[NamedExpression]]. */
protected[sql] def resolve(colName: String): NamedExpression = {
- logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
- throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
+ logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(throw new RuntimeException(
+ s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})"""))
}
/** Left here for compatibility reasons. */
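The behaviour of resolve is unchanged; only the error message now lists the columns that are available. Roughly, with made-up column names:

    // before: java.lang.RuntimeException: Cannot resolve column name "agee"
    // after:  java.lang.RuntimeException: Cannot resolve column name "agee" among (id, name, age)
    df("agee")   // a typo in the column name goes through resolve() and fails with the richer message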
@@ -131,6 +131,29 @@ class DataFrame protected[sql](
*/
def toDataFrame: DataFrame = this
+ /**
+ * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient when converting
+ * an RDD of tuples into a [[DataFrame]] with meaningful names. For example:
+ * {{{
+ * val rdd: RDD[(Int, String)] = ...
+ * rdd.toDataFrame // this implicit conversion creates a DataFrame with column names _1 and _2
+ * rdd.toDataFrame("id", "name") // this creates a DataFrame with column names "id" and "name"
+ * }}}
+ */
+ @scala.annotation.varargs
+ def toDataFrame(colName: String, colNames: String*): DataFrame = {
+ val newNames = colName +: colNames
+ require(schema.size == newNames.size,
+ "The number of columns doesn't match.\n" +
+ "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+ "New column names: " + newNames.mkString(", "))
+
+ val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) =>
+ apply(oldName).as(newName)
+ }
+ select(newCols :_*)
+ }
+
/** Returns the schema of this [[DataFrame]]. */
override def schema: StructType = queryExecution.analyzed.schema
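A hedged example of the new overload; the RDD, the SQLContext, and createDataFrame (referenced elsewhere in this patch) are assumed to be in scope:

    // a DataFrame built from an RDD[(Int, String)] gets the default column names _1 and _2
    val df = sqlContext.createDataFrame(rdd)
    val people = df.toDataFrame("id", "name")   // columns are now id, name

    // passing the wrong number of names fails fast, listing both the old and the new names
    // df.toDataFrame("id")                     // IllegalArgumentException from the require(...)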
@@ -227,7 +250,7 @@ class DataFrame protected[sql](
}
/**
- * Selects a single column and return it as a [[Column]].
+ * Selects a column based on the column name and returns it as a [[Column]].
*/
override def apply(colName: String): Column = colName match {
case "*" =>
@@ -467,6 +490,12 @@ class DataFrame protected[sql](
}
/**
+ * Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
+ * and then flattening the results.
+ */
+ override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
+
+ /**
* Returns a new RDD by applying a function to each partition of this DataFrame.
*/
override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
@@ -474,6 +503,16 @@ class DataFrame protected[sql](
}
/**
+ * Applies a function `f` to all rows.
+ */
+ override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+
+ /**
+ * Applies a function `f` to each partition of this [[DataFrame]].
+ */
+ override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
+
+ /**
* Returns the first `n` rows in the [[DataFrame]].
*/
override def take(n: Int): Array[Row] = head(n)
@@ -520,7 +559,7 @@ class DataFrame protected[sql](
/////////////////////////////////////////////////////////////////////////////
/**
- * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s.
+ * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
*/
override def rdd: RDD[Row] = {
val schema = this.schema
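All of these simply delegate to the row RDD. One sketch of the per-partition variant, where the per-partition resource is a stand-in for something genuinely expensive to create (a connection, a writer, ...):

    df.foreachPartition { rows =>
      val buffer = new StringBuilder           // hypothetical per-partition resource
      rows.foreach(row => buffer.append(row.toString).append('\n'))
      // hand the accumulated partition output to whatever sink is actually used
    }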
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index 5eeaf17d71..59634082f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] {
def map[R: ClassTag](f: T => R): RDD[R]
+ def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R]
+
def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]
+ def foreach(f: T => Unit): Unit
+
+ def foreachPartition(f: Iterator[T] => Unit): Unit
+
def take(n: Int): Array[T]
def collect(): Array[T]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
index 74d7649e08..16702afdb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
@@ -33,6 +33,13 @@ public class dsl {
private static package$ scalaDsl = package$.MODULE$;
/**
+ * Returns a {@link Column} based on the given column name.
+ */
+ public static Column col(String colName) {
+ return new Column(colName);
+ }
+
+ /**
* Creates a column of literal value.
*/
public static Column lit(Object literalValue) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
index 9f2d1427d4..dc851fc504 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.api.scala
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions._
@@ -37,6 +38,21 @@ package object dsl {
/** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+// /**
+// * An implicit conversion that turns a RDD of product into a [[DataFrame]].
+// *
+// * This method requires an implicit SQLContext in scope. For example:
+// * {{{
+// * implicit val sqlContext: SQLContext = ...
+// * val rdd: RDD[(Int, String)] = ...
+// * rdd.toDataFrame // triggers the implicit here
+// * }}}
+// */
+// implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext)
+// : DataFrame = {
+// context.createDataFrame(rdd)
+// }
+
/** Converts $"col name" into a [[Column]]. */
implicit class StringToColumn(val sc: StringContext) extends AnyVal {
def $(args: Any*): ColumnName = {
@@ -47,6 +63,11 @@ package object dsl {
private[this] implicit def toColumn(expr: Expression): Column = new Column(expr)
/**
+ * Returns a [[Column]] based on the given column name.
+ */
+ def col(colName: String): Column = new Column(colName)
+
+ /**
* Creates a [[Column]] of literal value.
*/
def lit(literal: Any): Column = {