aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
authorGio Borje <gborje@linkedin.com>2016-05-25 16:52:31 -0500
committerSean Owen <sowen@cloudera.com>2016-05-25 16:52:31 -0500
commit589cce93c821ac28e9090a478f6e7465398b7c30 (patch)
tree2bab935ee7a6d078a3d639a72d8b9c99f82a3a76 /mllib/src/main/scala
parent9c297df3d4d5fa4bbfdffdaad15f362586db384b (diff)
downloadspark-589cce93c821ac28e9090a478f6e7465398b7c30.tar.gz
spark-589cce93c821ac28e9090a478f6e7465398b7c30.tar.bz2
spark-589cce93c821ac28e9090a478f6e7465398b7c30.zip
Log warnings for numIterations * miniBatchFraction < 1.0
## What changes were proposed in this pull request? Add a warning log for the case that `numIterations * miniBatchFraction <1.0` during gradient descent. If the product of those two numbers is less than `1.0`, then not all training examples will be used during optimization. To put this concretely, suppose that `numExamples = 100`, `miniBatchFraction = 0.2` and `numIterations = 3`. Then, 3 iterations will occur each sampling approximately 6 examples each. In the best case, each of the 6 examples are unique; hence 18/100 examples are used. This may be counter-intuitive to most users and led to the issue during the development of another Spark ML model: https://github.com/zhengruifeng/spark-libFM/issues/11. If a user actually does not require the training data set, it would be easier and more intuitive to use `RDD.sample`. ## How was this patch tested? `build/mvn -DskipTests clean package` build succeeds Author: Gio Borje <gborje@linkedin.com> Closes #13265 from Hydrotoast/master.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala5
1 files changed, 5 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index a67ea836e5..735e780909 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -197,6 +197,11 @@ object GradientDescent extends Logging {
"< 1.0 can be unstable because of the stochasticity in sampling.")
}
+ if (numIterations * miniBatchFraction < 1.0) {
+ logWarning("Not all examples will be used if numIterations * miniBatchFraction < 1.0: " +
+ s"numIterations=$numIterations and miniBatchFraction=$miniBatchFraction")
+ }
+
val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
// Record previous weight and current one to calculate solution vector difference