aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala73
1 files changed, 43 insertions, 30 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c3b1d5cdd7..774170ff40 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -67,42 +67,14 @@ object MLUtils {
path: String,
numFeatures: Int,
minPartitions: Int): RDD[LabeledPoint] = {
- val parsed = sc.textFile(path, minPartitions)
- .map(_.trim)
- .filter(line => !(line.isEmpty || line.startsWith("#")))
- .map { line =>
- val items = line.split(' ')
- val label = items.head.toDouble
- val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
- val indexAndValue = item.split(':')
- val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
- val value = indexAndValue(1).toDouble
- (index, value)
- }.unzip
-
- // check if indices are one-based and in ascending order
- var previous = -1
- var i = 0
- val indicesLength = indices.length
- while (i < indicesLength) {
- val current = indices(i)
- require(current > previous, s"indices should be one-based and in ascending order;"
- + " found current=$current, previous=$previous; line=\"$line\"")
- previous = current
- i += 1
- }
-
- (label, indices.toArray, values.toArray)
- }
+ val parsed = parseLibSVMFile(sc, path, minPartitions)
// Determine number of features.
val d = if (numFeatures > 0) {
numFeatures
} else {
parsed.persist(StorageLevel.MEMORY_ONLY)
- parsed.map { case (label, indices, values) =>
- indices.lastOption.getOrElse(0)
- }.reduce(math.max) + 1
+ computeNumFeatures(parsed)
}
parsed.map { case (label, indices, values) =>
@@ -110,6 +82,47 @@ object MLUtils {
}
}
+ private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = {
+ rdd.map { case (label, indices, values) =>
+ indices.lastOption.getOrElse(0)
+ }.reduce(math.max) + 1
+ }
+
+ private[spark] def parseLibSVMFile(
+ sc: SparkContext,
+ path: String,
+ minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = {
+ sc.textFile(path, minPartitions)
+ .map(_.trim)
+ .filter(line => !(line.isEmpty || line.startsWith("#")))
+ .map(parseLibSVMRecord)
+ }
+
+ private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
+ val items = line.split(' ')
+ val label = items.head.toDouble
+ val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
+ val indexAndValue = item.split(':')
+ val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
+ val value = indexAndValue(1).toDouble
+ (index, value)
+ }.unzip
+
+ // check if indices are one-based and in ascending order
+ var previous = -1
+ var i = 0
+ val indicesLength = indices.length
+ while (i < indicesLength) {
+ val current = indices(i)
+ require(current > previous, s"indices should be one-based and in ascending order;"
+ + s" found current=$current, previous=$previous; line=\"$line\"")
+ previous = current
+ i += 1
+ }
+
+ (label, indices.toArray, values.toArray)
+ }
+
/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
* partitions.