// File: HistogramPredictor.java | Doc: API Doc | Category: Android 5.1 API | Size: 14400
// Date: Thu Mar 12 22:22:48 GMT 2015 | Package: android.bordeaux.learning
//
// HistogramPredictor.java

/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.bordeaux.learning;

import android.util.Log;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
/**
 * A histogram-based predictor that records co-occurrences of applications with specific
 * feature values, for example location, time of day, etc. Each feature's histogram is kept
 * in a two-level hash table: the first-level key is the feature value and the second-level
 * key is the app id.
 */
// TODOS:
// 1. Use a forgetting factor to down-weight instances in proportion to their age.
// 2. Different features could have different weights on prediction scores.
// 3. Add function to remove sampleid (i.e. remove apps that are uninstalled).


public class HistogramPredictor {
    final static String TAG = "HistogramPredictor";

    // Feature name -> histogram of that feature's (feature value, app) co-occurrence counts.
    private final HashMap<String, HistogramCounter> mPredictor =
            new HashMap<String, HistogramCounter>();

    // App (class) name -> number of samples observed for that app.
    private final HashMap<String, Integer> mClassCounts = new HashMap<String, Integer>();

    // Apps that findTopClasses() never returns.
    private final HashSet<String> mBlacklist = new HashSet<String>();

    // A feature value only contributes scores when the apps seen with it more than once
    // account for at least this many samples in total (see getClassScores()).
    private static final int MINIMAL_FEATURE_VALUE_COUNTS = 5;
    // An app must have been observed at least this many times to be a prediction candidate.
    private static final int MINIMAL_APP_APPEARANCE_COUNTS = 5;

    // This parameter ranges from 0 to 1 and determines the effect of the app prior.
    // When it is set to 0, the app prior is completely neglected. When it is set to 1,
    // the predictor is a standard naive Bayes model. Declared as double (it used to be an
    // int) so that weights strictly between 0 and 1 can actually be expressed.
    private static final double PRIOR_K_VALUE = 1.0;

    // Feature whose histogram is used to rebuild mClassCounts when a model is loaded.
    private static final String TIME_OF_WEEK = "Time of Week";
    private static final String MISSING_TIME_OF_WEEK_MESSAGE =
            "Precition model error: missing Time of Week!";

    // NOTE(review): this list is not referenced anywhere in this class; only the list
    // passed to the constructor is applied. TODO confirm whether it is meant as a default.
    private static final String[] APP_BLACKLIST = {
        "com.android.contacts",
        "com.android.chrome",
        "com.android.providers.downloads.ui",
        "com.android.settings",
        "com.android.vending",
        "com.android.mms",
        "com.google.android.gm",
        "com.google.android.gallery3d",
        "com.google.android.apps.googlevoice",
    };

    /**
     * Creates an empty predictor.
     *
     * @param blackList app names that {@link #findTopClasses} must never return
     */
    public HistogramPredictor(String[] blackList) {
        Collections.addAll(mBlacklist, blackList);
    }

    /*
     * Histogram counts for a single feature, stored as feature value -> (app name -> count).
     * It provides the per-app scores derived from the joint counts of <feature, class>.
     * It is static because it needs no state from the enclosing predictor.
     */
    private static class HistogramCounter {
        private final HashMap<String, HashMap<String, Integer> > mCounter =
                new HashMap<String, HashMap<String, Integer> >();

        public HistogramCounter() {
        }

        /** Replaces all counts with the entries of {@code counter}. */
        public void setCounter(HashMap<String, HashMap<String, Integer> > counter) {
            resetCounter();
            mCounter.putAll(counter);
        }

        /** Removes all counts. */
        public void resetCounter() {
            mCounter.clear();
        }

        /** Increments the co-occurrence count of {@code className} with {@code featureValue}. */
        public void addSample(String className, String featureValue) {
            HashMap<String, Integer> classCounts = mCounter.get(featureValue);
            if (classCounts == null) {
                classCounts = new HashMap<String, Integer>();
                mCounter.put(featureValue, classCounts);
            }
            Integer previous = classCounts.get(className);
            classCounts.put(className, (previous == null) ? 1 : previous + 1);
        }

