1

import copy

2


3

import numpy

4

import matplotlib

5

import matplotlib.pyplot

6

import scipy.stats

7


8

class Lattice(object):
    """
    Generalised tracking element. Overload "acceleration" to generate
    arbitrary symplectic tracking. Overload "transport_one" to do
    nonsymplectic things.
    """

    def __init__(self, step_size, n_steps):
        """
        Generalised tracking element. Overload "acceleration" to generate
        arbitrary symplectic tracking. Overload "transport_one" to do
        nonsymplectic things.

        step_size: integration time step dt used by transport_one.
        n_steps: number of leapfrog steps taken per call to transport_one.
        """
        self.step_size = step_size
        self.n_steps = n_steps

    def transport(self, distribution):
        """
        Transport a distribution and return it.

        transport_one is applied to every phase space vector in
        distribution.psv_list. Vectors for which it returns None (lost
        particles) are dropped, and psv_list is replaced by a new list of
        the surviving vectors.
        """
        # Build a fresh list instead of writing results back into psv_list:
        # psv_list may be a numpy float array, where assigning None stores NaN
        # and the lost particle would then never be filtered out.
        survivors = []
        for psv in distribution.psv_list:
            moved = self.transport_one(psv)
            if moved is not None:
                survivors.append(moved)
        distribution.psv_list = survivors
        return distribution

    def transport_one(self, psv):
        """
        Following "leap frog integration" which is a symplectic 2nd order integrator
        http://www.artcompsci.org/vol_1/v1_web/node34.html
        It is symplectic, because at each step we perform 2 shears in phase
        space - a shear in x followed by a shear in v.

        psv is indexed as [x, v]; it is updated in place and returned after
        n_steps steps of size step_size.
        """
        dt = self.step_size
        x = psv[0]
        # Opening half kick: v_{1/2} = v_0 + a(x_0) dt/2
        v_half = psv[1] + self.acceleration(x, 0)*dt/2.0
        for _ in range(self.n_steps):
            # drift: x_{i+1} = x_i + v_{i+1/2} dt
            x = x + v_half*dt
            # full kick: v_{i+3/2} = v_{i+1/2} + a(x_{i+1}) dt
            v_half = v_half + self.acceleration(x, 0)*dt
        # After the loop the velocity is half a step *ahead* of x, so remove
        # half a kick to synchronise: v_n = v_{n+1/2} - a(x_n) dt/2.
        # (Adding it instead would apply an extra full kick every call.)
        psv[0] = x
        psv[1] = v_half - self.acceleration(x, 0)*dt/2.0
        return psv

    def acceleration(self, x, z):
        """
        Return the acceleration at position x. Subclasses must overload this.

        As long as the acceleration is a function only of position, the
        leapfrog map in transport_one is area conserving. transport_one
        always passes z = 0.
        """
        raise NotImplementedError("Abstract base class of Transport")

57


58

class ConstantNonLinearFocusing(Lattice):
    """
    Lattice element whose acceleration is a fixed polynomial in position:
    acceleration(x, z) = a_0 x + a_1 x^2 + a_2 x^3 + ...
    """

    def __init__(self, step_size, n_steps, focusing_strength_vector):
        """
        Build the element from the integration step size, the number of
        steps per transport_one call, and the polynomial coefficients
        [a_0, a_1, ...] (stored as self.f).
        """
        super().__init__(step_size, n_steps)
        self.f = focusing_strength_vector

    def acceleration(self, x, z):
        """
        Evaluate f[0]*x + f[1]*x**2 + ... at position x; z is ignored.
        An empty coefficient list gives 0.0.
        """
        total = 0.0
        power = 1  # becomes x, x^2, x^3, ... just before each term is added
        for coefficient in self.f:
            power = power*x
            total += coefficient*power
        return total

81


82

