aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala62
1 files changed, 28 insertions, 34 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 0163fa8bd8..34ff929701 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -18,16 +18,16 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg.BLAS._
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Row}
/**
* Params for [[Word2Vec]] and [[Word2VecModel]].
@@ -43,7 +43,7 @@ private[feature] trait Word2VecBase extends Params
setDefault(vectorSize -> 100)
/** @group getParam */
- def getVectorSize: Int = getOrDefault(vectorSize)
+ def getVectorSize: Int = $(vectorSize)
/**
* Number of partitions for sentences of words.
@@ -53,7 +53,7 @@ private[feature] trait Word2VecBase extends Params
setDefault(numPartitions -> 1)
/** @group getParam */
- def getNumPartitions: Int = getOrDefault(numPartitions)
+ def getNumPartitions: Int = $(numPartitions)
/**
* The minimum number of times a token must appear to be included in the word2vec model's
@@ -64,7 +64,7 @@ private[feature] trait Word2VecBase extends Params
setDefault(minCount -> 5)
/** @group getParam */
- def getMinCount: Int = getOrDefault(minCount)
+ def getMinCount: Int = $(minCount)
setDefault(stepSize -> 0.025)
setDefault(maxIter -> 1)
@@ -73,10 +73,9 @@ private[feature] trait Word2VecBase extends Params
/**
* Validate and transform the input schema.
*/
- protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = extractParamMap(paramMap)
- SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true))
- SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
+ SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
}
@@ -112,25 +111,22 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
/** @group setParam */
def setMinCount(value: Int): this.type = set(minCount, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
- val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v }
+ override def fit(dataset: DataFrame): Word2VecModel = {
+ transformSchema(dataset.schema, logging = true)
+ val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0))
val wordVectors = new feature.Word2Vec()
- .setLearningRate(map(stepSize))
- .setMinCount(map(minCount))
- .setNumIterations(map(maxIter))
- .setNumPartitions(map(numPartitions))
- .setSeed(map(seed))
- .setVectorSize(map(vectorSize))
+ .setLearningRate($(stepSize))
+ .setMinCount($(minCount))
+ .setNumIterations($(maxIter))
+ .setNumPartitions($(numPartitions))
+ .setSeed($(seed))
+ .setVectorSize($(vectorSize))
.fit(input)
- val model = new Word2VecModel(this, map, wordVectors)
- Params.inheritValues(map, this, model)
- model
+ copyValues(new Word2VecModel(this, wordVectors))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}
@@ -141,7 +137,6 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
@AlphaComponent
class Word2VecModel private[ml] (
override val parent: Word2Vec,
- override val fittingParamMap: ParamMap,
wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
@@ -155,15 +150,14 @@ class Word2VecModel private[ml] (
* Transform a sentence column to a vector column to represent the whole sentence. The transform
* is performed by averaging all word vectors it contains.
*/
- override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = extractParamMap(paramMap)
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
val word2Vec = udf { sentence: Seq[String] =>
if (sentence.size == 0) {
- Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double])
+ Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
} else {
- val cum = Vectors.zeros(map(vectorSize))
+ val cum = Vectors.zeros($(vectorSize))
val model = bWordVectors.value.getVectors
for (word <- sentence) {
if (model.contains(word)) {
@@ -176,10 +170,10 @@ class Word2VecModel private[ml] (
cum
}
}
- dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol))))
+ dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
}
- override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap)
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
}
}