From b5347a4664625ede6ab9d8ef6558457a34ae423f Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Thu, 30 Apr 2015 21:56:03 -0700
Subject: [SPARK-7248] implemented random number generators for DataFrames

Adds the functions `rand` (Uniform Dist) and `randn` (Normal Dist.) as
expressions to DataFrames.

cc mengxr rxin

Author: Burak Yavuz

Closes #5819 from brkyvz/df-rng and squashes the following commits:

50d69d4 [Burak Yavuz] add seed for test that failed
4234c3a [Burak Yavuz] fix Rand expression
13cad5c [Burak Yavuz] couple fixes
7d53953 [Burak Yavuz] waiting for hive tests
b453716 [Burak Yavuz] move radn with seed down
03637f0 [Burak Yavuz] fix broken hive func
c5909eb [Burak Yavuz] deleted old implementation of Rand
6d43895 [Burak Yavuz] implemented random generators
---
 python/pyspark/sql/functions.py                       | 25 +++++++++-
 python/pyspark/sql/tests.py                           | 10 ++++
 .../spark/sql/catalyst/expressions/Rand.scala         | 36 --------------
 .../spark/sql/catalyst/expressions/random.scala       | 56 ++++++++++++++++++++++
 .../catalyst/optimizer/ConstantFoldingSuite.scala     |  4 +-
 .../scala/org/apache/spark/sql/functions.scala        | 30 +++++++++++-
 .../scala/org/apache/spark/sql/mathfunctions.scala    |  2 -
 .../org/apache/spark/sql/JavaDataFrameSuite.java      |  3 ++
 .../apache/spark/sql/ColumnExpressionSuite.scala      | 22 +++++++++
 .../scala/org/apache/spark/sql/hive/HiveQl.scala      |  7 ++-
 10 files changed, 149 insertions(+), 46 deletions(-)
 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
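A minimal usage sketch of the new API (an illustration, not part of the patch;
`sc`, `sqlContext`, and `df` are placeholder names, and the local setup is
assumed rather than taken from the commit):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions.{rand, randn}

    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("rng-demo"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // A small test DataFrame with a single "key" column.
    val df = sc.parallelize((1 to 10).map(Tuple1(_))).toDF("key")

    // Seeded: re-running the same plan over the same partition layout
    // reproduces the same column of draws from U[0.0, 1.0].
    df.select(df("key"), rand(5L).as("u")).show()

    // Unseeded: a seed is drawn at random; samples come from N(0, 1).
    df.select(df("key"), randn().as("n")).show()

One wrinkle visible in the Python wrappers below: the seed check is written
`if seed:`, so an explicit seed of 0 is treated the same as no seed at all.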
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 555c2fa5e7..241f821757 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -67,7 +67,6 @@ _functions = {
     'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
 }
 
-
 for _name, _doc in _functions.items():
     globals()[_name] = _create_function(_name, _doc)
 del _name, _doc
@@ -75,6 +74,30 @@
 __all__ += _functions.keys()
 __all__.sort()
 
 
+def rand(seed=None):
+    """
+    Generate a random column with i.i.d. samples from U[0.0, 1.0].
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.rand(seed)
+    else:
+        jc = sc._jvm.functions.rand()
+    return Column(jc)
+
+
+def randn(seed=None):
+    """
+    Generate a column with i.i.d. samples from the standard normal distribution.
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.randn(seed)
+    else:
+        jc = sc._jvm.functions.randn()
+    return Column(jc)
+
+
 def approxCountDistinct(col, rsd=None):
     """Returns a new :class:`Column` for approximate distinct count of ``col``.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2ffd18ebd7..5640bb5ea2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -416,6 +416,16 @@ class SQLTests(ReusedPySparkTestCase):
         assert_close([math.hypot(i, 2 * i) for i in range(10)],
                      df.select(functions.hypot(df.a, df.b)).collect())
 
+    def test_rand_functions(self):
+        df = self.df
+        from pyspark.sql import functions
+        rnd = df.select('key', functions.rand()).collect()
+        for row in rnd:
+            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
+        rndn = df.select('key', functions.randn(5)).collect()
+        for row in rndn:
+            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
+
     def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
deleted file mode 100644
index f5fea3f015..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import java.util.Random
-
-import org.apache.spark.sql.types.{DataType, DoubleType}
-
-
-case object Rand extends LeafExpression {
-  override def dataType: DataType = DoubleType
-  override def nullable: Boolean = false
-
-  private[this] lazy val rand = new Random
-
-  override def eval(input: Row = null): EvaluatedType = {
-    rand.nextDouble().asInstanceOf[EvaluatedType]
-  }
-
-  override def toString: String = "RAND()"
-}
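The file deleted above held a `case object` — a JVM singleton — whose single
unseeded `java.util.Random` was shared by every query, with nowhere to thread
a per-query seed through. The replacement below is an abstract class with
case-class subclasses, so each use carries its own seed and its own state. A
standalone sketch of the distinction (illustrative only; plain
`java.util.Random` stands in for the expression machinery):

    import java.util.Random

    // One global instance: all callers share this RNG and its state.
    case object SharedRand {
      private val rng = new Random()
      def next(): Double = rng.nextDouble()
    }

    // One instance per use, each with its own seeded state.
    case class SeededRand(seed: Long) {
      private val rng = new Random(seed)
      def next(): Double = rng.nextDouble()
    }

    val a = SeededRand(5L)
    val b = SeededRand(5L)
    assert(a.next() == b.next())  // identical seeds produce identical streams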
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
new file mode 100644
index 0000000000..66d7c8b07c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * A Random distribution generating expression.
+ * TODO: This can be made generic to generate any type of random distribution, or any type of
+ * StructType.
+ *
+ * Since this expression is stateful, it cannot be a case object.
+ */
+abstract class RDG(seed: Long) extends LeafExpression with Serializable {
+  self: Product =>
+
+  /**
+   * Record ID within each partition. By being transient, the Random Number Generator is
+   * reset every time we serialize and deserialize it.
+   */
+  @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId())
+
+  override type EvaluatedType = Double
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = DoubleType
+}
+
+/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
+case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+  override def eval(input: Row): Double = rng.nextDouble()
+}
+
+/** Generate a random column with i.i.d. gaussian random distribution. */
+case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+  override def eval(input: Row): Double = rng.nextGaussian()
+}
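Two details of the new expression are worth spelling out. First, the generator
is seeded with `seed + partitionId`, so a fixed seed gives every partition its
own deterministic stream. Second, `rng` is a `@transient lazy val`: it is not
shipped with the serialized expression but rebuilt on the executor the first
time `eval` touches it, which is when `TaskContext` is actually available. A
sketch of the seeding recipe, with `java.util.Random` standing in for the
package-private `XORShiftRandom` (names here are illustrative):

    import java.util.Random

    // Mirrors `new XORShiftRandom(seed + TaskContext.get().partitionId())`.
    def partitionStream(seed: Long, partitionId: Int, n: Int): Seq[Double] = {
      val rng = new Random(seed + partitionId)
      Seq.fill(n)(rng.nextDouble())
    }

    // Same seed and partition id: the same numbers on every run...
    assert(partitionStream(5L, 0, 3) == partitionStream(5L, 0, 3))
    // ...while each partition draws from a differently seeded stream.
    assert(partitionStream(5L, 0, 3) != partitionStream(5L, 1, 3))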
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 14b28e8402..18f92150b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -160,7 +160,7 @@ class ConstantFoldingSuite extends PlanTest {
     val originalQuery =
       testRelation
         .select(
-          Rand + Literal(1) as Symbol("c1"),
+          Rand(5L) + Literal(1) as Symbol("c1"),
           Sum('a) as Symbol("c2"))
 
     val optimized = Optimize.execute(originalQuery.analyze)
@@ -168,7 +168,7 @@ class ConstantFoldingSuite extends PlanTest {
     val correctAnswer =
       testRelation
         .select(
-          Rand + Literal(1.0) as Symbol("c1"),
+          Rand(5L) + Literal(1.0) as Symbol("c1"),
           Sum('a) as Symbol("c2"))
         .analyze
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index aa31d04a0c..242e64d3ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
-
+import org.apache.spark.util.Utils
 
 /**
  * :: Experimental ::
@@ -346,6 +346,34 @@ object functions {
    */
   def not(e: Column): Column = !e
 
+  /**
+   * Generate a random column with i.i.d. samples from U[0.0, 1.0].
+   *
+   * @group normal_funcs
+   */
+  def rand(seed: Long): Column = Rand(seed)
+
+  /**
+   * Generate a random column with i.i.d. samples from U[0.0, 1.0].
+   *
+   * @group normal_funcs
+   */
+  def rand(): Column = rand(Utils.random.nextLong)
+
+  /**
+   * Generate a column with i.i.d. samples from the standard normal distribution.
+   *
+   * @group normal_funcs
+   */
+  def randn(seed: Long): Column = Randn(seed)
+
+  /**
+   * Generate a column with i.i.d. samples from the standard normal distribution.
+   *
+   * @group normal_funcs
+   */
+  def randn(): Column = randn(Utils.random.nextLong)
+
   /**
    * Partition ID of the Spark task.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
index d901542b7e..db47480c38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.functions.lit
 /**
  * :: Experimental ::
  * Mathematical Functions available for [[DataFrame]].
- *
- * @groupname double_funcs Functions that require DoubleType as an input
  */
 @Experimental
 // scalastyle:off
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 966d879e1f..ebe96e649d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -104,6 +104,9 @@ public class JavaDataFrameSuite {
     df2.select(pow("a", "a"), pow("b", 2.0));
     df2.select(pow(col("a"), col("b")), exp("b"));
     df2.select(sin("a"), acos("b"));
+
+    df2.select(rand(), acos("b"));
+    df2.select(col("*"), randn(5L));
   }
 
   @Ignore
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 2ba5fc21ff..6322faf4d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import org.scalatest.Matchers._
+
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -349,4 +351,24 @@ class ColumnExpressionSuite extends QueryTest {
     assert(schema("value").metadata === Metadata.empty)
     assert(schema("abc").metadata === metadata)
   }
+
+  test("rand") {
+    val randCol = testData.select('key, rand(5L).as("rand"))
+    randCol.columns.length should be (2)
+    val rows = randCol.collect()
+    rows.foreach { row =>
+      assert(row.getDouble(1) <= 1.0)
+      assert(row.getDouble(1) >= 0.0)
+    }
+  }
+
+  test("randn") {
+    val randCol = testData.select('key, randn(5L).as("rand"))
+    randCol.columns.length should be (2)
+    val rows = randCol.collect()
+    rows.foreach { row =>
+      assert(row.getDouble(1) <= 4.0)
+      assert(row.getDouble(1) >= -4.0)
+    }
+  }
 }
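A note on the bounds the suites above check: `rand` output lies in [0.0, 1.0]
by construction, but the ±4.0 band for `randn` is probabilistic. For a
standard normal, P(|Z| > 4) is roughly 6.3e-5, so a 100-row sample crosses the
band with probability well under 1%, and pinning the seed (see the commit
"add seed for test that failed" above) makes the outcome fully deterministic.
A quick standalone check of the band, with `java.util.Random.nextGaussian`
standing in for the expression (the seed and sample size are illustrative):

    import java.util.Random

    val rng = new Random(5L)
    val samples = Seq.fill(100)(rng.nextGaussian())
    val maxAbs = samples.map(math.abs).max
    // Typically around 2.5-3, comfortably inside the +/-4 band.
    println(f"largest |draw| in 100 samples: $maxAbs%.3f")

The final hunk below wires both the seeded and unseeded forms of Hive's RAND
into the new expression.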
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 0a86519e14..63a8c05f77 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -19,13 +19,11 @@ package org.apache.spark.sql.hive
 
 import java.sql.Date
 
-
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
-
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
 import org.apache.hadoop.hive.ql.lib.Node
 import org.apache.hadoop.hive.ql.metadata.Table
 import org.apache.hadoop.hive.ql.parse._
@@ -1244,7 +1242,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     /* Other functions */
     case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
       CreateArray(children.map(nodeToExpr))
-    case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+    case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand()
+    case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong)
     case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
       Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType))
     case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
--
cgit v1.2.3
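With the parser cases above in place, both SQL forms of RAND resolve to the
new expression. A closing usage sketch (not part of the patch; `hiveContext`
and the `src` table are placeholders):

    // No argument parses to Rand() -- a seed is drawn at random per query.
    val unseeded = hiveContext.sql("SELECT key, rand() FROM src")

    // A literal argument parses to Rand(5L) via `seed.toString.toLong`, so
    // repeated runs over the same partition layout repeat the same values.
    val seeded = hiveContext.sql("SELECT key, rand(5) FROM src")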