class Scraper(Lattice):
    """Aperture element that removes particles with |x| > aperture."""

    def __init__(self, aperture):
        """
        Scraper kills particles having |x| > aperture. The step size and
        number of steps are not used by this element; both are set to 1.
        """
        super().__init__(1, 1)
        self.aperture = aperture

    def transport_one(self, psv):
        """
        Return psv unchanged, or None (particle lost) when the absolute
        position |psv[0]| exceeds the aperture.
        """
        return None if abs(psv[0]) > self.aperture else psv

98


99

class Damping(Lattice):
    """Element that scales every particle's position and velocity."""

    def __init__(self, x_damping, v_damping):
        """
        Damps the particle: each pass multiplies x by x_damping and v by
        v_damping. The step size and number of steps are unused (set to 1).
        """
        super().__init__(1, 1)
        self.x_damping = x_damping
        self.v_damping = v_damping

    def transport_one(self, psv):
        """Scale psv = [x, v] in place by the damping factors and return it."""
        psv[0] = psv[0]*self.x_damping
        psv[1] = psv[1]*self.v_damping
        return psv

111


112

class Distribution(object):
    """A collection of particles, each a phase space vector [x, v]."""

    def __init__(self):
        """Distribution just holds a collection of particles."""
        # psv_list is either a list of vectors or a numpy array with one row
        # per particle (see gen_gaussian_distribution).
        self.psv_list = []

    def __deepcopy__(self, memo):
        """Deepcopy - allocate new memory for the psv_list."""
        my_copy = Distribution()
        my_copy.psv_list = copy.deepcopy(self.psv_list, memo)
        return my_copy

    def __str__(self):
        """Return a string representation of the distribution: x, v; one psv per line."""
        return "".join(
            format(psv[0], "+8.4f") + format(psv[1], "+9.4f") + "\n"
            for psv in self.psv_list
        )

    def plot(self, axes, vmin, vmax):
        """
        Make a 2d histogram of the particles on the matplotlib axes *axes*,
        binned over x, v in [-2, 2] x [-2, 2], and attach a colorbar.

        vmin, vmax are the minimum/maximum of the colour scale (set to None
        to calculate automatically).
        """
        # Collect the columns explicitly so an empty distribution (e.g. every
        # particle removed by a scraper) still plots instead of raising.
        x_values = [psv[0] for psv in self.psv_list]
        v_values = [psv[1] for psv in self.psv_list]
        clr = axes.hist2d(x_values, v_values, 50,
                          range=[[-2, 2], [-2, 2]], vmin=vmin, vmax=vmax)[3]
        axes.get_figure().colorbar(clr)

    @classmethod
    def gen_gaussian_distribution(cls, sample_size, covariance):
        """
        Generate a multivariate gaussian distribution centred on the origin.

        sample_size is the number of particles and covariance the 2x2
        covariance matrix of (x, v). Returns an instance of cls whose
        psv_list is a numpy array of shape (sample_size, 2).
        """
        dist = cls()
        dist.psv_list = numpy.random.multivariate_normal([0.0, 0.0],
                                                         covariance,
                                                         sample_size)
        return dist

    def get_amplitude_list(self, psv_list, cov=None):
        """
        Calculate a set of amplitudes from the psv_list. If cov is defined,
        use this covariance matrix to calculate the amplitudes.

        The amplitude of a vector p is p . cov^-1 . p. When cov is None it is
        computed from psv_list and scaled to unit determinant.

        Returns a tuple of the (list of amplitudes, covariance matrix).
        """
        psv_array = numpy.asarray(psv_list, dtype=float)
        if cov is None:
            cov = numpy.cov(psv_array, rowvar=False)
            # scale so that det(cov) == 1
            cov /= numpy.linalg.det(cov)**(1.0/cov.shape[0])
        cov_inv = numpy.linalg.inv(cov)
        if psv_array.size == 0:
            return [], cov
        # Row-wise quadratic form p . cov_inv . p, vectorised over all events
        amplitudes = numpy.sum(numpy.dot(psv_array, cov_inv)*psv_array, axis=1)
        return amplitudes.tolist(), cov

    def amplitude(self, amp_cut):
        """
        Calculate amplitudes. amp_cut is a float that sets the maximum amplitude
        for applying the cut. We iterate on amp_cut until no events are included
        in covariance matrix calculation having amplitude > amp_cut.

        If amp_cut is falsy (None or 0) no cut is applied. Amplitudes are then
        returned for every event, using the covariance of the surviving events.
        Returns (list of amplitudes, covariance matrix).
        """
        test_array = numpy.array(self.psv_list)
        cov = None
        if amp_cut:
            while True:
                amp_list, cov = self.get_amplitude_list(test_array)
                cut = numpy.asarray(amp_list) > amp_cut
                if not cut.any():
                    break
                # Drop the offending events and recompute the covariance
                test_array = test_array[~cut]
        return self.get_amplitude_list(numpy.array(self.psv_list), cov)

