aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-11-11 10:52:23 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-11 10:52:23 -0800
commitec2b807212e568c9e98cd80746bcb61e02c7a98e (patch)
treed1909c9c970eb6bd55ae38b0c34bcccfe6a1593a /sql
parent9c57bc0efce0ac37d8319666f5a8d3e8dce7651c (diff)
downloadspark-ec2b807212e568c9e98cd80746bcb61e02c7a98e.tar.gz
spark-ec2b807212e568c9e98cd80746bcb61e02c7a98e.tar.bz2
spark-ec2b807212e568c9e98cd80746bcb61e02c7a98e.zip
[SPARK-11564][SQL][FOLLOW-UP] clean up java tuple encoder
We need to support custom classes like Java beans and combine them into tuples, and it's very hard to do that with the TypeTag-based approach. We should keep only the compose-based way to create tuple encoders. This PR also moves `Encoder` to `org.apache.spark.sql`. Author: Wenchen Fan <wenchen@databricks.com> Closes #9567 from cloud-fan/java.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala)65
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala2
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java78
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala1
14 files changed, 65 insertions, 113 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 6569b900fe..1ff7340557 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.encoders
+package org.apache.spark.sql
-import scala.reflect.ClassTag
-
-import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -38,9 +39,7 @@ trait Encoder[T] extends Serializable {
def clsTag: ClassTag[T]
}
-object Encoder {
- import scala.reflect.runtime.universe._
-
+object Encoders {
def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
@@ -129,54 +128,4 @@ object Encoder {
constructExpression,
ClassTag.apply(cls))
}
-
- def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
-
- private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
- import scala.reflect.api
-
- // val mirror = runtimeMirror(c.getClassLoader)
- val mirror = rootMirror
- val sym = mirror.staticClass(c.getName)
- val tpe = sym.selfType
- TypeTag(mirror, new api.TypeCreator {
- def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
- if (m eq mirror) tpe.asInstanceOf[U # Type]
- else throw new IllegalArgumentException(
- s"Type tag defined in $mirror cannot be migrated to other mirrors.")
- })
- }
-
- def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
- implicit val typeTag1 = getTypeTag(c1)
- implicit val typeTag2 = getTypeTag(c2)
- ExpressionEncoder[(T1, T2)]()
- }
-
- def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
- implicit val typeTag1 = getTypeTag(c1)
- implicit val typeTag2 = getTypeTag(c2)
- implicit val typeTag3 = getTypeTag(c3)
- ExpressionEncoder[(T1, T2, T3)]()
- }
-
- def forTuple[T1, T2, T3, T4](
- c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
- implicit val typeTag1 = getTypeTag(c1)
- implicit val typeTag2 = getTypeTag(c2)
- implicit val typeTag3 = getTypeTag(c3)
- implicit val typeTag4 = getTypeTag(c4)
- ExpressionEncoder[(T1, T2, T3, T4)]()
- }
-
- def forTuple[T1, T2, T3, T4, T5](
- c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
- : Encoder[(T1, T2, T3, T4, T5)] = {
- implicit val typeTag1 = getTypeTag(c1)
- implicit val typeTag2 = getTypeTag(c2)
- implicit val typeTag3 = getTypeTag(c3)
- implicit val typeTag4 = getTypeTag(c4)
- implicit val typeTag5 = getTypeTag(c5)
- ExpressionEncoder[(T1, T2, T3, T4, T5)]()
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 005c0627f5..294afde534 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,18 +17,18 @@
package org.apache.spark.sql.catalyst.encoders
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.util.Utils
-
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
/**
* A factory for constructing encoders that convert objects and primitves to and from the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index d4642a5006..2c35adca9c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.catalyst
+import org.apache.spark.sql.Encoder
+
package object encoders {
private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
case e: ExpressionEncoder[A] => e
case _ => sys.error(s"Only expression encoders are supported today")
}
}
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 764f8aaebd..597f03e752 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d26b6c3579..f0f275e91f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 691b476fff..a492099b93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -23,7 +23,6 @@ import java.util.Properties
import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
import com.fasterxml.jackson.core.JsonFactory
import org.apache.commons.lang3.StringUtils
@@ -35,7 +34,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.encoders.Encoder
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index db61499229..61e2a95450 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1cf1e30f96..cd1fdc4edb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index b5a87c56e6..dfcbac8687 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.aggregate
import scala.language.existentials
import org.apache.spark.Logging
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 2aa5a7d540..360c9a5bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index a59d738010..ab49ed4b5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -26,7 +26,7 @@ import scala.util.Try
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 2da63d1b96..33d8388f61 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -30,8 +30,8 @@ import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.catalyst.encoders.Encoder;
-import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.GroupedDataset;
import org.apache.spark.sql.test.TestSQLContext;
@@ -41,7 +41,6 @@ import static org.apache.spark.sql.functions.*;
public class JavaDatasetSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient TestSQLContext context;
- private transient Encoder$ e = Encoder$.MODULE$;
@Before
public void setUp() {
@@ -66,7 +65,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testCollect() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, e.STRING());
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
List<String> collected = ds.collectAsList();
Assert.assertEquals(Arrays.asList("hello", "world"), collected);
}
@@ -74,7 +73,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testTake() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, e.STRING());
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
List<String> collected = ds.takeAsList(1);
Assert.assertEquals(Arrays.asList("hello"), collected);
}
@@ -82,7 +81,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
- Dataset<String> ds = context.createDataset(data, e.STRING());
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
Assert.assertEquals("hello", ds.first());
Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -99,7 +98,7 @@ public class JavaDatasetSuite implements Serializable {
public Integer call(String v) throws Exception {
return v.length();
}
- }, e.INT());
+ }, Encoders.INT());
Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@@ -111,7 +110,7 @@ public class JavaDatasetSuite implements Serializable {
}
return ls;
}
- }, e.STRING());
+ }, Encoders.STRING());
Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
@@ -123,7 +122,7 @@ public class JavaDatasetSuite implements Serializable {
}
return ls;
}
- }, e.STRING());
+ }, Encoders.STRING());
Assert.assertEquals(
Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
flatMapped.collectAsList());
@@ -133,7 +132,7 @@ public class JavaDatasetSuite implements Serializable {
public void testForeach() {
final Accumulator<Integer> accum = jsc.accumulator(0);
List<String> data = Arrays.asList("a", "b", "c");
- Dataset<String> ds = context.createDataset(data, e.STRING());
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
ds.foreach(new ForeachFunction<String>() {
@Override
@@ -147,7 +146,7 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testReduce() {
List<Integer> data = Arrays.asList(1, 2, 3);
- Dataset<Integer> ds = context.createDataset(data, e.INT());
+ Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
int reduced = ds.reduce(new ReduceFunction<Integer>() {
@Override
@@ -161,13 +160,13 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
- Dataset<String> ds = context.createDataset(data, e.STRING());
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
@Override
public Integer call(String v) throws Exception {
return v.length();
}
- }, e.INT());
+ }, Encoders.INT());
Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
@Override
@@ -178,7 +177,7 @@ public class JavaDatasetSuite implements Serializable {
}
return sb.toString();
}
- }, e.STRING());
+ }, Encoders.STRING());
Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
@@ -193,27 +192,27 @@ public class JavaDatasetSuite implements Serializable {
return Collections.singletonList(sb.toString());
}
},
- e.STRING());
+ Encoders.STRING());
Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
List<Integer> data2 = Arrays.asList(2, 6, 10);
- Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+ Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
@Override
public Integer call(Integer v) throws Exception {
return v / 2;
}
- }, e.INT());
+ }, Encoders.INT());
Dataset<String> cogrouped = grouped.cogroup(
grouped2,
new CoGroupFunction<Integer, String, Integer, String>() {
@Override
public Iterable<String> call(
- Integer key,
- Iterator<String> left,
- Iterator<Integer> right) throws Exception {
+ Integer key,
+ Iterator<String> left,
+ Iterator<Integer> right) throws Exception {
StringBuilder sb = new StringBuilder(key.toString());
while (left.hasNext()) {
sb.append(left.next());
@@ -225,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
return Collections.singletonList(sb.toString());
}
},
- e.STRING());
+ Encoders.STRING());
Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
}
@@ -233,8 +232,9 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testGroupByColumn() {
List<String> data = Arrays.asList("a", "foo", "bar");
- Dataset<String> ds = context.createDataset(data, e.STRING());
- GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ GroupedDataset<Integer, String> grouped =
+ ds.groupBy(length(col("value"))).asKey(Encoders.INT());
Dataset<String> mapped = grouped.map(
new MapGroupFunction<Integer, String, String>() {
@@ -247,7 +247,7 @@ public class JavaDatasetSuite implements Serializable {
return sb.toString();
}
},
- e.STRING());
+ Encoders.STRING());
Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
}
@@ -255,11 +255,11 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testSelect() {
List<Integer> data = Arrays.asList(2, 6);
- Dataset<Integer> ds = context.createDataset(data, e.INT());
+ Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
Dataset<Tuple2<Integer, String>> selected = ds.select(
expr("value + 1"),
- col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
+ col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));
Assert.assertEquals(
Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
@@ -269,14 +269,14 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testSetOperation() {
List<String> data = Arrays.asList("abc", "abc", "xyz");
- Dataset<String> ds = context.createDataset(data, e.STRING());
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
Assert.assertEquals(
Arrays.asList("abc", "xyz"),
sort(ds.distinct().collectAsList().toArray(new String[0])));
List<String> data2 = Arrays.asList("xyz", "foo", "foo");
- Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+ Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
Dataset<String> intersected = ds.intersect(ds2);
Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testJoin() {
List<Integer> data = Arrays.asList(1, 2, 3);
- Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+ Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
List<Integer> data2 = Arrays.asList(2, 3, 4);
- Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+ Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");
Dataset<Tuple2<Integer, Integer>> joined =
ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -311,26 +311,28 @@ public class JavaDatasetSuite implements Serializable {
@Test
public void testTupleEncoder() {
- Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+ Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
Assert.assertEquals(data2, ds2.collectAsList());
- Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
+ Encoder<Tuple3<Integer, Long, String>> encoder3 =
+ Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
List<Tuple3<Integer, Long, String>> data3 =
Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
Assert.assertEquals(data3, ds3.collectAsList());
Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
- e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+ Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
List<Tuple4<Integer, String, Long, String>> data4 =
Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
Assert.assertEquals(data4, ds4.collectAsList());
Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
- e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+ Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(),
+ Encoders.BOOLEAN());
List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
@@ -342,7 +344,7 @@ public class JavaDatasetSuite implements Serializable {
public void testNestedTupleEncoder() {
// test ((int, string), string)
Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
- e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+ Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
List<Tuple2<Tuple2<Integer, String>, String>> data =
Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
@@ -350,7 +352,8 @@ public class JavaDatasetSuite implements Serializable {
// test (int, (string, string, long))
Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
- e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+ Encoders.tuple(Encoders.INT(),
+ Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG()));
List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
@@ -359,7 +362,8 @@ public class JavaDatasetSuite implements Serializable {
// test (int, ((string, long), string))
Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
- e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+ Encoders.tuple(Encoders.INT(),
+ Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING()));
List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index d4f0ab76cf..378cd36527 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -17,13 +17,11 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.encoders.Encoder
-import org.apache.spark.sql.functions._
import scala.language.postfixOps
import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Aggregator
/** An `Aggregator` that adds up any numeric type returned by the given function. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3c174efe73..7a8b7ae5bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.Encoder
abstract class QueryTest extends PlanTest {