author     Xiangrui Meng <meng@databricks.com>  2015-09-11 08:53:40 -0700
committer  Xiangrui Meng <meng@databricks.com>  2015-09-11 08:53:40 -0700
commit     960d2d0ac6b5a22242a922f87f745f7d1f736181 (patch)
tree       ff3578f7c6fbfcafe89e41235a8c61ed2d7c6a29 /mllib/src/main
parent     b01b26260625f0ba14e5f3010207666d62d93864 (diff)
[SPARK-10537] [ML] document LIBSVM source options in public API doc and some minor improvements
We should document options in the public API doc; otherwise, it is hard to find them without reading the code. I tried to make `DefaultSource` private and move the documentation to the package doc. However, since there would then be no public class under `source.libsvm`, the Java package doc doesn't show up in the generated HTML (http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4492654). So I put the doc on `DefaultSource` instead.

There are several minor updates in this PR:
1. Do the `vectorType == "sparse"` comparison only once.
2. Update `hashCode` and `equals`.
3. Remove inherited doc.
4. Delete the temp dir in `afterAll`.

Lewuathe

Author: Xiangrui Meng <meng@databricks.com>

Closes #8699 from mengxr/SPARK-10537.
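For context, a minimal sketch of the documented usage (both option names and the sample path come from the diff below; a live `sqlContext` is assumed):

{{{
// Load LIBSVM data, fixing the feature dimension up front and asking for
// dense vectors instead of the default sparse ones.
val df = sqlContext.read.format("libsvm")
  .option("numFeatures", "780")   // skips the extra pass that infers dimensionality
  .option("vectorType", "dense")  // default is "sparse"
  .load("data/mllib/sample_libsvm_data.txt")
}}}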
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 71
1 file changed, 44 insertions(+), 27 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b12cb62a4e..1f627777fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -21,12 +21,12 @@ import com.google.common.base.Objects
import org.apache.spark.Logging
import org.apache.spark.annotation.Since
-import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
/**
* LibSVMRelation provides the DataFrame constructed from LibSVM format data.
@@ -35,7 +35,7 @@ import org.apache.spark.sql.sources._
* @param vectorType The type of vector. It can be 'sparse' or 'dense'
* @param sqlContext The Spark SQLContext
*/
-private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
+private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with Logging with Serializable {
@@ -47,27 +47,56 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec
override def buildScan(): RDD[Row] = {
val sc = sqlContext.sparkContext
val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
-
+ val sparse = vectorType == "sparse"
baseRdd.map { pt =>
- val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse
+ val features = if (sparse) pt.features.toSparse else pt.features.toDense
Row(pt.label, features)
}
}
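Hoisting `val sparse = vectorType == "sparse"` out of the closure evaluates the string comparison once instead of once per row. The `toSparse`/`toDense` calls are the standard mllib `Vector` conversions; a standalone sketch with illustrative values:

{{{
import org.apache.spark.mllib.linalg.Vectors

// toSparse keeps only the nonzero entries; toDense materializes every slot.
val v  = Vectors.dense(0.0, 3.0, 0.0)
val sv = v.toSparse   // SparseVector: size 3, indices [1], values [3.0]
val dv = sv.toDense   // DenseVector: [0.0, 3.0, 0.0]
}}}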
override def hashCode(): Int = {
- Objects.hashCode(path, schema)
+ Objects.hashCode(path, Double.box(numFeatures), vectorType)
}
override def equals(other: Any): Boolean = other match {
- case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema)
- case _ => false
+ case that: LibSVMRelation =>
+ path == that.path &&
+ numFeatures == that.numFeatures &&
+ vectorType == that.vectorType
+ case _ =>
+ false
}
-
}
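To illustrate the contract the rewritten methods now uphold, a hypothetical comparison (the relation instances and `sqlCtx` are assumed for the sketch, not part of the commit):

{{{
// Equal parameters now imply equal hashCodes; a different vectorType
// breaks equality even though path and numFeatures match.
val a = new LibSVMRelation("/tmp/data.txt", 780, "sparse")(sqlCtx)
val b = new LibSVMRelation("/tmp/data.txt", 780, "sparse")(sqlCtx)
val c = new LibSVMRelation("/tmp/data.txt", 780, "dense")(sqlCtx)
assert(a == b && a.hashCode == b.hashCode)
assert(a != c)
}}}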
/**
- * This is used for creating DataFrame from LibSVM format file.
- * The LibSVM file path must be specified to DefaultSource.
+ * The `libsvm` package implements the Spark SQL data source API for loading LIBSVM data as a
+ * [[DataFrame]]. The loaded [[DataFrame]] has two columns: `label` containing labels stored as
+ * doubles and `features` containing feature vectors stored as [[Vector]]s.
+ *
+ * To use the LIBSVM data source, set "libsvm" as the format in [[DataFrameReader]] and
+ * optionally specify options, for example:
+ * {{{
+ * // Scala
+ * val df = sqlContext.read.format("libsvm")
+ * .option("numFeatures", "780")
+ * .load("data/mllib/sample_libsvm_data.txt")
+ *
+ * // Java
+ * DataFrame df = sqlContext.read.format("libsvm")
+ * .option("numFeatures, "780")
+ * .load("data/mllib/sample_libsvm_data.txt");
+ * }}}
+ *
+ * LIBSVM data source supports the following options:
+ * - "numFeatures": number of features.
+ * If unspecified or nonpositive, the number of features will be determined automatically at the
+ * cost of one additional pass.
+ * This is also useful when the dataset is already split into multiple files and you want to load
+ * them separately, because some features may not be present in certain files, which leads to
+ * inconsistent feature dimensions.
+ * - "vectorType": feature vector type, "sparse" (default) or "dense".
+ *
+ * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
*/
@Since("1.6.0")
class DefaultSource extends RelationProvider with DataSourceRegister {
@@ -75,24 +104,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
- private def checkPath(parameters: Map[String, String]): String = {
- require(parameters.contains("path"), "'path' must be specified")
- parameters.get("path").get
- }
-
- /**
- * Returns a new base relation with the given parameters.
- * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
- * by the Map that is passed to the function.
- */
+ @Since("1.6.0")
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
: BaseRelation = {
- val path = checkPath(parameters)
+ val path = parameters.getOrElse("path",
+ throw new IllegalArgumentException("'path' must be specified"))
val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
- /**
- * featuresType can be selected "dense" or "sparse".
- * This parameter decides the type of returned feature vector.
- */
val vectorType = parameters.getOrElse("vectorType", "sparse")
new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
}
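The `getOrElse` fallback above relies on Scala evaluating the default argument by name, so the exception is thrown only when "path" is absent; a self-contained sketch of the same pattern:

{{{
// Map.getOrElse takes its default by name: the throw below only runs
// when the key is missing, otherwise the stored value is returned.
val parameters = Map("vectorType" -> "dense")
val vectorType = parameters.getOrElse("vectorType", "sparse")  // "dense"
val path = parameters.getOrElse("path",
  throw new IllegalArgumentException("'path' must be specified"))  // throws
}}}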