In [1]:
from pyspark.ml import Pipeline
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import *
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import Row
from pyspark.sql.functions import *
from pyspark.sql.types import *
Basic example of Transformer and Estimator
In [2]:
# Prepare training data from a list of (label, features) tuples.
# A DenseVector is backed by a NumPy array.
training = spark.createDataFrame([
    (1, Vectors.dense([0.0, 1.1, 0.1])),
    (0, Vectors.dense([2.0, 1.0, -1.0])),
    (0, Vectors.dense([2.0, 1.3, 1.0])),
    (1, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"])
# Create a LogisticRegression instance. This instance is an Estimator.
lr = LogisticRegression(maxIter=10, regParam=0.01)
# Print out the parameters, documentation, and any default values.
print(lr.explainParams())
aggregationDepth: suggested depth for treeAggregate (>= 2). (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. (default: 0.0)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial (default: auto)
featuresCol: features column name. (default: features)
fitIntercept: whether to fit an intercept term. (default: True)
labelCol: label column name. (default: label)
maxIter: max number of iterations (>= 0). (default: 100, current: 10)
predictionCol: prediction column name. (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name. (default: rawPrediction)
regParam: regularization parameter (>= 0). (default: 0.0, current: 0.01)
standardization: whether to standardize the training features before fitting the model. (default: True)
threshold: Threshold in binary classification prediction, in range [0, 1]. If threshold and thresholds are both set, they must match. e.g. if threshold is p, then thresholds must be equal to [1-p, p]. (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold. (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0). (default: 1e-06)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0. (undefined)
In [3]:
# Learn a LogisticRegression model. This uses the parameters stored in lr.
model1 = lr.fit(training)
print(model1)
# model1 is a Model (i.e., a Transformer produced by an Estimator).
print("Model 1's trained coefficients:", model1.coefficients)
LogisticRegression_4c2e8a0650d8a2e4dd40
Model 1's trained coefficients: [-3.1120572566,2.6484863784,-0.39555871831]
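Besides the coefficients, the fitted LogisticRegressionModel also exposes the intercept term (a quick check; output not recorded here):
In [ ]:
print("Model 1's intercept:", model1.intercept)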
In [4]:
# We may alternatively specify parameters using a Python dictionary as a paramMap.
paramMap = {lr.maxIter: 20}
paramMap[lr.maxIter] = 30  # Specify 1 Param, overwriting the original maxIter.
paramMap.update({lr.regParam: 0.1, lr.threshold: 0.55})  # Specify multiple Params.
# You can combine paramMaps, which are Python dictionaries (see the sketch below).
paramMap[lr.probabilityCol] = "myProbability"  # Change the output column name.
# Now learn a new model using the paramMap parameters.
# paramMap overrides all parameters set earlier via lr.set* methods.
model2 = lr.fit(training, paramMap)
print("Model 2's trained coefficients:", model2.coefficients)
Model 2's trained coefficients: [-1.43136570278,0.432088704954,-0.149203769854]
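Since paramMaps are plain dictionaries, combining two of them is just a dict update, mirroring the Spark ML programming guide (paramMap2 here is a second, illustrative map):
In [ ]:
paramMap2 = {lr.probabilityCol: "myProbability"}  # illustrative second map
paramMapCombined = paramMap.copy()
paramMapCombined.update(paramMap2)  # plain dict update; later entries win
model2b = lr.fit(training, paramMapCombined)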
In [5]:
# Prepare test data
test = spark.createDataFrame([
    (1, Vectors.dense([-1.0, 1.5, 1.3])),
    (2, Vectors.dense([3.0, 2.0, -0.1])),
    (3, Vectors.dense([0.0, 2.2, -1.5]))], ["id", "features"])
# Make predictions on the test data using the Transformer.transform() method.
# LogisticRegressionModel.transform will only use the 'features' column.
# Note that model2.transform() outputs a "myProbability" column instead of the usual
# 'probability' column, since we renamed the lr.probabilityCol parameter previously.
model1.transform(test).show()
model2.transform(test).show()
+---+--------------+--------------------+--------------------+----------+
| id|      features|       rawPrediction|         probability|prediction|
+---+--------------+--------------------+--------------------+----------+
|  1|[-1.0,1.5,1.3]|[-6.6201561534426...|[0.00133144762578...|       1.0|
|  2|[3.0,2.0,-0.1]|[3.95004747812153...|[0.98110991800680...|       0.0|
|  3|[0.0,2.2,-1.5]|[-6.4696037729905...|[0.00154744141349...|       1.0|
+---+--------------+--------------------+--------------------+----------+
+---+--------------+--------------------+--------------------+----------+
| id|      features|       rawPrediction|       myProbability|prediction|
+---+--------------+--------------------+--------------------+----------+
|  1|[-1.0,1.5,1.3]|[-2.8046572125682...|[0.05707302714277...|       1.0|
|  2|[3.0,2.0,-0.1]|[2.49587596827288...|[0.92385220384891...|       0.0|
|  3|[0.0,2.2,-1.5]|[-2.0935241588488...|[0.10972783382176...|       1.0|
+---+--------------+--------------------+--------------------+----------+
Pipeline example
In [6]:
# Prepare training documents from a list of (id, text, label) tuples.
training = spark.createDataFrame([
    (0, "a b c d spark spark", 1),
    (1, "b d", 0),
    (2, "spark f g h", 1),
    (3, "hadoop mapreduce", 0)
], ["id", "text", "label"])
In [7]:
# A Tokenizer converts the input string to lowercase and then splits it on whitespace.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
tokenizer.transform(training).show()
+---+-------------------+-----+--------------------+
| id|               text|label|               words|
+---+-------------------+-----+--------------------+
|  0|a b c d spark spark|    1|[a, b, c, d, spar...|
|  1|                b d|    0|              [b, d]|
|  2|        spark f g h|    1|    [spark, f, g, h]|
|  3|   hadoop mapreduce|    0| [hadoop, mapreduce]|
+---+-------------------+-----+--------------------+
In [8]:
# The same result can be achieved with the DataFrame API's split() function,
# but you would need to wrap it as a Transformer to use it in a pipeline
# (see the sketch after the output below).
training.select('*', split(training['text'], ' ').alias('words')).show()
+---+-------------------+-----+--------------------+
| id|               text|label|               words|
+---+-------------------+-----+--------------------+
|  0|a b c d spark spark|    1|[a, b, c, d, spar...|
|  1|                b d|    0|              [b, d]|
|  2|        spark f g h|    1|    [spark, f, g, h]|
|  3|   hadoop mapreduce|    0| [hadoop, mapreduce]|
+---+-------------------+-----+--------------------+
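Wrapping split() as a custom Transformer might look roughly like this (a minimal sketch: SplitTokenizer is a hypothetical class name, and the _input_kwargs idiom assumes PySpark 2.1+):
In [ ]:
from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol

class SplitTokenizer(Transformer, HasInputCol, HasOutputCol):
    # Hypothetical: a minimal pipeline stage that wraps the split() SQL function.
    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(SplitTokenizer, self).__init__()
        self._set(**self._input_kwargs)  # copy keyword args into Params

    def _transform(self, df):
        return df.withColumn(self.getOutputCol(),
                             split(df[self.getInputCol()], ' '))

SplitTokenizer(inputCol="text", outputCol="words").transform(training).show()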
In [9]:
# HashingTF maps a sequence of terms to their term frequencies using the hashing trick.
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
a = hashingTF.transform(tokenizer.transform(training))
a.show(truncate=False)
print(a.select('features').first())
+---+-------------------+-----+--------------------------+----------------------------------------------------------------+
|id |text               |label|words                     |features                                                        |
+---+-------------------+-----+--------------------------+----------------------------------------------------------------+
|0  |a b c d spark spark|1    |[a, b, c, d, spark, spark]|(262144,[27526,28698,30913,227410,234657],[1.0,1.0,1.0,1.0,2.0])|
|1  |b d                |0    |[b, d]                    |(262144,[27526,30913],[1.0,1.0])                                |
|2  |spark f g h        |1    |[spark, f, g, h]          |(262144,[15554,24152,51505,234657],[1.0,1.0,1.0,1.0])           |
|3  |hadoop mapreduce   |0    |[hadoop, mapreduce]       |(262144,[42633,155117],[1.0,1.0])                               |
+---+-------------------+-----+--------------------------+----------------------------------------------------------------+
Row(features=SparseVector(262144, {27526: 1.0, 28698: 1.0, 30913: 1.0, 227410: 1.0, 234657: 2.0}))
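Conceptually, the hashing trick assigns each term a column index by hashing it modulo numFeatures, so colliding terms share a bucket. A rough sketch of the idea (illustrative only: PySpark's HashingTF actually uses MurmurHash3, not Python's built-in hash):
In [ ]:
def term_frequencies(terms, numFeatures=262144):
    # Count occurrences per hash bucket: index = hash(term) mod numFeatures.
    tf = {}
    for term in terms:
        idx = hash(term) % numFeatures  # stand-in for Spark's MurmurHash3
        tf[idx] = tf.get(idx, 0) + 1.0
    return tf

term_frequencies(['a', 'b', 'c', 'd', 'spark', 'spark'])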
In [10]:
# lr is an Estimator.
lr = LogisticRegression(maxIter=10, regParam=0.001)
# Now we are ready to assemble the pipeline.
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
# Fit the pipeline to training documents.
model = pipeline.fit(training)
# Prepare test documents, which are unlabeled (id, text) tuples.
test = spark.createDataFrame([
    (4, "spark i j k"),
    (5, "l m n"),
    (6, "spark hadoop spark"),
    (7, "apache hadoop")
], ["id", "text"])
# Make predictions on test documents. (Selecting just the columns of interest
# is sketched after the output below.)
model.transform(test).show()
+---+------------------+--------------------+--------------------+--------------------+--------------------+----------+
| id|              text|               words|            features|       rawPrediction|         probability|prediction|
+---+------------------+--------------------+--------------------+--------------------+--------------------+----------+
|  4|       spark i j k|    [spark, i, j, k]|(262144,[20197,24...|[-0.7908353682239...|[0.31198932781127...|       1.0|
|  5|             l m n|           [l, m, n]|(262144,[18910,10...|[1.52403423257606...|[0.82113177383528...|       0.0|
|  6|spark hadoop spark|[spark, hadoop, s...|(262144,[155117,2...|[-0.7068761675479...|[0.33028946002907...|       1.0|
|  7|     apache hadoop|    [apache, hadoop]|(262144,[66695,15...|[3.92286303405201...|[0.98059945726531...|       0.0|
+---+------------------+--------------------+--------------------+--------------------+--------------------+----------+
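To display only the prediction-related columns rather than every intermediate stage output, one can project before showing, as in the Spark ML guide:
In [ ]:
model.transform(test).select("id", "text", "probability", "prediction").show()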
Example: Analyzing food inspection data using logistic regression
In [11]:
# Data at https://www.cse.ust.hk/msbd5003/data/Food_Inspections1.csv
inspections = spark.read.csv('../data/Food_Inspections1.csv', inferSchema=True)
Let's take a look at its schema:
In [12]:
inspections.printSchema()
root
 |-- _c0: integer (nullable = true)
 |-- _c1: string (nullable = true)
 |-- _c2: string (nullable = true)
 |-- _c3: integer (nullable = true)
 |-- _c4: string (nullable = true)
 |-- _c5: string (nullable = true)
 |-- _c6: string (nullable = true)
 |-- _c7: string (nullable = true)
 |-- _c8: string (nullable = true)
 |-- _c9: integer (nullable = true)
 |-- _c10: string (nullable = true)
 |-- _c11: string (nullable = true)
 |-- _c12: string (nullable = true)
 |-- _c13: string (nullable = true)
 |-- _c14: double (nullable = true)
 |-- _c15: double (nullable = true)
 |-- _c16: string (nullable = true)
In [13]:
inspections.show()
+------+--------------------+--------------------+-------+--------------------+---------------+--------------------+-------+---+-----+----------+--------------------+------------------+--------------------+------------------+------------------+--------------------+
|   _c0|                 _c1|                 _c2|    _c3|                 _c4|            _c5|                 _c6|    _c7|_c8|  _c9|      _c10|                _c11|              _c12|                _c13|              _c14|              _c15|                _c16|
+------+--------------------+--------------------+-------+--------------------+---------------+--------------------+-------+---+-----+----------+--------------------+------------------+--------------------+------------------+------------------+--------------------+
|413707|       LUNA PARK INC|  LUNA PARK DAY CARE|2049789|Children's Servic...|  Risk 1 (High)|  3250 W FOSTER AVE |CHICAGO| IL|60625|09/21/2010|  License-Task Force|              Fail|24. DISH WASHING ...| 41.97583445690982| -87.7107455232781|(41.9758344569098...|
|391234|       CAFE SELMARIE|       CAFE SELMARIE|1069067|          Restaurant|  Risk 1 (High)| 4729 N LINCOLN AVE |CHICAGO| IL|60625|09/21/2010|             Canvass|              Fail|2. FACILITIES TO ...| 41.96740659751604|-87.68761642361608|(41.9674065975160...|
|413751|          MANCHU WOK|MANCHU WOK (T3-H/...|1909522|          Restaurant|  Risk 1 (High)|  11601 W TOUHY AVE |CHICAGO| IL|60666|09/21/2010|             Canvass|              Pass|33. FOOD AND NON-...|42.008536400868735|-87.91442843927047|(42.0085364008687...|
|413708|BENCHMARK HOSPITA...|BENCHMARK HOSPITA...|2049411|          Restaurant|  Risk 1 (High)|325 N LA SALLE ST...|CHICAGO| IL|60654|09/21/2010|Task Force Liquor...|              Pass|                null| 41.88819879207664|-87.63236298373182|(41.8881987920766...|
|413722|           JJ BURGER|           JJ BURGER|2055016|          Restaurant|Risk 2 (Medium)|   749 S CICERO AVE |CHICAGO| IL|60644|09/21/2010|             License|              Pass|                null| 41.87082601444883|-87.74476763884662|(41.8708260144488...|
|413752|GOLDEN HOOKS FISH...|GOLDEN HOOKS FISH...|2042435|          Restaurant|Risk 2 (Medium)|   3958 W MONROE ST |CHICAGO| IL|60624|09/21/2010|Short Form Complaint|              Pass|                null| 41.87987261425607|-87.72551692436804|(41.8798726142560...|
|413714|THE DOCK AT MONTR...|THE DOCK AT MONTR...|2043260|          Restaurant|  Risk 1 (High)|  4400 N SIMONDS DR |CHICAGO| IL|60640|09/21/2010|             License|              Fail|                null| 41.96390893734172|-87.63863624840039|(41.9639089373417...|
|413753|CLARK FOOD & CIGA...|                null|2042203|       Grocery Store|   Risk 3 (Low)|6761 N CLARK ST B...|CHICAGO| IL|60626|09/21/2010|             Canvass|              Pass|                null|  42.0053117273606|-87.67294053846207|(42.0053117273606...|
|120580|          SUSHI PINK|          SUSHI PINK|1847340|          Restaurant|  Risk 1 (High)|909 W WASHINGTON ...|CHICAGO| IL|60607|09/21/2010|             Canvass|              Pass|32. FOOD AND NON-...|41.882987317760424|-87.65014022876997|(41.8829873177604...|
|401216|       M.H.R.,L.L.C.|       M.H.R.,L.L.C.|1621323|          Restaurant|Risk 2 (Medium)|   623 S WABASH AVE |CHICAGO| IL|60605|09/21/2010|             Canvass|   Out of Business|                null| 41.87390845559158|-87.62583770570953|(41.8739084555915...|
|413715|              NABO'S|              NABO'S|1931861|          Restaurant|  Risk 1 (High)|    3351 N BROADWAY |CHICAGO| IL|60657|09/21/2010|Canvass Re-Inspec...|              Pass|19. OUTSIDE GARBA...| 41.94334005547684|-87.64466387044959|(41.9433400554768...|
|413721|THE NICHOLSON SCHOOL|THE NICHOLSON SCHOOL|2002702|Daycare (2 - 6 Ye...|  Risk 1 (High)| 1700 W CORTLAND ST |CHICAGO| IL|60622|09/21/2010|             License|              Pass|                null| 41.91618227133264| -87.6703413842735|(41.9161822713326...|
|401215|       M.H.R.,L.L.C.|       M.H.R.,L.L.C.|1621322|          Restaurant|  Risk 1 (High)| 600 S MICHIGAN AVE |CHICAGO| IL|60605|09/21/2010|             Canvass|   Out of Business|                null| 41.87437161535891|-87.62437952778167|(41.8743716153589...|
|420207|  WHOLE FOODS MARKET|  WHOLE FOODS MARKET|1933690|       Grocery Store|  Risk 1 (High)|1550 N KINGSBURY ST |CHICAGO| IL|60642|09/21/2010|           Complaint|              Pass|32. FOOD AND NON-...| 41.90939878780941|-87.65305069789407|(41.9093987878094...|
|154514|         LAS FUENTES|         LAS FUENTES|  12575|          Restaurant|  Risk 1 (High)|  2558 N HALSTED ST |CHICAGO| IL|60614|09/21/2010|             Canvass|              Fail|18. NO EVIDENCE O...|  41.9290354100918|-87.64903392789199|(41.9290354100918...|
|413711|CASA CENTRAL COMM...|CASA CENTRAL COMM...|  60766|          Restaurant|  Risk 1 (High)|1343 N CALIFORNIA...|CHICAGO| IL|60622|09/21/2010|             License|              Pass|41. PREMISES MAIN...| 41.90598597077873|-87.69680735572291|(41.9059859707787...|
|413764|LA BRUQUENA RESTA...|LA BRUQUENA RESTA...|1492868|          Restaurant|  Risk 1 (High)| 2726 W DIVISION ST |CHICAGO| IL|60622|09/21/2010|Suspected Food Po...|Pass w/ Conditions|4. SOURCE OF CROS...|41.903046386818346|  -87.695535129416|(41.9030463868183...|
|413732|SODEXHO AT UNITED...|SODEXHO AT UNITED...|  20467|          Restaurant|  Risk 1 (High)|  11601 W TOUHY AVE |CHICAGO| IL|60666|09/21/2010|             Canvass|   Out of Business|                null|42.008536400868735|-87.91442843927047|(42.0085364008687...|
|413757|WHIZ KIDS NURSERY...|WHIZ KIDS NURSERY...|1948277|Daycare Above and...|  Risk 1 (High)| 514-522 W 103RD ST |CHICAGO| IL|60628|09/21/2010|             Canvass|              Pass|35. WALLS, CEILIN...|41.707112812685075|-87.63620425242559|(41.7071128126850...|
|363272|FRUTERIA GUAYAUIT...|FRUTERIA GUAYAUIT...|1446823|       Grocery Store|   Risk 3 (Low)|  3849 S KEDZIE AVE |CHICAGO| IL|60632|09/21/2010|             Canvass|   Out of Business|                null| 41.82290845958193|-87.70426021024545|(41.8229084595819...|
+------+--------------------+--------------------+-------+--------------------+---------------+--------------------+-------+---+-----+----------+--------------------+------------------+--------------------+------------------+------------------+--------------------+
only showing top 20 rows
We now have the CSV file as a DataFrame. It has some columns we will not use; dropping them saves memory when we cache the DataFrame. We should also give the remaining columns meaningful names.
In [14]:
# Keep only the interesting columns and rename them to something meaningful.
# Mapping from column index to name.
columnNames = {0: "id", 1: "name", 12: "results", 13: "violations"}
# Rename each kept column from '_c{index}' to its meaningful name.
cols = [inspections[i].alias(columnNames[i]) for i in columnNames.keys()]
# Select only these columns (dropping the rest) and filter out null violations.
df = inspections.select(cols).where(col('violations').isNotNull())
df.cache()
df.show()
df.count()
+------+--------------------+------------------+--------------------+
|    id|                name|           results|          violations|
+------+--------------------+------------------+--------------------+
|413707|       LUNA PARK INC|              Fail|24. DISH WASHING ...|
|391234|       CAFE SELMARIE|              Fail|2. FACILITIES TO ...|
|413751|          MANCHU WOK|              Pass|33. FOOD AND NON-...|
|120580|          SUSHI PINK|              Pass|32. FOOD AND NON-...|
|413715|              NABO'S|              Pass|19. OUTSIDE GARBA...|
|420207|  WHOLE FOODS MARKET|              Pass|32. FOOD AND NON-...|
|154514|         LAS FUENTES|              Fail|18. NO EVIDENCE O...|
|413711|CASA CENTRAL COMM...|              Pass|41. PREMISES MAIN...|
|413764|LA BRUQUENA RESTA...|Pass w/ Conditions|4. SOURCE OF CROS...|
|413757|WHIZ KIDS NURSERY...|              Pass|35. WALLS, CEILIN...|
|154516|TACO & BURRITO PA...|              Pass|30. FOOD IN ORIGI...|
|413759|  MARISCOS EL VENENO|              Pass|18. NO EVIDENCE O...|
|114554|THE HANGGE- UPPE,...|              Pass|18. NO EVIDENCE O...|
|413758|   LINDY'S CHILI INC|Pass w/ Conditions|30. FOOD IN ORIGI...|
|343362|        FUMARE MEATS|              Pass|40. REFRIGERATION...|
|413754|              Subway|              Pass|38. VENTILATION: ...|
|289222|LITTLE CAESARS PIZZA|              Pass|34. FLOORS: CONST...|
|413755|       BILLY'S GRILL|              Pass|33. FOOD AND NON-...|
|343364|             FRESHII|              Fail|18. NO EVIDENCE O...|
|289221|NICKY'S GRILL & Y...|              Pass|33. FOOD AND NON-...|
+------+--------------------+------------------+--------------------+
only showing top 20 rows
Out[14]:
10469
In [15]:
df.take(1)
Out[15]:
[Row(id=413707, name=u'LUNA PARK INC', results=u'Fail', violations=u'24. DISH WASHING FACILITIES: PROPERLY DESIGNED, CONSTRUCTED, MAINTAINED, INSTALLED, LOCATED AND OPERATED - Comments: All dishwashing machines must be of a type that complies with all requirements of the plumbing section of the Municipal Code of Chicago and Rules and Regulation of the Board of Health. OBSEVERD THE 3 COMPARTMENT SINK BACKING UP INTO THE 1ST AND 2ND COMPARTMENT WITH CLEAR WATER AND SLOWLY DRAINING OUT. INST NEED HAVE IT REPAIR. CITATION ISSUED, SERIOUS VIOLATION 7-38-030 H000062369-10 COURT DATE 10-28-10 TIME 1 P.M. ROOM 107 400 W. SURPERIOR. | 36. LIGHTING: REQUIRED MINIMUM FOOT-CANDLES OF LIGHT PROVIDED, FIXTURES SHIELDED - Comments: Shielding to protect against broken glass falling into food shall be provided for all artificial lighting sources in preparation, service, and display facilities. LIGHT SHIELD ARE MISSING UNDER HOOD OF COOKING EQUIPMENT AND NEED TO REPLACE LIGHT UNDER UNIT. 4 LIGHTS ARE OUT IN THE REAR CHILDREN AREA,IN THE KINDERGARDEN CLASS ROOM. 2 LIGHT ARE OUT EAST REAR, LIGHT FRONT WEST ROOM. NEED TO REPLACE ALL LIGHT THAT ARE NOT WORKING. | 35. WALLS, CEILINGS, ATTACHED EQUIPMENT CONSTRUCTED PER CODE: GOOD REPAIR, SURFACES CLEAN AND DUST-LESS CLEANING METHODS - Comments: The walls and ceilings shall be in good repair and easily cleaned. MISSING CEILING TILES WITH STAINS IN WEST,EAST, IN FRONT AREA WEST, AND BY THE 15MOS AREA. NEED TO BE REPLACED. | 32. FOOD AND NON-FOOD CONTACT SURFACES PROPERLY DESIGNED, CONSTRUCTED AND MAINTAINED - Comments: All food and non-food contact equipment and utensils shall be smooth, easily cleanable, and durable, and shall be in good repair. SPLASH GUARDED ARE NEEDED BY THE EXPOSED HAND SINK IN THE KITCHEN AREA | 34. FLOORS: CONSTRUCTED PER CODE, CLEANED, GOOD REPAIR, COVING INSTALLED, DUST-LESS CLEANING METHODS USED - Comments: The floors shall be constructed per code, be smooth and easily cleaned, and be kept clean and in good repair. INST NEED TO ELEVATE ALL FOOD ITEMS 6INCH OFF THE FLOOR 6 INCH AWAY FORM WALL. ')]
The output of the above cell gives us an idea of the schema of the input file; the file includes the name of every establishment, the type of establishment, the address, the date of the inspection, and the location, among other things.
Let's start to get a sense of what our dataset contains. For example, what are the different values in the results column?
In [16]:
df.select('results').distinct().show()
+------------------+
|           results|
+------------------+
|              Fail|
|Pass w/ Conditions|
|              Pass|
+------------------+
In [17]:
df.groupBy('results').count().show()
+------------------+-----+
|           results|count|
+------------------+-----+
|              Fail| 2607|
|Pass w/ Conditions| 1028|
|              Pass| 6834|
+------------------+-----+
Let us develop a model that can guess the outcome of a food inspection, given the violations. Since logistic regression is a binary classification method, it makes sense to group our data into two categories: Fail and Pass. A "Pass w/ Conditions" is still a Pass, so when we train the model, we will treat the two results as equivalent.
Let us go ahead and convert our existing DataFrame (df) into a new DataFrame where each inspection is represented as a label-violations pair. In our case, a label of 0 represents a failure and a label of 1 represents a success.
In [18]:
# Convert the results column into a binary label.
labeledData = df.select(when(df.results == 'Fail', 0)
                        .when(df.results == 'Pass', 1)
                        .otherwise(1)  # 'Pass w/ Conditions' also counts as a pass
                        .alias('label'),
                        'violations')
labeledData.show()
+-----+--------------------+
|label|          violations|
+-----+--------------------+
|    0|24. DISH WASHING ...|
|    0|2. FACILITIES TO ...|
|    1|33. FOOD AND NON-...|
|    1|32. FOOD AND NON-...|
|    1|19. OUTSIDE GARBA...|
|    1|32. FOOD AND NON-...|
|    0|18. NO EVIDENCE O...|
|    1|41. PREMISES MAIN...|
|    1|4. SOURCE OF CROS...|
|    1|35. WALLS, CEILIN...|
|    1|30. FOOD IN ORIGI...|
|    1|18. NO EVIDENCE O...|
|    1|18. NO EVIDENCE O...|
|    1|30. FOOD IN ORIGI...|
|    1|40. REFRIGERATION...|
|    1|38. VENTILATION: ...|
|    1|34. FLOORS: CONST...|
|    1|33. FOOD AND NON-...|
|    0|18. NO EVIDENCE O...|
|    1|33. FOOD AND NON-...|
+-----+--------------------+
only showing top 20 rows
Train a logistic regression model from the input DataFrame:
In [19]:
trainingData, testData = labeledData.randomSplit([0.8, 0.2])
tokenizer = Tokenizer(inputCol="violations", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.01)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
model = pipeline.fit(trainingData)
predictionsDf = model.transform(testData)
predictionsDf.show()
+-----+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
|label|          violations|               words|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
|    0|10. SEWAGE AND WA...|[10., sewage, and...|(262144,[1234,197...|[9.84942348297464...|[0.99994722517389...|       0.0|
|    0|11. ADEQUATE NUMB...|[11., adequate, n...|(262144,[5220,850...|[-0.0513705633838...|[0.48716018264989...|       1.0|
|    0|11. ADEQUATE NUMB...|[11., adequate, n...|(262144,[5220,610...|[-0.3068435754019...|[0.42388537268736...|       1.0|
|    0|12. HAND WASHING ...|[12., hand, washi...|(262144,[422,1109...|[5.54757292737684...|[0.99611822306776...|       0.0|
|    0|12. HAND WASHING ...|[12., hand, washi...|(262144,[215,1836...|[0.73882932540309...|[0.67673980891158...|       0.0|
|    0|12. HAND WASHING ...|[12., hand, washi...|(262144,[1836,187...|[3.56729509687770...|[0.97254305209713...|       0.0|
|    0|14. PREVIOUS SERI...|[14., previous, s...|(262144,[1463,232...|[3.91236242291980...|[0.98039868054890...|       0.0|
|    0|14. PREVIOUS SERI...|[14., previous, s...|(262144,[3298,453...|[-4.2573365150736...|[0.01396226191792...|       1.0|
|    0|16. FOOD PROTECTE...|[16., food, prote...|(262144,[3067,329...|[1.97098950129219...|[0.87771735565373...|       0.0|
|    0|16. FOOD PROTECTE...|[16., food, prote...|(262144,[2786,316...|[6.34254542860340...|[0.99824327447327...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[4640,522...|[1.28286245451621...|[0.78293663555228...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[2325,278...|[4.23595524389172...|[0.98574029561866...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[1972,278...|[3.22877734489762...|[0.96190297324105...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[1234,197...|[2.35576269900135...|[0.91339118889834...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[2325,271...|[10.1098074944306...|[0.99995932301775...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[2786,316...|[-0.0951175279767...|[0.47623853015459...|       1.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[2786,316...|[3.72355228146601...|[0.97642134352393...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[2786,316...|[2.28258278037262...|[0.90742424303578...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[1972,278...|[7.37688453255977...|[0.99937484434029...|       0.0|
|    0|18. NO EVIDENCE O...|[18., no, evidenc...|(262144,[1972,278...|[3.54303378754711...|[0.97188771988297...|       0.0|
+-----+--------------------+--------------------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows
In [20]:
numSuccesses = predictionsDf.where('label == prediction').count()
numInspections = predictionsDf.count()
print("There were %d inspections and there were %d successful predictions" % (numInspections, numSuccesses))
print("This is a %d%% success rate" % (float(numSuccesses) / float(numInspections) * 100))
There were 2104 inspections and there were 1826 successful predictions
This is a 86% success rate
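The same accuracy figure could also be computed with Spark's built-in evaluator instead of counting matches by hand (a small sketch; MulticlassClassificationEvaluator's "accuracy" metric assumes Spark 2.0+, and the evaluator expects a double-typed label column):
In [ ]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(labelCol='label', predictionCol='prediction',
                                              metricName='accuracy')
# Cast the integer label to double, which the evaluator expects.
print("Accuracy: %g" % evaluator.evaluate(
    predictionsDf.withColumn('label', col('label').cast('double'))))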
Cross-Validation
CrossValidator begins by splitting the dataset into a set of folds which are used as separate training and test datasets. E.g., with k=5 folds, CrossValidator will generate 5 (training, test) dataset pairs, each of which uses 4/5 of the data for training and 1/5 for testing. To evaluate a particular ParamMap, CrossValidator computes the average evaluation metric for the 5 Models produced by fitting the Estimator on the 5 different (training, test) dataset pairs.
After identifying the best ParamMap, CrossValidator finally re-fits the Estimator using the best ParamMap and the entire dataset.
In [21]:
# We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
# This will allow us to jointly choose parameters for all Pipeline stages.
# A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
# We use a ParamGridBuilder to construct a grid of parameters to search over.
# With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
# this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
paramGrid = ParamGridBuilder() \
    .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
    .addGrid(lr.regParam, [0.1, 0.01]) \
    .build()
crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(),
                          numFolds=5)
# Run cross-validation, and choose the best set of parameters.
cvModel = crossval.fit(trainingData)
predictionsDf = cvModel.transform(testData)
numSuccesses = predictionsDf.where('label == prediction').count()
numInspections = predictionsDf.count()
print("There were %d inspections and there were %d successful predictions" % (numInspections, numSuccesses))
print("This is a %d%% success rate" % (float(numSuccesses) / float(numInspections) * 100))
There were 2104 inspections and there were 1854 successful predictions
This is a 88% success rate
In [22]:
cvModel.explainParams()
Out[22]:
estimator: estimator to be cross-validated (current: Pipeline_4d508b7dc95479bf7897)
estimatorParamMaps: estimator param maps (current: [{Param(parent=u'LogisticRegression_4de3adfa7c180a25b4d2', name='regParam', doc='regularization parameter (>= 0).'): 0.1, Param(parent=u'HashingTF_4bf9b60fc08483f17273', name='numFeatures', doc='number of features.'): 10}, {Param(parent=u'LogisticRegression_4de3adfa7c180a25b4d2', name='regParam', doc='regularization parameter (>= 0).'): 0.1, Param(parent=u'HashingTF_4bf9b60fc08483f17273', name='numFeatures', doc='number of features.'): 100}, {Param(parent=u'LogisticRegression_4de3adfa7c180a25b4d2', name='regParam', doc='regularization parameter (>= 0).'): 0.1, Param(parent=u'HashingTF_4bf9b60fc08483f17273', name='numFeatures', doc='number of features.'): 1000}, {Param(parent=u'LogisticRegression_4de3adfa7c180a25b4d2', name='regParam', doc='regularization parameter (>= 0).'): 0.01, Param(parent=u'HashingTF_4bf9b60fc08483f17273', name='numFeatures', doc='number of features.'): 10}, {Param(parent=u'LogisticRegression_4de3adfa7c180a25b4d2', name='regParam', doc='regularization parameter (>= 0).'): 0.01, Param(parent=u'HashingTF_4bf9b60fc08483f17273', name='numFeatures', doc='number of features.'): 100}, {Param(parent=u'LogisticRegression_4de3adfa7c180a25b4d2', name='regParam', doc='regularization parameter (>= 0).'): 0.01, Param(parent=u'HashingTF_4bf9b60fc08483f17273', name='numFeatures', doc='number of features.'): 1000}])
evaluator: evaluator used to select hyper-parameters that maximize the validator metric (current: BinaryClassificationEvaluator_416488c06eaaa8f8fdc2)
seed: random seed. (default: -4372709618522015412)
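To see how each of the six parameter combinations fared during cross-validation, the fitted CrossValidatorModel can be inspected (a brief sketch; avgMetrics is a list parallel to the param grid):
In [ ]:
# avgMetrics[i] is the mean cross-validated metric for paramGrid[i].
for params, metric in zip(paramGrid, cvModel.avgMetrics):
    print({p.name: v for p, v in params.items()}, metric)
print(cvModel.bestModel.stages)  # the fitted stages of the winning pipeline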