summaryrefslogtreecommitdiff
path: root/scripts/prepare-data.py
blob: 7250bde6f2b6ab1b537c76d83bb82d5e630875ee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python

import csv
import getopt
import os, sys
import scipy.interpolate
import numpy as np
import matplotlib.pyplot as plt

# Interpolation kinds accepted by the -i option; the first entry is the default.
INTERP_METHODS = ['cubic', 'linear']
# Number of output data points produced when -N is not given.
DEFAULT_N = 100

def usage():
    """Print the command-line help text to standard output.

    The text names the running script (basename of ``sys.argv[0]``), lists
    the available interpolation methods and states the defaults used for
    the -i and -N options.
    """
    help_template = """usage: {} [OPTION...] CSV-FILE...

    -d DIR     use DIR as output directory (default: same as input file)
    -f         force overwrite of existing files
    -i INTERP  use interpolation method INTERP. Must be one of: {} (default: {})
    -N N       produce N data points in the output (default: {})
    -p         plot result (default: off)
    -h         show this help and exit"""
    script_name = os.path.basename(sys.argv[0])
    method_list = ','.join(INTERP_METHODS)
    default_method = INTERP_METHODS[0]
    print(help_template.format(script_name, method_list, default_method,
                               DEFAULT_N))

def main():
    """Interpolate the two-column CSV files named on the command line.

    Options are parsed from ``sys.argv`` with getopt (see ``usage()``).  For
    each input file the header row and the numeric (x, y) rows are read, the
    data is resampled at N evenly spaced x positions between the smallest and
    largest x value using ``scipy.interpolate.interp1d``, optionally plotted,
    and written to ``<input basename>.interp`` using the CSV dialect sniffed
    from the input file.  The output goes to the directory given with ``-d``
    or, by default, next to the input file.

    Exits with status -1 on invalid options or when no input file is given,
    and with status 0 after printing the help for ``-h``.  Input files that
    do not exist and output files that already exist (without ``-f``) are
    reported and skipped.
    """
    try:
        opts, args = getopt.getopt(sys.argv[1:], "d:fi:N:ph")
    except getopt.GetoptError as err:
        print(str(err))
        usage()
        sys.exit(-1)

    output_dir = None
    force = False
    interp = INTERP_METHODS[0]
    N = DEFAULT_N
    do_plot = False
    # Fixed axis limits used for the optional plot.
    x_max = 100
    y_max = 50
    for o, a in opts:
        if o == '-d':
            # Must be an existing directory; a regular file cannot hold output.
            if os.path.isdir(a):
                output_dir = a
            else:
                print("Error: output directory {} is not a valid path".format(a))
                sys.exit(-1)
        elif o == '-f':
            force = True
        elif o == '-i':
            if a in INTERP_METHODS:
                interp = a
            else:
                print("Error: invalid interpolation method: {}".format(a))
                usage()
                sys.exit(-1)
        elif o == '-N':
            try:
                N = int(a)
            except ValueError:
                # Report the offending argument, not the current value of N.
                print("Error: invalid number of data points: {}".format(a))
                usage()
                sys.exit(-1)
        elif o == '-p':
            do_plot = True
        elif o == '-h':
            usage()
            sys.exit(0)
        else:
            assert False, "unhandled option"

    # Checked after option handling so that a lone -h exits successfully.
    if len(args) < 1:
        usage()
        sys.exit(-1)

    for csv_file in args:
        if not os.path.exists(csv_file):
            print("Error: File {} not found, skipping".format(csv_file))
            continue

        with open(csv_file, 'r') as cf:
            # Guess delimiter/quoting from the first KiB, then rewind.
            dialect = csv.Sniffer().sniff(cf.read(1024))
            cf.seek(0)
            csv_reader = csv.reader(cf, dialect)
            x_name, y_name = next(csv_reader)  # header line

            # Blank lines come back as empty rows; drop them before unpacking.
            rows = [row for row in csv_reader if row]
            X = np.array([[float(_x), float(_y)] for _x, _y in rows])
            x = X[:, 0]
            y = X[:, 1]

            # TODO: maybe smoothen data first?

            # interpolate data points
            xnew = np.linspace(min(x), max(x), N)
            f = scipy.interpolate.interp1d(x, y, kind=interp)
            ynew = f(xnew)

            if do_plot:
                plt.plot(x, y, 'x')
                plt.plot(xnew, ynew, '-')
                plt.legend(['data', "{} interpolation".format(interp)], loc='best')
                # Only the limits are passed: axis() takes a single positional
                # argument, and current matplotlib rejects a second one.
                plt.axis([0, x_max, 0, y_max])
                plt.xlabel(x_name)
                plt.ylabel(y_name)
                plt.title(csv_file)
                plt.grid(True)

                plt.show()

            # One (x, y) pair per output row.
            Y = np.transpose(np.array([xnew, ynew]))

            if output_dir:
                od = output_dir
            else:
                od = os.path.dirname(csv_file)
            out_file = os.path.join(od, os.path.basename(csv_file) + '.interp')
            if os.path.exists(out_file) and not force:
                print("Error: file {} already exists, not overwriting".format(out_file))
            else:
                # Context manager guarantees the output is flushed and closed.
                with open(out_file, 'w') as of:
                    csv_writer = csv.writer(of, dialect)
                    csv_writer.writerow([x_name, y_name])
                    csv_writer.writerows(Y)
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()