about | summary | refs | log | tree | commit | diff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-08-06 14:07:51 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-06 14:07:51 -0700
commit25cff1019da9d6cfc486a31d035b372ea5fbdfd2 (patch)
tree2336b0899ff33ad0187442c6098b089587f5047b /mllib/src/main
parent4e982364426c7d65032e8006c63ca4f9a0d40470 (diff)
downloadspark-25cff1019da9d6cfc486a31d035b372ea5fbdfd2.tar.gz
spark-25cff1019da9d6cfc486a31d035b372ea5fbdfd2.tar.bz2
spark-25cff1019da9d6cfc486a31d035b372ea5fbdfd2.zip
[SPARK-2852][MLLIB] API consistency for `mllib.feature`
This is part of SPARK-2828:

1. added a Java-friendly fit method to Word2Vec with tests
2. change DeveloperApi to Experimental for Normalizer & StandardScaler
3. change default feature dimension to 2^20 in HashingTF

Author: Xiangrui Meng <meng@databricks.com>

Closes #1807 from mengxr/feature-api-check and squashes the following commits:

773c1a9 [Xiangrui Meng] change default numFeatures to 2^20 in HashingTF; change annotation from DeveloperApi to Experimental in Normalizer and StandardScaler
883e122 [Xiangrui Meng] add @Experimental to Word2VecModel; add a Java-friendly method to Word2Vec.fit with tests
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala      |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala     |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala       | 19
4 files changed, 25 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index 0f6d5809e0..c534758183 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -32,12 +32,12 @@ import org.apache.spark.util.Utils
* :: Experimental ::
* Maps a sequence of terms to their term frequencies using the hashing trick.
*
- * @param numFeatures number of features (default: 1000000)
+ * @param numFeatures number of features (default: 2^20^)
*/
@Experimental
class HashingTF(val numFeatures: Int) extends Serializable {
- def this() = this(1000000)
+ def this() = this(1 << 20)
/**
* Returns the index of the input term.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index ea9fd0a80d..3afb477672 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -19,11 +19,11 @@ package org.apache.spark.mllib.feature
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
- * :: DeveloperApi ::
+ * :: Experimental ::
* Normalizes samples individually to unit L^p^ norm
*
* For any 1 <= p < Double.PositiveInfinity, normalizes samples using
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
*
* @param p Normalization in L^p^ space, p = 2 by default.
*/
-@DeveloperApi
+@Experimental
class Normalizer(p: Double) extends VectorTransformer {
def this() = this(2)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index cc2d7579c2..e6c9f8f67d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.feature
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
/**
- * :: DeveloperApi ::
+ * :: Experimental ::
* Standardizes features by removing the mean and scaling to unit variance using column summary
* statistics on the samples in the training set.
*
@@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD
* dense output, so this does not work on sparse input and will raise an exception.
* @param withStd True by default. Scales the data to unit standard deviation.
*/
-@DeveloperApi
+@Experimental
class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer {
def this() = this(false, true)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 3bf44ad7c4..395037e1ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -17,6 +17,9 @@
package org.apache.spark.mllib.feature
+import java.lang.{Iterable => JavaIterable}
+
+import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -25,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
@@ -239,7 +243,7 @@ class Word2Vec extends Serializable with Logging {
a += 1
}
}
-
+
/**
* Computes the vector representation of each word in vocabulary.
* @param dataset an RDD of words
@@ -369,11 +373,22 @@ class Word2Vec extends Serializable with Logging {
new Word2VecModel(word2VecMap.toMap)
}
+
+ /**
+ * Computes the vector representation of each word in vocabulary (Java version).
+ * @param dataset a JavaRDD of words
+ * @return a Word2VecModel
+ */
+ def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = {
+ fit(dataset.rdd.map(_.asScala))
+ }
}
/**
-* Word2Vec model
+ * :: Experimental ::
+ * Word2Vec model
*/
+@Experimental
class Word2VecModel private[mllib] (
private val model: Map[String, Array[Float]]) extends Serializable {