summaryrefslogtreecommitdiff
path: root/scripts/csv2sto.py
blob: 338b7737513922bc49c10413d1e0d80df7470fb2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2012-2013 Tobias Klauser <tklauser@distanz.ch>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.

import os, sys
import getopt
import csv
import numpy as np

# Default for -n: the name written on the first line of the .sto file.
DEFAULT_NAME = "control"
# Default for -T. Input CSV times are assumed to span 0..100, so keeping
# this value means "do not rescale time".
DEFAULT_MAXTIME = 100.0
# Default for -Y. Input y values are assumed to span 0..50, so keeping
# this value means "do not rescale y values".
DEFAULT_MAXY = 50.0

def usage():
    """Print the command-line help text to stdout.

    The program name shown is the basename of sys.argv[0]; the defaults shown
    come from the module-level DEFAULT_NAME, DEFAULT_MAXTIME and DEFAULT_MAXY
    constants. Leaving -T/-Y at their defaults disables the respective scaling.
    """
    print("""usage: {prog} [OPTION...] CSV-FILE... STO-FILE

  -f        force overwrite of existing files
  -m NAMES  use names from comma-separated list NAMES to name muscles instead
            of CSV file names (# of names in list must match # of CSV files)
  -n NAME   name of the activation data in sto file (default: {name})
  -T N      scale time from range 0.0 to {maxtime} into range 0.0 to N
            (default: {maxtime}, i.e. no scaling)
  -Y N      scale y values from range 0.0 to {maxy} into range 0.0 to N
            (default: {maxy}, i.e. no scaling)
  -h        show this help and exit""".format(
        prog=os.path.basename(sys.argv[0]),
        name=DEFAULT_NAME,
        maxtime=DEFAULT_MAXTIME,
        maxy=DEFAULT_MAXY))

def _fail(msg):
    """Report *msg* as an error on stderr and exit with status -1."""
    sys.stderr.write("Error: " + msg + "\n")
    sys.exit(-1)


def _parse_float(opt, arg):
    """Return *arg* (the argument of option *opt*) as a float, or exit with an error."""
    try:
        return float(arg)
    except ValueError:
        _fail("invalid number for option {}: {}".format(opt, arg))


def _parse_args(argv):
    """Parse the command line *argv* (without the program name).

    Returns a tuple (overwrite, muscles, name, maxtime, maxy, fs_csv, f_sto).
    Prints the usage text and exits for -h (status 0), and for unknown
    options or fewer than two positional arguments (status -1).
    """
    try:
        opts, args = getopt.getopt(argv, "fm:n:hT:Y:")
    except getopt.GetoptError as err:
        sys.stderr.write(str(err) + "\n")
        usage()
        sys.exit(-1)

    overwrite = False
    muscles = None
    name = DEFAULT_NAME
    maxtime = DEFAULT_MAXTIME
    maxy = DEFAULT_MAXY
    for o, a in opts:
        if o == '-f':
            overwrite = True
        elif o == '-m':
            muscles = [x.strip() for x in a.split(',')]
        elif o == '-n':
            name = a
        elif o == '-T':
            maxtime = _parse_float(o, a)
        elif o == '-Y':
            maxy = _parse_float(o, a)
        elif o == '-h':
            usage()
            sys.exit(0)
        else:
            # getopt only returns the options declared above; raise instead
            # of assert so the check survives `python -O`.
            raise AssertionError("unhandled option " + o)

    # Checked after the options so that a bare "-h" prints the help and
    # exits successfully instead of failing for missing file arguments.
    if len(args) < 2:
        usage()
        sys.exit(-1)

    return overwrite, muscles, name, maxtime, maxy, args[:-1], args[-1]


