aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPillis <pillis.work@gmail.com>2014-01-10 00:07:36 -0800
committerPillis <pillis.work@gmail.com>2014-01-10 00:07:36 -0800
commit8d021b42bc53a81172d98b556a340f7c2c4de0f3 (patch)
tree0f3e2d121a5a8e604f34ecec9b754f5ea38de2b7
parent181471906ed590347cbbe3422bd92e9b82f9e1bf (diff)
downloadspark-8d021b42bc53a81172d98b556a340f7c2c4de0f3.tar.gz
spark-8d021b42bc53a81172d98b556a340f7c2c4de0f3.tar.bz2
spark-8d021b42bc53a81172d98b556a340f7c2c4de0f3.zip
SPARK-961. Add a Vector.random() method - update 1
-rw-r--r--core/src/main/scala/org/apache/spark/util/Vector.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/util/VectorSuite.scala16
2 files changed, 12 insertions, 10 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index f9c6cdf2be..62fd6d8da5 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -126,7 +126,11 @@ object Vector {
def ones(length: Int) = Vector(length, _ => 1)
- def random(length: Int, random: Random = new Random()) = Vector(length, _ => random.nextDouble());
+ /**
+ * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers
+ * between 0.0 and 1.0. Optional [[scala.util.Random]] number generator can be provided.
+ */
+ def random(length: Int, random: Random = new XORShiftRandom()) = Vector(length, _ => random.nextDouble())
class Multiplier(num: Double) {
def * (vec: Vector) = vec * num
diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
index 23d1bdb193..7006571ef0 100644
--- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
@@ -27,20 +27,18 @@ import org.scalatest.FunSuite
class VectorSuite extends FunSuite {
def verifyVector(vector: Vector, expectedLength: Int) = {
- assert(vector.length == expectedLength); // Array must be of expected length
- assert(vector.length == vector.elements.distinct.length); // Values should not repeat
- assert(vector.sum > 0); // All values must not be 0
- assert(vector.sum < vector.length); // All values must not be 1
- assert(vector.elements.product > 0); // No value is 0
+ assert(vector.length == expectedLength)
+ assert(vector.elements.min > 0.0)
+ assert(vector.elements.max < 1.0)
}
test("random with default random number generator") {
- val vector100 = Vector.random(100);
- verifyVector(vector100, 100);
+ val vector100 = Vector.random(100)
+ verifyVector(vector100, 100)
}
test("random with given random number generator") {
- val vector100 = Vector.random(100, new Random(100));
- verifyVector(vector100, 100);
+ val vector100 = Vector.random(100, new Random(100))
+ verifyVector(vector100, 100)
}
}