author     Reynold Xin <rxin@databricks.com>   2015-03-30 20:47:10 -0700
committer  Reynold Xin <rxin@databricks.com>   2015-03-30 20:47:10 -0700
commit     b8ff2bc61c9835867f56afa1860ab5eb727c4a58 (patch)
tree       e29f737f32f9c21e22ff6fd7778549ec907c6015 /sql
parent     fde6945417355ae57500b67d034c9cad4f20d240 (diff)
[SPARK-6119][SQL] DataFrame support for missing data handling
This pull request adds variants of DataFrame.na.drop and DataFrame.na.fill to the
Scala/Java API, and DataFrame.fillna and DataFrame.dropna to the Python API.

Author: Reynold Xin <rxin@databricks.com>

Closes #5274 from rxin/df-missing-value and squashes the following commits:

4ee1b98 [Reynold Xin] Improve error reporting in Python.
33a330c [Reynold Xin] Remove replace for now.
bc4fdbb [Reynold Xin] Added documentation for replace.
d56f5a5 [Reynold Xin] Added replace for Scala/Java.
2385d00 [Reynold Xin] Feedback from Xiangrui on "how".
914a374 [Reynold Xin] fill with map.
185c67e [Reynold Xin] Allow specifying column subsets in fill.
749eb47 [Reynold Xin] fillna
249b94e [Reynold Xin] Removing undefined functions.
6a73c68 [Reynold Xin] Missing file.
67d7003 [Reynold Xin] [SPARK-6119][SQL] DataFrame.na.drop (Scala/Java) and DataFrame.dropna (Python)
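For illustration, a minimal sketch of the new Scala API surface, assuming a DataFrame `df`
with hypothetical nullable columns "name", "age", "height":

    df.na.drop()                         // drop rows containing any null values
    df.na.drop("all")                    // drop rows only if every column is null
    df.na.drop(2, Seq("age", "height"))  // keep rows with at least 2 non-nulls among these columns
    df.na.fill(0.0)                      // replace nulls in numeric columns with 0.0
    df.na.fill(Map("name" -> "unknown", "age" -> 18))   // per-column replacement values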
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala  |  25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                               |  15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala                    | 228
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala                             |   5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala                            |   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala               | 157
6 files changed, 424 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index d1f3d4f4ee..f9161cf34f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -35,7 +35,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
override def toString: String = s"Coalesce(${children.mkString(",")})"
- def dataType: DataType = if (resolved) {
+ override def dataType: DataType = if (resolved) {
children.head.dataType
} else {
val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ")
@@ -74,3 +74,26 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
child.eval(input) != null
}
}
+
+/**
+ * A predicate that evaluates to true if there are at least `n` non-null values.
+ */
+case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
+ override def nullable: Boolean = false
+ override def foldable: Boolean = false
+ override def toString: String = s"AtLeastNNonNulls($n, ${children.mkString(",")})"
+
+ private[this] val childrenArray = children.toArray
+
+ override def eval(input: Row): Boolean = {
+ var numNonNulls = 0
+ var i = 0
+ while (i < childrenArray.length && numNonNulls < n) {
+ if (childrenArray(i).eval(input) != null) {
+ numNonNulls += 1
+ }
+ i += 1
+ }
+ numNonNulls >= n
+ }
+}
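The predicate above backs the threshold variant of drop added below; a hedged sketch of how
it surfaces through the public API (column names hypothetical):

    // Keep only rows with at least 2 non-null values among "age" and "height";
    // internally this filters on AtLeastNNonNulls(2, Seq(<resolved columns>)).
    df.na.drop(2, Seq("age", "height"))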
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 423ef3912b..5cd0a18ff6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -237,8 +237,8 @@ class DataFrame private[sql](
def toDF(colNames: String*): DataFrame = {
require(schema.size == colNames.size,
"The number of columns doesn't match.\n" +
- "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
- "New column names: " + colNames.mkString(", "))
+ s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" +
+ s"New column names (${colNames.size}): " + colNames.mkString(", "))
val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) =>
apply(oldName).as(newName)
@@ -320,6 +320,17 @@ class DataFrame private[sql](
def show(): Unit = show(20)
/**
+ * Returns a [[DataFrameNaFunctions]] for working with missing data.
+ * {{{
+ * // Dropping rows containing any null values.
+ * df.na.drop()
+ * }}}
+ *
+ * @group dfops
+ */
+ def na: DataFrameNaFunctions = new DataFrameNaFunctions(this)
+
+ /**
* Cartesian join with another [[DataFrame]].
*
* Note that cartesian joins are very expensive without an extra filter that can be pushed down.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
new file mode 100644
index 0000000000..3a3dc70f72
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.{lang => jl}
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+
+/**
+ * Functionality for working with missing data in [[DataFrame]]s.
+ */
+final class DataFrameNaFunctions private[sql](df: DataFrame) {
+
+ /**
+ * Returns a new [[DataFrame]] that drops rows containing any null values.
+ */
+ def drop(): DataFrame = drop("any", df.columns)
+
+ /**
+ * Returns a new [[DataFrame]] that drops rows containing null values.
+ *
+ * If `how` is "any", then drop rows containing any null values.
+ * If `how` is "all", then drop rows only if every column is null for that row.
+ */
+ def drop(how: String): DataFrame = drop(how, df.columns)
+
+ /**
+ * Returns a new [[DataFrame]] that drops rows containing any null values
+ * in the specified columns.
+ */
+ def drop(cols: Array[String]): DataFrame = drop(cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing any null values
+ * in the specified columns.
+ */
+ def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols)
+
+ /**
+ * Returns a new [[DataFrame]] that drops rows containing null values
+ * in the specified columns.
+ *
+ * If `how` is "any", then drop rows containing any null values in the specified columns.
+ * If `how` is "all", then drop rows only if every specified column is null for that row.
+ */
+ def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values
+ * in the specified columns.
+ *
+ * If `how` is "any", then drop rows containing any null values in the specified columns.
+ * If `how` is "all", then drop rows only if every specified column is null for that row.
+ */
+ def drop(how: String, cols: Seq[String]): DataFrame = {
+ how.toLowerCase match {
+ case "any" => drop(cols.size, cols)
+ case "all" => drop(1, cols)
+ case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'")
+ }
+ }
+
+ /**
+ * Returns a new [[DataFrame]] that drops rows containing fewer than `minNonNulls` non-null values.
+ */
+ def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns)
+
+ /**
+ * Returns a new [[DataFrame]] that drops rows containing fewer than `minNonNulls` non-null
+ * values in the specified columns.
+ */
+ def drop(minNonNulls: Int, cols: Array[String]): DataFrame = drop(minNonNulls, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing fewer than
+ * `minNonNulls` non-null values in the specified columns.
+ */
+ def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
+ // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values.
+ val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name)))
+ df.filter(Column(predicate))
+ }
+
+ /**
+ * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`.
+ */
+ def fill(value: Double): DataFrame = fill(value, df.columns)
+
+ /**
+ * Returns a new [[DataFrame]] that replaces null values in string columns with `value`.
+ */
+ def fill(value: String): DataFrame = fill(value, df.columns)
+
+ /**
+ * Returns a new [[DataFrame]] that replaces null values in the specified numeric columns
+ * with `value`. If a specified column is not a numeric column, it is ignored.
+ */
+ def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified
+ * numeric columns. If a specified column is not a numeric column, it is ignored.
+ */
+ def fill(value: Double, cols: Seq[String]): DataFrame = {
+ val columnEquals = df.sqlContext.analyzer.resolver
+ val projections = df.schema.fields.map { f =>
+ // Only fill if the column is part of the cols list.
+ if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) {
+ fillCol[Double](f, value)
+ } else {
+ df.col(f.name)
+ }
+ }
+ df.select(projections : _*)
+ }
+
+ /**
+ * Returns a new [[DataFrame]] that replaces null values in the specified string columns
+ * with `value`. If a specified column is not a string column, it is ignored.
+ */
+ def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in
+ * specified string columns. If a specified column is not a string column, it is ignored.
+ */
+ def fill(value: String, cols: Seq[String]): DataFrame = {
+ val columnEquals = df.sqlContext.analyzer.resolver
+ val projections = df.schema.fields.map { f =>
+ // Only fill if the column is part of the cols list.
+ if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) {
+ fillCol[String](f, value)
+ } else {
+ df.col(f.name)
+ }
+ }
+ df.select(projections : _*)
+ }
+
+ /**
+ * Returns a new [[DataFrame]] that replaces null values.
+ *
+ * The key of the map is the column name, and the value of the map is the replacement value.
+ * The value must be one of the following types: `Integer`, `Long`, `Float`, `Double`, `String`.
+ *
+ * For example, the following replaces null values in column "A" with string "unknown", and
+ * null values in column "B" with numeric value 1.0.
+ * {{{
+ * import com.google.common.collect.ImmutableMap;
+ * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0));
+ * }}}
+ */
+ def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] that replaces null values.
+ *
+ * The key of the map is the column name, and the value of the map is the replacement value.
+ * The value must be one of the following types: `Int`, `Long`, `Float`, `Double`, `String`.
+ *
+ * For example, the following replaces null values in column "A" with string "unknown", and
+ * null values in column "B" with numeric value 1.0.
+ * {{{
+ * df.na.fill(Map(
+ * "A" -> "unknown",
+ * "B" -> 1.0
+ * ))
+ * }}}
+ */
+ def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
+
+ private def fill0(values: Seq[(String, Any)]): DataFrame = {
+ // Error handling
+ values.foreach { case (colName, replaceValue) =>
+ // Check column name exists
+ df.resolve(colName)
+
+ // Check data type
+ replaceValue match {
+ case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _: String =>
+ // This is good
+ case _ => throw new IllegalArgumentException(
+ s"Unsupported value type ${replaceValue.getClass.getName} ($replaceValue).")
+ }
+ }
+
+ val columnEquals = df.sqlContext.analyzer.resolver
+ val projections = df.schema.fields.map { f =>
+ values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
+ v match {
+ case v: jl.Float => fillCol[Double](f, v.toDouble)
+ case v: jl.Double => fillCol[Double](f, v)
+ case v: jl.Long => fillCol[Double](f, v.toDouble)
+ case v: jl.Integer => fillCol[Double](f, v.toDouble)
+ case v: String => fillCol[String](f, v)
+ }
+ }.getOrElse(df.col(f.name))
+ }
+ df.select(projections : _*)
+ }
+
+ /**
+ * Returns a [[Column]] expression that replaces null values in `col` with `replacement`.
+ */
+ private def fillCol[T](col: StructField, replacement: T): Column = {
+ coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
+ }
+}
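For a sense of what fill expands to: every column is projected through unchanged unless
targeted, in which case it is wrapped in coalesce. A rough equivalent of
df.na.fill(50.6, Seq("height")) for a frame with columns "name" and "height", assuming
"height" is of DoubleType (names hypothetical):

    import org.apache.spark.sql.functions.{coalesce, lit}
    import org.apache.spark.sql.types.DoubleType

    // Nulls in "height" become 50.6; "name" passes through untouched.
    val filled = df.select(
      df.col("name"),
      coalesce(df.col("height"), lit(50.6).cast(DoubleType)).as("height"))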
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 45a63ae26e..a5e6b638d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -127,10 +127,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
* import com.google.common.collect.ImmutableMap;
- * df.groupBy("department").agg(ImmutableMap.<String, String>builder()
- * .put("age", "max")
- * .put("expense", "sum")
- * .build());
+ * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum"));
* }}}
*/
def agg(exprs: java.util.Map[String, String]): DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 2b0358c4e2..0b770f2251 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -49,7 +49,7 @@ private[sql] object JsonRDD extends Logging {
val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1)
val allKeys =
if (schemaData.isEmpty()) {
- Set.empty[(String,DataType)]
+ Set.empty[(String, DataType)]
} else {
parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
new file mode 100644
index 0000000000..0896f175c0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+
+class DataFrameNaFunctionsSuite extends QueryTest {
+
+ def createDF(): DataFrame = {
+ Seq[(String, java.lang.Integer, java.lang.Double)](
+ ("Bob", 16, 176.5),
+ ("Alice", null, 164.3),
+ ("David", 60, null),
+ ("Amy", null, null),
+ (null, null, null)).toDF("name", "age", "height")
+ }
+
+ test("drop") {
+ val input = createDF()
+ val rows = input.collect()
+
+ checkAnswer(
+ input.na.drop("name" :: Nil),
+ rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil)
+
+ checkAnswer(
+ input.na.drop("age" :: Nil),
+ rows(0) :: rows(2) :: Nil)
+
+ checkAnswer(
+ input.na.drop("age" :: "height" :: Nil),
+ rows(0) :: Nil)
+
+ checkAnswer(
+ input.na.drop(),
+ rows(0))
+
+ // drop() on a DataFrame with no columns should return an empty DataFrame.
+ val empty = input.sqlContext.emptyDataFrame.select()
+ assert(empty.na.drop().count() === 0L)
+
+ // Make sure the columns are properly named.
+ assert(input.na.drop().columns.toSeq === input.columns.toSeq)
+ }
+
+ test("drop with how") {
+ val input = createDF()
+ val rows = input.collect()
+
+ checkAnswer(
+ input.na.drop("all"),
+ rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil)
+
+ checkAnswer(
+ input.na.drop("any"),
+ rows(0) :: Nil)
+
+ checkAnswer(
+ input.na.drop("any", Seq("age", "height")),
+ rows(0) :: Nil)
+
+ checkAnswer(
+ input.na.drop("all", Seq("age", "height")),
+ rows(0) :: rows(1) :: rows(2) :: Nil)
+ }
+
+ test("drop with threshold") {
+ val input = createDF()
+ val rows = input.collect()
+
+ checkAnswer(
+ input.na.drop(2, Seq("age", "height")),
+ rows(0) :: Nil)
+
+ checkAnswer(
+ input.na.drop(3, Seq("name", "age", "height")),
+ rows(0))
+
+ // Make sure the columns are properly named.
+ assert(input.na.drop(2, Seq("age", "height")).columns.toSeq === input.columns.toSeq)
+ }
+
+ test("fill") {
+ val input = createDF()
+
+ val fillNumeric = input.na.fill(50.6)
+ checkAnswer(
+ fillNumeric,
+ Row("Bob", 16, 176.5) ::
+ Row("Alice", 50, 164.3) ::
+ Row("David", 60, 50.6) ::
+ Row("Amy", 50, 50.6) ::
+ Row(null, 50, 50.6) :: Nil)
+
+ // Make sure the columns are properly named.
+ assert(fillNumeric.columns.toSeq === input.columns.toSeq)
+
+ // string
+ checkAnswer(
+ input.na.fill("unknown").select("name"),
+ Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil)
+ assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq)
+
+ // fill double with subset columns
+ checkAnswer(
+ input.na.fill(50.6, "age" :: Nil),
+ Row("Bob", 16, 176.5) ::
+ Row("Alice", 50, 164.3) ::
+ Row("David", 60, null) ::
+ Row("Amy", 50, null) ::
+ Row(null, 50, null) :: Nil)
+
+ // fill string with subset columns
+ checkAnswer(
+ Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil),
+ Row("test", null))
+ }
+
+ test("fill with map") {
+ val df = Seq[(String, String, java.lang.Long, java.lang.Double)](
+ (null, null, null, null)).toDF("a", "b", "c", "d")
+ checkAnswer(
+ df.na.fill(Map(
+ "a" -> "test",
+ "c" -> 1,
+ "d" -> 2.2
+ )),
+ Row("test", null, 1, 2.2))
+
+ // Test Java version
+ checkAnswer(
+ df.na.fill(mapAsJavaMap(Map(
+ "a" -> "test",
+ "c" -> 1,
+ "d" -> 2.2
+ ))),
+ Row("test", null, 1, 2.2))
+ }
+}