184


185


186

def plot_phase_space_trajectory(psv, lattice, n_steps, axes):
    """
    Plot the trajectory in phase space for a particle traversing *lattice*. Plot
    n_steps number of iterations of lattice transport. Axes is the matplotlib axes
    used for plotting.

    psv is copied first, so the caller's vector is not modified. Tracking
    stops early if lattice.transport_one returns None (particle lost).
    """
    psv = list(psv)
    x_list = [psv[0]]
    y_list = [psv[1]]
    for _ in range(n_steps):
        # Use the returned vector: elements may return a new vector rather
        # than mutating in place, or None when the particle is lost.
        psv = lattice.transport_one(psv)
        if psv is None:
            break
        x_list.append(psv[0])
        y_list.append(psv[1])
    axes.scatter(x_list, y_list, s=0.1)

199


200

def plot_distribution(lattice_list_1, contour_element, dist_in, n_turns, title, amplitude_cut):
    """
    Transport a distribution a number of turns and then plot it. Make a 2D
    histogram of x, v on the left and an amplitude distribution on the right.

    - lattice_list_1 is the list of elements used for tracking the distribution
    - contour_element is a single element that is used for plotting contours
    - dist_in is the distribution; a deep copy is tracked, dist_in is unchanged
    - n_turns is the number of turns to track distribution before plotting
    - title is a string title for the plot
    - amplitude_cut is the amplitude cut to use when calculating amplitudes

    Returns the matplotlib figure.
    """
    # Amplitude histogram binning; the reference curve below is normalised
    # with the same values so the two cannot drift apart.
    n_amp_bins = 50
    amp_max = 2.0
    n_contour_steps = 1000

    figure = matplotlib.pyplot.figure(figsize=(20, 10))
    axes1 = figure.add_subplot(1, 2, 1)
    n_events = len(dist_in.psv_list)
    dist_out = copy.deepcopy(dist_in)
    for _ in range(n_turns):
        for element in lattice_list_1:
            element.transport(dist_out)
    dist_out.plot(axes1, None, None)
    for x_start in (0.1, 0.75, 1.249):
        plot_phase_space_trajectory([x_start, 0], contour_element, n_contour_steps, axes1)
    figure.suptitle(title + "\n" + str(n_turns) + " turns")

    axes1.set_xlabel("Position [au]")
    axes1.set_ylabel("Momentum [au]")

    axes2 = figure.add_subplot(1, 2, 2)
    amplitude_list, _ = dist_out.amplitude(amplitude_cut)
    axes2.hist(amplitude_list, bins=n_amp_bins, range=[0.0, amp_max])
    axes2.set_title(str(len(amplitude_list)) + " events")
    axes2.set_xlabel("Amplitude [au]")

    # Reference curve: amplitude plotted as chi2(2 dof) / 2, i.e. density
    # 2 * chi2.pdf(2a), scaled to counts per bin for the *input* event count.
    # NOTE(review): the factor 1/2 presumably matches the 0.5*identity input
    # covariance built in distributions(); other inputs need another scale.
    bin_width = amp_max / n_amp_bins
    chi2_x = numpy.linspace(0.0, 2.0*amp_max, 401)
    expected_counts = 2.0*scipy.stats.chi2.pdf(chi2_x, 2)*n_events*bin_width
    axes2.plot(chi2_x/2.0, expected_counts)

    return figure

