author	Xiangrui Meng <meng@databricks.com>	2015-09-11 08:53:40 -0700
committer	Xiangrui Meng <meng@databricks.com>	2015-09-11 08:53:40 -0700
commit	960d2d0ac6b5a22242a922f87f745f7d1f736181 (patch)
tree	ff3578f7c6fbfcafe89e41235a8c61ed2d7c6a29 /mllib
parent	b01b26260625f0ba14e5f3010207666d62d93864 (diff)
[SPARK-10537] [ML] document LIBSVM source options in public API doc and some minor improvements
We should document options in the public API doc; otherwise it is hard to find them without reading the code. I tried to make `DefaultSource` private and move the documentation into the package doc. However, since there would then be no public class under `source.libsvm`, the Java package doc does not show up in the generated HTML (http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4492654). So I put the doc on `DefaultSource` instead.

There are several minor updates in this PR:

1. Compute `vectorType == "sparse"` only once.
2. Update `hashCode` and `equals`.
3. Remove inherited doc.
4. Delete the temp dir in `afterAll`.

Lewuathe

Author: Xiangrui Meng <meng@databricks.com>

Closes #8699 from mengxr/SPARK-10537.
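For reference, the usage that the new `DefaultSource` scaladoc documents, as a minimal Scala sketch (the path and option values mirror the doc example in the diff below):

    // Load LIBSVM data through the Spark SQL data source API.
    // "numFeatures" skips the extra pass that infers the feature dimension;
    // "vectorType" selects the vector storage format ("sparse" is the default).
    val df = sqlContext.read.format("libsvm")
      .option("numFeatures", "780")
      .option("vectorType", "dense")
      .load("data/mllib/sample_libsvm_data.txt")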
Diffstat (limited to 'mllib')
-rw-r--r--	mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala	| 71
-rw-r--r--	mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java (renamed from mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java)	| 24
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala)	| 14
3 files changed, 66 insertions(+), 43 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b12cb62a4e..1f627777fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -21,12 +21,12 @@ import com.google.common.base.Objects
import org.apache.spark.Logging
import org.apache.spark.annotation.Since
-import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext}
import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
/**
* LibSVMRelation provides the DataFrame constructed from LibSVM format data.
@@ -35,7 +35,7 @@ import org.apache.spark.sql.sources._
* @param vectorType The type of vector. It can be 'sparse' or 'dense'
* @param sqlContext The Spark SQLContext
*/
-private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
+private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with Logging with Serializable {
@@ -47,27 +47,56 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec
override def buildScan(): RDD[Row] = {
val sc = sqlContext.sparkContext
val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
-
+ val sparse = vectorType == "sparse"
baseRdd.map { pt =>
- val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse
+ val features = if (sparse) pt.features.toSparse else pt.features.toDense
Row(pt.label, features)
}
}
override def hashCode(): Int = {
- Objects.hashCode(path, schema)
+ Objects.hashCode(path, Double.box(numFeatures), vectorType)
}
override def equals(other: Any): Boolean = other match {
- case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema)
- case _ => false
+ case that: LibSVMRelation =>
+ path == that.path &&
+ numFeatures == that.numFeatures &&
+ vectorType == that.vectorType
+ case _ =>
+ false
}
-
}
/**
- * This is used for creating DataFrame from LibSVM format file.
- * The LibSVM file path must be specified to DefaultSource.
+ * The `libsvm` package implements the Spark SQL data source API for loading LIBSVM data as a [[DataFrame]].
+ * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and
+ * `features` containing feature vectors stored as [[Vector]]s.
+ *
+ * To use the LIBSVM data source, set "libsvm" as the format in [[DataFrameReader]] and
+ * optionally specify options, for example:
+ * {{{
+ * // Scala
+ * val df = sqlContext.read.format("libsvm")
+ * .option("numFeatures", "780")
+ * .load("data/mllib/sample_libsvm_data.txt")
+ *
+ * // Java
+ * DataFrame df = sqlContext.read.format("libsvm")
+ * .option("numFeatures, "780")
+ * .load("data/mllib/sample_libsvm_data.txt");
+ * }}}
+ *
+ * LIBSVM data source supports the following options:
+ * - "numFeatures": number of features.
+ * If unspecified or nonpositive, the number of features will be determined automatically at the
+ * cost of one additional pass.
+ * This is also useful when the dataset is already split into multiple files and you want to load
+ * them separately, because some features may not be present in certain files, which leads to
+ * inconsistent feature dimensions.
+ * - "vectorType": feature vector type, "sparse" (default) or "dense".
+ *
+ * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
*/
@Since("1.6.0")
class DefaultSource extends RelationProvider with DataSourceRegister {
@@ -75,24 +104,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
- private def checkPath(parameters: Map[String, String]): String = {
- require(parameters.contains("path"), "'path' must be specified")
- parameters.get("path").get
- }
-
- /**
- * Returns a new base relation with the given parameters.
- * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
- * by the Map that is passed to the function.
- */
+ @Since("1.6.0")
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
: BaseRelation = {
- val path = checkPath(parameters)
+ val path = parameters.getOrElse("path",
+ throw new IllegalArgumentException("'path' must be specified"))
val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
- /**
- * featuresType can be selected "dense" or "sparse".
- * This parameter decides the type of returned feature vector.
- */
val vectorType = parameters.getOrElse("vectorType", "sparse")
new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 11fa4eec0c..2976b38e45 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.ml.source;
+package org.apache.spark.ml.source.libsvm;
import java.io.File;
import java.io.IOException;
@@ -42,34 +42,34 @@ import org.apache.spark.util.Utils;
*/
public class JavaLibSVMRelationSuite {
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient SQLContext sqlContext;
- private File tmpDir;
- private File path;
+ private File tempDir;
+ private String path;
@Before
public void setUp() throws IOException {
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
- jsql = new SQLContext(jsc);
-
- tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
- path = new File(tmpDir.getPath(), "part-00000");
+ sqlContext = new SQLContext(jsc);
+ tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
+ File file = new File(tempDir, "part-00000");
String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
- Files.write(s, path, Charsets.US_ASCII);
+ Files.write(s, file, Charsets.US_ASCII);
+ path = tempDir.toURI().toString();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
- Utils.deleteRecursively(tmpDir);
+ Utils.deleteRecursively(tempDir);
}
@Test
public void verifyLibSVMDF() {
- dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath());
+ DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+ .load(path);
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
Row r = dataset.first();
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 8ed134128c..997f574e51 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.ml.source
+package org.apache.spark.ml.source.libsvm
import java.io.File
@@ -23,11 +23,12 @@ import com.google.common.base.Charsets
import com.google.common.io.Files
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
+ var tempDir: File = _
var path: String = _
override def beforeAll(): Unit = {
@@ -38,12 +39,17 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
- val tempDir = Utils.createTempDir()
- val file = new File(tempDir.getPath, "part-00000")
+ tempDir = Utils.createTempDir()
+ val file = new File(tempDir, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
path = tempDir.toURI.toString
}
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ super.afterAll()
+ }
+
test("select as sparse vector") {
val df = sqlContext.read.format("libsvm").load(path)
assert(df.columns(0) == "label")