Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os.path
- import glob
- import matplotlib.pyplot as plt
- from matplotlib.lines import Line2D
- import numpy as np
- import h5py
- import yt
- from unyt import unyt_array, dimensions, gauss
- from yt.funcs import mylog
# Raise the threshold of yt's logger to 30 (logging.WARNING) so that
# lower-severity yt messages (e.g. INFO) are not printed.
mylog.setLevel(level=30)
def plot_field_ray(ax, ray, field, delta_val=0.0, multiply=1., **kwargs):
    """
    Plot the values of ``field`` sampled by a YTRay or YTOrthoRay.

    The field is plotted against distance along the ray. Samples are ordered
    by increasing distance, so line styles render without zig-zags.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes on which ``ax.plot`` is called.
    ray : YTRay or YTOrthoRay
        Ray data object. Objects with an ``axis`` attribute are treated as
        axis-aligned (ortho) rays. All others are treated as oblique rays and
        must provide the parametric ``"t"`` field.
    field : field key
        Field to plot, e.g. ``("gas", "density")``.
    delta_val : float, optional
        Value subtracted from the field before scaling (e.g. a background).
    multiply : float, optional
        Factor applied to ``field - delta_val``.
    **kwargs
        Forwarded unchanged to ``ax.plot``.

    Returns
    -------
    Whatever ``ax.plot`` returns (a list of Line2D for matplotlib Axes).

    Raises
    ------
    ValueError
        If the axis of an ortho ray cannot be matched to 'x', 'y' or 'z'.
    """
    if hasattr(ray, 'axis'):
        # Axis-aligned ray: the distance is the coordinate along that axis.
        axis_name = None
        for name in ('x', 'y', 'z'):
            if ray.ds.coordinates.axis_id[name] == ray.axis:
                axis_name = name
                break
        if axis_name is None:
            raise ValueError(
                f"could not identify the axis of the ray (axis={ray.axis!r})")
        coord = ray[axis_name]
        # Samples are not guaranteed to be ordered along the ray; sort them
        # by position so connected line styles are drawn in order.
        idx = np.argsort(coord)
        dists = coord[idx]
    else:
        t = ray["t"]
        idx = np.argsort(t)
        if hasattr(ray, 'start_point') and hasattr(ray, 'end_point'):
            # "t" is the fractional position along start_point -> end_point,
            # so the distance from the start is t * |end_point - start_point|.
            length = np.sqrt(np.sum(np.square(ray.end_point - ray.start_point)))
        else:
            # Fallback for ray-like objects without end points: estimate the
            # length from the positions of the first and last samples. This
            # is less accurate than using the true end points.
            start_i = np.argmin(t)
            final_i = np.argmax(t)
            delta_comp = [ray[dim][final_i] - ray[dim][start_i]
                          for dim in ('x', 'y', 'z')]
            length = np.sqrt(np.sum(np.square(delta_comp)))
        dists = t[idx] * length
    # np.asarray strips units (like unyt's .value) and also accepts plain arrays.
    values = np.asarray(ray[field])[idx]
    return ax.plot(dists, multiply * (values - delta_val), **kwargs)
