diff options
author | Sean Owen <sowen@cloudera.com> | 2016-09-04 12:40:51 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-09-04 12:40:51 +0100 |
commit | cdeb97a8cd26e3282cc2a4f126242ed2199f3898 (patch) | |
tree | 22bb93ee40ae08cb0f1928c7c2fdd535739ecd23 /mllib/src/main/scala | |
parent | e75c162e9e510d74b07f28ccf6c7948ac317a7c6 (diff) | |
download | spark-cdeb97a8cd26e3282cc2a4f126242ed2199f3898.tar.gz spark-cdeb97a8cd26e3282cc2a4f126242ed2199f3898.tar.bz2 spark-cdeb97a8cd26e3282cc2a4f126242ed2199f3898.zip |
[SPARK-17311][MLLIB] Standardize Python-Java MLlib API to accept optional long seeds in all cases
## What changes were proposed in this pull request?
Related to https://github.com/apache/spark/pull/14524 -- just the 'fix' rather than a behavior change.
- `PythonMLLibAPI` methods that take a seed now consistently take a `java.lang.Long`, allowing the Python API to pass `null` to specify "no seed"
- .mllib's Word2VecModel was the odd one out in .mllib because it picked its own random seed. It now defaults to None, which lets the Scala implementation pick a seed
- For consistency with the rest of .mllib, BisectingKMeansModel arguably should not hard-code a seed either; however, I left it as is.
## How was this patch tested?
Existing tests
Author: Sean Owen <sowen@cloudera.com>
Closes #14826 from srowen/SPARK-16832.2.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 20 |
1 file changed, 11 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index a80cca70f4..2ed6c6be1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -126,13 +126,13 @@ private[python] class PythonMLLibAPI extends Serializable { k: Int, maxIterations: Int, minDivisibleClusterSize: Double, - seed: Long): BisectingKMeansModel = { - new BisectingKMeans() + seed: java.lang.Long): BisectingKMeansModel = { + val kmeans = new BisectingKMeans() .setK(k) .setMaxIterations(maxIterations) .setMinDivisibleClusterSize(minDivisibleClusterSize) - .setSeed(seed) - .run(data) + if (seed != null) kmeans.setSeed(seed) + kmeans.run(data) } /** @@ -678,7 +678,7 @@ private[python] class PythonMLLibAPI extends Serializable { learningRate: Double, numPartitions: Int, numIterations: Int, - seed: Long, + seed: java.lang.Long, minCount: Int, windowSize: Int): Word2VecModelWrapper = { val word2vec = new Word2Vec() @@ -686,9 +686,9 @@ private[python] class PythonMLLibAPI extends Serializable { .setLearningRate(learningRate) .setNumPartitions(numPartitions) .setNumIterations(numIterations) - .setSeed(seed) .setMinCount(minCount) .setWindowSize(windowSize) + if (seed != null) word2vec.setSeed(seed) try { val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) new Word2VecModelWrapper(model) @@ -751,7 +751,7 @@ private[python] class PythonMLLibAPI extends Serializable { impurityStr: String, maxDepth: Int, maxBins: Int, - seed: Int): RandomForestModel = { + seed: java.lang.Long): RandomForestModel = { val algo = Algo.fromString(algoStr) val impurity = Impurities.fromString(impurityStr) @@ -763,11 +763,13 @@ private[python] class PythonMLLibAPI extends Serializable { maxBins = maxBins, categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap) val 
cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + // Only done because methods below want an int, not an optional Long + val intSeed = getSeedOrDefault(seed).toInt try { if (algo == Algo.Classification) { - RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed) + RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, intSeed) } else { - RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed) + RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, intSeed) } } finally { cached.unpersist(blocking = false) |