-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 71
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java (renamed from mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java) | 24
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala) | 14
3 files changed, 66 insertions(+), 43 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b12cb62a4e..1f627777fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -21,12 +21,12 @@ import com.google.common.base.Objects
import org.apache.spark.Logging
import org.apache.spark.annotation.Since
-import org.apache.spark.mllib.linalg.VectorUDT
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{StructType, StructField, DoubleType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext}
import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
/**
* LibSVMRelation provides the DataFrame constructed from LibSVM format data.
@@ -35,7 +35,7 @@ import org.apache.spark.sql.sources._
* @param vectorType The type of vector. It can be 'sparse' or 'dense'
* @param sqlContext The Spark SQLContext
*/
-private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
+private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String)
(@transient val sqlContext: SQLContext)
extends BaseRelation with TableScan with Logging with Serializable {
@@ -47,27 +47,56 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec
override def buildScan(): RDD[Row] = {
val sc = sqlContext.sparkContext
val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
-
+ val sparse = vectorType == "sparse"
baseRdd.map { pt =>
- val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse
+ val features = if (sparse) pt.features.toSparse else pt.features.toDense
Row(pt.label, features)
}
}
override def hashCode(): Int = {
- Objects.hashCode(path, schema)
+ Objects.hashCode(path, Double.box(numFeatures), vectorType)
}
override def equals(other: Any): Boolean = other match {
- case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema)
- case _ => false
+ case that: LibSVMRelation =>
+ path == that.path &&
+ numFeatures == that.numFeatures &&
+ vectorType == that.vectorType
+ case _ =>
+ false
}
-
}
/**
- * This is used for creating DataFrame from LibSVM format file.
- * The LibSVM file path must be specified to DefaultSource.
+ * The `libsvm` package implements the Spark SQL data source API for loading LIBSVM data as a
+ * [[DataFrame]]. The loaded [[DataFrame]] has two columns: `label`, containing labels stored as
+ * doubles, and `features`, containing feature vectors stored as [[Vector]]s.
+ *
+ * To use the LIBSVM data source, set "libsvm" as the format in [[DataFrameReader]] and
+ * optionally specify options, for example:
+ * {{{
+ * // Scala
+ * val df = sqlContext.read.format("libsvm")
+ * .option("numFeatures", "780")
+ * .load("data/mllib/sample_libsvm_data.txt")
+ *
+ * // Java
+ * DataFrame df = sqlContext.read.format("libsvm")
+ * .option("numFeatures", "780")
+ * .load("data/mllib/sample_libsvm_data.txt");
+ * }}}
+ *
+ * LIBSVM data source supports the following options:
+ * - "numFeatures": number of features.
+ * If unspecified or nonpositive, the number of features will be determined automatically at the
+ * cost of one additional pass.
+ * Specifying the number of features explicitly is also useful when the dataset is split into
+ * multiple files that are loaded separately, because some features may not be present in every
+ * file, which would otherwise lead to inconsistent feature dimensions.
+ * - "vectorType": feature vector type, "sparse" (default) or "dense".
+ *
+ * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]]
*/
@Since("1.6.0")
class DefaultSource extends RelationProvider with DataSourceRegister {
@@ -75,24 +104,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
- private def checkPath(parameters: Map[String, String]): String = {
- require(parameters.contains("path"), "'path' must be specified")
- parameters.get("path").get
- }
-
- /**
- * Returns a new base relation with the given parameters.
- * Note: the parameters' keywords are case insensitive and this insensitivity is enforced
- * by the Map that is passed to the function.
- */
+ @Since("1.6.0")
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
: BaseRelation = {
- val path = checkPath(parameters)
+ val path = parameters.getOrElse("path",
+ throw new IllegalArgumentException("'path' must be specified"))
val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt
- /**
- * featuresType can be selected "dense" or "sparse".
- * This parameter decides the type of returned feature vector.
- */
val vectorType = parameters.getOrElse("vectorType", "sparse")
new LibSVMRelation(path, numFeatures, vectorType)(sqlContext)
}
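
For context, here is an end-to-end sketch of the reader API documented above. It assumes a live SparkContext named `sc` and the sample data file that ships with Spark; the option values are illustrative, not required:

{{{
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)

// Both options are optional; values are passed as strings.
val df = sqlContext.read.format("libsvm")
  .option("numFeatures", "780")
  .option("vectorType", "dense")
  .load("data/mllib/sample_libsvm_data.txt")

df.printSchema()                       // label: double, features: vector
val first = df.first()
val label = first.getDouble(0)         // labels are stored as doubles
val features = first.getAs[Vector](1)  // features as mllib Vectors
}}}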
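
And a minimal sketch of what buildScan does internally, stripped of the data source machinery and again assuming the same `sc` and sample path; this mirrors the sparse/dense branch in the relation rather than reproducing the class itself:

{{{
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.Row

// Parse LIBSVM text into LabeledPoints; -1 asks MLUtils to infer numFeatures.
val baseRdd = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt", -1)

// Decide the representation once, outside the closure, as buildScan does.
val vectorType = "sparse"
val sparse = vectorType == "sparse"
val rows = baseRdd.map { pt =>
  val features = if (sparse) pt.features.toSparse else pt.features.toDense
  Row(pt.label, features)
}
}}}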
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 11fa4eec0c..2976b38e45 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.ml.source;
+package org.apache.spark.ml.source.libsvm;
import java.io.File;
import java.io.IOException;
@@ -42,34 +42,34 @@ import org.apache.spark.util.Utils;
*/
public class JavaLibSVMRelationSuite {
private transient JavaSparkContext jsc;
- private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient SQLContext sqlContext;
- private File tmpDir;
- private File path;
+ private File tempDir;
+ private String path;
@Before
public void setUp() throws IOException {
jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
- jsql = new SQLContext(jsc);
-
- tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
- path = new File(tmpDir.getPath(), "part-00000");
+ sqlContext = new SQLContext(jsc);
+ tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
+ File file = new File(tempDir, "part-00000");
String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
- Files.write(s, path, Charsets.US_ASCII);
+ Files.write(s, file, Charsets.US_ASCII);
+ path = tempDir.toURI().toString();
}
@After
public void tearDown() {
jsc.stop();
jsc = null;
- Utils.deleteRecursively(tmpDir);
+ Utils.deleteRecursively(tempDir);
}
@Test
public void verifyLibSVMDF() {
- dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath());
+ DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+ .load(path);
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
Row r = dataset.first();
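
The Java assertions continue past the cut above; a Scala sketch of the same check, assuming the three-record fixture written in setUp and the suite's `sqlContext` and `path`. LIBSVM indices are 1-based, so with six inferred features the first record expands as shown:

{{{
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}

val df = sqlContext.read.format("libsvm").option("vectorType", "dense").load(path)
assert(df.columns(0) == "label" && df.columns(1) == "features")

// "1 1:1.0 3:2.0 5:3.0" becomes a length-6 dense vector.
val row = df.first()
assert(row.getDouble(0) == 1.0)
assert(row.getAs[DenseVector](1) == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
}}}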
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 8ed134128c..997f574e51 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.ml.source
+package org.apache.spark.ml.source.libsvm
import java.io.File
@@ -23,11 +23,12 @@ import com.google.common.base.Charsets
import com.google.common.io.Files
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils
class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
+ var tempDir: File = _
var path: String = _
override def beforeAll(): Unit = {
@@ -38,12 +39,17 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
- val tempDir = Utils.createTempDir()
- val file = new File(tempDir.getPath, "part-00000")
+ tempDir = Utils.createTempDir()
+ val file = new File(tempDir, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
path = tempDir.toURI.toString
}
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ super.afterAll()
+ }
+
test("select as sparse vector") {
val df = sqlContext.read.format("libsvm").load(path)
assert(df.columns(0) == "label")
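
For anyone adapting this fixture, a self-contained sketch of the temp-directory pattern both suites rely on, assuming a `sqlContext` in scope; note that org.apache.spark.util.Utils is private[spark], so this only compiles inside Spark's own packages:

{{{
import java.io.File

import com.google.common.base.Charsets
import com.google.common.io.Files

import org.apache.spark.util.Utils

val tempDir = Utils.createTempDir()
try {
  // Same three records the suites use, including the empty (label-only) row.
  val lines = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"
  Files.write(lines, new File(tempDir, "part-00000"), Charsets.US_ASCII)
  val df = sqlContext.read.format("libsvm").load(tempDir.toURI.toString)
  assert(df.count() == 3)
} finally {
  Utils.deleteRecursively(tempDir)
}
}}}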