aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala6
1 files changed, 2 insertions, 4 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
index a0bb5dabf4..0b5d31c0ff 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
@@ -118,17 +118,15 @@ object OneVsRestExample {
val inputData = sqlContext.read.format("libsvm").load(params.input)
// compute the train/test split: if testInput is not provided use part of input.
val data = params.testInput match {
- case Some(t) => {
+ case Some(t) =>
// compute the number of features in the training set.
val numFeatures = inputData.first().getAs[Vector](1).size
val testData = sqlContext.read.option("numFeatures", numFeatures.toString)
.format("libsvm").load(t)
Array[DataFrame](inputData, testData)
- }
- case None => {
+ case None =>
val f = params.fracTest
inputData.randomSplit(Array(1 - f, f), seed = 12345)
- }
}
val Array(train, test) = data.map(_.cache())