author    Wenchen Fan <wenchen@databricks.com>   2015-11-06 15:37:07 -0800
committer Reynold Xin <rxin@databricks.com>      2015-11-06 15:37:07 -0800
commit    7e9a9e603abce8689938bdd62d04b29299644aa4 (patch)
tree      cff62e9cde2e44aae8a2b5a8a2ae536bb67b9f38 /sql
parent    f6680cdc5d2912dea9768ef5c3e2cc101b06daf8 (diff)
[SPARK-11269][SQL] Java API support & test cases for Dataset
This simply brings https://github.com/apache/spark/pull/9358 up-to-date.

Author: Wenchen Fan <wenchen@databricks.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #9528 from rxin/dataset-java.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala     | 123
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala  |  21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                           | 126
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala                     |   6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala                    |  17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                        |   4
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java               | 357
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala             |   2
8 files changed, 644 insertions(+), 12 deletions(-)
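For orientation before the diff itself, here is a minimal sketch of how the Java-facing API introduced by this patch is exercised. It follows the JavaDatasetSuite added below; the local-mode SparkContext/TestSQLContext setup and the `Encoder$.MODULE$` shorthand are borrowed from that suite and are illustrative only, not a separate public API.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.catalyst.encoders.Encoder$;
import org.apache.spark.sql.test.TestSQLContext;

public class DatasetJavaExample {
  public static void main(String[] args) {
    SparkContext sc = new SparkContext("local[*]", "example");
    TestSQLContext context = new TestSQLContext(sc);
    Encoder$ e = Encoder$.MODULE$;

    // Create a Dataset<String> from a Java list, passing the encoder explicitly
    // because Java cannot supply it implicitly the way Scala does.
    List<String> data = Arrays.asList("hello", "world");
    Dataset<String> ds = context.createDataset(data, e.STRING());

    // Map each element to its length using the Java-specific overload that takes
    // a Function plus a target Encoder.
    Dataset<Integer> lengths = ds.map(new Function<String, Integer>() {
      @Override
      public Integer call(String v) throws Exception {
        return v.length();
      }
    }, e.INT());

    // collectAsList() keeps the generic element type on the Java side,
    // unlike collect(), whose Java return type erases to Object.
    System.out.println(lengths.collectAsList());  // [5, 5]

    sc.stop();
  }
}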
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index 329a132d3d..f05e18288d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.catalyst.encoders
-
-
import scala.reflect.ClassTag
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.catalyst.expressions._
/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable {
/** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
def clsTag: ClassTag[T]
}
+
+object Encoder {
+ import scala.reflect.runtime.universe._
+
+ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
+ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
+ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
+ def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
+ def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
+ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
+ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
+ def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
+
+ def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
+ tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
+ .asInstanceOf[ExpressionEncoder[(T1, T2)]]
+ }
+
+ def tuple[T1, T2, T3](
+ enc1: Encoder[T1],
+ enc2: Encoder[T2],
+ enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+ tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
+ .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+ }
+
+ def tuple[T1, T2, T3, T4](
+ enc1: Encoder[T1],
+ enc2: Encoder[T2],
+ enc3: Encoder[T3],
+ enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+ tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
+ .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+ }
+
+ def tuple[T1, T2, T3, T4, T5](
+ enc1: Encoder[T1],
+ enc2: Encoder[T2],
+ enc3: Encoder[T3],
+ enc4: Encoder[T4],
+ enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+ tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
+ .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+ }
+
+ private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+ assert(encoders.length > 1)
+ // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
+ assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+
+ val schema = StructType(encoders.zipWithIndex.map {
+ case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+ })
+
+ val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+
+ val extractExpressions = encoders.map {
+ case e if e.flat => e.extractExpressions.head
+ case other => CreateStruct(other.extractExpressions)
+ }.zipWithIndex.map { case (expr, index) =>
+ expr.transformUp {
+ case BoundReference(0, t: ObjectType, _) =>
+ Invoke(
+ BoundReference(0, ObjectType(cls), true),
+ s"_${index + 1}",
+ t)
+ }
+ }
+
+ val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+ if (enc.flat) {
+ enc.constructExpression.transform {
+ case b: BoundReference => b.copy(ordinal = index)
+ }
+ } else {
+ enc.constructExpression.transformUp {
+ case BoundReference(ordinal, dt, _) =>
+ GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt)
+ }
+ }
+ }
+
+ val constructExpression =
+ NewInstance(cls, constructExpressions, false, ObjectType(cls))
+
+ new ExpressionEncoder[Any](
+ schema,
+ false,
+ extractExpressions,
+ constructExpression,
+ ClassTag.apply(cls))
+ }
+
+
+ def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
+
+ private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
+ import scala.reflect.api
+
+ // val mirror = runtimeMirror(c.getClassLoader)
+ val mirror = rootMirror
+ val sym = mirror.staticClass(c.getName)
+ val tpe = sym.selfType
+ TypeTag(mirror, new api.TypeCreator {
+ def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
+ if (m eq mirror) tpe.asInstanceOf[U # Type]
+ else throw new IllegalArgumentException(
+ s"Type tag defined in $mirror cannot be migrated to other mirrors.")
+ })
+ }
+
+ def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
+ implicit val typeTag1 = getTypeTag(c1)
+ implicit val typeTag2 = getTypeTag(c2)
+ ExpressionEncoder[(T1, T2)]()
+ }
+}
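As a quick usage note on the tuple factories above, a short sketch follows; it mirrors testTupleEncoder in the Java suite added later in this patch and assumes a `context` and the `e = Encoder$.MODULE$` shorthand (plus imports) set up as in the example before the diff.

// Flat encoders compose into tuple encoders; nesting also works,
// e.g. e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING()).
Encoder<Tuple2<Integer, String>> enc = e.tuple(e.INT(), e.STRING());
List<Tuple2<Integer, String>> pairs =
  Arrays.asList(new Tuple2<Integer, String>(1, "a"), new Tuple2<Integer, String>(2, "b"));
// The data round-trips through the composite encoder:
// ds.collectAsList() returns the same tuples.
Dataset<Tuple2<Integer, String>> ds = context.createDataset(pairs, enc);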
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 8185528976..4f58464221 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
}
}
+
+case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
+ extends UnaryExpression {
+
+ override def nullable: Boolean = true
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val row = child.gen(ctx)
+ s"""
+ ${row.code}
+ final boolean ${ev.isNull} = ${row.isNull};
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)};
+ }
+ """
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4bca9c3b3f..fecbdac9a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,13 @@
package org.apache.spark.sql
+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
+
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
@@ -151,18 +155,37 @@ class Dataset[T] private[sql](
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
/**
+ * (Scala-specific)
* Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
* @since 1.6.0
*/
def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
/**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+ * @since 1.6.0
+ */
+ def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] =
+ filter(t => func.call(t).booleanValue())
+
+ /**
+ * (Scala-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
/**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * @since 1.6.0
+ */
+ def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ map(t => func.call(t))(encoder)
+
+ /**
+ * (Scala-specific)
* Returns a new [[Dataset]] that contains the result of applying `func` to each element.
* @since 1.6.0
*/
@@ -177,30 +200,77 @@ class Dataset[T] private[sql](
logicalPlan))
}
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+ * @since 1.6.0
+ */
+ def mapPartitions[U](
+ f: FlatMapFunction[java.util.Iterator[T], U],
+ encoder: Encoder[U]): Dataset[U] = {
+ val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala
+ mapPartitions(func)(encoder)
+ }
+
+ /**
+ * (Scala-specific)
+ * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+ * and then flattening the results.
+ * @since 1.6.0
+ */
def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
mapPartitions(_.flatMap(func))
+ /**
+ * (Java-specific)
+ * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+ * and then flattening the results.
+ * @since 1.6.0
+ */
+ def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ val func: (T) => Iterable[U] = x => f.call(x).asScala
+ flatMap(func)(encoder)
+ }
+
/* ************** *
* Side effects *
* ************** */
/**
+ * (Scala-specific)
* Runs `func` on each element of this Dataset.
* @since 1.6.0
*/
def foreach(func: T => Unit): Unit = rdd.foreach(func)
/**
+ * (Java-specific)
+ * Runs `func` on each element of this Dataset.
+ * @since 1.6.0
+ */
+ def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_))
+
+ /**
+ * (Scala-specific)
* Runs `func` on each partition of this Dataset.
* @since 1.6.0
*/
def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
+ /**
+ * (Java-specific)
+ * Runs `func` on each partition of this Dataset.
+ * @since 1.6.0
+ */
+ def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit =
+ foreachPartition(it => func.call(it.asJava))
+
/* ************* *
* Aggregation *
* ************* */
/**
+ * (Scala-specific)
* Reduces the elements of this Dataset using the specified binary function. The given function
* must be commutative and associative or the result may be non-deterministic.
* @since 1.6.0
@@ -208,6 +278,15 @@ class Dataset[T] private[sql](
def reduce(func: (T, T) => T): T = rdd.reduce(func)
/**
+ * (Java-specific)
+ * Reduces the elements of this Dataset using the specified binary function. The given function
+ * must be commutative and associative or the result may be non-deterministic.
+ * @since 1.6.0
+ */
+ def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _))
+
+ /**
+ * (Scala-specific)
* Aggregates the elements of each partition, and then the results for all the partitions, using a
* given associative and commutative function and a neutral "zero value".
*
@@ -221,6 +300,15 @@ class Dataset[T] private[sql](
def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op)
/**
+ * (Java-specific)
+ * Aggregates the elements of each partition, and then the results for all the partitions, using a
+ * given associative and commutative function and a neutral "zero value".
+ * @since 1.6.0
+ */
+ def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _))
+
+ /**
+ * (Scala-specific)
* Returns a [[GroupedDataset]] where the data is grouped by the given key function.
* @since 1.6.0
*/
@@ -258,6 +346,14 @@ class Dataset[T] private[sql](
keyAttributes)
}
+ /**
+ * (Java-specific)
+ * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+ * @since 1.6.0
+ */
+ def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+ groupBy(f.call(_))(encoder)
+
/* ****************** *
* Typed Relational *
* ****************** */
@@ -267,8 +363,7 @@ class Dataset[T] private[sql](
* {{{
* df.select($"colA", $"colB" + 1)
* }}}
- * @group dfops
- * @since 1.3.0
+ * @since 1.6.0
*/
// Copied from Dataframe to make sure we don't have invalid overloads.
@scala.annotation.varargs
@@ -279,7 +374,7 @@ class Dataset[T] private[sql](
*
* {{{
* val ds = Seq(1, 2, 3).toDS()
- * val newDS = ds.select(e[Int]("value + 1"))
+ * val newDS = ds.select(expr("value + 1").as[Int])
* }}}
* @since 1.6.0
*/
@@ -405,6 +500,8 @@ class Dataset[T] private[sql](
* This type of join can be useful both for preserving type-safety with the original object
* types as well as working with relational data where either side of the join has column
* names in common.
+ *
+ * @since 1.6.0
*/
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
val left = this.logicalPlan
@@ -438,12 +535,31 @@ class Dataset[T] private[sql](
* Gather to Driver Actions *
* ************************** */
- /** Returns the first element in this [[Dataset]]. */
+ /**
+ * Returns the first element in this [[Dataset]].
+ * @since 1.6.0
+ */
def first(): T = rdd.first()
- /** Collects the elements to an Array. */
+ /**
+ * Collects the elements to an Array.
+ * @since 1.6.0
+ */
def collect(): Array[T] = rdd.collect()
+ /**
+ * (Java-specific)
+ * Collects the elements to a Java list.
+ *
+ * Due to an incompatibility between Scala and Java, the return type of [[collect()]] on the
+ * Java side is `java.lang.Object`, which is not easy to use. Java users can call this method
+ * instead to keep the generic type of the result.
+ *
+ * @since 1.6.0
+ */
+ def collectAsList(): java.util.List[T] =
+ rdd.collect().toSeq.asJava
+
/** Returns the first `num` elements of this [[Dataset]] as an Array. */
def take(num: Int): Array[T] = rdd.take(num)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
index 45f0098b92..08097e9f02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala
@@ -27,9 +27,9 @@ package org.apache.spark.sql
*
* @since 1.6.0
*/
-case class DatasetHolder[T] private[sql](private val df: Dataset[T]) {
+case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) {
// This is declared with parentheses to prevent the Scala compiler from treating
- // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
- def toDS(): Dataset[T] = df
+ // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset.
+ def toDS(): Dataset[T] = ds
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b8fc373dff..b2803d5a9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,11 @@
package org.apache.spark.sql
+import java.util.{Iterator => JIterator}
+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
@@ -104,6 +108,12 @@ class GroupedDataset[K, T] private[sql](
MapGroups(f, groupingAttributes, logicalPlan))
}
+ def mapGroups[U](
+ f: JFunction2[K, JIterator[T], JIterator[U]],
+ encoder: Encoder[U]): Dataset[U] = {
+ mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
+ }
+
// To ensure valid overloading.
protected def agg(expr: Column, exprs: Column*): DataFrame =
groupedData.agg(expr, exprs: _*)
@@ -196,4 +206,11 @@ class GroupedDataset[K, T] private[sql](
this.logicalPlan,
other.logicalPlan))
}
+
+ def cogroup[U, R](
+ other: GroupedDataset[K, U],
+ f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]],
+ encoder: Encoder[R]): Dataset[R] = {
+ cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 5ad3871093..5598731af5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -508,6 +508,10 @@ class SQLContext private[sql](
new Dataset[T](this, plan)
}
+ def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
+ createDataset(data.asScala)
+ }
+
/**
* Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
* converted to Catalyst rows.
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
new file mode 100644
index 0000000000..a9493d576d
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.io.Serializable;
+import java.util.*;
+
+import scala.Tuple2;
+import scala.Tuple3;
+import scala.Tuple4;
+import scala.Tuple5;
+import org.junit.*;
+
+import org.apache.spark.Accumulator;
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.catalyst.encoders.Encoder;
+import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.GroupedDataset;
+import org.apache.spark.sql.test.TestSQLContext;
+
+import static org.apache.spark.sql.functions.*;
+
+public class JavaDatasetSuite implements Serializable {
+ private transient JavaSparkContext jsc;
+ private transient TestSQLContext context;
+ private transient Encoder$ e = Encoder$.MODULE$;
+
+ @Before
+ public void setUp() {
+ // Trigger static initializer of TestData
+ SparkContext sc = new SparkContext("local[*]", "testing");
+ jsc = new JavaSparkContext(sc);
+ context = new TestSQLContext(sc);
+ context.loadTestData();
+ }
+
+ @After
+ public void tearDown() {
+ context.sparkContext().stop();
+ context = null;
+ jsc = null;
+ }
+
+ private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+ return new Tuple2<T1, T2>(t1, t2);
+ }
+
+ @Test
+ public void testCollect() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+ String[] collected = (String[]) ds.collect();
+ Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected));
+ }
+
+ @Test
+ public void testCommonOperation() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+ Assert.assertEquals("hello", ds.first());
+
+ Dataset<String> filtered = ds.filter(new Function<String, Boolean>() {
+ @Override
+ public Boolean call(String v) throws Exception {
+ return v.startsWith("h");
+ }
+ });
+ Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList());
+
+
+ Dataset<Integer> mapped = ds.map(new Function<String, Integer>() {
+ @Override
+ public Integer call(String v) throws Exception {
+ return v.length();
+ }
+ }, e.INT());
+ Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
+
+ Dataset<String> parMapped = ds.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
+ @Override
+ public Iterable<String> call(Iterator<String> it) throws Exception {
+ List<String> ls = new LinkedList<String>();
+ while (it.hasNext()) {
+ ls.add(it.next().toUpperCase());
+ }
+ return ls;
+ }
+ }, e.STRING());
+ Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
+
+ Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String s) throws Exception {
+ List<String> ls = new LinkedList<String>();
+ for (char c : s.toCharArray()) {
+ ls.add(String.valueOf(c));
+ }
+ return ls;
+ }
+ }, e.STRING());
+ Assert.assertEquals(
+ Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
+ flatMapped.collectAsList());
+ }
+
+ @Test
+ public void testForeach() {
+ final Accumulator<Integer> accum = jsc.accumulator(0);
+ List<String> data = Arrays.asList("a", "b", "c");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+
+ ds.foreach(new VoidFunction<String>() {
+ @Override
+ public void call(String s) throws Exception {
+ accum.add(1);
+ }
+ });
+ Assert.assertEquals(3, accum.value().intValue());
+ }
+
+ @Test
+ public void testReduce() {
+ List<Integer> data = Arrays.asList(1, 2, 3);
+ Dataset<Integer> ds = context.createDataset(data, e.INT());
+
+ int reduced = ds.reduce(new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer v1, Integer v2) throws Exception {
+ return v1 + v2;
+ }
+ });
+ Assert.assertEquals(6, reduced);
+
+ int folded = ds.fold(1, new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer v1, Integer v2) throws Exception {
+ return v1 * v2;
+ }
+ });
+ Assert.assertEquals(6, folded);
+ }
+
+ @Test
+ public void testGroupBy() {
+ List<String> data = Arrays.asList("a", "foo", "bar");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+ GroupedDataset<Integer, String> grouped = ds.groupBy(new Function<String, Integer>() {
+ @Override
+ public Integer call(String v) throws Exception {
+ return v.length();
+ }
+ }, e.INT());
+
+ Dataset<String> mapped = grouped.mapGroups(
+ new Function2<Integer, Iterator<String>, Iterator<String>>() {
+ @Override
+ public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (data.hasNext()) {
+ sb.append(data.next());
+ }
+ return Collections.singletonList(sb.toString()).iterator();
+ }
+ },
+ e.STRING());
+
+ Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+
+ List<Integer> data2 = Arrays.asList(2, 6, 10);
+ Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+ GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer v) throws Exception {
+ return v / 2;
+ }
+ }, e.INT());
+
+ Dataset<String> cogrouped = grouped.cogroup(
+ grouped2,
+ new Function3<Integer, Iterator<String>, Iterator<Integer>, Iterator<String>>() {
+ @Override
+ public Iterator<String> call(
+ Integer key,
+ Iterator<String> left,
+ Iterator<Integer> right) throws Exception {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (left.hasNext()) {
+ sb.append(left.next());
+ }
+ sb.append("#");
+ while (right.hasNext()) {
+ sb.append(right.next());
+ }
+ return Collections.singletonList(sb.toString()).iterator();
+ }
+ },
+ e.STRING());
+
+ Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
+ }
+
+ @Test
+ public void testGroupByColumn() {
+ List<String> data = Arrays.asList("a", "foo", "bar");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+ GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
+
+ Dataset<String> mapped = grouped.mapGroups(
+ new Function2<Integer, Iterator<String>, Iterator<String>>() {
+ @Override
+ public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (data.hasNext()) {
+ sb.append(data.next());
+ }
+ return Collections.singletonList(sb.toString()).iterator();
+ }
+ },
+ e.STRING());
+
+ Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+ }
+
+ @Test
+ public void testSelect() {
+ List<Integer> data = Arrays.asList(2, 6);
+ Dataset<Integer> ds = context.createDataset(data, e.INT());
+
+ Dataset<Tuple2<Integer, String>> selected = ds.select(
+ expr("value + 1").as(e.INT()),
+ col("value").cast("string").as(e.STRING()));
+
+ Assert.assertEquals(
+ Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
+ selected.collectAsList());
+ }
+
+ @Test
+ public void testSetOperation() {
+ List<String> data = Arrays.asList("abc", "abc", "xyz");
+ Dataset<String> ds = context.createDataset(data, e.STRING());
+
+ Assert.assertEquals(
+ Arrays.asList("abc", "xyz"),
+ sort(ds.distinct().collectAsList().toArray(new String[0])));
+
+ List<String> data2 = Arrays.asList("xyz", "foo", "foo");
+ Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+
+ Dataset<String> intersected = ds.intersect(ds2);
+ Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
+
+ Dataset<String> unioned = ds.union(ds2);
+ Assert.assertEquals(
+ Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"),
+ sort(unioned.collectAsList().toArray(new String[0])));
+
+ Dataset<String> subtracted = ds.subtract(ds2);
+ Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList());
+ }
+
+ private <T extends Comparable<T>> List<T> sort(T[] data) {
+ Arrays.sort(data);
+ return Arrays.asList(data);
+ }
+
+ @Test
+ public void testJoin() {
+ List<Integer> data = Arrays.asList(1, 2, 3);
+ Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+ List<Integer> data2 = Arrays.asList(2, 3, 4);
+ Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+
+ Dataset<Tuple2<Integer, Integer>> joined =
+ ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
+ Assert.assertEquals(
+ Arrays.asList(tuple2(2, 2), tuple2(3, 3)),
+ joined.collectAsList());
+ }
+
+ @Test
+ public void testTupleEncoder() {
+ Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+ List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
+ Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
+ Assert.assertEquals(data2, ds2.collectAsList());
+
+ Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
+ List<Tuple3<Integer, Long, String>> data3 =
+ Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
+ Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
+ Assert.assertEquals(data3, ds3.collectAsList());
+
+ Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
+ e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+ List<Tuple4<Integer, String, Long, String>> data4 =
+ Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
+ Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
+ Assert.assertEquals(data4, ds4.collectAsList());
+
+ Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
+ e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+ List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
+ Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
+ Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
+ context.createDataset(data5, encoder5);
+ Assert.assertEquals(data5, ds5.collectAsList());
+ }
+
+ @Test
+ public void testNestedTupleEncoder() {
+ // test ((int, string), string)
+ Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
+ e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+ List<Tuple2<Tuple2<Integer, String>, String>> data =
+ Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
+ Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
+ Assert.assertEquals(data, ds.collectAsList());
+
+ // test (int, (string, string, long))
+ Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
+ e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+ List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
+ Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
+ Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
+ context.createDataset(data2, encoder2);
+ Assert.assertEquals(data2, ds2.collectAsList());
+
+ // test (int, ((string, long), string))
+ Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
+ e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+ List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
+ Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
+ Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
+ context.createDataset(data3, encoder3);
+ Assert.assertEquals(data3, ds3.collectAsList());
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 32443557fb..e3b0346f85 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -59,7 +59,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("foreach") {
val ds = Seq(1, 2, 3).toDS()
val acc = sparkContext.accumulator(0)
- ds.foreach(acc +=)
+ ds.foreach(acc += _)
assert(acc.value == 6)
}