diff options
Diffstat (limited to 'mllib')
6 files changed, 34 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index e545df1e37..081a574bee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -162,11 +162,15 @@ class PipelineModel private[ml] ( } override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { - transformSchema(dataset.schema, paramMap, logging = true) - stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap)) + // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap + val map = (fittingParamMap ++ this.paramMap) ++ paramMap + transformSchema(dataset.schema, map, logging = true) + stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap)) + // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap + val map = (fittingParamMap ++ this.paramMap) ++ paramMap + stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 490e6609ad..23fbd228d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -18,16 +18,14 @@ package org.apache.spark.ml import scala.annotation.varargs -import scala.reflect.runtime.universe.TypeTag import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ import org.apache.spark.sql.SchemaRDD import org.apache.spark.sql.api.java.JavaSchemaRDD -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.expressions.ScalaUdf import org.apache.spark.sql.catalyst.types._ /** @@ -86,7 +84,7 @@ abstract class Transformer extends PipelineStage with Params { * Abstract class for transformers that take one input column, apply transformation, and output the * result as a new column. */ -private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]] +private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] @@ -100,6 +98,11 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor protected def createTransformFunc(paramMap: ParamMap): IN => OUT /** + * Returns the data type of the output column. + */ + protected def outputDataType: DataType + + /** * Validates the input type. Throw an exception if it is invalid. */ protected def validateInputType(inputType: DataType): Unit = {} @@ -111,9 +114,8 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor if (schema.fieldNames.contains(map(outputCol))) { throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") } - val output = ScalaReflection.schemaFor[OUT] val outputFields = schema.fields :+ - StructField(map(outputCol), output.dataType, output.nullable) + StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive) StructType(outputFields) } @@ -121,7 +123,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor transformSchema(dataset.schema, paramMap, logging = true) import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val udf = this.createTransformFunc(map) - dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol)) + val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr)) + dataset.select(Star(None), udf as map(outputCol)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index b98b1755a3..e0bfb1e484 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -21,7 +21,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.sql.catalyst.types.DataType /** * :: AlphaComponent :: @@ -39,4 +40,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform } + + override protected def outputDataType: DataType = new VectorUDT() } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0a6599b64c..9352f40f37 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.{DataType, StringType} +import org.apache.spark.sql.{DataType, StringType, ArrayType} /** * :: AlphaComponent :: @@ -36,4 +36,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { protected override def validateInputType(inputType: DataType): Unit = { require(inputType == StringType, s"Input type must be string type but got $inputType.") } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 8fd46aef4b..4b4340af54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -17,13 +17,12 @@ package org.apache.spark.ml.param -import java.lang.reflect.Modifier - -import org.apache.spark.annotation.AlphaComponent - import scala.annotation.varargs import scala.collection.mutable +import java.lang.reflect.Modifier + +import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Identifiable /** @@ -221,7 +220,9 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten /** * Puts a list of param pairs (overwrites if the input params exists). + * Not usable from Java */ + @varargs def put(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => put(p.param.asInstanceOf[Param[Any]], p.value) @@ -282,6 +283,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * where the latter overwrites this if there exists conflicts. */ def ++(other: ParamMap): ParamMap = { + // TODO: Provide a better method name for Java users. new ParamMap(this.map ++ other.map) } @@ -290,6 +292,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * Adds all parameters from the input param map into this param map. */ def ++=(other: ParamMap): this.type = { + // TODO: Provide a better method name for Java users. this.map ++= other.map this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 8c4c9c6cf6..9fed513bec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -96,7 +96,9 @@ private[spark] object BLAS extends Serializable with Logging { * dot(x, y) */ def dot(x: Vector, y: Vector): Double = { - require(x.size == y.size) + require(x.size == y.size, + "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" + + " x.size = " + x.size + ", y.size = " + y.size) (x, y) match { case (dx: DenseVector, dy: DenseVector) => dot(dx, dy) |