author    zsxwing <zsxwing@gmail.com>          2015-09-11 15:00:13 -0700
committer Andrew Or <andrew@databricks.com>    2015-09-11 15:00:13 -0700
commit    e626ac5f5c27dcc74113070f2fec03682bcd12bd
tree      f7bfa842b61b12334ecb676e3242d67c6a9c7ec8
parent    1eede3b254ee3793841c92971707094ac8afee35
[SPARK-9992] [SPARK-9994] [SPARK-9998] [SQL] Implement the local TopK, sample and intersect operators
This PR is in conflict with #8535. I will update this one when #8535 gets merged.

Author: zsxwing <zsxwing@gmail.com>

Closes #8573 from zsxwing/more-local-operators.
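All of the operators added here implement the pull-based LocalNode contract: open() prepares the node and its children, next() advances to the next row and reports whether one exists, fetch() returns the current row, and close() releases resources. As a minimal sketch of how a caller drives that contract (the drain helper and its loop are hypothetical, not part of this patch):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.local.LocalNode

    // Hypothetical helper: drain a LocalNode into a materialized Seq.
    // fetch() is only meaningful after next() has returned true, and the row
    // it returns may be a reused mutable buffer, hence the defensive copy().
    def drain(node: LocalNode): Seq[InternalRow] = {
      val rows = Seq.newBuilder[InternalRow]
      node.open()
      try {
        while (node.next()) {
          rows += node.fetch().copy()
        }
      } finally {
        node.close()
      }
      rows.result()
    }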
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                       |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala                  | 63
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala                      |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala                     | 82
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala      | 73
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala             | 35
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala                | 40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala | 54
8 files changed, 353 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 3f68b05a24..bf6d44c098 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
  *                    will be ub - lb.
  * @param withReplacement Whether to sample with replacement.
  * @param seed the random seed
- * @param child the QueryPlan
+ * @param child the SparkPlan
  */
 @DeveloperApi
 case class Sample(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala
new file mode 100644
index 0000000000..740d485f8d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode)
+  extends BinaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = left.output
+
+  private[this] var leftRows: mutable.HashSet[InternalRow] = _
+
+  private[this] var currentRow: InternalRow = _
+
+  override def open(): Unit = {
+    left.open()
+    leftRows = mutable.HashSet[InternalRow]()
+    while (left.next()) {
+      leftRows += left.fetch().copy()
+    }
+    left.close()
+    right.open()
+  }
+
+  override def next(): Boolean = {
+    currentRow = null
+    while (currentRow == null && right.next()) {
+      currentRow = right.fetch()
+      if (!leftRows.contains(currentRow)) {
+        currentRow = null
+      }
+    }
+    currentRow != null
+  }
+
+  override def fetch(): InternalRow = currentRow
+
+  override def close(): Unit = {
+    left.close()
+    right.close()
+  }
+
+}
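IntersectNode is a build/probe pair: open() materializes the left child into a hash set (copying each row because fetch() may return a reused mutable InternalRow), and next() then streams the right child, surfacing only rows present in the set. A sketch of the same structure over plain Scala iterators (illustrative only; like the node above, it passes duplicate matches from the right side through):

    import scala.collection.mutable

    // Sketch of IntersectNode's build/probe structure over ordinary iterators.
    def intersect[T](left: Iterator[T], right: Iterator[T]): Iterator[T] = {
      val leftSet = mutable.HashSet.empty[T]
      leftSet ++= left                // build phase: materialize the left side
      right.filter(leftSet.contains)  // probe phase: stream the right side
    }

    // intersect(Iterator(1, 2, 3), Iterator(2, 4, 2)).toList == List(2, 2)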
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index c4f8ae304d..a2c275db9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -70,6 +70,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
   def close(): Unit
 
   /**
+   * Returns the content through the [[Iterator]] interface.
+   */
+  final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
+
+  /**
    * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
    */
   def collect(): Seq[Row] = {
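The new asIterator bridge delegates to LocalNodeIterator, which is not shown in this diff. A minimal adapter consistent with the contract could look like the following sketch (hypothetical name and shape; one row of lookahead is buffered because next() both tests for and advances to the next row):

    import org.apache.spark.sql.catalyst.InternalRow

    // Hypothetical sketch of an Iterator adapter over a LocalNode.
    // Assumes the caller has already open()ed the node.
    class LocalNodeIteratorSketch(node: LocalNode) extends Iterator[InternalRow] {
      private[this] var lookahead: InternalRow = advance()

      private def advance(): InternalRow =
        if (node.next()) node.fetch() else null

      override def hasNext: Boolean = lookahead != null

      override def next(): InternalRow = {
        if (lookahead == null) throw new NoSuchElementException("end of input")
        val row = lookahead
        lookahead = advance()
        row
      }
    }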
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
new file mode 100644
index 0000000000..abf3df1c0c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import java.util.Random
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
+
+/**
+ * Sample the dataset.
+ *
+ * @param conf the SQLConf
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ * will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the LocalNode
+ */
+case class SampleNode(
+    conf: SQLConf,
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: LocalNode) extends UnaryLocalNode(conf) {
+
+  override def output: Seq[Attribute] = child.output
+
+  private[this] var iterator: Iterator[InternalRow] = _
+
+  private[this] var currentRow: InternalRow = _
+
+  override def open(): Unit = {
+    child.open()
+    val (sampler, _seed) = if (withReplacement) {
+      val random = new Random(seed)
+      // Disable gap sampling: the gap sampling method buffers two rows internally,
+      // forcing us to copy each row, which costs more than the random number generator saves.
+      (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
+        // Use the seed for partition 0, like PartitionwiseSampledRDD, to generate
+        // the same result as DataFrame.sample
+        random.nextLong())
+    } else {
+      (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
+    }
+    sampler.setSeed(_seed)
+    iterator = sampler.sample(child.asIterator)
+  }
+
+  override def next(): Boolean = {
+    if (iterator.hasNext) {
+      currentRow = iterator.next()
+      true
+    } else {
+      false
+    }
+  }
+
+  override def fetch(): InternalRow = currentRow
+
+  override def close(): Unit = child.close()
+
+}
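The expected fraction of rows kept is upperBound - lowerBound; expressing the probability as a half-open interval [lowerBound, upperBound) is what allows disjoint samples (e.g. random splits) to be carved out of the same pass. A standalone sketch of the without-replacement path, using the same sampler calls the node makes above (illustrative values):

    import org.apache.spark.util.random.BernoulliCellSampler

    // Keep ~30% of the input without replacement, mirroring
    // SampleNode(conf, 0.0, 0.3, withReplacement = false, seed, child).
    val sampler = new BernoulliCellSampler[Int](0.0, 0.3)
    sampler.setSeed(42L)
    val kept = sampler.sample(Iterator.range(1, 101)).toList
    // kept.size is ~30 in expectation; the exact count depends on the seed.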
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
new file mode 100644
index 0000000000..53f1dcc65d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.util.BoundedPriorityQueue
+
+case class TakeOrderedAndProjectNode(
+    conf: SQLConf,
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: LocalNode) extends UnaryLocalNode(conf) {
+
+  private[this] var projection: Option[Projection] = _
+  private[this] var ord: InterpretedOrdering = _
+  private[this] var iterator: Iterator[InternalRow] = _
+  private[this] var currentRow: InternalRow = _
+
+  override def output: Seq[Attribute] = {
+    val projectOutput = projectList.map(_.map(_.toAttribute))
+    projectOutput.getOrElse(child.output)
+  }
+
+  override def open(): Unit = {
+    child.open()
+    projection = projectList.map(new InterpretedProjection(_, child.output))
+    ord = new InterpretedOrdering(sortOrder, child.output)
+    // The priority queue keeps its largest elements, so reverse the ordering to keep the smallest.
+    val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse)
+    while (child.next()) {
+      queue += child.fetch()
+    }
+    // Close the child eagerly since we no longer need it.
+    child.close()
+    iterator = queue.iterator
+  }
+
+  override def next(): Boolean = {
+    if (iterator.hasNext) {
+      val _currentRow = iterator.next()
+      currentRow = projection match {
+        case Some(p) => p(_currentRow)
+        case None => _currentRow
+      }
+      true
+    } else {
+      false
+    }
+  }
+
+  override def fetch(): InternalRow = currentRow
+
+  override def close(): Unit = child.close()
+
+}
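The reversal in open() is the standard bounded-heap top-K trick: BoundedPriorityQueue keeps the largest elements under whatever ordering it is given, so feeding it ord.reverse retains the limit rows that come first under sortOrder, i.e. the ORDER BY ... LIMIT result. The same idea with a plain size-bounded max-heap, as a sketch:

    import scala.collection.mutable

    // Sketch: keep the k smallest elements using a max-heap of size k.
    // The heap's head is the largest element retained so far, so any
    // incoming element smaller than the head displaces it.
    def topK[T](input: Iterator[T], k: Int)(implicit ord: Ordering[T]): Seq[T] = {
      val heap = mutable.PriorityQueue.empty[T](ord) // max-heap under ord
      input.foreach { x =>
        if (heap.size < k) {
          heap.enqueue(x)
        } else if (ord.lt(x, heap.head)) {
          heap.dequeue()
          heap.enqueue(x)
        }
      }
      heap.toSeq.sorted
    }

    // topK(Iterator(5, 1, 9, 3, 7), 3) == Seq(1, 3, 5)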
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
new file mode 100644
index 0000000000..7deaa375fc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+class IntersectNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  test("basic") {
+    val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
+    val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")
+
+    checkAnswer2(
+      input1,
+      input2,
+      (node1, node2) => IntersectNode(conf, node1, node2),
+      input1.intersect(input2).collect()
+    )
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
new file mode 100644
index 0000000000..87a7da4539
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+class SampleNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  private def testSample(withReplacement: Boolean): Unit = {
+    test(s"withReplacement: $withReplacement") {
+      val seed = 0L
+      val input = sqlContext.sparkContext.
+        parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
+        toDF("key", "value")
+      checkAnswer(
+        input,
+        node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
+        input.sample(withReplacement, 0.3, seed).collect()
+      )
+    }
+  }
+
+  testSample(withReplacement = true)
+  testSample(withReplacement = false)
+}
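The "only 1 partition" comment is load-bearing: SampleNode derives its sampler seed as new Random(seed).nextLong(), the same value PartitionwiseSampledRDD assigns to partition 0, so its output matches DataFrame.sample only when the entire input sits in a single partition. The derivation being matched, as a sketch:

    import java.util.Random

    // PartitionwiseSampledRDD draws one Long per partition from Random(seed);
    // partition 0 receives the first draw, which is what SampleNode reproduces.
    val seed = 0L
    val partition0Seed = new Random(seed).nextLong()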
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
new file mode 100644
index 0000000000..ff28b24eef
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
+
+class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
+
+  import testImplicits._
+
+  private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
+    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+      col.expr match {
+        case expr: SortOrder =>
+          expr
+        case expr: Expression =>
+          SortOrder(expr, Ascending)
+      }
+    }
+    sortOrder
+  }
+
+  private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
+    val testCaseName = if (desc) "desc" else "asc"
+    test(testCaseName) {
+      val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
+      val sortColumn = if (desc) input.col("key").desc else input.col("key")
+      checkAnswer(
+        input,
+        node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
+        input.sort(sortColumn).limit(5).collect()
+      )
+    }
+  }
+
+  testTakeOrderedAndProjectNode(desc = false)
+  testTakeOrderedAndProjectNode(desc = true)
+}