def find_block_boundaries(ds, aligned_ax = 0,
                          transverse_coord = (0.1,0.1),
                          skip_domain_boundaries = True):
    """
    Return the positions of block (grid) edges crossed by an axis-aligned ray.

    Parameters
    ----------
    ds : yt Dataset
        Dataset whose ``ds.index`` grid edges are inspected.
    aligned_ax : int or str, optional
        Axis the ray is aligned with, as an index or an axis name
        (e.g. 0 or 'x').
    transverse_coord : pair of float or unyt_array, optional
        Position of the ray along the two transverse axes, ordered the same
        way ``ds.ortho_ray(aligned_ax, transverse_coord)`` interprets it:
        ``ds.coordinates.x_axis[aligned_ax]`` first, then
        ``ds.coordinates.y_axis[aligned_ax]``. Values that are not a
        ``unyt_array`` are interpreted as ``code_length``.
    skip_domain_boundaries : bool, optional
        When True, edges coinciding with the domain boundaries along
        ``aligned_ax`` are omitted.

    Returns
    -------
    set of float
        Distinct edge positions along ``aligned_ax``, in the units of the
        grid edge arrays.
    """
    coords = ds.coordinates
    # Accept axis names as well as integer indices.
    aligned_ax = coords.axis_id.get(aligned_ax, aligned_ax)
    # Use the same transverse-axis ordering as ds.ortho_ray so that both
    # describe the same line (for the x axis in Cartesian geometry: y, z).
    map_ind = (coords.x_axis[aligned_ax], coords.y_axis[aligned_ax])

    if not isinstance(transverse_coord, unyt_array):
        transverse_coord = ds.arr(transverse_coord, units = 'code_length')
    left_edges = ds.index.grid_left_edge
    right_edges = ds.index.grid_right_edge

    # A grid is crossed by the ray when the ray's transverse position lies in
    # the half-open interval [left, right) along both transverse axes.
    intersects = np.ones(len(left_edges), dtype=bool)
    for coord, dim in zip(transverse_coord, map_ind):
        intersects &= ((coord >= left_edges[:, dim]) &
                       (coord < right_edges[:, dim]))

    domain_bounds = (float(ds.domain_left_edge[aligned_ax].v),
                     float(ds.domain_right_edge[aligned_ax].v))
    out = set()
    for grid_ind in np.flatnonzero(intersects):
        for edges in (left_edges, right_edges):
            b = float(edges[grid_ind, aligned_ax].v)
            if skip_domain_boundaries and b in domain_bounds:
                continue
            out.add(b)
    return out
def field_plot_props(fields, bkg_vals, labels, common_plot_kwargs):
    """
    Pair each field with its background value and its ``ax.plot`` kwargs.

    Parameters
    ----------
    fields, bkg_vals, labels : iterables of equal length
        Field keys, background values and legend labels, paired element-wise.
    common_plot_kwargs : mapping
        Keyword arguments shared by every field's plot call. It must not
        contain ``'label'``, because labels come from ``labels``.

        If any key contains ``'color'`` or is one of the aliases ``'c'``,
        ``'mec'``, ``'mfc'`` or ``'mfcalt'``, no per-field colour is added.
        Otherwise field ``i`` gets the colour ``f'C{i}'``.

    Returns
    -------
    generator of (field, bkg_val, plot_kwargs) tuples
        ``plot_kwargs`` is a new dict per field.

    Raises
    ------
    ValueError
        If ``common_plot_kwargs`` contains ``'label'``, or if the three
        inputs have different lengths. Raised eagerly, when the function is
        called, rather than on first iteration.
    """
    if 'label' in common_plot_kwargs:
        raise ValueError("common_plot_kwargs must not contain 'label'; "
                         "labels are taken from the labels argument")
    fields, bkg_vals, labels = list(fields), list(bkg_vals), list(labels)
    # zip() would silently drop trailing fields on a length mismatch.
    if not len(fields) == len(bkg_vals) == len(labels):
        raise ValueError(
            "fields, bkg_vals and labels must have the same length "
            f"(got {len(fields)}, {len(bkg_vals)}, {len(labels)})")
    color_aliases = ('c', 'mec', 'mfc', 'mfcalt')
    user_sets_color = any(
        ('color' in key) or (key in color_aliases)
        for key in common_plot_kwargs
    )

    def _generate():
        for i, (field, bkg_val, label) in enumerate(
                zip(fields, bkg_vals, labels)):
            plot_kwargs = {'label': label, **common_plot_kwargs}
            if not user_sets_color:
                plot_kwargs['color'] = f'C{i}'
            yield field, bkg_val, plot_kwargs

    return _generate()
