Overhauled the prediction sample and updated to v1.4.
This commit is contained in:
parent
2b746cb379
commit
1c300f091f
File diff suppressed because it is too large
Load Diff
|
@ -2,7 +2,7 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# Copyright:: Copyright 2011 Google Inc.
|
# Copyright:: Copyright 2011 Google Inc.
|
||||||
# License:: All Rights Reserved.
|
# License:: Apache 2.0
|
||||||
# Original Author:: Bob Aman, Winton Davies, Robert Kaplow
|
# Original Author:: Bob Aman, Winton Davies, Robert Kaplow
|
||||||
# Maintainer:: Robert Kaplow (mailto:rkaplow@google.com)
|
# Maintainer:: Robert Kaplow (mailto:rkaplow@google.com)
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ require 'datamapper'
|
||||||
require 'google/api_client'
|
require 'google/api_client'
|
||||||
require 'yaml'
|
require 'yaml'
|
||||||
|
|
||||||
use Rack::Session::Pool, :expire_after => 86400 # 1 day
|
enable :sessions
|
||||||
|
|
||||||
# Set up our token store
|
# Set up our token store
|
||||||
DataMapper.setup(:default, 'sqlite::memory:')
|
DataMapper.setup(:default, 'sqlite::memory:')
|
||||||
|
@ -20,8 +20,8 @@ class TokenPair
|
||||||
include DataMapper::Resource
|
include DataMapper::Resource
|
||||||
|
|
||||||
property :id, Serial
|
property :id, Serial
|
||||||
property :refresh_token, String
|
property :refresh_token, String, :length => 255
|
||||||
property :access_token, String
|
property :access_token, String, :length => 255
|
||||||
property :expires_in, Integer
|
property :expires_in, Integer
|
||||||
property :issued_at, Integer
|
property :issued_at, Integer
|
||||||
|
|
||||||
|
@ -43,10 +43,32 @@ class TokenPair
|
||||||
end
|
end
|
||||||
TokenPair.auto_migrate!
|
TokenPair.auto_migrate!
|
||||||
|
|
||||||
before do
|
def save_token_pair(session, client)
|
||||||
|
token_pair = if session[:token_id]
|
||||||
|
TokenPair.first_or_create(:id => session[:token_id])
|
||||||
|
else
|
||||||
|
TokenPair.new
|
||||||
|
end
|
||||||
|
token_pair.update_token!(client.authorization)
|
||||||
|
if token_pair.save
|
||||||
|
session[:token_id] = token_pair.id
|
||||||
|
else
|
||||||
|
token_pair.errors.each do |e|
|
||||||
|
raise e
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# FILL IN THIS SECTION
|
||||||
|
# This is the name of the {bucket}/{object} you are using for the language
|
||||||
|
# file.
|
||||||
|
# ------------------------
|
||||||
|
DATA_OBJECT = "bucket/object"
|
||||||
|
# ------------------------
|
||||||
|
|
||||||
|
before do
|
||||||
# FILL IN THIS SECTION
|
# FILL IN THIS SECTION
|
||||||
# This will work if your yaml file is stored as ./google-api.yaml
|
# This will work if your yaml file is stored as .google-api.yaml
|
||||||
# ------------------------
|
# ------------------------
|
||||||
oauth_yaml = YAML.load_file('.google-api.yaml')
|
oauth_yaml = YAML.load_file('.google-api.yaml')
|
||||||
@client = Google::APIClient.new
|
@client = Google::APIClient.new
|
||||||
|
@ -59,20 +81,17 @@ before do
|
||||||
|
|
||||||
@client.authorization.redirect_uri = to('/oauth2callback')
|
@client.authorization.redirect_uri = to('/oauth2callback')
|
||||||
|
|
||||||
# Workaround for now as expires_in may be nil, but when converted to int it becomes 0.
|
|
||||||
@client.authorization.expires_in = 1800 if @client.authorization.expires_in.to_i == 0
|
|
||||||
|
|
||||||
if session[:token_id]
|
if session[:token_id]
|
||||||
# Load the access token here if it's available
|
# Load the access token here if it's available
|
||||||
token_pair = TokenPair.get(session[:token_id])
|
token_pair = TokenPair.get(session[:token_id])
|
||||||
@client.authorization.update_token!(token_pair.to_hash)
|
@client.authorization.update_token!(token_pair.to_hash) if token_pair
|
||||||
end
|
end
|
||||||
if @client.authorization.refresh_token && @client.authorization.expired?
|
if @client.authorization.refresh_token && @client.authorization.expired?
|
||||||
@client.authorization.fetch_access_token!
|
@client.authorization.fetch_access_token!
|
||||||
|
save_token_pair(session, @client)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@prediction = @client.discovered_api('prediction', 'v1.4')
|
||||||
@prediction = @client.discovered_api('prediction', 'v1.3')
|
|
||||||
unless @client.authorization.access_token || request.path_info =~ /^\/oauth2/
|
unless @client.authorization.access_token || request.path_info =~ /^\/oauth2/
|
||||||
redirect to('/oauth2authorize')
|
redirect to('/oauth2authorize')
|
||||||
end
|
end
|
||||||
|
@ -84,144 +103,80 @@ end
|
||||||
|
|
||||||
get '/oauth2callback' do
|
get '/oauth2callback' do
|
||||||
@client.authorization.fetch_access_token!
|
@client.authorization.fetch_access_token!
|
||||||
# Persist the token here
|
save_token_pair(session, @client)
|
||||||
token_pair = if session[:token_id]
|
|
||||||
TokenPair.get(session[:token_id])
|
|
||||||
else
|
|
||||||
TokenPair.new
|
|
||||||
end
|
|
||||||
token_pair.update_token!(@client.authorization)
|
|
||||||
token_pair.save()
|
|
||||||
session[:token_id] = token_pair.id
|
|
||||||
redirect to('/')
|
redirect to('/')
|
||||||
end
|
end
|
||||||
|
|
||||||
get '/' do
|
get '/' do
|
||||||
# FILL IN DATAFILE:
|
erb :index
|
||||||
# ----------------------------------------
|
|
||||||
datafile = "BUCKET/OBJECT"
|
|
||||||
# ----------------------------------------
|
|
||||||
# Train a predictive model.
|
|
||||||
train(datafile)
|
|
||||||
# Check to make sure the training has completed.
|
|
||||||
if (is_done?(datafile))
|
|
||||||
# Do a prediction.
|
|
||||||
# FILL IN DESIRED INPUT:
|
|
||||||
# -------------------------------------------------------------------------------
|
|
||||||
# Note, the input features should match the features of the dataset.
|
|
||||||
prediction,score = get_prediction(datafile, ["Alice noticed with some surprise."])
|
|
||||||
# -------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# We currently just dump the results to output, but you can display them on the page if desired.
|
|
||||||
puts prediction
|
|
||||||
puts score
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
##
|
get '/train' do
|
||||||
# Trains a predictive model.
|
training = @prediction.trainedmodels.insert.request_schema.new
|
||||||
#
|
training.id = 'language-sample'
|
||||||
# @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
|
training.storage_data_location = DATA_OBJECT
|
||||||
# include the gs:// part. If the Google Storage path is gs://bucket/object,
|
result = @client.execute(
|
||||||
# then the correct string is "bucket/object"
|
:api_method => @prediction.trainedmodels.insert,
|
||||||
def train(datafile)
|
:headers => {'Content-Type' => 'application/json'},
|
||||||
input = "{\"id\" : \"#{datafile}\"}"
|
:body_object => training
|
||||||
puts "training input: #{input}"
|
)
|
||||||
result = @client.execute(:api_method => @prediction.training.insert,
|
|
||||||
:merged_body => input,
|
|
||||||
:headers => {'Content-Type' => 'application/json'}
|
|
||||||
)
|
|
||||||
status, headers, body = result.response
|
|
||||||
end
|
end
|
||||||
|
|
||||||
##
|
get '/checkStatus' do
|
||||||
# Returns the current training status
|
result = @client.execute(
|
||||||
#
|
:api_method => @prediction.trainedmodels.get,
|
||||||
# @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
|
:parameters => {'id' => 'language-sample'}
|
||||||
# include the gs:// part. If the Google Storage path is gs://bucket/object,
|
)
|
||||||
# then the correct string is "bucket/object"
|
|
||||||
# @return [Integer] status The HTTP status code of the training job.
|
|
||||||
def get_training_status(datafile)
|
|
||||||
result = @client.execute(:api_method => @prediction.training.get,
|
|
||||||
:parameters => {'data' => datafile})
|
|
||||||
status, headers, body = result.response
|
|
||||||
return status
|
|
||||||
end
|
|
||||||
|
|
||||||
|
# Assemble some JSON our client-side code can work with.
|
||||||
##
|
json = {}
|
||||||
# Checks the training status until a model exists (will loop forever).
|
if result.status != 200
|
||||||
#
|
if result.data["error"]
|
||||||
# @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
|
message = result.data["error"]["errors"].first["message"]
|
||||||
# include the gs:// part. If the Google Storage path is gs://bucket/object,
|
json["message"] = "#{message} [#{result.status}]"
|
||||||
# then the correct string is "bucket/object"
|
else
|
||||||
# @return [Bool] exists True if model exists and can be used for predictions.
|
json["message"] = "Error. [#{result.status}]"
|
||||||
|
end
|
||||||
def is_done?(datafile)
|
json["response"] = ::JSON.parse(result.body)
|
||||||
status = get_training_status(datafile)
|
json["status"] = "error"
|
||||||
# We use an exponential backoff approach here.
|
|
||||||
test_counter = 0
|
|
||||||
while test_counter < 10 do
|
|
||||||
puts "Attempting to check model #{datafile} - Status: #{status} "
|
|
||||||
return true if status == 200
|
|
||||||
sleep 5 * (test_counter + 1)
|
|
||||||
status = get_training_status(datafile)
|
|
||||||
test_counter += 1
|
|
||||||
end
|
|
||||||
return false
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
##
|
|
||||||
# Returns the prediction and most most likely class score if categorization.
|
|
||||||
#
|
|
||||||
# @param [String] filename The name of the file in Google Storage. NOTE: this do *not*
|
|
||||||
# include the gs:// part. If the Google Storage path is gs://bucket/object,
|
|
||||||
# then the correct string is "bucket/object"
|
|
||||||
# @param [List] input_features A list of input features.
|
|
||||||
#
|
|
||||||
# @return [String or Double] prediction The returned prediction, String if categorization,
|
|
||||||
# Double if regression
|
|
||||||
# @return [Double] trueclass_score The numeric score of the most likely label. (Categorical only).
|
|
||||||
|
|
||||||
def get_prediction(datafile,input_features)
|
|
||||||
# We take the input features and put it in the right input (json) format.
|
|
||||||
input="{\"input\" : { \"csvInstance\" : #{input_features}}}"
|
|
||||||
puts "Prediction Input: #{input}"
|
|
||||||
result = @client.execute(:api_method => @prediction.training.predict,
|
|
||||||
:parameters => {'data' => datafile},
|
|
||||||
:merged_body => input,
|
|
||||||
:headers => {'Content-Type' => 'application/json'})
|
|
||||||
status, headers, body = result.response
|
|
||||||
prediction_data = result.data
|
|
||||||
puts status
|
|
||||||
puts body
|
|
||||||
puts prediction_data
|
|
||||||
# Categorical
|
|
||||||
if prediction_data["outputLabel"] != nil
|
|
||||||
# Pull the most likely label.
|
|
||||||
prediction = prediction_data["outputLabel"]
|
|
||||||
# Pull the class probabilities.
|
|
||||||
probs = prediction_data["outputMulti"]
|
|
||||||
puts probs
|
|
||||||
# Verify we are getting a value result.
|
|
||||||
puts ["ERROR", input_features].join("\t") if probs.nil?
|
|
||||||
return "error", -1.0 if probs.nil?
|
|
||||||
|
|
||||||
# Extract the score for the most likely class.
|
|
||||||
trueclass_score = probs.select{|hash|
|
|
||||||
hash["label"] == prediction
|
|
||||||
}[0]["score"]
|
|
||||||
|
|
||||||
# Regression.
|
|
||||||
else
|
else
|
||||||
prediction = prediction_data["outputValue"]
|
json["response"] = ::JSON.parse(result.body)
|
||||||
# Class core unused.
|
json["status"] = "success"
|
||||||
trueclass_score = -1
|
|
||||||
end
|
end
|
||||||
|
return [
|
||||||
puts [prediction,trueclass_score,input_features].join("\t")
|
200,
|
||||||
return prediction,trueclass_score
|
[["Content-Type", "application/json"]],
|
||||||
|
::JSON.generate(json)
|
||||||
|
]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
post '/predict' do
|
||||||
|
input = @prediction.trainedmodels.predict.request_schema.new
|
||||||
|
input.input = {}
|
||||||
|
input.input.csv_instance = [params["input"]]
|
||||||
|
result = @client.execute(
|
||||||
|
:api_method => @prediction.trainedmodels.predict,
|
||||||
|
:parameters => {'id' => 'language-sample'},
|
||||||
|
:headers => {'Content-Type' => 'application/json'},
|
||||||
|
:body_object => input
|
||||||
|
)
|
||||||
|
json = {}
|
||||||
|
if result.status != 200
|
||||||
|
if result.data["error"]
|
||||||
|
message = result.data["error"]["errors"].first["message"]
|
||||||
|
json["message"] = "#{message} [#{result.status}]"
|
||||||
|
else
|
||||||
|
json["message"] = "Error. [#{result.status}]"
|
||||||
|
end
|
||||||
|
json["response"] = ::JSON.parse(result.body)
|
||||||
|
json["status"] = "error"
|
||||||
|
else
|
||||||
|
json["response"] = ::JSON.parse(result.body)
|
||||||
|
json["status"] = "success"
|
||||||
|
end
|
||||||
|
return [
|
||||||
|
200,
|
||||||
|
[["Content-Type", "application/json"]],
|
||||||
|
::JSON.generate(json)
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta http-equiv="Content-Type" content="text/html; charset=utf-8">
|
||||||
|
<title>Prediction API</title>
|
||||||
|
<style type="text/css">
|
||||||
|
body {
|
||||||
|
font-family: Arial, Helvetica, sans-serif;
|
||||||
|
}
|
||||||
|
#log {
|
||||||
|
font-family: monospace;
|
||||||
|
background-color: #eee;
|
||||||
|
padding: 1em;
|
||||||
|
}
|
||||||
|
#log p {
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
#predict {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
#predict label, #predict textarea, #predict button {
|
||||||
|
margin: 1em 0;
|
||||||
|
font-size: 1em;
|
||||||
|
display: block;
|
||||||
|
width: 50%;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Prediction API: Language Sample</h1>
|
||||||
|
<div id="log">
|
||||||
|
</div>
|
||||||
|
<div id="predict">
|
||||||
|
<label for="input">Input</label>
|
||||||
|
<textarea id="input" placeholder="Généralement, les gens qui savant peu parlent beaucoup, et les gens qui savant beaucoup parlent peu."></textarea>
|
||||||
|
<button id="go">Submit</button>
|
||||||
|
</div>
|
||||||
|
<script src="//ajax.googleapis.com/ajax/libs/jquery/1.6.2/jquery.min.js"></script>
|
||||||
|
<script type="text/javascript">
|
||||||
|
function logMessage(message) {
|
||||||
|
$("#log").append("<p>" + message + "</p>");
|
||||||
|
}
|
||||||
|
$(document).ready(function(e) {
|
||||||
|
$.getJSON("/train", function (data) {
|
||||||
|
logMessage("Training started...");
|
||||||
|
var delay = 1000;
|
||||||
|
var checkStatus = function () {
|
||||||
|
logMessage("Checking training status...");
|
||||||
|
$.getJSON("/checkStatus", function(data) {
|
||||||
|
if (data && data.status == 'success') {
|
||||||
|
logMessage("Training complete.");
|
||||||
|
$("#predict").show();
|
||||||
|
$("#go").click(function () {
|
||||||
|
var input = $("#input").val();
|
||||||
|
$.ajax({
|
||||||
|
type: "POST",
|
||||||
|
url: "/predict",
|
||||||
|
data: {"input": input},
|
||||||
|
success: function(data) {
|
||||||
|
if (data && data.status == 'success') {
|
||||||
|
logMessage("Predicted label: " + data.response.outputLabel);
|
||||||
|
} else if (data && data.message) {
|
||||||
|
logMessage(data.message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
} else if (data && data.message) {
|
||||||
|
logMessage(data.message);
|
||||||
|
}
|
||||||
|
delay = delay * 2;
|
||||||
|
if (delay > 30000) {
|
||||||
|
// Upper maximum delay.
|
||||||
|
delay = 30000;
|
||||||
|
}
|
||||||
|
logMessage("Checking again in " + (delay / 1000) + " seconds.");
|
||||||
|
setTimeout(checkStatus, delay);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
setTimeout(checkStatus, delay);
|
||||||
|
});
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
Loading…
Reference in New Issue