aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala8
1 files changed, 5 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 95bae1c8a3..a72692960f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
/** @group setParam */
def setMinCount(value: Int): this.type = set(minCount, value)
- override def fit(dataset: DataFrame): Word2VecModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): Word2VecModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
val wordVectors = new feature.Word2Vec()
@@ -219,7 +220,8 @@ class Word2VecModel private[ml] (
* Transform a sentence column to a vector column to represent the whole sentence. The transform
* is performed by averaging all word vectors it contains.
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val vectors = wordVectors.getVectors
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))