        /**
         * Returns log(count) for every app seen more than once with {@code featureValue}.
         * It returns an empty map when the value is unknown, or when those apps' counts sum
         * to less than MINIMAL_FEATURE_VALUE_COUNTS.
         */
        public HashMap<String, Double> getClassScores(String featureValue) {
            HashMap<String, Double> classScores = new HashMap<String, Double>();
            HashMap<String, Integer> classCounts = mCounter.get(featureValue);
            if (classCounts == null) {
                return classScores;
            }

            int totalCount = 0;
            for (Map.Entry<String, Integer> entry : classCounts.entrySet()) {
                int count = entry.getValue();
                // Apps with a count of one would score log(1) == 0, so they are left out
                // of classScores. totalCount also ignores single occurrences.
                if (count > 1) {
                    classScores.put(entry.getKey(), Math.log((double) count));
                    totalCount += count;
                }
            }
            if (totalCount < MINIMAL_FEATURE_VALUE_COUNTS) {
                classScores.clear();
            }
            return classScores;
        }

        /** Serializes this feature's counts to a byte array. */
        public byte[] getModel() {
            synchronized (mCounter) {
                return serialize(mCounter);
            }
        }

        /**
         * Replaces this feature's counts with the ones deserialized from {@code modelData}.
         * Existing counts are kept if deserialization fails.
         */
        public boolean setModel(final byte[] modelData) {
            HashMap<String, HashMap<String, Integer> > model = deserialize(modelData);
            synchronized (mCounter) {
                mCounter.clear();
                mCounter.putAll(model);
            }
            return true;
        }

        /** Returns the live (not copied) count table. */
        public HashMap<String, HashMap<String, Integer> > getCounter() {
            return mCounter;
        }

        @Override
        public String toString() {
            StringBuilder result = new StringBuilder();
            for (Map.Entry<String, HashMap<String, Integer> > entry : mCounter.entrySet()) {
                result.append("{ ").append(entry.getKey()).append(" : ")
                        .append(entry.getValue().toString()).append(" }");
            }
            return result.toString();
        }
    }

