diff options
author | lewuathe <lewuathe@me.com> | 2015-09-09 09:29:10 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-09-09 09:29:10 -0700 |
commit | 2ddeb63126d26149eda197e85b7b26ef16a6e97c (patch) | |
tree | 6d208f63c719af95e7392b4dcf3cd21949301e0e /mllib/src | |
parent | c1bc4f439f54625c01a585691e5293cd9961eb0c (diff) | |
download | spark-2ddeb63126d26149eda197e85b7b26ef16a6e97c.tar.gz spark-2ddeb63126d26149eda197e85b7b26ef16a6e97c.tar.bz2 spark-2ddeb63126d26149eda197e85b7b26ef16a6e97c.zip |
[SPARK-10117] [MLLIB] Implement SQL data source API for reading LIBSVM data
It is convenient to implement the data source API for the LIBSVM format, so that it integrates better with DataFrames and the ML pipeline API.
Two options are implemented:
* `numFeatures`: Specify the dimension of features vector
* `vectorType`: Specify the type of output vector, `sparse` or `dense`. `sparse` is the default.
Author: lewuathe <lewuathe@me.com>
Closes #8537 from Lewuathe/SPARK-10117 and squashes the following commits:
986999d [lewuathe] Change unit test phrase
11d513f [lewuathe] Fix some reviews
21600a4 [lewuathe] Merge branch 'master' into SPARK-10117
9ce63c7 [lewuathe] Rewrite service loader file
1fdd2df [lewuathe] Merge branch 'SPARK-10117' of github.com:Lewuathe/spark into SPARK-10117
ba3657c [lewuathe] Merge branch 'master' into SPARK-10117
0ea1c1c [lewuathe] LibSVMRelation is registered into META-INF
4f40891 [lewuathe] Improve test suites
5ab62ab [lewuathe] Merge branch 'master' into SPARK-10117
8660d0e [lewuathe] Fix Java unit test
b56a948 [lewuathe] Merge branch 'master' into SPARK-10117
2c12894 [lewuathe] Remove unnecessary tag
7d693c2 [lewuathe] Resolv conflict
62010af [lewuathe] Merge branch 'master' into SPARK-10117
a97ee97 [lewuathe] Fix some points
aef9564 [lewuathe] Fix
70ee4dd [lewuathe] Add Java test
3fd8dce [lewuathe] [SPARK-10117] Implement SQL data source API for reading LIBSVM data
40d3027 [lewuathe] Add Java test
7056d4a [lewuathe] Merge branch 'master' into SPARK-10117
99accaa [lewuathe] [SPARK-10117] Implement SQL data source API for reading LIBSVM data
Diffstat (limited to 'mllib/src')
4 files changed, 256 insertions, 0 deletions
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000..f632dd603c --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.ml.source.libsvm.DefaultSource diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala new file mode 100644 index 0000000000..b12cb62a4e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.source.libsvm + +import com.google.common.base.Objects + +import org.apache.spark.Logging +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{StructType, StructField, DoubleType} +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.sources._ + +/** + * LibSVMRelation provides the DataFrame constructed from LibSVM format data. + * @param path File path of LibSVM format + * @param numFeatures The number of features + * @param vectorType The type of vector. It can be 'sparse' or 'dense' + * @param sqlContext The Spark SQLContext + */ +private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) + (@transient val sqlContext: SQLContext) + extends BaseRelation with TableScan with Logging with Serializable { + + override def schema: StructType = StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false) :: Nil + ) + + override def buildScan(): RDD[Row] = { + val sc = sqlContext.sparkContext + val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) + + baseRdd.map { pt => + val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse + Row(pt.label, features) + } + } + + override def hashCode(): Int = { + Objects.hashCode(path, schema) + } + + override def equals(other: Any): Boolean = other match { + case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema) + case _ => false + } + +} + +/** + * This is used for creating DataFrame from LibSVM format file. + * The LibSVM file path must be specified to DefaultSource. 
+ */ +@Since("1.6.0") +class DefaultSource extends RelationProvider with DataSourceRegister { + + @Since("1.6.0") + override def shortName(): String = "libsvm" + + private def checkPath(parameters: Map[String, String]): String = { + require(parameters.contains("path"), "'path' must be specified") + parameters.get("path").get + } + + /** + * Returns a new base relation with the given parameters. + * Note: the parameters' keywords are case insensitive and this insensitivity is enforced + * by the Map that is passed to the function. + */ + override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) + : BaseRelation = { + val path = checkPath(parameters) + val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt + /** + * featuresType can be selected "dense" or "sparse". + * This parameter decides the type of returned feature vector. + */ + val vectorType = parameters.getOrElse("vectorType", "sparse") + new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java new file mode 100644 index 0000000000..11fa4eec0c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source; + +import java.io.File; +import java.io.IOException; + +import com.google.common.base.Charsets; +import com.google.common.io.Files; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.DenseVector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + + +/** + * Test LibSVMRelation in Java. 
+ */ +public class JavaLibSVMRelationSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + private transient DataFrame dataset; + + private File tmpDir; + private File path; + + @Before + public void setUp() throws IOException { + jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); + jsql = new SQLContext(jsc); + + tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + path = new File(tmpDir.getPath(), "part-00000"); + + String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; + Files.write(s, path, Charsets.US_ASCII); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + Utils.deleteRecursively(tmpDir); + } + + @Test + public void verifyLibSVMDF() { + dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath()); + Assert.assertEquals("label", dataset.columns()[0]); + Assert.assertEquals("features", dataset.columns()[1]); + Row r = dataset.first(); + Assert.assertEquals(1.0, r.getDouble(0), 1e-15); + DenseVector v = r.getAs(1); + Assert.assertEquals(Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0), v); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala new file mode 100644 index 0000000000..8ed134128c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils + +class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { + var path: String = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val lines = + """ + |1 1:1.0 3:2.0 5:3.0 + |0 + |0 2:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + path = tempDir.toURI.toString + } + + test("select as sparse vector") { + val df = sqlContext.read.format("libsvm").load(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } + + test("select as dense vector") { + val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense")) + .load(path) + assert(df.columns(0) == "label") + assert(df.columns(1) == "features") + assert(df.count() == 3) + val row1 = df.first() + assert(row1.getDouble(0) == 1.0) + val v = row1.getAs[DenseVector](1) + assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0)) + } + + test("select a vector with specifying the longer dimension") { + val df 
= sqlContext.read.option("numFeatures", "100").format("libsvm") + .load(path) + val row1 = df.first() + val v = row1.getAs[SparseVector](1) + assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + } +} |