-rw-r--r--  core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java           |  29
-rw-r--r--  core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java          |  29
-rw-r--r--  core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java |  28
-rw-r--r--  core/src/main/java/org/apache/spark/api/java/function/Function0.java                |   2
-rw-r--r--  core/src/main/java/org/apache/spark/api/java/function/MapFunction.java              |  27
-rw-r--r--  core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java    |  28
-rw-r--r--  core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java           |  27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala    |  38
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                        |  47
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                          | 100
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java            |   7
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java              |  36
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala            |   5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                     |  10
14 files changed, 316 insertions(+), 97 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java
new file mode 100644
index 0000000000..e8d999dd00
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for a function used in Dataset's filter function.
+ *
+ * If the function returns true, the element is included in the returned Dataset.
+ */
+public interface FilterFunction<T> extends Serializable {
+ boolean call(T value) throws Exception;
+}
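
A minimal usage sketch for the new interface, assuming `ds` is an existing Dataset<String>; the anonymous-class style mirrors the JavaDatasetSuite changes further down in this patch:

    Dataset<String> startingWithH = ds.filter(new FilterFunction<String>() {
      @Override
      public boolean call(String value) throws Exception {
        // returning true keeps the element in the result
        return value.startsWith("h");
      }
    });
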
diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java
new file mode 100644
index 0000000000..07e54b28fa
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for a function used in Dataset's foreach function.
+ *
+ * Spark will invoke the call function on each element in the input Dataset.
+ */
+public interface ForeachFunction<T> extends Serializable {
+ void call(T t) throws Exception;
+}
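
A hedged sketch of the corresponding call site, assuming `ds` is a Dataset<String> (the accumulator-based test below exercises the same shape):

    ds.foreach(new ForeachFunction<String>() {
      @Override
      public void call(String s) throws Exception {
        // invoked once per element, on the executors
        System.out.println(s);
      }
    });
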
diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java
new file mode 100644
index 0000000000..4938a51bcd
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * Base interface for a function used in Dataset's foreachPartition function.
+ */
+public interface ForeachPartitionFunction<T> extends Serializable {
+ void call(Iterator<T> t) throws Exception;
+}
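
A sketch of how a caller might drain each partition's iterator, again assuming `ds` is a Dataset<String>:

    ds.foreachPartition(new ForeachPartitionFunction<String>() {
      @Override
      public void call(Iterator<String> it) throws Exception {
        // per-partition setup (e.g. opening a connection) would go here, before the loop
        while (it.hasNext()) {
          System.out.println(it.next());
        }
      }
    });
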
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java
index 38e410c5de..c86928dd05 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java
@@ -23,5 +23,5 @@ import java.io.Serializable;
* A zero-argument function that returns an R.
*/
public interface Function0<R> extends Serializable {
- public R call() throws Exception;
+ R call() throws Exception;
}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java
new file mode 100644
index 0000000000..3ae6ef4489
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for a map function used in Dataset's map function.
+ */
+public interface MapFunction<T, U> extends Serializable {
+ U call(T value) throws Exception;
+}
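
A usage sketch, assuming `ds` is a Dataset<String> and `intEncoder` is an Encoder<Integer> obtained elsewhere (the test suite below obtains it via e.INT()):

    Dataset<Integer> lengths = ds.map(new MapFunction<String, Integer>() {
      @Override
      public Integer call(String value) throws Exception {
        return value.length();  // one output element per input element
      }
    }, intEncoder);
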
diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java
new file mode 100644
index 0000000000..6cb569ce0c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * Base interface for a function used in Dataset's mapPartitions function.
+ */
+public interface MapPartitionsFunction<T, U> extends Serializable {
+ Iterable<U> call(Iterator<T> input) throws Exception;
+}
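
The interface takes a java.util.Iterator and returns an Iterable, so a whole partition can be transformed at once. A hedged sketch, assuming `ds` is a Dataset<String> and `stringEncoder` is an Encoder<String>:

    Dataset<String> upperCased = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
      @Override
      public Iterable<String> call(Iterator<String> it) throws Exception {
        List<String> out = new ArrayList<>();  // java.util.List / java.util.ArrayList
        while (it.hasNext()) {
          out.add(it.next().toUpperCase());
        }
        return out;
      }
    }, stringEncoder);
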
diff --git a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java
new file mode 100644
index 0000000000..ee092d0058
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for a function used in Dataset's reduce function.
+ */
+public interface ReduceFunction<T> extends Serializable {
+ T call(T v1, T v2) throws Exception;
+}
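
A sketch matching the reduce contract (the function must be commutative and associative), assuming `intDs` is a Dataset<Integer>:

    int sum = intDs.reduce(new ReduceFunction<Integer>() {
      @Override
      public Integer call(Integer v1, Integer v2) throws Exception {
        return v1 + v2;  // commutative and associative, so partition order does not matter
      }
    });
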
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index f05e18288d..6569b900fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.reflect.ClassTag
import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
import org.apache.spark.sql.catalyst.expressions._
/**
@@ -100,7 +100,7 @@ object Encoder {
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
Invoke(
- BoundReference(0, ObjectType(cls), true),
+ BoundReference(0, ObjectType(cls), nullable = true),
s"_${index + 1}",
t)
}
@@ -114,13 +114,13 @@ object Encoder {
} else {
enc.constructExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
- GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
+ GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
}
}
val constructExpression =
- NewInstance(cls, constructExpressions, false, ObjectType(cls))
+ NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
new ExpressionEncoder[Any](
schema,
@@ -130,7 +130,6 @@ object Encoder {
ClassTag.apply(cls))
}
-
def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
@@ -148,9 +147,36 @@ object Encoder {
})
}
- def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
+ def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
implicit val typeTag1 = getTypeTag(c1)
implicit val typeTag2 = getTypeTag(c2)
ExpressionEncoder[(T1, T2)]()
}
+
+ def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
+ implicit val typeTag1 = getTypeTag(c1)
+ implicit val typeTag2 = getTypeTag(c2)
+ implicit val typeTag3 = getTypeTag(c3)
+ ExpressionEncoder[(T1, T2, T3)]()
+ }
+
+ def forTuple[T1, T2, T3, T4](
+ c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
+ implicit val typeTag1 = getTypeTag(c1)
+ implicit val typeTag2 = getTypeTag(c2)
+ implicit val typeTag3 = getTypeTag(c3)
+ implicit val typeTag4 = getTypeTag(c4)
+ ExpressionEncoder[(T1, T2, T3, T4)]()
+ }
+
+ def forTuple[T1, T2, T3, T4, T5](
+ c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
+ : Encoder[(T1, T2, T3, T4, T5)] = {
+ implicit val typeTag1 = getTypeTag(c1)
+ implicit val typeTag2 = getTypeTag(c2)
+ implicit val typeTag3 = getTypeTag(c3)
+ implicit val typeTag4 = getTypeTag(c4)
+ implicit val typeTag5 = getTypeTag(c5)
+ ExpressionEncoder[(T1, T2, T3, T4, T5)]()
+ }
}
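
A hedged sketch of how the new forTuple overloads might be reached from Java; only the forTuple signatures come from this patch, and the Encoder$.MODULE$ access path is the usual Scala-object-from-Java pattern rather than something this diff shows:

    // returns an Encoder for scala.Tuple2<Integer, String>
    Encoder<Tuple2<Integer, String>> pairEncoder =
        Encoder$.MODULE$.forTuple(Integer.class, String.class);

    // the 3-, 4- and 5-arity overloads follow the same shape
    Encoder<Tuple3<Integer, String, Long>> tripleEncoder =
        Encoder$.MODULE$.forTuple(Integer.class, String.class, Long.class);
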
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f2d4db5550..8ab958adad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1478,18 +1478,54 @@ class DataFrame private[sql](
/**
* Returns the first `n` rows in the [[DataFrame]].
+ *
+ * Running take requires moving data into the application's driver process, and doing so on a
+ * very large dataset can crash the driver process with OutOfMemoryError.
+ *
* @group action
* @since 1.3.0
*/
def take(n: Int): Array[Row] = head(n)
/**
+ * Returns the first `n` rows in the [[DataFrame]] as a list.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with
+ * a very large `n` can crash the driver process with OutOfMemoryError.
+ *
+ * @group action
+ * @since 1.6.0
+ */
+ def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*)
+
+ /**
* Returns an array that contains all of [[Row]]s in this [[DataFrame]].
+ *
+ * Running collect requires moving all the data into the application's driver process, and
+ * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ *
+ * For Java API, use [[collectAsList]].
+ *
* @group action
* @since 1.3.0
*/
def collect(): Array[Row] = collect(needCallback = true)
+ /**
+ * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
+ *
+ * Running collect requires moving all the data into the application's driver process, and
+ * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ *
+ * @group action
+ * @since 1.3.0
+ */
+ def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+ withNewExecutionId {
+ java.util.Arrays.asList(rdd.collect() : _*)
+ }
+ }
+
private def collect(needCallback: Boolean): Array[Row] = {
def execute(): Array[Row] = withNewExecutionId {
queryExecution.executedPlan.executeCollectPublic()
@@ -1503,17 +1539,6 @@ class DataFrame private[sql](
}
/**
- * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
- * @group action
- * @since 1.3.0
- */
- def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
- withNewExecutionId {
- java.util.Arrays.asList(rdd.collect() : _*)
- }
- }
-
- /**
* Returns the number of rows in the [[DataFrame]].
* @group action
* @since 1.3.0
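
A short Java-side sketch of the two new list-returning DataFrame actions; the table name and filter mirror the JavaDataFrameSuite test added below, and `context` is assumed to be an SQLContext with that table registered:

    DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
    List<Row> everything = df.collectAsList();  // moves all matching rows to the driver
    List<Row> firstTwo   = df.takeAsList(2);    // moves only the first two rows
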
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fecbdac9a6..959e0f5ba0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
+import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
@@ -75,7 +75,11 @@ class Dataset[T] private[sql](
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
- /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */
+ /**
+ * Returns the schema of the encoded form of the objects in this [[Dataset]].
+ *
+ * @since 1.6.0
+ */
def schema: StructType = encoder.schema
/* ************* *
@@ -103,6 +107,7 @@ class Dataset[T] private[sql](
/**
* Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
* the same name after two Datasets have been joined.
+ * @since 1.6.0
*/
def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _))
@@ -166,8 +171,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
* @since 1.6.0
*/
- def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
- filter(t => func.call(t).booleanValue())
+ def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
/**
* (Scala-specific)
@@ -181,7 +185,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
- def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
map(t => func.call(t))(encoder)
/**
@@ -205,10 +209,8 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
- def mapPartitions[U](
- f: FlatMapFunction[java.util.Iterator[T], U],
- encoder: Encoder[U]): Dataset[U] = {
- val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
+ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala
mapPartitions(func)(encoder)
}
@@ -248,7 +250,7 @@ class Dataset[T] private[sql](
* Runs `func` on each element of this Dataset.
* @since 1.6.0
*/
- def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))
+ def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
/**
* (Scala-specific)
@@ -262,7 +264,7 @@ class Dataset[T] private[sql](
* Runs `func` on each partition of this Dataset.
* @since 1.6.0
*/
- def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
+ def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
foreachPartition(it => func.call(it.asJava))
/* ************* *
@@ -271,7 +273,7 @@ class Dataset[T] private[sql](
/**
* (Scala-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given function
+ * Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
@@ -279,33 +281,11 @@ class Dataset[T] private[sql](
/**
* (Java-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given function
+ * Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
*/
- def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))
-
- /**
- * (Scala-specific)
- * Aggregates the elements of each partition, and then the results for all the partitions, using a
- * given associative and commutative function and a neutral "zero value".
- *
- * This behaves somewhat differently than the fold operations implemented for non-distributed
- * collections in functional languages like Scala. This fold operation may be applied to
- * partitions individually, and then those results will be folded into the final result.
- * If op is not commutative, then the result may differ from that of a fold applied to a
- * non-distributed collection.
- * @since 1.6.0
- */
- def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
-
- /**
- * (Java-specific)
- * Aggregates the elements of each partition, and then the results for all the partitions, using a
- * given associative and commutative function and a neutral "zero value".
- * @since 1.6.0
- */
- def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
+ def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
/**
* (Scala-specific)
@@ -351,7 +331,7 @@ class Dataset[T] private[sql](
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* @since 1.6.0
*/
- def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+ def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
groupBy(f.call(_))(encoder)
/* ****************** *
@@ -367,7 +347,7 @@ class Dataset[T] private[sql](
*/
// Copied from Dataframe to make sure we don't have invalid overloads.
@scala.annotation.varargs
- def select(cols: Column*): DataFrame = toDF().select(cols: _*)
+ protected def select(cols: Column*): DataFrame = toDF().select(cols: _*)
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
@@ -462,8 +442,7 @@ class Dataset[T] private[sql](
* and thus is not affected by a custom `equals` function defined on `T`.
* @since 1.6.0
*/
- def intersect(other: Dataset[T]): Dataset[T] =
- withPlan[T](other)(Intersect)
+ def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect)
/**
* Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
@@ -473,8 +452,7 @@ class Dataset[T] private[sql](
+ * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
* @since 1.6.0
*/
- def union(other: Dataset[T]): Dataset[T] =
- withPlan[T](other)(Union)
+ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
/**
* Returns a new [[Dataset]] where any elements present in `other` have been removed.
@@ -542,27 +520,47 @@ class Dataset[T] private[sql](
def first(): T = rdd.first()
/**
- * Collects the elements to an Array.
+ * Returns an array that contains all the elements in this [[Dataset]].
+ *
+ * Running collect requires moving all the data into the application's driver process, and
+ * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ *
+ * For Java API, use [[collectAsList]].
* @since 1.6.0
*/
def collect(): Array[T] = rdd.collect()
/**
- * (Java-specific)
- * Collects the elements to a Java list.
+ * Returns a Java list that contains all the elements in this [[Dataset]].
*
- * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at
- * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method
- * instead and keep the generic type for result.
+ * Running collect requires moving all the data into the application's driver process, and
+ * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
*
* @since 1.6.0
*/
- def collectAsList(): java.util.List[T] =
- rdd.collect().toSeq.asJava
+ def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
- /** Returns the first `num` elements of this [[Dataset]] as an Array. */
+ /**
+ * Returns the first `num` elements of this [[Dataset]] as an array.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with
+ * a very large `num` can crash the driver process with OutOfMemoryError.
+ *
+ * @since 1.6.0
+ */
def take(num: Int): Array[T] = rdd.take(num)
+ /**
+ * Returns the first `num` elements of this [[Dataset]] as a list.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with
+ * a very large `num` can crash the driver process with OutOfMemoryError.
+ *
+ * @since 1.6.0
+ */
+ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
+
/* ******************** *
* Internal Functions *
* ******************** */
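
A hedged sketch of why the *AsList variants are the Java-friendly actions: Scala's Array[T] erases on the Java side (the point made by the javadoc this patch removes from collectAsList), while the list variants keep the element type. `ds` is assumed to be a Dataset<String>:

    Object raw = ds.take(1);                // Array[T] arrives as an untyped Object in Java
    List<String> first = ds.takeAsList(1);  // typed list, no casting needed
    List<String> all   = ds.collectAsList();  // same idea for the full dataset
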
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 40bff57a17..d191b50fa8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -65,6 +65,13 @@ public class JavaDataFrameSuite {
Assert.assertEquals(1, df.select("key").collect()[0].get(0));
}
+ @Test
+ public void testCollectAndTake() {
+ DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+ Assert.assertEquals(3, df.select("key").collectAsList().size());
+ Assert.assertEquals(2, df.select("key").takeAsList(2).size());
+ }
+
/**
* See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java.
*/
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0d3b1a5af5..0f90de774d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -68,8 +68,16 @@ public class JavaDatasetSuite implements Serializable {
public void testCollect() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, e.STRING());
- String[] collected = (String[]) ds.collect();
- Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected));
+ List<String> collected = ds.collectAsList();
+ Assert.assertEquals(Arrays.asList("hello", "world"), collected);
+ }
+
+ @Test
+ public void testTake() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+ List<String> collected = ds.takeAsList(1);
+ Assert.assertEquals(Arrays.asList("hello"), collected);
}
@Test
@@ -78,16 +86,16 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> ds = context.createDataset(data, e.STRING());
Assert.assertEquals("hello", ds.first());
- Dataset<String> filtered = ds.filter(new Function<String, Boolean>() {
+ Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@Override
- public Boolean call(String v) throws Exception {
+ public boolean call(String v) throws Exception {
return v.startsWith("h");
}
});
Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
- Dataset<Integer> mapped = ds.map(new Function<String, Integer>() {
+ Dataset<Integer> mapped = ds.map(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@@ -95,7 +103,7 @@ public class JavaDatasetSuite implements Serializable {
}, e.INT());
Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
- Dataset<String> parMapped = ds.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
+ Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@Override
public Iterable<String> call(Iterator<String> it) throws Exception {
List<String> ls = new LinkedList<String>();
@@ -128,7 +136,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("a", "b", "c");
Dataset<String> ds = context.createDataset(data, e.STRING());
- ds.foreach(new VoidFunction<String>() {
+ ds.foreach(new ForeachFunction<String>() {
@Override
public void call(String s) throws Exception {
accum.add(1);
@@ -142,28 +150,20 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data = Arrays.asList(1, 2, 3);
Dataset<Integer> ds = context.createDataset(data, e.INT());
- int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() {
+ int reduced = ds.reduce(new ReduceFunction<Integer>() {
@Override
public Integer call(Integer v1, Integer v2) throws Exception {
return v1 + v2;
}
});
Assert.assertEquals(6, reduced);
-
- int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() {
- @Override
- public Integer call(Integer v1, Integer v2) throws Exception {
- return v1 * v2;
- }
- });
- Assert.assertEquals(6, folded);
}
@Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, e.STRING());
- GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() {
+ GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
@@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable {
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
- GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new Function<Integer, Integer>() {
+ GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
@Override
public Integer call(Integer v) throws Exception {
return v / 2;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index fcf03f7180..63b00975e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -75,11 +75,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
assert(ds.reduce(_ + _) == 6)
}
- test("fold") {
- val ds = Seq(1, 2, 3).toDS()
- assert(ds.fold(0)(_ + _) == 6)
- }
-
test("groupBy function, keys") {
val ds = Seq(1, 2, 3, 4, 5).toDS()
val grouped = ds.groupBy(_ % 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6f1174e657..aea5a700d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -61,6 +61,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)))
}
+ test("as case class - take") {
+ val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData]
+ assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
+ }
+
test("map") {
val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
checkAnswer(
@@ -137,11 +142,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
}
- test("fold") {
- val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
- assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
- }
-
test("joinWith, flat schema") {
val ds1 = Seq(1, 2, 3).toDS().as("a")
val ds2 = Seq(1, 2).toDS().as("b")