diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 73 |
1 files changed, 43 insertions, 30 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index c3b1d5cdd7..774170ff40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -67,42 +67,14 @@ object MLUtils { path: String, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = { - val parsed = sc.textFile(path, minPartitions) - .map(_.trim) - .filter(line => !(line.isEmpty || line.startsWith("#"))) - .map { line => - val items = line.split(' ') - val label = items.head.toDouble - val (indices, values) = items.tail.filter(_.nonEmpty).map { item => - val indexAndValue = item.split(':') - val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. - val value = indexAndValue(1).toDouble - (index, value) - }.unzip - - // check if indices are one-based and in ascending order - var previous = -1 - var i = 0 - val indicesLength = indices.length - while (i < indicesLength) { - val current = indices(i) - require(current > previous, s"indices should be one-based and in ascending order;" - + " found current=$current, previous=$previous; line=\"$line\"") - previous = current - i += 1 - } - - (label, indices.toArray, values.toArray) - } + val parsed = parseLibSVMFile(sc, path, minPartitions) // Determine number of features. val d = if (numFeatures > 0) { numFeatures } else { parsed.persist(StorageLevel.MEMORY_ONLY) - parsed.map { case (label, indices, values) => - indices.lastOption.getOrElse(0) - }.reduce(math.max) + 1 + computeNumFeatures(parsed) } parsed.map { case (label, indices, values) => @@ -110,6 +82,47 @@ object MLUtils { } } + private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = { + rdd.map { case (label, indices, values) => + indices.lastOption.getOrElse(0) + }.reduce(math.max) + 1 + } + + private[spark] def parseLibSVMFile( + sc: SparkContext, + path: String, + minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = { + sc.textFile(path, minPartitions) + .map(_.trim) + .filter(line => !(line.isEmpty || line.startsWith("#"))) + .map(parseLibSVMRecord) + } + + private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = { + val items = line.split(' ') + val label = items.head.toDouble + val (indices, values) = items.tail.filter(_.nonEmpty).map { item => + val indexAndValue = item.split(':') + val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. + val value = indexAndValue(1).toDouble + (index, value) + }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, s"indices should be one-based and in ascending order;" + + " found current=$current, previous=$previous; line=\"$line\"") + previous = current + i += 1 + } + + (label, indices.toArray, values.toArray) + } + /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. |