240


241

def savefig(figure, title, n_turns):
    """
    Save *figure* as a PNG named after *title* and *n_turns*.

    Commas and newlines are removed from the title and spaces become
    underscores; "_<n_turns>_turns.png" is appended. For example the title
    "Scraping, nonlinear" with 10 turns gives
    "Scraping_nonlinear_10_turns.png".
    """
    cleanup = str.maketrans({",": None, "\n": None, " ": "_"})
    stem = "{}_{}_turns".format(title.translate(cleanup), n_turns)
    figure.savefig(stem + ".png")

250


251

def distributions(title, focusing, damping, aperture, amplitude_cut,
                  sample_size=100000, turn_list=(0, 1, 10)):
    """
    Generate a distribution and lattice. Then track the distribution and plot it
    using plot_distribution, saving one figure per entry of turn_list.

    - title: figure title prefix, also used to build the output file name
    - focusing: polynomial coefficients for ConstantNonLinearFocusing
    - damping: (x_damping, v_damping) tuple passed to Damping
    - aperture: aperture of the Scraper
    - amplitude_cut: amplitude cut passed on to plot_distribution
    - sample_size: number of particles in the Gaussian input distribution
    - turn_list: numbers of turns after which a figure is made and saved
    """
    suptitle = (title + "Focusing: " + str(focusing)
                + " Damping: " + str(damping)
                + " Aperture: " + str(aperture)
                + " Amplitude cut: " + str(amplitude_cut))
    lattice = ConstantNonLinearFocusing(1, 1, focusing)
    # keep the `damping` tuple intact; the element gets its own name
    damper = Damping(*damping)
    scraper = Scraper(aperture)
    # uncorrelated Gaussian with variance 0.5 in both x and v
    distribution = Distribution.gen_gaussian_distribution(
        sample_size, [[0.5, 0.0], [0.0, 0.5]])

    for n_turns in turn_list:
        figure = plot_distribution([lattice, scraper, damper], lattice,
                                   distribution, n_turns, suptitle,
                                   amplitude_cut)
        savefig(figure, title, n_turns)

265


266

def main():
    """
    Main function: run every tracking scenario and save its figures.
    """
    # Each row: (title, focusing coefficients, (x_damping, v_damping),
    #            aperture, amplitude cut)
    # NOTE(review): ConstantNonLinearFocusing.acceleration evaluates
    # f[0]*x + f[1]*x**2 + ..., so a positive leading coefficient pushes
    # particles away from x = 0 rather than focusing them. If focusing was
    # intended these coefficients may have lost their minus signs - confirm.
    scenarios = (
        ("Basic lattice\n", [0.1], (1.0, 1.0), 100.0, 0.5),
        ("Scraping\n", [0.1], (1.0, 1.0), 1.8, 0.5),
        ("Damping\n", [0.1], (1.0, 0.9), 100.0, 0.5),
        ("Nonlinear\n", [0.1, 0.0, 0.03], (1.0, 1.0), 100.0, 0.05),
        ("Scraping, damping, nonlinear\n", [0.1, 0.0, 0.01], (1.0, 0.9), 1.0, 0.5),
        ("Scraping, nonlinear\n", [0.1, 0.0, 0.01], (1.0, 1.0), 1.0, 0.5),
    )
    for title, focusing, damping, aperture, amplitude_cut in scenarios:
        distributions(title, focusing, damping, aperture, amplitude_cut)

276


277

if __name__ == "__main__":
    main()
    # Open every figure window without blocking, then wait on the prompt so
    # the process (and the windows) stay alive until the user presses <CR>.
    matplotlib.pyplot.show(block=False)
    input("Press <CR> to finish")

281

