author     Burak Yavuz <brkyvz@gmail.com>     2015-04-30 21:56:03 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-04-30 21:56:03 -0700
commit     b5347a4664625ede6ab9d8ef6558457a34ae423f (patch)
tree       6d068b5ede427fa4e4658d5d3ebcd9798c64950f
parent     69a739c7f5fd002432ece203957e1458deb2f4c3 (diff)
[SPARK-7248] implemented random number generators for DataFrames
Adds the functions `rand` (Uniform Dist) and `randn` (Normal Dist.) as expressions to DataFrames.

cc mengxr rxin

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5819 from brkyvz/df-rng and squashes the following commits:

50d69d4 [Burak Yavuz] add seed for test that failed
4234c3a [Burak Yavuz] fix Rand expression
13cad5c [Burak Yavuz] couple fixes
7d53953 [Burak Yavuz] waiting for hive tests
b453716 [Burak Yavuz] move radn with seed down
03637f0 [Burak Yavuz] fix broken hive func
c5909eb [Burak Yavuz] deleted old implementation of Rand
6d43895 [Burak Yavuz] implemented random generators
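
For orientation, a minimal usage sketch of the new Scala API, assuming a SQLContext is in scope and `df` is a hypothetical DataFrame with a column named "key":

    import org.apache.spark.sql.functions.{rand, randn}

    // rand() appends a column of i.i.d. uniform draws, one per row;
    // randn(seed) does the same with standard normal draws, reproducibly.
    val withUniform  = df.select(df("key"), rand().as("u"))      // u in [0, 1)
    val withGaussian = df.select(df("key"), randn(42L).as("z"))  // z ~ N(0, 1)
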
-rw-r--r--  python/pyspark/sql/functions.py | 25
-rw-r--r--  python/pyspark/sql/tests.py | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala | 36
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala | 56
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala | 2
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 22
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 7
10 files changed, 149 insertions, 46 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 555c2fa5e7..241f821757 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -67,7 +67,6 @@ _functions = {
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}
-
for _name, _doc in _functions.items():
globals()[_name] = _create_function(_name, _doc)
del _name, _doc
@@ -75,6 +74,30 @@ __all__ += _functions.keys()
__all__.sort()
+def rand(seed=None):
+ """
+ Generate a random column with i.i.d. samples from U[0.0, 1.0].
+ """
+ sc = SparkContext._active_spark_context
+ if seed:
+ jc = sc._jvm.functions.rand(seed)
+ else:
+ jc = sc._jvm.functions.rand()
+ return Column(jc)
+
+
+def randn(seed=None):
+ """
+ Generate a column with i.i.d. samples from the standard normal distribution.
+ """
+ sc = SparkContext._active_spark_context
+ if seed:
+ jc = sc._jvm.functions.randn(seed)
+ else:
+ jc = sc._jvm.functions.randn()
+ return Column(jc)
+
+
def approxCountDistinct(col, rsd=None):
"""Returns a new :class:`Column` for approximate distinct count of ``col``.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2ffd18ebd7..5640bb5ea2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -416,6 +416,16 @@ class SQLTests(ReusedPySparkTestCase):
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())
+ def test_rand_functions(self):
+ df = self.df
+ from pyspark.sql import functions
+ rnd = df.select('key', functions.rand()).collect()
+ for row in rnd:
+ assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
+ rndn = df.select('key', functions.randn(5)).collect()
+ for row in rndn:
+ assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
deleted file mode 100644
index f5fea3f015..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import java.util.Random
-
-import org.apache.spark.sql.types.{DataType, DoubleType}
-
-
-case object Rand extends LeafExpression {
- override def dataType: DataType = DoubleType
- override def nullable: Boolean = false
-
- private[this] lazy val rand = new Random
-
- override def eval(input: Row = null): EvaluatedType = {
- rand.nextDouble().asInstanceOf[EvaluatedType]
- }
-
- override def toString: String = "RAND()"
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
new file mode 100644
index 0000000000..66d7c8b07c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * A Random distribution generating expression.
+ * TODO: This can be made generic to generate any type of random distribution, or any type of
+ * StructType.
+ *
+ * Since this expression is stateful, it cannot be a case object.
+ */
+abstract class RDG(seed: Long) extends LeafExpression with Serializable {
+ self: Product =>
+
+ /**
+ * Record ID within each partition. By being transient, the Random Number Generator is
+ * reset every time we serialize and deserialize it.
+ */
+ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId())
+
+ override type EvaluatedType = Double
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = DoubleType
+}
+
+/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
+case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+ override def eval(input: Row): Double = rng.nextDouble()
+}
+
+/** Generate a random column with i.i.d. gaussian random distribution. */
+case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+ override def eval(input: Row): Double = rng.nextGaussian()
+}
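
The seeding scheme above is the heart of the change: each task offsets the user-supplied seed by its partition ID, and the @transient lazy val recreates the generator on the executor after deserialization, so every partition draws a distinct but reproducible stream. A minimal sketch of the same idea on a plain RDD, with scala.util.Random standing in for XORShiftRandom and an assumed SparkContext `sc`:

    import scala.util.Random

    // Each partition seeds its own generator with seed + partition index,
    // so reruns with the same seed and partitioning reproduce the values.
    def seededDoubles(sc: org.apache.spark.SparkContext, seed: Long) =
      sc.parallelize(1 to 8, numSlices = 2).mapPartitionsWithIndex { (pid, rows) =>
        val rng = new Random(seed + pid)  // stand-in for XORShiftRandom
        rows.map(_ => rng.nextDouble())
      }
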
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 14b28e8402..18f92150b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -160,7 +160,7 @@ class ConstantFoldingSuite extends PlanTest {
val originalQuery =
testRelation
.select(
- Rand + Literal(1) as Symbol("c1"),
+ Rand(5L) + Literal(1) as Symbol("c1"),
Sum('a) as Symbol("c2"))
val optimized = Optimize.execute(originalQuery.analyze)
@@ -168,7 +168,7 @@ class ConstantFoldingSuite extends PlanTest {
val correctAnswer =
testRelation
.select(
- Rand + Literal(1.0) as Symbol("c1"),
+ Rand(5L) + Literal(1.0) as Symbol("c1"),
Sum('a) as Symbol("c2"))
.analyze
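
The explicit `Rand(5L)` seed is needed now that Rand is a case class, but the expectation is unchanged: the optimizer must leave the expression alone, because folding a random expression at planning time would freeze a single draw for every row. A plain-Scala sketch of the distinction:

    import scala.util.Random

    // "Folding" evaluates once and repeats the value; the unfolded
    // expression draws afresh for each row.
    val foldedOnce = { val x = new Random(5L).nextDouble() + 1; Seq.fill(3)(x) }
    val perRow     = { val r = new Random(5L); Seq.fill(3)(r.nextDouble() + 1) }
    // foldedOnce holds one value three times; perRow holds three distinct values.
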
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index aa31d04a0c..242e64d3ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
-
+import org.apache.spark.util.Utils
/**
* :: Experimental ::
@@ -347,6 +347,34 @@ object functions {
def not(e: Column): Column = !e
/**
+ * Generate a random column with i.i.d. samples from U[0.0, 1.0].
+ *
+ * @group normal_funcs
+ */
+ def rand(seed: Long): Column = Rand(seed)
+
+ /**
+ * Generate a random column with i.i.d. samples from U[0.0, 1.0].
+ *
+ * @group normal_funcs
+ */
+ def rand(): Column = rand(Utils.random.nextLong)
+
+ /**
+ * Generate a column with i.i.d. samples from the standard normal distribution.
+ *
+ * @group normal_funcs
+ */
+ def randn(seed: Long): Column = Randn(seed)
+
+ /**
+ * Generate a column with i.i.d. samples from the standard normal distribution.
+ *
+ * @group normal_funcs
+ */
+ def randn(): Column = randn(Utils.random.nextLong)
+
+ /**
* Partition ID of the Spark task.
*
* Note that this is indeterministic because it depends on data partitioning and task scheduling.
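
The seeded overloads make generated columns reproducible: since each partition's generator is seeded deterministically from the given seed, re-evaluating the same plan over the same partitioning yields the same draws. A sketch, assuming a hypothetical DataFrame `df` whose partitioning does not change between calls:

    import org.apache.spark.sql.functions.rand

    // Same seed, same partitioning => the two collected columns agree.
    val a = df.select(rand(7L)).collect().map(_.getDouble(0))
    val b = df.select(rand(7L)).collect().map(_.getDouble(0))
    assert(a.sameElements(b))
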
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
index d901542b7e..db47480c38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.functions.lit
/**
* :: Experimental ::
* Mathematical Functions available for [[DataFrame]].
- *
- * @groupname double_funcs Functions that require DoubleType as an input
*/
@Experimental
// scalastyle:off
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 966d879e1f..ebe96e649d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -104,6 +104,9 @@ public class JavaDataFrameSuite {
df2.select(pow("a", "a"), pow("b", 2.0));
df2.select(pow(col("a"), col("b")), exp("b"));
df2.select(sin("a"), acos("b"));
+
+ df2.select(rand(), acos("b"));
+ df2.select(col("*"), randn(5L));
}
@Ignore
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 2ba5fc21ff..6322faf4d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import org.scalatest.Matchers._
+
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -349,4 +351,24 @@ class ColumnExpressionSuite extends QueryTest {
assert(schema("value").metadata === Metadata.empty)
assert(schema("abc").metadata === metadata)
}
+
+ test("rand") {
+ val randCol = testData.select('key, rand(5L).as("rand"))
+ randCol.columns.length should be (2)
+ val rows = randCol.collect()
+ rows.foreach { row =>
+ assert(row.getDouble(1) <= 1.0)
+ assert(row.getDouble(1) >= 0.0)
+ }
+ }
+
+ test("randn") {
+ val randCol = testData.select('key, randn(5L).as("rand"))
+ randCol.columns.length should be (2)
+ val rows = randCol.collect()
+ rows.foreach { row =>
+ assert(row.getDouble(1) <= 4.0)
+ assert(row.getDouble(1) >= -4.0)
+ }
+ }
}
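
The [-4, 4] window in the randn test is a probabilistic safety margin: a standard normal draw lands outside it with probability roughly 6e-5, and the fixed seed makes the outcome deterministic in any case. A quick simulation sketch of that tail mass:

    import scala.util.Random

    // Estimate the chance of a standard normal draw landing outside [-4, 4].
    val rng = new Random(5L)
    val n = 1000000
    val outliers = Iterator.fill(n)(rng.nextGaussian()).count(x => math.abs(x) > 4.0)
    println(f"fraction beyond 4 sigma: ${outliers.toDouble / n}%.2e")  // ~6e-5
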
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 0a86519e14..63a8c05f77 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -19,13 +19,11 @@ package org.apache.spark.sql.hive
import java.sql.Date
-
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
-
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.parse._
@@ -1244,7 +1242,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
/* Other functions */
case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
CreateArray(children.map(nodeToExpr))
- case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand()
+ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong)
case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType))
case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>