diff --git a/sustain.py b/sustain.py index 1eb50b3..31c703d 100644 --- a/sustain.py +++ b/sustain.py @@ -376,9 +376,9 @@ def shepard_six(data, env): # this is clunky but it's just to organize the data to compute RMSError # output total RMS Error (have to average back the data to avg. error every two blocks) - hd = resize(humandata,(1,len(humandata)*len(humandata[0])))[[0]] + hd = resize(humandata,(1,len(humandata)*len(humandata[0])))[0] avgbacklc = map(lambda x: map(mean,resize(x,(len(x)/2,2))),lc) - md = resize(avgbacklc,(1,len(avgbacklc)*len(avgbacklc[0])))[[0]] + md = resize(avgbacklc,(1,len(avgbacklc)*len(avgbacklc[0])))[0] md = map(lambda x: x/(len(trainingitems)*ntimes), md) print "Total RMSError = ", sqrt(sum(map(lambda x,y: pow(x-y,2.0),md,hd))/(len(md)-1)) @@ -404,4 +404,4 @@ def main(): ########################################################### if __name__ == '__main__': - main() \ No newline at end of file + main()