def _read_csv(f_csv, maxtime, maxy):
    """Read CSV file *f_csv* and return its data rows as a list of (time, y) floats.

    The CSV dialect is sniffed from the first 1024 characters of the file.
    Blank rows are ignored, the first remaining row is treated as a header
    and skipped, and every data row must hold exactly two numeric fields
    (time, y). Time is rescaled from 0..DEFAULT_MAXTIME to 0..maxtime and y
    from 0..DEFAULT_MAXY to 0..maxy, unless the respective maximum is left
    at its default.

    Raises ValueError with a user-readable message if the file cannot be
    parsed; I/O errors from open() propagate unchanged.
    """
    with open(f_csv, 'r') as fd:
        try:
            dialect = csv.Sniffer().sniff(fd.read(1024))
        except csv.Error as err:
            raise ValueError("cannot determine CSV format of {}: {}".format(f_csv, err))
        fd.seek(0)
        # csv.reader yields [] for empty lines; drop those (and whitespace-only
        # rows) so e.g. a trailing newline does not break parsing or row counts.
        rows = [row for row in csv.reader(fd, dialect)
                if any(cell.strip() for cell in row)]

    if not rows:
        raise ValueError("CSV file {} is empty".format(f_csv))

    data = []
    for row in rows[1:]:    # rows[0] is the header line
        if len(row) != 2:
            raise ValueError("CSV file {}: expected 2 columns (time, value), got {}: {}"
                             .format(f_csv, len(row), row))
        try:
            t, y = float(row[0]), float(row[1])
        except ValueError:
            raise ValueError("CSV file {}: non-numeric value in row {}".format(f_csv, row))
        if maxtime != DEFAULT_MAXTIME:
            # Input times are assumed to span the default 0..DEFAULT_MAXTIME range.
            t = (t / DEFAULT_MAXTIME) * maxtime
        if maxy != DEFAULT_MAXY:
            # Input y values are assumed to span the default 0..DEFAULT_MAXY range.
            y = (y / DEFAULT_MAXY) * maxy
        data.append((t, y))
    return data


def main():
    """Merge the CSV files given on the command line into a single .sto file.

    Every CSV file holds (time, y) rows below a header line and contributes
    one column to the output, named after the file (or after the matching
    entry of the -m list). The .sto file consists of the data name, a
    version/nRows/nColumns header terminated by "endheader", a tab-separated
    column-name line starting with "time", and one tab-separated line per
    data row. All CSV files must have the same number of data rows.
    Exits with status -1 on any error.
    """
    overwrite, muscles, name, maxtime, maxy, fs_csv, f_sto = _parse_args(sys.argv[1:])

    for f_csv in fs_csv:
        if not os.path.exists(f_csv):
            _fail("CSV file {} does not exist".format(f_csv))

    if muscles and len(muscles) != len(fs_csv):
        _fail("number of names ({}) provided does not match number of CSV files ({})"
              .format(len(muscles), len(fs_csv)))

    if not overwrite and os.path.exists(f_sto):
        _fail("STO file {} already exists".format(f_sto))

    # Parse every input before opening the output, so a bad CSV file neither
    # leaves a truncated .sto behind nor clobbers an existing one (with -f).
    datasets = []
    for f_csv in fs_csv:
        try:
            data = _read_csv(f_csv, maxtime, maxy)
        except (IOError, OSError, ValueError) as err:
            _fail(str(err))
        if datasets and len(data) != len(datasets[0]):
            _fail("number of rows in {} ({}) does not match {} ({})"
                  .format(f_csv, len(data), fs_csv[0], len(datasets[0])))
        datasets.append(data)

    if muscles:
        colnames = muscles
    else:
        colnames = [os.path.splitext(os.path.basename(f))[0] for f in fs_csv]

    print("Writing sto file {} with name {} ({} muscles)".format(f_sto, name, len(fs_csv)))
    if maxtime != DEFAULT_MAXTIME:
        print("Scaling to max. time {}".format(maxtime))
    if maxy != DEFAULT_MAXY:
        print("Scaling to max. y {}".format(maxy))

    # XXX: For now assume same time spacing in all CSV files: the time column
    # is taken from the first file only.
    n_rows = len(datasets[0])
    times = np.array([t for t, _ in datasets[0]], dtype=float)
    vals = np.zeros((n_rows, len(datasets)))
    for i, data in enumerate(datasets):
        vals[:, i] = [y for _, y in data]

    with open(f_sto, 'w') as fd_sto:
        fd_sto.write(name + "\n")
        fd_sto.write("version=1\n")
        fd_sto.write("nRows=" + str(n_rows) + "\n")
        # one time column plus one column per CSV file
        fd_sto.write("nColumns=" + str(1 + len(colnames)) + "\n")
        fd_sto.write("endheader\n")

        fd_sto.write("time")
        for colname in colnames:
            print("[+] Adding muscle {}...".format(colname))
            fd_sto.write("\t" + colname)
        fd_sto.write("\n")

        for t, row in zip(times, vals):
            fd_sto.write("\t".join([str(t)] + [str(val) for val in row]) + "\n")

# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()