karoo_gp/tools/karoo_iris_plot.py

# Karoo Iris Plot
# by Kai Staats, MSc UCT / AIMS and Arun Kumar, PhD
# version 0.9.1.6
import sys
import numpy as np
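import matplotlib.pyplot as mpl
from mpl_toolkits.mplot3d import Axes3D # registers the '3d' projection; required by older matplotlib releases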
np.set_printoptions(linewidth = 320) # set the terminal to print 320 characters before line-wrapping in order to view Trees
'''
This is a functional, if rudimentary, script designed to help you visualise your 2D or 3D data against a function
generated by Karoo GP. The challenge comes with solving complex equations for a single variable such that you have a
plot-able function. If the algebra required is beyond your skills (or you forgot what you learned in high school),
tools such as Matlab may be of some assistance. If you desire to normalise your data in advance of using this script,
the Karoo GP normalisation script included in the karoo_gp/tools/ directory is very easy to use.
By default, this script plots a Karoo GP derived function against a scatter plot of one of the Iris datasets
included with this package: karoo_gp/files/Iris_dataset/data_IRIS_virginica-vs-setosa_3-col_PLOT.csv
If you are new to plotting, https://www.youtube.com/channel/UCfzlCWGWYyIQ0aLC5w48gBQ provides a good, visual
tutorial, as do many other web and video based guides.
'''
### USER INTERACTION ###
if len(sys.argv) == 1:
    filename = '../files/Iris_dataset/data_IRIS_virginica-vs-setosa_3-col_PLOT.csv'
    print('\n\t\033[31mYou have not assigned an input file, therefore "IRIS_virginica-vs-setosa_3-col_PLOT" will be used.\033[0;0m')
elif len(sys.argv) > 2: print('\n\t\033[31mERROR! You have assigned too many command line arguments. Try again ...\033[0;0m'); sys.exit()
else: filename = sys.argv[1]
### LOAD THE DATA and PREPARE AN EMPTY ARRAY ###
print('\n\t\033[36mLoading dataset: ' + filename + ' \033[0;0m\n')
data = np.loadtxt(filename, delimiter=',', dtype = str)
# convert each of the first three columns from strings to floats
data_a = [float(value) for value in data[:,0]]
data_b = [float(value) for value in data[:,1]]
data_c = [float(value) for value in data[:,2]]
### PREP THE FUNCTION ###
b = np.arange(2, 4, 0.25) # plot from 2 up to (but not including) 4 in steps of 0.25
c = np.arange(2, 4, 0.25) # plot from 2 up to (but not including) 4 in steps of 0.25
b, c = np.meshgrid(b, c)
# -b*c + c**2 + c - 1 # Karoo GP derived function
# -a/c - b**2 + c**2 # Karoo GP derived function
# -a - b + c**2 # Karoo GP derived function; set to 0 and solved for a, it becomes a = -b + c**2
a = -b + c**2
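# A minimal sketch, not part of the original script: when rearranging a Karoo GP
# expression by hand is awkward, SymPy (an assumed extra dependency) can solve it for
# one variable. solve_for_variable is a hypothetical helper; it assumes the surface to
# plot is where the expression equals 0, as in the rearrangement for a = -b + c**2.
def solve_for_variable(expression, variable = 'a'):
    import sympy # imported here so the script still runs when SymPy is not installed
    target = sympy.Symbol(variable) # the variable to isolate
    parsed = sympy.sympify(expression) # parse the string, e.g. '-a - b + c**2'
    return sympy.solve(parsed, target) # list of expressions equal to the target
# Example: solve_for_variable('-a - b + c**2') returns a one-item list holding the
# SymPy expression for a in terms of b and c, which can then be typed in as the function.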
### PLOT THE FUNCTION and DATA ###
fig = mpl.figure()
ax = fig.add_subplot(111, projection = '3d')
ax.scatter(data_a, data_b, data_c, c = 'r', marker = 'o') # 3D data
ax.plot_wireframe(a,b,c) # 3D function
ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('c')
mpl.show()