From 11be9a6038b78ee6bcea699f0242696fcf4b86df Mon Sep 17 00:00:00 2001 From: Kai Staats Date: Wed, 14 Sep 2016 19:49:38 -0600 Subject: [PATCH] name change --- tools/karoo_sort.py | 67 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tools/karoo_sort.py diff --git a/tools/karoo_sort.py b/tools/karoo_sort.py new file mode 100644 index 0000000..145f32d --- /dev/null +++ b/tools/karoo_sort.py @@ -0,0 +1,67 @@ +# Karoo Dataset Builder +# by Kai Staats, MSc UCT / AIMS and Arun Kumar, PhD +# version 0.9.2.0 + +import sys +import numpy as np + +np.set_printoptions(linewidth = 320) # set the terminal to print 320 characters before line-wrapping in order to view Trees + +''' +In machine learning, it is often the case that your engaged dataset is derived from a larger parent. In constructing +the subset, if we grab a series of datapoints (rows in a .csv) from the larger dataset in sequential order, only from +the top, middle, or bottom, we will likely bias the new dataset and incorrectly train the machine learning algorithm. +Therefore, it is imperative that we engage a random function, guided only by the number of data points for each class. + +This script can be used *before* karoo_normalise.py, and assumes no header has yet been applied to the .csv. +''' + +### USER INTERACTION ### +if len(sys.argv) == 1: print '\n\t\033[31mERROR! You have not assigned an input file. Try again ...\033[0;0m'; sys.exit() +elif len(sys.argv) > 2: print '\n\t\033[31mERROR! You have assigned too many command line arguments. Try again ...\033[0;0m'; sys.exit() +else: filename = sys.argv[1] + +#n = range(1,101) +#while True: +# try: +# labels = raw_input('\n\tEnter number of unique class labels, or 0 for a regression dataset (default 2): ') +# if labels not in str(n) and labels not in '': raise ValueError() +# # if labels == '0': labels = 1; break +# labels = labels or 2; labels = int(labels); break +# except ValueError: print '\n\t\033[32mEnter a number from 0 including 100. Try again ...\033[0;0m' + +n = range(10,10001) +while True: + try: + samples = raw_input('\n\tEnter number of desired datapoints per class (default 100): ') + if samples not in str(n) and samples not in '': raise ValueError() + if samples == '0': samples = 10; break + samples = samples or 100; samples = int(samples); break + except ValueError: print '\n\t\033[32mEnter a number from 10 including 10000. Try again ...\033[0;0m' + + +### LOAD THE ORIGINAL DATASET ### +print '\n\t\033[36m\n\tLoading dataset:', filename, '\033[0;0m\n' +data = np.loadtxt(filename, delimiter = ',') # load data +data_sort = np.empty(shape = [0, data.shape[1]]) # build an empty array of the proper dimensions + + +### SORT DATA by LABEL ### +labels = len(np.unique(data[:,-1])) + +for label in range(labels): + data_list = np.where(data[:,-1] == label) # build a list of all rows which end in the current label + + data_select = np.random.choice(data_list[0], samples, replace = False) # select user defined 'samples' from list + print data_select + + data_sort = np.append(data_sort, data[data_select], axis = 0) + + +### SAVE THE SORTED DATASET ### +file_tmp = filename.split('.')[0] +np.savetxt(file_tmp + '-SORT.csv', data_sort, delimiter = ',') + +print '\n\t\033[36mThe sorted dataset has been written to the file:', file_tmp + '-SORT.csv', '\033[0;0m' + +