diff options
author | Zheng RuiFeng <ruifengz@foxmail.com> | 2016-05-27 21:57:41 -0500 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-05-27 21:57:41 -0500 |
commit | 9893dc975784551a62f65bbd709f8972e0204b2a (patch) | |
tree | 1706b1efab6def3bbcd4d90acc2c0acdfe75eccc /mllib | |
parent | 88c9c467a31630c558719679ca0894873a268b27 (diff) | |
download | spark-9893dc975784551a62f65bbd709f8972e0204b2a.tar.gz spark-9893dc975784551a62f65bbd709f8972e0204b2a.tar.bz2 spark-9893dc975784551a62f65bbd709f8972e0204b2a.zip |
[SPARK-15610][ML] update error message for k in pca
## What changes were proposed in this pull request?
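Fix the misleading error message for the bound of `k` in `PCA`.
The check itself is unchanged (`require(k <= numFeatures, ...)`, where `numFeatures = sources.first().size` is now computed once); only the message changes: `"source vector size is ... must be greater than k=$k"` -> `"source vector size $numFeatures must be no less than k=$k"`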
BTW, remove unused import in `ml.ElementwiseProduct`
## How was this patch tested?
manual tests
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #13356 from zhengruifeng/fix_pca.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala | 1 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala | 6 |
2 files changed, 3 insertions, 4 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 91989c3d2f..9d2e60fa3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -23,7 +23,6 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.Param
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.sql.types.DataType
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
index 30c403e547..15b72205ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -40,8 +40,9 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
    */
   @Since("1.4.0")
   def fit(sources: RDD[Vector]): PCAModel = {
-    require(k <= sources.first().size,
-      s"source vector size is ${sources.first().size} must be greater than k=$k")
+    val numFeatures = sources.first().size
+    require(k <= numFeatures,
+      s"source vector size $numFeatures must be no less than k=$k")
 
     val mat = new RowMatrix(sources)
     val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
@@ -58,7 +59,6 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
       case m =>
         throw new IllegalArgumentException("Unsupported matrix format. Expected " +
           s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}")
-
     }
     val denseExplainedVariance = explainedVariance match {
       case dv: DenseVector =>
```