def plot_aligned_waves(fig, ax_array, fnames, fields, factor,
                       bkg_vals, labels, field_adder = None,
                       aligned_ax = 0,
                       transverse_coord = (0.1,0.1),
                       prec = 3,
                       legend_ax = None,
                       common_plot_kwargs = None,
                       ylabel = r"$\delta U\times 10^6$",
                       xlabel = r"$x$"):
    """
    Plot field perturbations along an axis-aligned ray, one snapshot per panel.

    For each dataset in ``fnames``, this loads it with ``yt.load`` and
    extracts ``ds.ortho_ray(aligned_ax, transverse_coord)``. It then plots
    ``factor * (field - bkg_val)`` for every field and marks the block
    boundaries crossed by the ray with dotted vertical lines.

    Parameters
    ----------
    fig : matplotlib Figure
        Currently unused by this function.
    ax_array : 2-D numpy array of Axes (entries may be None)
        Panels to draw on, filled in row-major order. ``None`` entries (e.g.
        a panel reused for the legend) are skipped. Extra files beyond the
        number of available panels are ignored.
    fnames : iterable of str
        Paths passed to ``yt.load``.
    fields, bkg_vals, labels : sequences
        Fields to plot, their background values and legend labels.
    factor : float
        Multiplier applied to ``field - bkg_val``.
    field_adder : callable, list/tuple of callables, or None, optional
        Each callable is invoked as ``adder(ds)`` after loading a dataset.
    aligned_ax : int or str, optional
        Axis the ray is aligned with.
    transverse_coord : pair, optional
        Transverse position of the ray, as accepted by ``ds.ortho_ray``.
    prec : positive int or None, optional
        Number of decimals for the time annotation; None disables it.
    legend_ax : matplotlib Axes or None, optional
        If given, it is turned into a legend-only panel.
    common_plot_kwargs : mapping or None, optional
        Extra ``ax.plot`` kwargs shared by every field.
    ylabel, xlabel : str, optional
        Axis labels for the left column and the lowest panel of each column.

    Raises
    ------
    ValueError
        If ``field_adder`` or ``prec`` has an invalid value.
    """
    if field_adder is None:
        adder_l = []
    elif callable(field_adder):
        adder_l = [field_adder]
    elif isinstance(field_adder, (list, tuple)):
        adder_l = list(field_adder)
    else:
        raise ValueError(
            "field_adder must be callable or list/tuple of callables")
    if prec is not None and not (isinstance(prec, int) and prec > 0):
        raise ValueError("prec must be None or a positive int")
    if common_plot_kwargs is None:
        common_plot_kwargs = {}

    plt_triples = tuple(field_plot_props(
        fields = fields, bkg_vals = bkg_vals, labels = labels,
        common_plot_kwargs = common_plot_kwargs
    ))

    # Never pair a file with an empty (None) slot; that would crash on ax.plot.
    axes = [ax for ax in ax_array.flat if ax is not None]
    for ax, fname in zip(axes, fnames):
        ds = yt.load(fname)
        for adder in adder_l:
            adder(ds)
        ray = ds.ortho_ray(aligned_ax, transverse_coord)
        for field, bkg_val, plot_kwargs in plt_triples:
            plot_field_ray(ax, ray, field, bkg_val, factor, **plot_kwargs)
        if prec is not None:
            ax.text(0.95, 0.01, f"t={ds.current_time.v:.{prec}f}",
                    verticalalignment='bottom', horizontalalignment='right',
                    transform=ax.transAxes,
                    fontsize=10)
        for b in find_block_boundaries(ds, aligned_ax, transverse_coord):
            ax.axvline(b, color = 'gray', ls = ':')

    if legend_ax is not None:
        legend_elements = [
            Line2D([0], [0], **kw) for _, _, kw in plt_triples
        ]
        legend_ax.legend(loc="center", handles = legend_elements)
        legend_ax.axis('off')

    _label_outer_axes(ax_array, xlabel = xlabel, ylabel = ylabel)


