aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-04-08 20:37:01 -0700
committerPatrick Wendell <pwendell@gmail.com>2014-04-08 20:37:01 -0700
commitb9e0c937dfa1ca93b63d0b39d5f156b16c2fdc0a (patch)
tree1a170ca9021d09676150061380d607eb4939fdbe /mllib/src/test
parentce8ec5456169682f27f846e7b8d51e6c4bcf75e3 (diff)
downloadspark-b9e0c937dfa1ca93b63d0b39d5f156b16c2fdc0a.tar.gz
spark-b9e0c937dfa1ca93b63d0b39d5f156b16c2fdc0a.tar.bz2
spark-b9e0c937dfa1ca93b63d0b39d5f156b16c2fdc0a.zip
[SPARK-1434] [MLLIB] change labelParser from anonymous function to trait
This is a patch to address @mateiz's comment in https://github.com/apache/spark/pull/245 MLUtils#loadLibSVMData uses an anonymous function for the label parser. Java users won't like it. So I make a trait for LabelParser and provide two implementations: binary and multiclass. Author: Xiangrui Meng <meng@databricks.com> Closes #345 from mengxr/label-parser and squashes the following commits: ac44409 [Xiangrui Meng] use singleton objects for label parsers 3b1a7c6 [Xiangrui Meng] add tests for label parsers c2e571c [Xiangrui Meng] rename LabelParser.apply to LabelParser.parse use extends for singleton 11c94e0 [Xiangrui Meng] add return types 7f8eb36 [Xiangrui Meng] change labelParser from anonymous function to trait
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala41
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala4
2 files changed, 43 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala
new file mode 100644
index 0000000000..ac85677f2f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.FunSuite
+
+/**
+ * Unit tests for `BinaryLabelParser` and `MulticlassLabelParser`.
+ *
+ * Every test runs the same assertions against the parser object itself and against the value
+ * returned by its `getInstance()` method, so both ways of obtaining a parser are exercised.
+ */
+class LabelParsersSuite extends FunSuite {
+  test("binary label parser") {
+    for (parser <- Seq(BinaryLabelParser, BinaryLabelParser.getInstance())) {
+      // "+1" and "1" are expected to parse to 1.0.
+      assert(parser.parse("+1") === 1.0)
+      assert(parser.parse("1") === 1.0)
+      // "0" and "-1" are both expected to parse to 0.0.
+      assert(parser.parse("0") === 0.0)
+      assert(parser.parse("-1") === 0.0)
+    }
+  }
+
+  test("multiclass label parser") {
+    for (parser <- Seq(MulticlassLabelParser, MulticlassLabelParser.getInstance())) {
+      // Use ScalaTest's `===` throughout so a failing assertion reports both operands.
+      assert(parser.parse("0") === 0.0)
+      assert(parser.parse("+1") === 1.0)
+      assert(parser.parse("1") === 1.0)
+      assert(parser.parse("2") === 2.0)
+      assert(parser.parse("3") === 3.0)
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 27d41c7869..e451c350b8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString
- val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
+ val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect()
val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
@@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
}
- val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
+ val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect()
assert(multiclassPoints.length === 3)
assert(multiclassPoints(0).label === 1.0)
assert(multiclassPoints(1).label === -1.0)