aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2015-11-25 01:02:36 -0800
committerReynold Xin <rxin@databricks.com>2015-11-25 01:02:36 -0800
commit2610e06124c7fc0b2b1cfb2e3050a35ab492fb71 (patch)
treec4bf32e700515557a1443964af48b0d2d9c0c25b /sql/core
parent2169886883d33b33acf378ac42a626576b342df1 (diff)
downloadspark-2610e06124c7fc0b2b1cfb2e3050a35ab492fb71.tar.gz
spark-2610e06124c7fc0b2b1cfb2e3050a35ab492fb71.tar.bz2
spark-2610e06124c7fc0b2b1cfb2e3050a35ab492fb71.zip
[SPARK-11970][SQL] Adding JoinType into JoinWith and support Sample in Dataset API
Except inner join, maybe the other join types are also useful when users are using the joinWith function. Thus, added the joinType into the existing joinWith call in Dataset APIs. Also providing another joinWith interface for the cartesian-join-like functionality. Please provide your opinions. marmbrus rxin cloud-fan Thank you! Author: gatorsmile <gatorsmile@gmail.com> Closes #9921 from gatorsmile/joinWith.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala45
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala36
2 files changed, 65 insertions, 16 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index dd84b8bc11..97eb5b9692 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,16 +20,16 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function._
-
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
/**
* :: Experimental ::
@@ -83,7 +83,6 @@ class Dataset[T] private[sql](
/**
* Returns the schema of the encoded form of the objects in this [[Dataset]].
- *
* @since 1.6.0
*/
def schema: StructType = resolvedTEncoder.schema
@@ -185,7 +184,6 @@ class Dataset[T] private[sql](
* .transform(featurize)
* .transform(...)
* }}}
- *
* @since 1.6.0
*/
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
@@ -453,6 +451,21 @@ class Dataset[T] private[sql](
c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
+ /**
+ * Returns a new [[Dataset]] by sampling a fraction of records.
+ * @since 1.6.0
+ */
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] =
+ withPlan(Sample(0.0, fraction, withReplacement, seed, _))
+
+ /**
+ * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed.
+ * @since 1.6.0
+ */
+ def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = {
+ sample(withReplacement, fraction, Utils.random.nextLong)
+ }
+
/* **************** *
* Set operations *
* **************** */
@@ -511,13 +524,17 @@ class Dataset[T] private[sql](
* types as well as working with relational data where either side of the join has column
* names in common.
*
+ * @param other Right side of the join.
+ * @param condition Join expression.
+ * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
* @since 1.6.0
*/
- def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
val left = this.logicalPlan
val right = other.logicalPlan
- val joined = sqlContext.executePlan(Join(left, right, Inner, Some(condition.expr)))
+ val joined = sqlContext.executePlan(Join(left, right, joinType =
+ JoinType(joinType), Some(condition.expr)))
val leftOutput = joined.analyzed.output.take(left.output.length)
val rightOutput = joined.analyzed.output.takeRight(right.output.length)
@@ -540,6 +557,18 @@ class Dataset[T] private[sql](
}
}
+ /**
+ * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair
+ * where `condition` evaluates to true.
+ *
+ * @param other Right side of the join.
+ * @param condition Join expression.
+ * @since 1.6.0
+ */
+ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+ joinWith(other, condition, "inner")
+ }
+
/* ************************** *
* Gather to Driver Actions *
* ************************** */
@@ -584,7 +613,6 @@ class Dataset[T] private[sql](
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
- *
* @since 1.6.0
*/
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
@@ -594,7 +622,6 @@ class Dataset[T] private[sql](
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
- *
* @since 1.6.0
*/
def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c253fdbb8c..7d539180de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -185,17 +185,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds2 = Seq(1, 2).toDS().as("b")
checkAnswer(
- ds1.joinWith(ds2, $"a.value" === $"b.value"),
+ ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"),
(1, 1), (2, 2))
}
- test("joinWith, expression condition") {
- val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
- val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+ test("joinWith, expression condition, outer join") {
+ val nullInteger = null.asInstanceOf[Integer]
+ val nullString = null.asInstanceOf[String]
+ val ds1 = Seq(ClassNullableData("a", 1),
+ ClassNullableData("c", 3)).toDS()
+ val ds2 = Seq(("a", new Integer(1)),
+ ("b", new Integer(2))).toDS()
checkAnswer(
- ds1.joinWith(ds2, $"_1" === $"a"),
- (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
+ ds1.joinWith(ds2, $"_1" === $"a", "outer"),
+ (ClassNullableData("a", 1), ("a", new Integer(1))),
+ (ClassNullableData("c", 3), (nullString, nullInteger)),
+ (ClassNullableData(nullString, nullInteger), ("b", new Integer(2))))
}
test("joinWith tuple with primitive, expression") {
@@ -225,7 +231,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
((("a", 1), ("a", 1)), ("a", 1)),
((("b", 2), ("b", 2)), ("b", 2)))
-
}
test("groupBy function, keys") {
@@ -367,6 +372,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1 -> "a", 2 -> "bc", 3 -> "d")
}
+ test("sample with replacement") {
+ val n = 100
+ val data = sparkContext.parallelize(1 to n, 2).toDS()
+ checkAnswer(
+ data.sample(withReplacement = true, 0.05, seed = 13),
+ 5, 10, 52, 73)
+ }
+
+ test("sample without replacement") {
+ val n = 100
+ val data = sparkContext.parallelize(1 to n, 2).toDS()
+ checkAnswer(
+ data.sample(withReplacement = false, 0.05, seed = 13),
+ 3, 17, 27, 58, 62)
+ }
+
test("SPARK-11436: we should rebind right encoder when join 2 datasets") {
val ds1 = Seq("1", "2").toDS().as("a")
val ds2 = Seq(2, 3).toDS().as("b")
@@ -440,6 +461,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
case class ClassData(a: String, b: Int)
+case class ClassNullableData(a: String, b: Integer)
/**
* A class used to test serialization using encoders. This class throws exceptions when using