aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
author: Zheng RuiFeng <ruifengz@foxmail.com> 2016-05-27 21:57:41 -0500
committer: Sean Owen <sowen@cloudera.com> 2016-05-27 21:57:41 -0500
commit: 9893dc975784551a62f65bbd709f8972e0204b2a (patch)
tree: 1706b1efab6def3bbcd4d90acc2c0acdfe75eccc /mllib/src
parent: 88c9c467a31630c558719679ca0894873a268b27 (diff)
download: spark-9893dc975784551a62f65bbd709f8972e0204b2a.tar.gz
spark-9893dc975784551a62f65bbd709f8972e0204b2a.tar.bz2
spark-9893dc975784551a62f65bbd709f8972e0204b2a.zip
[SPARK-15610][ML] update error message for k in pca
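## What changes were proposed in this pull request?
Fix the error message for the bound of `k` in `PCA`. The check itself is unchanged: `require(k <= sources.first().size, ...)` becomes `require(k <= numFeatures, ...)`, with the vector size read once into `numFeatures`. The message changes from "source vector size is ... must be greater than k" to "source vector size ... must be no less than k", so that it matches the `<=` bound.
BTW, remove unused import in `ml.ElementwiseProduct`.
## How was this patch tested?
manual tests
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #13356 from zhengruifeng/fix_pca.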
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala | 1
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala | 6
2 files changed, 3 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 91989c3d2f..9d2e60fa3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -23,7 +23,6 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.types.DataType
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
index 30c403e547..15b72205ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -40,8 +40,9 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
*/
@Since("1.4.0")
def fit(sources: RDD[Vector]): PCAModel = {
- require(k <= sources.first().size,
- s"source vector size is ${sources.first().size} must be greater than k=$k")
+ val numFeatures = sources.first().size
+ require(k <= numFeatures,
+ s"source vector size $numFeatures must be no less than k=$k")
val mat = new RowMatrix(sources)
val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
@@ -58,7 +59,6 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
case m =>
throw new IllegalArgumentException("Unsupported matrix format. Expected " +
s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}")
-
}
val denseExplainedVariance = explainedVariance match {
case dv: DenseVector =>