File | Doc | Category | Size | Date | Package
StochasticLinearRanker.java | API Doc | Android 5.1 API | 6932 | Thu Mar 12 22:22:48 GMT 2015 | android.bordeaux.learning

StochasticLinearRanker.java

/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


package android.bordeaux.learning;

import android.util.Log;

import java.io.Serializable;
import java.util.List;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Stochastic Linear Ranker, learns how to rank a sample. The learned rank score
 * can be used to compare samples.
 * This java class wraps the native StochasticLinearRanker class.
 * To update the ranker, call {@link #updateClassifier} with two samples, with the
 * first one having higher rank than the second one.
 * To get the rank score of a sample, call {@link #scoreSample}.
 *
 * <p>Each instance owns one native classifier object, referenced through a
 * {@code long} handle. The handle is released in {@link #finalize()} and
 * replaced in {@link #resetRanker()}. This class performs no synchronization.
 *
 *  TODO: adding more interfaces for changing the learning parameters
 */
public class StochasticLinearRanker {
    /** Log tag used by {@link #print(Model)}. */
    static final String TAG = "StochasticLinearRanker";

    /**
     * Length of the key/value arrays that {@link #getUModel()} allocates for
     * {@code nativeGetParameterClassifier}.
     */
    public static int VAR_NUM = 14;

    /**
     * Serializable snapshot of a ranker: its weights, the weight normalizer and
     * its parameters, as read by {@link #getUModel()} and applied by
     * {@link #loadModel(Model)}.
     *
     * <p>No explicit {@code serialVersionUID} is declared; introducing one would
     * have to match the implicitly computed value to stay compatible with models
     * that were already serialized.
     */
    static public class Model implements Serializable {
        /** Weight keys mapped to their weight values. */
        public HashMap<String, Float> weights = new HashMap<String, Float>();
        /** Normalizer handed to the native side together with the weights. */
        public float weightNormalizer = 1;
        /** Parameter names mapped to their values, both kept as strings. */
        public HashMap<String, String> parameters = new HashMap<String, String>();
    }

    /**
     * Initializing a ranker: creates the backing native classifier.
     */
    public StochasticLinearRanker() {
        mNativeClassifier = initNativeClassifier();
    }

    /**
     * Reset the ranker: deletes the current native classifier and creates a new
     * one in its place.
     */
    public void resetRanker() {
        deleteNativeClassifier(mNativeClassifier);
        // Drop the stale handle before re-creating: if initNativeClassifier()
        // throws, finalize() must not delete the same native object again.
        mNativeClassifier = 0;
        mNativeClassifier = initNativeClassifier();
    }

    /**
     * Train the ranker with a pair of samples. A sample is a pair of arrays of
     * keys and values. The first sample should have higher rank than the second
     * one.
     *
     * @param keys_positive   keys of the higher-ranked sample
     * @param values_positive values of the higher-ranked sample
     * @param keys_negative   keys of the lower-ranked sample
     * @param values_negative values of the lower-ranked sample
     * @return the result reported by the native update call
     */
    public boolean updateClassifier(String[] keys_positive,
                                    float[] values_positive,
                                    String[] keys_negative,
                                    float[] values_negative) {
        return nativeUpdateClassifier(keys_positive, values_positive,
                                      keys_negative, values_negative,
                                      mNativeClassifier);
    }

    /**
     * Get the rank score of the sample, a sample is a list of key, value pairs.
     *
     * @param keys   keys of the sample
     * @param values values of the sample
     * @return the score computed by the native classifier
     */
    public float scoreSample(String[] keys, float[] values) {
        return nativeScoreSample(keys, values, mNativeClassifier);
    }

    /**
     * Get the current model and parameters of ranker.
     *
     * @return a new {@link Model} filled from the native classifier
     */
    public Model getUModel() {
        Model slrModel = new Model();

        int len = nativeGetLengthClassifier(mNativeClassifier);
        String[] wKeys = new String[len];
        float[] wValues = new float[len];
        // NOTE(review): a Java float is passed by value, so the native call
        // cannot modify wNormalizer; the returned model therefore always
        // carries weightNormalizer == 1. Fixing this needs a different JNI
        // signature (e.g. a float[1] out-parameter), which is not visible here.
        float wNormalizer = 1;
        nativeGetWeightClassifier(wKeys, wValues, wNormalizer, mNativeClassifier);
        slrModel.weightNormalizer = wNormalizer;
        for (int i = 0; i < wKeys.length; i++) {
            slrModel.weights.put(wKeys[i], wValues[i]);
        }

        String[] paramKeys = new String[VAR_NUM];
        String[] paramValues = new String[VAR_NUM];
        nativeGetParameterClassifier(paramKeys, paramValues, mNativeClassifier);
        for (int i = 0; i < paramKeys.length; i++) {
            slrModel.parameters.put(paramKeys[i], paramValues[i]);
        }
        return slrModel;
    }

    /**
     * Load the given model and parameters to the ranker.
     *
     * <p>Weights are applied first, then parameters one by one in the map's
     * iteration order. Loading stops at the first setter that reports failure.
     *
     * @param model the model to apply
     * @return {@code false} as soon as setting the weights or any parameter
     *         returns {@code false}; {@code true} otherwise
     */
    public boolean loadModel(Model model) {
        int count = model.weights.size();
        String[] wKeys = new String[count];
        float[] wValues = new float[count];
        int i = 0;
        for (Map.Entry<String, Float> e : model.weights.entrySet()) {
            wKeys[i] = e.getKey();
            wValues[i] = e.getValue();
            i++;
        }
        if (!setModelWeights(wKeys, wValues, model.weightNormalizer)) {
            return false;
        }

        for (Map.Entry<String, String> e : model.parameters.entrySet()) {
            if (!setModelParameter(e.getKey(), e.getValue())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Set the weights and the normalizer of the native classifier.
     *
     * @param keys       weight keys
     * @param values     weight values, paired with {@code keys} by index
     * @param normalizer weight normalizer
     * @return the result reported by the native call
     */
    public boolean setModelWeights(String[] keys, float [] values, float normalizer){
        return nativeSetWeightClassifier(keys, values, normalizer, mNativeClassifier);
    }

    /**
     * Set a single parameter of the native classifier.
     *
     * @param key   parameter name
     * @param value parameter value, as a string
     * @return the result reported by the native call
     */
    public boolean setModelParameter(String key, String value){
        return nativeSetParameterClassifier(key, value, mNativeClassifier);
    }

    /**
     * Print a model for debugging. Writes the weights, the normalizer and the
     * parameters to the log at info level under {@link #TAG}.
     *
     * @param model the model to print
     */
    public void print(Model model) {
        StringBuilder weightsText = new StringBuilder();
        for (Map.Entry<String, Float> e : model.weights.entrySet()) {
            weightsText.append("<").append(e.getKey()).append(",")
                    .append(e.getValue()).append("> ");
        }
        StringBuilder paramsText = new StringBuilder();
        for (Map.Entry<String, String> e : model.parameters.entrySet()) {
            paramsText.append("<").append(e.getKey()).append(",")
                    .append(e.getValue()).append("> ");
        }
        Log.i(TAG, "Weights are " + weightsText);
        Log.i(TAG, "Normalizer is " + model.weightNormalizer);
        Log.i(TAG, "Parameters are " + paramsText);
    }

    /**
     * Releases the native classifier, if one is still held, and then chains to
     * the superclass finalizer.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            // 0 is treated as "no native object" (constructor failed, or a
            // reset was interrupted after the delete).
            if (mNativeClassifier != 0) {
                deleteNativeClassifier(mNativeClassifier);
                mNativeClassifier = 0;
            }
        } finally {
            super.finalize();
        }
    }

    static {
        System.loadLibrary("bordeaux");
    }

    /** Handle of the native classifier owned by this instance; 0 when none is held. */
    private long mNativeClassifier;

    /*
     * The following methods are the java stubs for the jni implementations.
     * Their names and signatures are part of the JNI contract and must stay
     * in sync with the native library.
     */
    private native long initNativeClassifier();

    private native void deleteNativeClassifier(long classifierPtr);

    private native boolean nativeUpdateClassifier(
            String[] keys_positive,
            float[] values_positive,
            String[] keys_negative,
            float[] values_negative,
            long classifierPtr);

    private native float nativeScoreSample(String[] keys, float[] values, long classifierPtr);

    private native void nativeGetWeightClassifier(String [] keys, float[] values, float normalizer,
                                                  long classifierPtr);

    private native void nativeGetParameterClassifier(String [] keys, String[] values,
                                                  long classifierPtr);

    private native int nativeGetLengthClassifier(long classifierPtr);

    private native boolean nativeSetWeightClassifier(String [] keys, float[] values,
                                                     float normalizer, long classifierPtr);

    private native boolean nativeSetParameterClassifier(String key, String value,
                                                        long classifierPtr);
}