    /**
     * Given a map of feature name -> value pairs, returns the topK most likely apps to be
     * launched together with their scores, sorted by descending score.
     *
     * @param features feature name -> observed feature value
     * @param topK maximum number of results; zero (or any non-positive value) returns the
     *        whole list
     * @return the best-scoring candidate apps
     * @throws RuntimeException if a scored app has no entry in the class counts
     */
    public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) {
        HashMap<String, Double> appScores = new HashMap<String, Double>();
        int validFeatureCount = 0;

        // Sum the per-feature log-count scores of every app.
        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
            String featureName = entry.getKey();
            if (!features.containsKey(featureName)) {
                continue;
            }
            HashMap<String, Double> scoreMap =
                    entry.getValue().getClassScores(features.get(featureName));
            if (scoreMap.isEmpty()) {
                continue;
            }
            validFeatureCount++;

            for (Map.Entry<String, Double> item : scoreMap.entrySet()) {
                String appName = item.getKey();
                double appScore = item.getValue();
                Double previous = appScores.get(appName);
                if (previous != null) {
                    appScore += previous;
                }
                appScores.put(appName, appScore);
            }
        }

        // Filter the candidates and subtract the weighted log prior.
        HashMap<String, Double> appCandidates = new HashMap<String, Double>();
        for (Map.Entry<String, Double> entry : appScores.entrySet()) {
            String appName = entry.getKey();
            if (mBlacklist.contains(appName)) {
                Log.i(TAG, appName + " is in blacklist");
                continue;
            }
            if (!mClassCounts.containsKey(appName)) {
                throw new RuntimeException("class count error!");
            }
            int appCount = mClassCounts.get(appName);
            if (appCount < MINIMAL_APP_APPEARANCE_COUNTS) {
                Log.i(TAG, appName + " doesn't have enough counts");
                continue;
            }

            double appScore = entry.getValue();
            double appPrior = Math.log((double) appCount);
            // Naive Bayes in log space: sum_f log c(f, app) - (n - K) * log c(app).
            appCandidates.put(appName,
                              appScore - appPrior * (validFeatureCount - PRIOR_K_VALUE));
        }

        List<Map.Entry<String, Double> > appList =
                new ArrayList<Map.Entry<String, Double> >(appCandidates.entrySet());
        Collections.sort(appList, new Comparator<Map.Entry<String, Double> >() {
            @Override
            public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                // Descending by score.
                return o2.getValue().compareTo(o1.getValue());
            }
        });

        // A negative topK used to make subList() throw; treat it like 0 ("everything").
        int limit = (topK <= 0) ? appList.size() : Math.min(topK, appList.size());
        return appList.subList(0, limit);
    }

    /**
     * Adds a new observation of the given sample id (app) with the given features to the
     * histograms, and increments the sample's class count.
     */
    public void addSample(String sampleId, Map<String, String> features) {
        for (Map.Entry<String, String> entry : features.entrySet()) {
            useFeature(entry.getKey()).addSample(sampleId, entry.getValue());
        }

        Integer previous = mClassCounts.get(sampleId);
        mClassCounts.put(sampleId, (previous == null) ? 1 : previous + 1);
    }

    /** Resets the predictor to an empty model. */
    public void resetPredictor() {
        // TODO: not sure this step would reduce memory waste
        for (HistogramCounter counter : mPredictor.values()) {
            counter.resetCounter();
        }
        mPredictor.clear();
        mClassCounts.clear();
    }

    /**
     * Converts the prediction model into a byte array. The array holds a Java-serialized
     * map of feature name -> feature value -> app -> count.
     */
    public byte[] getModel() {
        // TODO: convert model to a more memory efficient data structure.
        HashMap<String, HashMap<String, HashMap<String, Integer> > > model =
                new HashMap<String, HashMap<String, HashMap<String, Integer> > >();
        for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
            model.put(entry.getKey(), entry.getValue().getCounter());
        }
        return serialize(model);
    }

    /**
     * Sets the prediction model from model data produced by {@link #getModel()}.
     * The model is validated before the current state is touched, so a bad model leaves
     * the predictor unchanged.
     *
     * @throws RuntimeException if the data cannot be deserialized or lacks the
     *         "Time of Week" feature
     */
    public boolean setModel(final byte[] modelData) {
        HashMap<String, HashMap<String, HashMap<String, Integer> > > model =
                deserialize(modelData);
        if (!model.containsKey(TIME_OF_WEEK)) {
            throw new RuntimeException(MISSING_TIME_OF_WEEK_MESSAGE);
        }

        resetPredictor();
        for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry :
                model.entrySet()) {
            useFeature(entry.getKey()).setCounter(entry.getValue());
        }

        // TODO: this is a temporary fix for now
        loadClassCounter();

        return true;
    }

    /*
     * Rebuilds mClassCounts by summing, per app, the counts in the "Time of Week" histogram.
     * Apps missing from that histogram get no class count.
     */
    private void loadClassCounter() {
        // setModel() already validates this; the guard protects any other caller.
        if (!mPredictor.containsKey(TIME_OF_WEEK)) {
            throw new RuntimeException(MISSING_TIME_OF_WEEK_MESSAGE);
        }

        HashMap<String, HashMap<String, Integer> > counter =
                mPredictor.get(TIME_OF_WEEK).getCounter();

        mClassCounts.clear();
        for (HashMap<String, Integer> map : counter.values()) {
            for (Map.Entry<String, Integer> entry : map.entrySet()) {
                String className = entry.getKey();
                int classCount = entry.getValue();
                Integer previous = mClassCounts.get(className);
                if (previous != null) {
                    classCount += previous;
                }
                mClassCounts.put(className, classCount);
            }
        }
        Log.i(TAG, "class counts: " + mClassCounts);
    }

    /* Returns the histogram for featureName, creating an empty one if needed. */
    private HistogramCounter useFeature(String featureName) {
        HistogramCounter counter = mPredictor.get(featureName);
        if (counter == null) {
            counter = new HistogramCounter();
            mPredictor.put(featureName, counter);
        }
        return counter;
    }

    /* Java-serializes object; wraps I/O failures, keeping the cause. */
    private static byte[] serialize(Object object) {
        try {
            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
            ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
            objStream.writeObject(object);
            // close() flushes the object stream's internal buffer into byteStream.
            objStream.close();
            return byteStream.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException("Can't get model", e);
        }
    }

    /*
     * Deserializes one object from data; wraps failures, keeping the cause.
     * NOTE(review): this uses Java native deserialization. Only pass trusted, locally
     * produced model bytes.
     */
    @SuppressWarnings("unchecked") // Callers know the model layout they stored.
    private static <T> T deserialize(byte[] data) {
        try {
            ObjectInputStream objStream = new ObjectInputStream(new ByteArrayInputStream(data));
            Object result = objStream.readObject();
            objStream.close();
            return (T) result;
        } catch (IOException e) {
            throw new RuntimeException("Can't load model", e);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Learning class not found", e);
        }
    }
}