about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-05-19 10:57:47 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-19 10:57:47 -0700
commit7b16e9f2118fbfbb1c0ba957161fe500c9aff82a (patch)
tree671d93cd8addfd5188b150747647fffc15ac2651 /mllib
parentfb90273212dc7241c9a0c3446e25e0e0b9377750 (diff)
downloadspark-7b16e9f2118fbfbb1c0ba957161fe500c9aff82a.tar.gz
spark-7b16e9f2118fbfbb1c0ba957161fe500c9aff82a.tar.bz2
spark-7b16e9f2118fbfbb1c0ba957161fe500c9aff82a.zip
[SPARK-7678] [ML] Fix default random seed in HasSeed
Changed shared param HasSeed to have default based on hashCode of class name, instead of random number. Also, removed fixed random seeds from Word2Vec and ALS. CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #6251 from jkbradley/scala-fixed-seed and squashes the following commits: 0e37184 [Joseph K. Bradley] Fixed Word2VecSuite, ALSSuite in spark.ml to use original fixed random seeds 678ec3a [Joseph K. Bradley] Removed fixed random seeds from Word2Vec and ALS. Changed shared param HasSeed to have default based on hashCode of class name, instead of random number.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 1
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 2
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 2
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 1
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 16
6 files changed, 14 insertions, 12 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 8ace8c53bb..90f0be76df 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -68,7 +68,6 @@ private[feature] trait Word2VecBase extends Params
setDefault(stepSize -> 0.025)
setDefault(maxIter -> 1)
- setDefault(seed -> 42L)
/**
* Validate and transform the input schema.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 5085b798da..8b8cb81373 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -53,7 +53,7 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
- ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
+ ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
" For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
isValid = "ParamValidators.inRange(0, 1)"),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 7525d37007..3a4976d3dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -232,7 +232,7 @@ private[ml] trait HasFitIntercept extends Params {
}
/**
- * (private[ml]) Trait for shared param seed (default: Utils.random.nextLong()).
+ * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong).
*/
private[ml] trait HasSeed extends Params {
@@ -242,7 +242,7 @@ private[ml] trait HasSeed extends Params {
*/
final val seed: LongParam = new LongParam(this, "seed", "random seed")
- setDefault(seed, Utils.random.nextLong())
+ setDefault(seed, this.getClass.getName.hashCode.toLong)
/** @group getParam */
final def getSeed: Long = $(seed)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 45c57b50da..2a5ddbfae5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -148,7 +148,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
- ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 0L)
+ ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
/**
* Validates and transforms the input schema.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 03ba86670d..43a09cc418 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -52,6 +52,7 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
.setVectorSize(3)
.setInputCol("text")
.setOutputCol("result")
+ .setSeed(42L)
.fit(docDF)
model.transform(docDF).select("result", "expected").collect().foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index fc7349330c..6cc6ec94eb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -345,6 +345,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
.setImplicitPrefs(implicitPrefs)
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
+ .setSeed(0)
val alpha = als.getAlpha
val model = als.fit(training.toDF())
val predictions = model.transform(test.toDF())
@@ -425,17 +426,18 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating))
- val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4)
+ val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4, seed = 0)
assert(longUserFactors.first()._1.getClass === classOf[Long])
val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating))
- val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4)
+ val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4, seed = 0)
assert(strUserFactors.first()._1.getClass === classOf[String])
}
test("nonnegative constraint") {
val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
- val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true)
+ val (userFactors, itemFactors) =
+ ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true, seed = 0)
def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = {
factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _)
}
@@ -459,7 +461,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
test("partitioner in returned factors") {
val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
val (userFactors, itemFactors) = ALS.train(
- ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4)
+ ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4, seed = 0)
for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) {
assert(userFactors.partitioner.isDefined, s"$tpe factors should have partitioner.")
val part = userFactors.partitioner.get
@@ -476,8 +478,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
test("als with large number of iterations") {
val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
- ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2)
- ALS.train(
- ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true)
+ ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, seed = 0)
+ ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2,
+ implicitPrefs = true, seed = 0)
}
}