add prediction sample
This commit is contained in:
		
							parent
							
								
									f336ab34a7
								
							
						
					
					
						commit
						9682ef3971
					
				|  | @ -0,0 +1,56 @@ | ||||||
|  | Data Setup: | ||||||
|  | ---------- | ||||||
|  | Before you can run the prediction sample prediction.rb, you must load some csv | ||||||
|  | formatted data into Google Storage. You can do this by running setup.sh with a  | ||||||
|  | bucket/object name of your choice. You must first create the bucket you want to  | ||||||
|  | use. This can be done with the gsutil function or via the web UI (Storage  | ||||||
|  | Access) in the Google APIs Console. i.e.: | ||||||
|  | # chmod 744 setup.sh | ||||||
|  | # ./setup.sh BUCKET/OBJECT | ||||||
|  | Note you need gsutil in your path for this to work. | ||||||
|  | 
 | ||||||
|  | In the script, you must then modify the datafile string. This must correspond with the | ||||||
|  | bucket/object of your dataset (if you are using your own dataset). We have | ||||||
|  | provided a setup.sh which will upload some basic sample data. The section is | ||||||
|  | near the bottom of the script, under 'FILL IN DATAFILE' | ||||||
|  | 
 | ||||||
|  | API setup: | ||||||
|  | --------- | ||||||
|  | We need to allow the application to use your API access. Go to APIs Console | ||||||
|  | https://code.google.com/apis/console, and select the project you want, go to API | ||||||
|  | Access, and create an OAuth2 client if you have not yet already. You should | ||||||
|  | generate a client ID and secret.  | ||||||
|  | 
 | ||||||
|  | This example will run through the server-side example, where the application | ||||||
|  | gets authorization ahead of time, which is the normal use case for Prediction | ||||||
|  | API. You can also set it up so the user can grant access. | ||||||
|  | 
 | ||||||
|  | First, run the google-api script to generate access and refresh tokens. Ex. | ||||||
|  | 
 | ||||||
|  | # cd google-api-ruby-client | ||||||
|  | # ruby-1.9.2-p290  bin/google-api oauth-2-login --scope=https://www.googleapis.com/auth/prediction --client-id=NUMBER.apps.googleusercontent.com --client-secret=CLIENT_SECRET | ||||||
|  | 
 | ||||||