def _label_outer_axes(ax_array, xlabel, ylabel):
    """Label the left column (ylabel) and the lowest non-None panel of each column (xlabel)."""
    n_rows, n_cols = ax_array.shape
    for row in range(n_rows):
        if ax_array[row, 0] is not None:
            ax_array[row, 0].set_ylabel(ylabel)
    for col in range(n_cols):
        # Walk up from the bottom row so an empty bottom slot (e.g. a legend
        # panel) hands its x label to the panel directly above it.
        for row in reversed(range(n_rows)):
            if ax_array[row, col] is not None:
                ax_array[row, col].set_xlabel(xlabel)
                break
def locate_block_list(prefix):
    """
    Find the ``.block_list`` file in every output directory matching a glob.

    This is a simplified version of collect_files: it only locates the
    block_list files. For a matched directory ``<path>/<name>``, the file
    ``<path>/<name>/<name>.block_list`` is expected to exist.

    Parameters
    ----------
    prefix : str
        Glob pattern matched against output directory paths. Matches that
        are not directories (e.g. archives) are ignored.

    Returns
    -------
    list of str
        Paths of the block_list files, ordered by sorted directory path.
        Empty when nothing matches.

    Raises
    ------
    FileNotFoundError
        If a matched directory does not contain its block_list file.
    """
    fnames = []
    for dirname in sorted(glob.glob(prefix)):
        if not os.path.isdir(dirname):
            continue
        # normpath drops a trailing separator, which would otherwise make
        # basename() return an empty string.
        name = os.path.basename(os.path.normpath(dirname))
        fname = os.path.join(dirname, f"{name}.block_list")
        if not os.path.isfile(fname):
            raise FileNotFoundError(
                f"expected block_list file not found: {fname}")
        fnames.append(fname)
    return fnames
if __name__ == '__main__':
    # Output directories of the x-aligned AMR sound-wave run to plot.
    snapshot_pattern = (
        "../method_ppm-xaxis-AMR-soundN32_[0-9].[0257][05][0][0]*"
    )
    fnames = sorted(locate_block_list(snapshot_pattern))

    # 3x2 grid of panels. The bottom-right panel is reserved for the legend,
    # so its slot in the snapshot array is cleared.
    fig, ax_array = plt.subplots(3, 2, figsize=(8, 12),
                                 sharex=True, sharey=True)
    legend_ax = ax_array[-1, -1]
    ax_array[-1, -1] = None

    # One row per plotted quantity: (field key, background value, label).
    field_table = (
        (("enzoe", "density"),          1.0, r"$\rho$"),
        (("gas", "momentum_density_x"), 0.0, r"$\rho v_x$"),
        (("gas", "momentum_density_y"), 0.0, r"$\rho v_y$"),
        (("gas", "momentum_density_z"), 0.0, r"$\rho v_z$"),
        (("enzoe", "total_energy"),     0.9, r'$E_{\rm dens}$'),
    )
    fields = [entry[0] for entry in field_table]
    bkg_vals = [entry[1] for entry in field_table]
    labels = [entry[2] for entry in field_table]

    factor = 1.e6          # multiplier applied to (value - background)
    ylims = (-1.6, 1.6)

    plot_aligned_waves(fig, ax_array, fnames, fields, factor, bkg_vals,
                       labels,
                       field_adder=[],
                       aligned_ax=0,
                       transverse_coord=(0.1, 0.1),
                       legend_ax=legend_ax,
                       common_plot_kwargs={"ls": "none", "marker": '.'})

    # The panels share both axes, so limits set on one apply to all.
    top_left = ax_array[0, 0]
    top_left.set_ylim(*ylims)
    top_left.set_xlim(0, 1)

    # NOTE(review): the curves show (value - background) * factor, so this
    # dashed line at 0.9 is not the zero level. Confirm it is intended.
    for panel in (ax for ax in ax_array.flat if ax is not None):
        panel.axhline(0.9, color='k', ls='--')

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.0)
    fig.savefig("example-output.png")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement