From 254e0509762937acc9c72b432d5d953bf72c3c52 Mon Sep 17 00:00:00 2001 From: Vincenzo Selvaggio Date: Wed, 29 Apr 2015 23:21:21 -0700 Subject: [SPARK-1406] Mllib pmml model export See PDF attached to the JIRA issue 1406. The contribution is my original work and I license the work to the project under the project's open source license. Author: Vincenzo Selvaggio Author: Xiangrui Meng Author: selvinsource Closes #3062 from selvinsource/mllib_pmml_model_export_SPARK-1406 and squashes the following commits: 852aac6 [Vincenzo Selvaggio] [SPARK-1406] Update JPMML version to 1.1.15 in LICENSE file 085cf42 [Vincenzo Selvaggio] [SPARK-1406] Added Double Min and Max Fixed scala style 30165c4 [Vincenzo Selvaggio] [SPARK-1406] Fixed extreme cases for logit 7a5e0ec [Vincenzo Selvaggio] [SPARK-1406] Binary classification for SVM and Logistic Regression cfcb596 [Vincenzo Selvaggio] [SPARK-1406] Throw IllegalArgumentException when exporting a multinomial logistic regression 25dce33 [Vincenzo Selvaggio] [SPARK-1406] Update code to latest pmml model dea98ca [Vincenzo Selvaggio] [SPARK-1406] Exclude transitive dependency for pmml model 66b7c12 [Vincenzo Selvaggio] [SPARK-1406] Updated pmml model lib to 1.1.15, latest Java 6 compatible a0a55f7 [Vincenzo Selvaggio] Merge pull request #2 from mengxr/SPARK-1406 3c22f79 [Xiangrui Meng] more code style e2313df [Vincenzo Selvaggio] Merge pull request #1 from mengxr/SPARK-1406 472d757 [Xiangrui Meng] fix code style 1676e15 [Vincenzo Selvaggio] fixed scala issue e2ffae8 [Vincenzo Selvaggio] fixed scala style b8823b0 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 b25bbf7 [Vincenzo Selvaggio] [SPARK-1406] Added export of pmml to distributed file system using the spark context 7a949d0 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style f46c75c [Vincenzo Selvaggio] [SPARK-1406] Added PMMLExportable to supported models 7b33b4e [Vincenzo Selvaggio] [SPARK-1406] Added a PMMLExportable interface Restructured code in a new package mllib.pmml Supported models implements the new PMMLExportable interface: LogisticRegression, SVM, KMeansModel, LinearRegression, RidgeRegression, Lasso d559ec5 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 8fe12bb [Vincenzo Selvaggio] [SPARK-1406] Adjusted logistic regression export description and target categories 03bc3a5 [Vincenzo Selvaggio] added logistic regression da2ec11 [Vincenzo Selvaggio] [SPARK-1406] added linear SVM PMML export 82f2131 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 19adf29 [Vincenzo Selvaggio] [SPARK-1406] Fixed scala style 1faf985 [Vincenzo Selvaggio] [SPARK-1406] Added target field to the regression model for completeness Adjusted unit test to deal with this change 3ae8ae5 [Vincenzo Selvaggio] [SPARK-1406] Adjusted imported order according to the guidelines c67ce81 [Vincenzo Selvaggio] Merge remote-tracking branch 'upstream/master' into mllib_pmml_model_export_SPARK-1406 78515ec [Vincenzo Selvaggio] [SPARK-1406] added pmml export for LinearRegressionModel, RidgeRegressionModel and LassoModel e29dfb9 [Vincenzo Selvaggio] removed version, by default is set to 4.2 (latest from jpmml) removed copyright ae8b993 [Vincenzo Selvaggio] updated some commented tests to use the new ModelExporter object reordered the imports df8a89e [Vincenzo Selvaggio] added pmml version to pmml model changed the copyright to spark a1b4dc3 [Vincenzo Selvaggio] updated imports 834ca44 [Vincenzo Selvaggio] reordered the import accordingly to the guidelines 349a76b [Vincenzo Selvaggio] new helper object to serialize the models to pmml format c3ef9b8 [Vincenzo Selvaggio] set it to private 6357b98 [Vincenzo Selvaggio] set it to private e1eb251 [Vincenzo Selvaggio] removed serialization part, this will be part of the ModelExporter helper object aba5ee1 [Vincenzo Selvaggio] fixed cluster export cd6c07c [Vincenzo Selvaggio] fixed scala style to run tests f75b988 [Vincenzo Selvaggio] Merge remote-tracking branch 'origin/master' into mllib_pmml_model_export_SPARK-1406 07a29bf [selvinsource] Update LICENSE 8841439 [Vincenzo Selvaggio] adjust scala style in order to compile 1433b11 [Vincenzo Selvaggio] complete suite tests 8e71b8d [Vincenzo Selvaggio] kmeans pmml export implementation 9bc494f [Vincenzo Selvaggio] added scala suite tests added saveLocalFile to ModelExport trait 226e184 [Vincenzo Selvaggio] added javadoc and export model type in case there is a need to support other types of export (not just PMML) a0e3679 [Vincenzo Selvaggio] export and pmml export traits kmeans test implementation --- LICENSE | 1 + mllib/pom.xml | 15 ++++ .../mllib/classification/LogisticRegression.scala | 3 +- .../apache/spark/mllib/classification/SVM.scala | 3 +- .../spark/mllib/clustering/KMeansModel.scala | 4 +- .../apache/spark/mllib/pmml/PMMLExportable.scala | 74 +++++++++++++++++ .../BinaryClassificationPMMLModelExport.scala | 90 ++++++++++++++++++++ .../export/GeneralizedLinearPMMLModelExport.scala | 75 +++++++++++++++++ .../mllib/pmml/export/KMeansPMMLModelExport.scala | 83 +++++++++++++++++++ .../spark/mllib/pmml/export/PMMLModelExport.scala | 47 +++++++++++ .../mllib/pmml/export/PMMLModelExportFactory.scala | 64 +++++++++++++++ .../org/apache/spark/mllib/regression/Lasso.scala | 3 +- .../spark/mllib/regression/LinearRegression.scala | 3 +- .../spark/mllib/regression/RidgeRegression.scala | 3 +- .../BinaryClassificationPMMLModelExportSuite.scala | 84 +++++++++++++++++++ .../GeneralizedLinearPMMLModelExportSuite.scala | 84 +++++++++++++++++++ .../pmml/export/KMeansPMMLModelExportSuite.scala | 49 +++++++++++ .../pmml/export/PMMLModelExportFactorySuite.scala | 95 ++++++++++++++++++++++ 18 files changed, 774 insertions(+), 6 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala diff --git a/LICENSE b/LICENSE index 9b364a4d00..21c42e9a20 100644 --- a/LICENSE +++ b/LICENSE @@ -814,6 +814,7 @@ BSD-style licenses The following components are provided under a BSD-style license. See project link for details. (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) + (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) diff --git a/mllib/pom.xml b/mllib/pom.xml index 5dfab36c76..a3c57ae260 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -109,6 +109,21 @@ test-jar test + + org.jpmml + pmml-model + 1.1.15 + + + com.sun.xml.fastinfoset + FastInfoset + + + com.sun.istack + istack-commons-runtime + + + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 057b628c6a..bd2e9079ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -23,6 +23,7 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} import org.apache.spark.rdd.RDD @@ -46,7 +47,7 @@ class LogisticRegressionModel ( val numFeatures: Int, val numClasses: Int) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable - with Saveable { + with Saveable with PMMLExportable { if (numClasses == 2) { require(weights.size == numFeatures, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 52fb62dcff..33104cf06c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -22,6 +22,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD @@ -36,7 +37,7 @@ class SVMModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable - with Saveable { + with Saveable with PMMLExportable { private var threshold: Option[Double] = Some(0.0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index e4e411a3c8..ba228b11fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -25,6 +25,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext @@ -34,7 +35,8 @@ import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ -class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable { +class KMeansModel ( + val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable { /** A Java-friendly constructor that takes an Iterable of Vectors. */ def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala new file mode 100644 index 0000000000..354e90f3ee --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml + +import java.io.{File, OutputStream, StringWriter} +import javax.xml.transform.stream.StreamResult + +import org.jpmml.model.JAXBUtil + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory + +/** + * Export model to the PMML format + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +trait PMMLExportable { + + /** + * Export the model to the stream result in PMML format + */ + private def toPMML(streamResult: StreamResult): Unit = { + val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this) + JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult) + } + + /** + * Export the model to a local file in PMML format + */ + def toPMML(localPath: String): Unit = { + toPMML(new StreamResult(new File(localPath))) + } + + /** + * Export the model to a directory on a distributed file system in PMML format + */ + def toPMML(sc: SparkContext, path: String): Unit = { + val pmml = toPMML() + sc.parallelize(Array(pmml), 1).saveAsTextFile(path) + } + + /** + * Export the model to the OutputStream in PMML format + */ + def toPMML(outputStream: OutputStream): Unit = { + toPMML(new StreamResult(outputStream)) + } + + /** + * Export the model to a String in PMML format + */ + def toPMML(): String = { + val writer = new StringWriter + toPMML(new StreamResult(writer)) + writer.toString + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala new file mode 100644 index 0000000000..34b447584e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.regression.GeneralizedLinearModel + +/** + * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel + */ +private[mllib] class BinaryClassificationPMMLModelExport( + model : GeneralizedLinearModel, + description : String, + normalizationMethod : RegressionNormalizationMethodType, + threshold: Double) + extends PMMLModelExport { + + populateBinaryClassificationPMML() + + /** + * Export the input LogisticRegressionModel or SVMModel to PMML format. + */ + private def populateBinaryClassificationPMML(): Unit = { + pmml.getHeader.setDescription(description) + + if (model.weights.size > 0) { + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val regressionTableYES = new RegressionTable(model.intercept).withTargetCategory("1") + var interceptNO = threshold + if (RegressionNormalizationMethodType.LOGIT == normalizationMethod) { + if (threshold <= 0) { + interceptNO = Double.MinValue + } else if (threshold >= 1) { + interceptNO = Double.MaxValue + } else { + interceptNO = -math.log(1 / threshold - 1) + } + } + val regressionTableNO = new RegressionTable(interceptNO).withTargetCategory("0") + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.CLASSIFICATION) + .withMiningSchema(miningSchema) + .withModelName(description) + .withNormalizationMethod(normalizationMethod) + .withRegressionTables(regressionTableYES, regressionTableNO) + + for (i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + + // add target field + val targetField = FieldName.create("target") + dataDictionary + .withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.STRING)) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala new file mode 100644 index 0000000000..1874786af0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExport.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.regression.GeneralizedLinearModel + +/** + * PMML Model Export for GeneralizedLinearModel abstract class + */ +private[mllib] class GeneralizedLinearPMMLModelExport( + model: GeneralizedLinearModel, + description: String) + extends PMMLModelExport { + + populateGeneralizedLinearPMML(model) + + /** + * Export the input GeneralizedLinearModel model to PMML format. + */ + private def populateGeneralizedLinearPMML(model: GeneralizedLinearModel): Unit = { + pmml.getHeader.setDescription(description) + + if (model.weights.size > 0) { + val fields = new SArray[FieldName](model.weights.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val regressionTable = new RegressionTable(model.intercept) + val regressionModel = new RegressionModel() + .withFunctionName(MiningFunctionType.REGRESSION) + .withMiningSchema(miningSchema) + .withModelName(description) + .withRegressionTables(regressionTable) + + for (i <- 0 until model.weights.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + regressionTable.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) + } + + // for completeness add target field + val targetField = FieldName.create("target") + dataDictionary.withDataFields(new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(targetField) + .withUsageType(FieldUsageType.TARGET)) + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(regressionModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala new file mode 100644 index 0000000000..069e7afc9f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.clustering.KMeansModel + +/** + * PMML Model Export for KMeansModel class + */ +private[mllib] class KMeansPMMLModelExport(model : KMeansModel) extends PMMLModelExport{ + + populateKMeansPMML(model) + + /** + * Export the input KMeansModel model to PMML format. + */ + private def populateKMeansPMML(model : KMeansModel): Unit = { + pmml.getHeader.setDescription("k-means clustering") + + if (model.clusterCenters.length > 0) { + val clusterCenter = model.clusterCenters(0) + val fields = new SArray[FieldName](clusterCenter.size) + val dataDictionary = new DataDictionary + val miningSchema = new MiningSchema + val comparisonMeasure = new ComparisonMeasure() + .withKind(ComparisonMeasure.Kind.DISTANCE) + .withMeasure(new SquaredEuclidean()) + val clusteringModel = new ClusteringModel() + .withModelName("k-means") + .withMiningSchema(miningSchema) + .withComparisonMeasure(comparisonMeasure) + .withFunctionName(MiningFunctionType.CLUSTERING) + .withModelClass(ClusteringModel.ModelClass.CENTER_BASED) + .withNumberOfClusters(model.clusterCenters.length) + + for (i <- 0 until clusterCenter.size) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema + .withMiningFields(new MiningField(fields(i)) + .withUsageType(FieldUsageType.ACTIVE)) + clusteringModel.withClusteringFields( + new ClusteringField(fields(i)).withCompareFunction(CompareFunctionType.ABS_DIFF)) + } + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + for (i <- 0 until model.clusterCenters.length) { + val cluster = new Cluster() + .withName("cluster_" + i) + .withArray(new org.dmg.pmml.Array() + .withType(Array.Type.REAL) + .withN(clusterCenter.size) + .withValue(model.clusterCenters(i).toArray.mkString(" "))) + // we don't have the size of the single cluster but only the centroids (withValue) + // .withSize(value) + clusteringModel.withClusters(cluster) + } + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(clusteringModel) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala new file mode 100644 index 0000000000..ebdeae50bb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import java.text.SimpleDateFormat +import java.util.Date + +import scala.beans.BeanProperty + +import org.dmg.pmml.{Application, Header, PMML, Timestamp} + +private[mllib] trait PMMLModelExport { + + /** + * Holder of the exported model in PMML format + */ + @BeanProperty + val pmml: PMML = new PMML + + setHeader(pmml) + + private def setHeader(pmml: PMML): Unit = { + val version = getClass.getPackage.getImplementationVersion + val app = new Application().withName("Apache Spark MLlib").withVersion(version) + val timestamp = new Timestamp() + .withContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + val header = new Header() + .withApplication(app) + .withTimestamp(timestamp) + pmml.setHeader(header) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala new file mode 100644 index 0000000000..c16e83d6a0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionNormalizationMethodType + +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.regression.LassoModel +import org.apache.spark.mllib.regression.LinearRegressionModel +import org.apache.spark.mllib.regression.RidgeRegressionModel + +private[mllib] object PMMLModelExportFactory { + + /** + * Factory object to help creating the necessary PMMLModelExport implementation + * taking as input the machine learning model (for example KMeansModel). + */ + def createPMMLModelExport(model: Any): PMMLModelExport = { + model match { + case kmeans: KMeansModel => + new KMeansPMMLModelExport(kmeans) + case linear: LinearRegressionModel => + new GeneralizedLinearPMMLModelExport(linear, "linear regression") + case ridge: RidgeRegressionModel => + new GeneralizedLinearPMMLModelExport(ridge, "ridge regression") + case lasso: LassoModel => + new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") + case svm: SVMModel => + new BinaryClassificationPMMLModelExport( + svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm.getThreshold.getOrElse(0.0)) + case logistic: LogisticRegressionModel => + if (logistic.numClasses == 2) { + new BinaryClassificationPMMLModelExport( + logistic, "logistic regression", RegressionNormalizationMethodType.LOGIT, + logistic.getThreshold.getOrElse(0.5)) + } else { + throw new IllegalArgumentException( + "PMML Export not supported for Multinomial Logistic Regression") + } + case _ => + throw new IllegalArgumentException( + "PMML Export not supported for model: " + model.getClass.getName) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index e8b0381657..4f482384f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD @@ -34,7 +35,7 @@ class LassoModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable with Saveable { + with RegressionModel with Serializable with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 6fa7ad52a5..9453c4f66c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD @@ -34,7 +35,7 @@ class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable - with Saveable { + with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 309f9af466..e0c03d8180 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD @@ -35,7 +36,7 @@ class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable with Saveable { + with RegressionModel with Serializable with Saveable with PMMLExportable { override protected def predictPoint( dataMatrix: Vector, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala new file mode 100644 index 0000000000..0b646cf1ce --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionModel +import org.dmg.pmml.RegressionNormalizationMethodType +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.util.LinearDataGenerator + +class BinaryClassificationPMMLModelExportSuite extends FunSuite { + + test("logistic regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) + + val logisticModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + + // assert that the PMML format is as expected + assert(logisticModelExport.isInstanceOf[PMMLModelExport]) + val pmml = logisticModelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "logistic regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === logisticRegressionModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with the same number of predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === logisticRegressionModel.weights.size) + // verify if there is a second table with target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) + // ensure logistic regression has normalization method set to LOGIT + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) + } + + test("linear SVM PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + + // assert that the PMML format is as expected + assert(svmModelExport.isInstanceOf[PMMLModelExport]) + val pmml = svmModelExport.getPmml + assert(pmml.getHeader.getDescription + === "linear SVM") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === svmModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table (for target category 1) + // with the same number of predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getTargetCategory === "1") + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === svmModel.weights.size) + // verify if there is a second table with target category 0 and no predictors + assert(pmmlRegressionModel.getRegressionTables.get(1).getTargetCategory === "0") + assert(pmmlRegressionModel.getRegressionTables.get(1).getNumericPredictors.size === 0) + // ensure linear SVM has normalization method set to NONE + assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala new file mode 100644 index 0000000000..f9afbd888d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.RegressionModel +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} +import org.apache.spark.mllib.util.LinearDataGenerator + +class GeneralizedLinearPMMLModelExportSuite extends FunSuite { + + test("linear regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) + // assert that the PMML format is as expected + assert(linearModelExport.isInstanceOf[PMMLModelExport]) + val pmml = linearModelExport.getPmml + assert(pmml.getHeader.getDescription === "linear regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === linearRegressionModel.weights.size + 1) + // This verifies that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === linearRegressionModel.weights.size) + } + + test("ridge regression PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) + // assert that the PMML format is as expected + assert(ridgeModelExport.isInstanceOf[PMMLModelExport]) + val pmml = ridgeModelExport.getPmml + assert(pmml.getHeader.getDescription === "ridge regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === ridgeRegressionModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === ridgeRegressionModel.weights.size) + } + + test("lasso PMML export") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + // assert that the PMML format is as expected + assert(lassoModelExport.isInstanceOf[PMMLModelExport]) + val pmml = lassoModelExport.getPmml + assert(pmml.getHeader.getDescription === "lasso regression") + // check that the number of fields match the weights size + assert(pmml.getDataDictionary.getNumberOfFields === lassoModel.weights.size + 1) + // This verify that there is a model attached to the pmml object and the model is a regression + // one. It also verifies that the pmml model has a regression table with the same number of + // predictors of the model weights. + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[RegressionModel] + assert(pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors.size + === lassoModel.weights.size) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala new file mode 100644 index 0000000000..b985d0446d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.dmg.pmml.ClusteringModel +import org.scalatest.FunSuite + +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.linalg.Vectors + +class KMeansPMMLModelExportSuite extends FunSuite { + + test("KMeansPMMLModelExport generate PMML format") { + val clusterCenters = Array( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0)) + val kmeansModel = new KMeansModel(clusterCenters) + + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) + + // assert that the PMML format is as expected + assert(modelExport.isInstanceOf[PMMLModelExport]) + val pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml + assert(pmml.getHeader.getDescription === "k-means clustering") + // check that the number of fields match the single vector size + assert(pmml.getDataDictionary.getNumberOfFields === clusterCenters(0).size) + // This verify that there is a model attached to the pmml object and the model is a clustering + // one. It also verifies that the pmml model has the same number of clusters of the spark model. + val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] + assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala new file mode 100644 index 0000000000..f28a4ac8ad --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} +import org.apache.spark.mllib.clustering.KMeansModel +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} +import org.apache.spark.mllib.util.LinearDataGenerator + +class PMMLModelExportFactorySuite extends FunSuite { + + test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { + val clusterCenters = Array( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0)) + val kmeansModel = new KMeansModel(clusterCenters) + + val modelExport = PMMLModelExportFactory.createPMMLModelExport(kmeansModel) + + assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) + } + + test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + + "LinearRegressionModel, RidgeRegressionModel or LassoModel") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + + val linearRegressionModel = + new LinearRegressionModel(linearInput(0).features, linearInput(0).label) + val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel) + assert(linearModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + val ridgeRegressionModel = + new RidgeRegressionModel(linearInput(0).features, linearInput(0).label) + val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel) + assert(ridgeModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + + val lassoModel = new LassoModel(linearInput(0).features, linearInput(0).label) + val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel) + assert(lassoModelExport.isInstanceOf[GeneralizedLinearPMMLModelExport]) + } + + test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + + "when passing a LogisticRegressionModel or SVMModel") { + val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) + + val logisticRegressionModel = + new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) + val logisticRegressionModelExport = + PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) + assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) + assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) + } + + test("PMMLModelExportFactory throw IllegalArgumentException " + + "when passing a Multinomial Logistic Regression") { + /** 3 classes, 2 features */ + val multiclassLogisticRegressionModel = new LogisticRegressionModel( + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + numFeatures = 2, numClasses = 3) + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) + } + } + + test("PMMLModelExportFactory throw IllegalArgumentException when passing an unsupported model") { + val invalidModel = new Object + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(invalidModel) + } + } +} -- cgit v1.2.3