|  | Fill in your client-id and client-secret from the API Access page. You will | ||||||
|  | probably have to set a redirect URI in your client ID | ||||||
|  | (ex. http://localhost:12736/). You can do this by hitting 'Edit settings' in the | ||||||
|  | API Access / Client ID section, and adding it to Authorized Redirect URIs. Not | ||||||
|  | that this has to be exactly the same URI, http://localhost:12736 and | ||||||
|  | http://localhost:12736/ are not the same in this case. | ||||||
|  | 
 | ||||||
|  | This should pop up a browser window, where you grant access. This will then | ||||||
|  | generate a ~/.google-api.yaml file. You have two options here, you can either | ||||||
|  | copy the the information directly in your code, or you can store this as a file | ||||||
|  | and load it in the sample as a yaml. In this example we do the latter. NOTE: if | ||||||
|  | you are loading it as a yaml, ensure you rename/move the file, as the | ||||||
|  | ~/.google-api.yaml file can get overwritten. The script will work as is if you | ||||||
|  | move the .google-api.yaml file to the sample directory. | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | This sample currently does not cover some newer features of Prediction API such | ||||||
|  | as streaming training, hosted models or class weights. If there are any | ||||||
|  | questions or suggestions to improve the script please email us at | ||||||
|  | prediction-api-discuss@googlegroups.com. | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -0,0 +1,219 @@ | ||||||
|  | #!/usr/bin/ruby1.8  | ||||||
|  | # -*- coding: utf-8 -*- | ||||||
|  | 
 | ||||||
|  | # Copyright:: Copyright 2011 Google Inc. | ||||||
|  | # License:: All Rights Reserved. | ||||||
|  | # Original Author:: Bob Aman, Winton Davies, Robert Kaplow | ||||||
|  | # Maintainer:: Robert Kaplow (mailto:rkaplow@google.com) | ||||||
|  | 
 | ||||||
|  | $:.unshift('lib') | ||||||
|  | require 'rubygems' | ||||||
|  | require 'sinatra' | ||||||
|  | require 'datamapper' | ||||||
|  | require 'google/api_client' | ||||||
|  | require 'yaml' | ||||||
|  | 
 | ||||||
|  | use Rack::Session::Pool, :expire_after => 86400 # 1 day | ||||||
|  | 
 | ||||||
|  | # Set up our token store | ||||||
|  | DataMapper.setup(:default, 'sqlite::memory:') | ||||||
|  | class TokenPair | ||||||
|  |   include DataMapper::Resource | ||||||
|  | 
 | ||||||
|  |   property :id, Serial | ||||||
|  |   property :refresh_token, String | ||||||
|  |   property :access_token, String | ||||||
|  |   property :expires_in, Integer | ||||||
|  |   property :issued_at, Integer | ||||||
|  | 
 | ||||||
|  |   def update_token!(object) | ||||||
|  |     self.refresh_token = object.refresh_token | ||||||
|  |     self.access_token = object.access_token | ||||||
|  |     self.expires_in = object.expires_in | ||||||
|  |     self.issued_at = object.issued_at | ||||||
|  |   end | ||||||
|  | 
 | ||||||
|  |   def to_hash | ||||||
|  |     return { | ||||||
|  |       :refresh_token => refresh_token, | ||||||
|  |       :access_token => access_token, | ||||||
|  |       :expires_in => expires_in, | ||||||
|  |       :issued_at => Time.at(issued_at) | ||||||
|  |     } | ||||||
|  |   end | ||||||
|  | end | ||||||
|  | TokenPair.auto_migrate! | ||||||
|  | 
 | ||||||
|  | before do | ||||||
|  | 
 | ||||||
|  |   # FILL IN THIS SECTION | ||||||
|  |   # This will work if your yaml file is stored as ./google-api.yaml | ||||||
|  |   # ------------------------ | ||||||
|  |   oauth_yaml = YAML.load_file('.google-api.yaml') | ||||||
|  |   @client = Google::APIClient.new | ||||||
|  |   @client.authorization.client_id = oauth_yaml["client_id"] | ||||||
|  |   @client.authorization.client_secret = oauth_yaml["client_secret"] | ||||||
|  |   @client.authorization.scope = oauth_yaml["scope"] | ||||||
|  |   @client.authorization.refresh_token = oauth_yaml["refresh_token"] | ||||||
|  |   @client.authorization.access_token = oauth_yaml["access_token"] | ||||||
|  |   # ----------------------- | ||||||
|  | 
 | ||||||
|  |   @client.authorization.redirect_uri = to('/oauth2callback') | ||||||
|  | 
 | ||||||
|  |   # Workaround for now as expires_in may be nil, but when converted to int it becomes 0. | ||||||
|  |   @client.authorization.expires_in = Time.now + 1800 if @client.authorization.expires_in.to_i == 0 | ||||||
|  | 
 | ||||||
|  |   if session[:token_id] | ||||||
|  |     # Load the access token here if it's available | ||||||
|  |     token_pair = TokenPair.get(session[:token_id]) | ||||||
|  |     @client.authorization.update_token!(token_pair.to_hash) | ||||||
|  |   end | ||||||
|  |   if @client.authorization.refresh_token && @client.authorization.expired? | ||||||
|  |     @client.authorization.fetch_access_token! | ||||||
|  |   end | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |   @prediction = @client.discovered_api('prediction', 'v1.3') | ||||||
|  |   unless @client.authorization.access_token || request.path_info =~ /^\/oauth2/ | ||||||
|  |     redirect to('/oauth2authorize') | ||||||
|  |   end | ||||||
|  | end | ||||||
|  | 
 | ||||||
|  | get '/oauth2authorize' do | ||||||
|  |   redirect @client.authorization.authorization_uri.to_s, 303 | ||||||
|  | end | ||||||
|  | 
 | ||||||
|  | get '/oauth2callback' do | ||||||
|  |   @client.authorization.fetch_access_token! | ||||||
|  |   # Persist the token here | ||||||
|  |   token_pair = if session[:token_id] | ||||||
|  |     TokenPair.get(session[:token_id]) | ||||||
|  |   else | ||||||
|  |     TokenPair.new | ||||||
|  |   end | ||||||
|  |   token_pair.update_token!(@client.authorization) | ||||||
|  |   token_pair.save() | ||||||
|  |   session[:token_id] = token_pair.id | ||||||
|  |   redirect to('/') | ||||||
|  | end | ||||||
|  | 
 | ||||||
|  | get '/' do | ||||||
|  |   # FILL IN DATAFILE: | ||||||
|  |   # ---------------------------------------- | ||||||
|  |   datafile = "BUCKET/OBJECT" | ||||||
|  |   # ---------------------------------------- | ||||||
|  |   # Train a predictive model. | ||||||
|  |   train(datafile) | ||||||
|  |   # Check to make sure the training has completed. | ||||||
|  |   if (is_done?(datafile)) | ||||||
|  |     # Do a prediction. | ||||||
|  |     # FILL IN DESIRED INPUT: | ||||||
|  |     # ------------------------------------------------------------------------------- | ||||||
|  |     prediction,score = get_prediction(datafile, ["Alice noticed with some surprise."]) | ||||||
|  |     # ------------------------------------------------------------------------------- | ||||||
|  | 
 | ||||||
|  |     # We currently just dump the results to output, but you can display them on the page if desired. | ||||||
|  |     puts prediction | ||||||
|  |     puts score     | ||||||
|  |   end | ||||||
|  | end | ||||||
|  | 
 | ||||||
|  | ## | ||||||
|  | # Trains a predictive model. | ||||||
|  | # | ||||||
|  | # @param [String] filename The name of the file in Google Storage. NOTE: this do *not* | ||||||
|  | #                 include the gs:// part. If the Google Storage path is gs://bucket/object, | ||||||
|  | #                 then the correct string is "bucket/object" | ||||||
|  | def train(datafile) | ||||||
|  |   input = "{\"id\" : \"#{datafile}\"}" | ||||||
|  |   puts "training input: #{input}" | ||||||
|  |   status, headers, body = @client.execute(@prediction.training.insert, | ||||||
|  |                                           {}, | ||||||
|  |                                           input, | ||||||
|  |                                           {'Content-Type' => 'application/json'}) | ||||||
|  | end | ||||||
|  | 
 | ||||||
|  | ## | ||||||
|  | # Returns the current training status | ||||||
|  | # | ||||||
|  | # @param [String] filename The name of the file in Google Storage. NOTE: this do *not* | ||||||
|  | #                 include the gs:// part. If the Google Storage path is gs://bucket/object, | ||||||
|  | #                 then the correct string is "bucket/object" | ||||||
|  | # @return [Integer] status The HTTP status code of the training job. | ||||||
|  | def get_training_status(datafile) | ||||||
|  |   status, headers, body = @client.execute(@prediction.training.get, | ||||||
|  |                                           {'data' => datafile}) | ||||||
|  |   return status | ||||||
|  | end | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | ## | ||||||
|  | # Checks the training status until a model exists (will loop forever). | ||||||
|  | # | ||||||
|  | # @param [String] filename The name of the file in Google Storage. NOTE: this do *not* | ||||||
|  | #                 include the gs:// part. If the Google Storage path is gs://bucket/object, | ||||||
|  | #                 then the correct string is "bucket/object" | ||||||
|  | # @return [Bool] exists True if model exists and can be used for predictions. | ||||||
|  | 
 | ||||||
|  | def is_done?(datafile) | ||||||
|  |   status = get_training_status(datafile) | ||||||
|  |   while true do | ||||||
|  |     puts "Attempting to check model #{datafile} - Status: #{status} " | ||||||
|  |     return true if status == 200 | ||||||
|  |     sleep 10 | ||||||
|  |     status = get_training_status(datafile) | ||||||
|  |   end | ||||||
|  |   return false | ||||||
|  | end | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | ## | ||||||
|  | # Returns the prediction and most most likely class score if categorization. | ||||||
|  | # | ||||||
|  | # @param [String] filename The name of the file in Google Storage. NOTE: this do *not* | ||||||
|  | #                 include the gs:// part. If the Google Storage path is gs://bucket/object, | ||||||
|  | #                 then the correct string is "bucket/object" | ||||||
|  | # @param [List] input_features A list of input features. | ||||||
|  | # | ||||||
|  | # @return [String or Double] prediction The returned prediction, String if categorization, | ||||||
|  | #                            Double if regression | ||||||
|  | # @return [Double] trueclass_score The numeric score of the most likely label. (Categorical only). | ||||||
|  | 
 | ||||||
|  | def get_prediction(datafile,input_features) | ||||||
|  |   # We take the input features and put it in the right input (json) format. | ||||||
|  |   input="{\"input\" : { \"csvInstance\" :  #{input_features}}}" | ||||||
|  |   puts "Prediction Input: #{input}" | ||||||
|  |   status, headers, body = @client.execute(@prediction.training.predict, | ||||||
|  |                                                      {'data' => datafile}, | ||||||
|  |                                                      input, | ||||||
|  |                                                      {'Content-Type' => 'application/json'}) | ||||||
|  |   prediction_data = JSON.parse(body[0]) | ||||||
|  |    | ||||||
|  |   # Categorical | ||||||
|  |   if prediction_data["outputLabel"] != nil | ||||||
|  |     # Pull the most likely label. | ||||||
|  |     prediction = prediction_data["outputLabel"] | ||||||
|  |     # Pull the class probabilities. | ||||||
|  |     probs = prediction_data["outputMulti"] | ||||||
|  |     puts probs | ||||||
|  |     # Verify we are getting a value result. | ||||||
|  |     puts ["ERROR", input_features].join("\t")  if probs.nil? | ||||||
|  |     return "error", -1.0 if probs.nil? | ||||||
|  | 
 | ||||||
|  |     # Extract the score for the most likely class. | ||||||
|  |     trueclass_score = probs.select{|hash| | ||||||
|  |       hash["label"] ==  prediction | ||||||
|  |     }[0]["score"] | ||||||
|  | 
 | ||||||
|  |     # Regression. | ||||||
|  |   else | ||||||
|  |     prediction = prediction_data["outputValue"] | ||||||
|  |     # Class core unused. | ||||||
|  |     trueclass_score = -1 | ||||||
|  |   end | ||||||
|  | 
 | ||||||
|  |   puts [prediction,trueclass_score,input_features].join("\t")  | ||||||
|  |   return prediction,trueclass_score | ||||||
|  | end | ||||||
|  | 
 | ||||||
|  | @ -0,0 +1,16 @@ | ||||||
|  | #!/bin/bash | ||||||
|  | # | ||||||
|  | # Copyright 2011 Google Inc. All Rights Reserved. | ||||||
|  | # Author: rkaplow@google.com (Robert Kaplow) | ||||||
|  | # | ||||||
|  | # Uploads a training data set to Google Storage to be used by this sample | ||||||
|  | # application.  | ||||||
|  | # | ||||||
|  | # Usage: | ||||||
|  | # setup.sh bucket/object  | ||||||
|  | # | ||||||
|  | # Requirements: | ||||||
|  | #   gsutil - a client application for interacting with Google Storage. It | ||||||
|  | #     can be downloaded from https://code.google.com/apis/storage/docs/gsutil.html | ||||||
|  | OBJECT_NAME=$1 | ||||||
|  | gsutil cp language_id.txt gs://$OBJECT_NAME | ||||||
		Loading…
	
		Reference in New Issue