Advertisement
mabruzzo

illustrate_wave.py

Mar 2nd, 2023
658
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.31 KB | None | 0 0
  1. import os.path
  2. import glob
  3.  
  4. import matplotlib.pyplot as plt
  5. from matplotlib.lines import Line2D
  6. import numpy as np
  7. import h5py
  8. import yt
  9. from unyt import unyt_array, dimensions, gauss
  10.  
  11.  
  12. from yt.funcs import mylog
  13. mylog.setLevel(30)
  14.  
  15. def plot_field_ray(ax,ray,field,delta_val=0.0,multiply =1.,**kwargs):
  16.     """
  17.    Plots field values contained by a YTRay or YTOrthoRay
  18.    """
  19.     # first we need to compute distance along the ray
  20.     if hasattr(ray,'axis'):
  21.         axis_id = ray.axis
  22.         for axis in ['x','y','z']:
  23.             if ray.ds.coordinates.axis_id[axis] == axis_id:
  24.                 dists = ray[axis]
  25.                 break
  26.         idx = slice(None)
  27.     else:
  28.         idx = np.argsort(ray["t"])
  29.         start_i = np.argmin(ray["t"])
  30.         final_i = np.argmax(ray["t"])
  31.         delta_comp = [ray[dim][final_i] - ray[dim][start_i] \
  32.                       for dim in ('x','y','z')]
  33.         dists = ray["t"][idx]*np.sqrt(np.sum(np.square(delta_comp)))
  34.     return ax.plot(dists,multiply*(ray[field][idx].value - delta_val),
  35.                    **kwargs)
  36.        
  37. def find_block_boundaries(ds, aligned_ax = 0,
  38.                           transverse_coord = (0.1,0.1),
  39.                           skip_domain_boundaries = True):
  40.     if (aligned_ax == 0):
  41.         map_ind = [1,2]
  42.     else:
  43.         raise NotImplementedError()
  44.        
  45.     if not isinstance(transverse_coord,
  46.                       unyt_array):
  47.         transverse_coord = ds.arr(transverse_coord,
  48.                                   units = 'code_length')
  49.  
  50.     left_edges = ds.index.grid_left_edge
  51.     right_edges = ds.index.grid_right_edge
  52.    
  53.     in_left = np.logical_and(
  54.         transverse_coord[0] >= left_edges[:,map_ind[0]],
  55.         transverse_coord[1] >= left_edges[:,map_ind[1]],
  56.     )
  57.     in_right = np.logical_and(
  58.         transverse_coord[0] < right_edges[:,map_ind[0]],
  59.         transverse_coord[1] < right_edges[:,map_ind[1]],
  60.     )
  61.    
  62.     domain_bounds = (float(ds.domain_left_edge[aligned_ax].v),
  63.                      float(ds.domain_right_edge[aligned_ax].v))
  64.    
  65.     grid_inds = np.argwhere(np.logical_and(in_left, in_right)).flatten()
  66.  
  67.     out = set()
  68.     for grid_ind in grid_inds:
  69.         boundaries = (float(left_edges[ grid_ind,aligned_ax].v),
  70.                       float(right_edges[grid_ind,aligned_ax].v))
  71.         for b in boundaries:
  72.             if skip_domain_boundaries and b in domain_bounds:
  73.                 continue
  74.             elif b not in out:
  75.                 out.add(b)
  76.     return out
  77.  
  78. def field_plot_props(fields, bkg_vals, labels, common_plot_kwargs):
  79.     triples = zip(fields,bkg_vals,labels)
  80.  
  81.     _keys = ('c','mec','mfc','mfcalt')
  82.     skip_color_spec = any(
  83.         ('color' in key) or (key in _keys)
  84.         for key in common_plot_kwargs
  85.     )
  86.     assert 'label' not in common_plot_kwargs
  87.  
  88.     for i,(field,bkg_val,label) in enumerate(triples):
  89.         plot_kwargs = {'label' : label, **common_plot_kwargs}
  90.         if not skip_color_spec:
  91.             plot_kwargs['color'] = f'C{i}'
  92.         yield field, bkg_val, plot_kwargs
  93.  
  94. def plot_aligned_waves(fig, ax_array, fnames, fields, factor,
  95.                        bkg_vals, labels, field_adder = [],
  96.                        aligned_ax = 0,
  97.                        transverse_coord = (0.1,0.1),
  98.                        prec = 3,
  99.                        legend_ax = None,
  100.                        common_plot_kwargs = {}):
  101.     """
  102.    Function that plots oblique waves
  103.    """
  104.  
  105.     if field_adder is None:
  106.         adder_l = []
  107.     elif callable(field_adder):
  108.         adder_l = [field_adder]
  109.     elif isinstance(field_adder, (list,tuple)):
  110.         adder_l = field_adder
  111.     else:
  112.         raise ValueError(
  113.             "field_adder must be callable or list/tuple of callables")
  114.    
  115.     if prec is not None:
  116.         assert isinstance(prec,int) and prec > 0
  117.    
  118.    
  119.     plt_triples = tuple(field_plot_props(
  120.         fields = fields, bkg_vals = bkg_vals, labels = labels,
  121.         common_plot_kwargs = common_plot_kwargs
  122.     ))
  123.  
  124.     for ax,fname in zip(ax_array.flatten(),fnames):
  125.         ds = yt.load(fname)
  126.         time = ds.current_time
  127.         for adder in adder_l:
  128.             adder(ds)
  129.  
  130.         ray = ds.ortho_ray(aligned_ax,transverse_coord)
  131.    
  132.         for field,bkg_val,plot_kwargs in plt_triples:
  133.             plot_field_ray(ax, ray, field, bkg_val, factor,
  134.                            **plot_kwargs)
  135.         if prec is not None:
  136.             ax.text(0.95, 0.01, 't={:.0{prec}f}'.format(time.v,prec=prec),
  137.                     verticalalignment='bottom', horizontalalignment='right',
  138.                     transform=ax.transAxes,
  139.                     fontsize=10)
  140.  
  141.         boundaries = find_block_boundaries(ds, aligned_ax,
  142.                                            transverse_coord)
  143.         for b in boundaries:
  144.             ax.axvline(b, color = 'gray', ls = ':')
  145.    
  146.     if legend_ax is not None:
  147.         legend_elements = [
  148.             Line2D([0], [0], **kw) for _,_,kw in plt_triples
  149.         ]
  150.         legend_ax.legend(loc="center", handles = legend_elements)
  151.         legend_ax.axis('off')
  152.  
  153.     for i in range(ax_array.shape[0]):
  154.         if ax_array[i,0] is None:
  155.             continue
  156.         ax_array[i,0].set_ylabel(r"$\delta U\times 10^6$")
  157.     for i in range(ax_array.shape[1]):
  158.         row_ind = -1
  159.         if (ax_array[row_ind,i] is None):
  160.             if ax_array.shape[1] == 1:
  161.                 continue
  162.             else:
  163.                 row_ind = -2
  164.         ax_array[row_ind,i].set_xlabel(r"$x$")
  165.  
  166. def locate_block_list(prefix):
  167.     """
  168.    Simplified version of collect_files. It just attempts to find the
  169.    block_list_files
  170.    """
  171.     dirs = glob.glob(prefix)
  172.     dirs.sort()
  173.     fnames = []
  174.     for dirname in dirs:
  175.         fname = os.path.join(dirname,
  176.                              '.'.join((os.path.basename(dirname),
  177.                                        'block_list')))
  178.         assert os.path.isfile(fname)
  179.         fnames.append(fname)
  180.     return fnames
  181.  
  182. if __name__ == '__main__':
  183.     fnames = locate_block_list(
  184.         "../method_ppm-xaxis-AMR-soundN32_[0-9].[0257][05][0][0]*"
  185.     )
  186.     fnames = sorted(fnames)
  187.    
  188.     fig,ax_array = plt.subplots(3,2,figsize = (8,12),sharex=True,sharey=True)
  189.     legend_ax = ax_array[-1,-1]
  190.     ax_array[-1,-1] = None
  191.    
  192.  
  193.  
  194.     fields = [("enzoe","density"),
  195.               ("gas","momentum_density_x"),
  196.               ("gas","momentum_density_y"),
  197.               ("gas","momentum_density_z"),
  198.               ("enzoe","total_energy")]
  199.     bkg_vals = [1.0,
  200.                 0.0,0.0,0.0,
  201.                 0.9]
  202.     labels = [r"$\rho$",
  203.               r"$\rho v_x$",r"$\rho v_y$",r"$\rho v_z$",
  204.               r'$E_{\rm dens}$']
  205.     ylims = (-1.6,1.6)
  206.     factor = 1.e6
  207.  
  208.     plot_aligned_waves(fig, ax_array, fnames, fields, factor, bkg_vals,
  209.                        labels, field_adder = [],
  210.                        aligned_ax = 0, transverse_coord = (0.1,0.1),
  211.                        legend_ax = legend_ax,
  212.                        common_plot_kwargs = {"ls" : "none", "marker" : '.'})
  213.     ax_array[0,0].set_ylim(*ylims)
  214.     ax_array[0,0].set_xlim(0,1)
  215.     for ax in ax_array.flatten():
  216.         if ax is not None:
  217.             ax.axhline(0.9, color = 'k', ls = '--')
  218.  
  219.     fig.tight_layout()
  220.     fig.subplots_adjust(wspace = 0.0)
  221.     plt.savefig("example-output.png")
  222.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement