diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 1b0a9a12e8..91989c3d2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -19,10 +19,12 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.sql.types.DataType /** @@ -52,7 +54,7 @@ class ElementwiseProduct(override val uid: String) override protected def createTransformFunc: Vector => Vector = { require(params.contains(scalingVec), s"transformation requires a weight vector") val elemScaler = new feature.ElementwiseProduct($(scalingVec)) - elemScaler.transform + v => elemScaler.transform(v) } override protected def outputDataType: DataType = new VectorUDT() |