#!/usr/bin/env python

import os
import subprocess

import ROOT

import xboa.Bunch
from xboa.Bunch import Bunch
import xboa.Common as Common

# Module-level handle for the ROOT.TF1 reference curve that plot_output()
# creates and draws. Presumably kept global so the object stays alive after
# plot_output() returns while its canvas is displayed -- TODO confirm.
IDEAL_FIELD = None

def run_sim():
    """Run the ERIT field-test simulation in a child python process.

    The simulation script and its configuration file are located relative to
    the directory named by the ERIT_ROOT environment variable. Both paths are
    printed before the child is launched, and the call blocks until the child
    has exited.

    Returns the exit status of the child process.

    Raises RuntimeError if ERIT_ROOT is not set.
    """
    erit_root_dir = os.getenv('ERIT_ROOT')
    if erit_root_dir is None:
        # Fail with an explicit message; os.path.join(None, ...) would
        # otherwise raise an error that never mentions the variable.
        raise RuntimeError('ERIT_ROOT environment variable is not set')
    maus_dir = os.path.join(erit_root_dir, 'maus')
    sim_path = os.path.join(maus_dir, 'simulate_erit.py')
    conf_path = os.path.join(maus_dir, 'tests', 'test_field',
                             'field_configuration.py')
    print(sim_path)
    print(conf_path)
    # subprocess.call waits for the child and hands back its exit status, so
    # callers can tell whether the simulation actually succeeded.
    return subprocess.call(['python', sim_path, '-configuration_file',
                            conf_path])

def plot_output():
    """Plot the primaries and virtual hits read from 'simulation.out'.

    Draws x-z scatter graphs of the primaries and of all virtual hits. Then,
    for every virtual-hit station whose mean z has magnitude below 1e-9,
    draws by against x, overlays the reference curve
    0.727*(2.350/x)**1.92 on the range 2.0 to 2.7, and prints x, z and by
    for each hit in that station.
    """
    # The reference curve is stored in the module-level IDEAL_FIELD rather
    # than a local (presumably to keep it alive after this function returns).
    global IDEAL_FIELD
    sim_output = 'simulation.out'

    primaries = Bunch.new_from_read_builtin('maus_primary', sim_output)
    primaries.root_scatter_graph('x', 'z', 'm', 'm')

    all_hits = Bunch.new_from_read_builtin('maus_virtual_hit', sim_output)
    all_hits.root_scatter_graph('x', 'z', 'm', 'm')

    stations = Bunch.new_list_from_read_builtin('maus_virtual_hit', sim_output)
    for station in stations:
        mean_z = station.mean(['z'])['z']
        # Only stations sitting at z = 0 are compared with the reference curve.
        if not abs(mean_z) < 1e-9:
            continue
        canvas, _frame, field_graph = station.root_scatter_graph(
            'x', 'by', 'm', 'T', ymin=0.5, ymax=1.0)
        field_graph.Draw('l')
        IDEAL_FIELD = ROOT.TF1("ideal_field", "0.727*(2.350/x)**1.92",
                               2.000, 2.700)
        IDEAL_FIELD.Draw('lsame')
        canvas.Update()
        for hit in station:
            print('%s %s %s' % (hit['x'], hit['z'], hit['by']))

    # three_d_plot(all_hits, 'x', 'z', 'by', 'm', 'm', 'T') can be called
    # here to draw by over the x-z plane; it is currently disabled.

def three_d_plot(bunch, x_axis, y_axis, z_axis, x_units='', y_units='', z_units=''):
    """Draw z_axis against x_axis and y_axis for every hit in bunch.

    Each hit value is divided by the Common.units entry for the matching
    units string and printed. The scaled values are then passed to
    Common.make_root_graph_2d, and the returned histogram and graphs are
    drawn on a canvas named '<x_axis>-<y_axis>-<z_axis>'.
    """
    def axis_label(axis, units):
        # Axis title in the form 'name [units]'.
        return axis + ' [' + units + ']'

    canvas_name = '-'.join([x_axis, y_axis, z_axis])
    canvas = Common.make_root_canvas(canvas_name)

    points = []
    for hit in bunch:
        point = (hit[x_axis] / Common.units[x_units],
                 hit[y_axis] / Common.units[y_units],
                 hit[z_axis] / Common.units[z_units])
        points.append(point)
        print('%s %s %s' % point)
    x_list = [p[0] for p in points]
    y_list = [p[1] for p in points]
    z_list = [p[2] for p in points]

    hist, graph_3D, graph_2D = Common.make_root_graph_2d(
        canvas_name,
        x_list, axis_label(x_axis, x_units),
        y_list, axis_label(y_axis, y_units),
        z_list, axis_label(z_axis, z_units))
    canvas.cd()
    hist.Draw()
    graph_3D.Draw('CONTZ')
    graph_2D.Draw('')
    canvas.Update()
    print('done')

def main():
    """Run the simulation, plot its output, then wait for a line on stdin."""
    # The plots read the file written by the simulation, so it must run first.
    run_sim()
    plot_output()
    # Block until the user presses return so the script does not exit
    # straight after plotting.
    raw_input()

# Run the simulation and plotting only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()


