Reproducing the observation
We want a simulator that reproduces the outputs of Observation 5 for different values of the turbulent power spectrum, while keeping the emissivity field, the binning and everything else fixed.
We start with the imports.
# Classics
import matplotlib.pyplot as plt
from astropy.io import fits
import jax.numpy as jnp
import cmasher as cmr
import jax
import xarray as xr
# Astropy
import astropy.units as units
from astropy.cosmology import LambdaCDM
cosmo = LambdaCDM(H0 = 70, Om0 = 0.3, Ode0 = 0.7)
# fast_turbulence_sim
# I do this because I don't want to deal with making environments for this code
import sys
sys.path.append('../../src/')
from simulation import rng_key, Simulation
from binning import LoadBinning
from grid import SpatialGrid3D
from structure_function import StructureFunction
from projection import VCubeProjection_v2
# haiku
import haiku as hk
Define the spatial grid
# Pixel size: 317 um detector pixels behind a 12 m focal length
pixsize_m = 317e-6  # 317 um
focal_length = 12  # 12 m

# Angular size of one pixel: radians -> degrees -> arcminutes
pixsize_rad = pixsize_m / focal_length
pixsize_arcmin = pixsize_rad * 180 / jnp.pi * 60

# Physical size of one pixel at redshift z = 0.1 (value hard-coded in this notebook)
redshift = 0.1
pixsize_kpc = pixsize_arcmin * cosmo.kpc_proper_per_arcmin(redshift).value

# Spatial grid: 232 x 232 pixels on the sky.
# los_factor presumably sets the line-of-sight depth relative to the sky extent
# -- TODO confirm in grid.SpatialGrid3D.
spatial_grid = SpatialGrid3D(
    pixsize=pixsize_kpc,
    shape=(232, 232),
    los_factor=5,
)
print(f'The pixel size in kiloparsecs is {pixsize_kpc:.2f} kpc')
The pixel size in kiloparsecs is 10.05 kpc
Load the emissivity cube used for emission weighting (see the corresponding notebook for how this cube is created).
# Emissivity cube used to emission-weight the projected velocity field.
# NOTE(review): this is the only absolute, machine-specific path in the notebook;
# every other input is read relative to '../../data/'. Consider making it relative.
emissivity_file = '/xifu/home/mola/SBI_Turbulence_newXIFU/data/emissivity_cube_19p_newxifu_novign.npy'
em = jnp.load(emissivity_file)
Load the binning used for Observation 5 and define the structure function (SF).
# Binning of the 232 x 232 image used for Observation 5
binning = LoadBinning(
    shape=(232, 232),
    binning_file='../../data/XIFU_Observation5/region_files/19p_region_dict.p',
    count_map_file='../../data/XIFU_Observation5/19p_count_image.fits',
)

# Pixel coordinates, per-pixel bin index, number of bins, bin barycentres
# and the bin-number map, as returned by the binning object.
(X_pixels, Y_pixels, bin_num_pix, nb_bins,
 xBar_bins, yBar_bins, bin_nb_map) = binning()

# Structure function on 20 geometrically spaced separations from 3 to 200
# (pixels, per the plot labels below). These are presumably bin edges, since
# the simulated SFs further down have 19 entries -- TODO confirm.
sf_bins = jnp.geomspace(3, 200, 20)
sf = StructureFunction(bins=sf_bins)
Load the FITS file defining the instrument PSF, normalize the PSF to unit sum, and pad it to the 232×232 image size.
# Instrument PSF stamp, normalised so that its pixels sum to one.
# NOTE(review): unless jax_enable_x64 is set, JAX silently stores this as float32.
PSF_kernel = jnp.array(fits.getdata('../../data/XIFU_Observation5/PSF_image.fits'), dtype = 'float64')
PSF_kernel = PSF_kernel / jnp.sum(PSF_kernel)

# Pad the stamp to the full image shape. The pad width is derived per axis from
# the stamp's own shape (instead of assuming a 58x58 stamp); any odd remainder
# goes after the data so the result is exactly `image_shape`.
image_shape = (232, 232)
psf_pad_width = [((target - size) // 2, (target - size) - (target - size) // 2)
                 for target, size in zip(image_shape, PSF_kernel.shape)]

# Zero padding (mode='constant'), as the variable name says. mode='edge' would
# replicate the PSF border values over the whole padded area, so unless the border
# is exactly zero the kernel would no longer sum to one after the normalisation above.
PSF_kernel_zero_padded = jnp.pad(PSF_kernel, pad_width = psf_pad_width, mode = 'constant')
Load the measurement-error statistics.
# Measurement-error statistics for Observation 5, copied into float32 JAX arrays.
# For an .npz archive, jnp.load defers to numpy.load and returns its NpzFile, which
# keeps the file open; the `with` block closes it once every array has been copied out.
with jnp.load('../../data/XIFU_Observation5/mes_err_stats.npz') as data_mes_err:
    radial_bins_mes_errors = jnp.array(data_mes_err['radial_bins_mes_errors'], dtype = 'float32')
    censhift_offsets = jnp.array(data_mes_err['censhift_offsets'], dtype = 'float32')
    censhift_errors = jnp.array(data_mes_err['censhift_errors'], dtype = 'float32')
    broad_offsets = jnp.array(data_mes_err['broad_offsets'], dtype = 'float32')
    broad_errors = jnp.array(data_mes_err['broad_errors'], dtype = 'float32')
Define the simulator
# Projection, built from the binning, the emissivity cube `em` and the padded PSF kernel
projection = VCubeProjection_v2(binning, em, PSF_kernel_zero_padded)


def _simulation_fn(spatial_grid=spatial_grid,
                   sf=sf,
                   binning=binning,
                   projection=projection,
                   radial_bins_mes_errors=radial_bins_mes_errors,
                   censhift_offsets=censhift_offsets,
                   censhift_errors=censhift_errors,
                   broad_offsets=broad_offsets,
                   broad_errors=broad_errors):
    """Build the Simulation module and run it once; this is the function haiku transforms.

    Every input is bound as a default argument, i.e. captured when this cell runs.
    A lambda closing over the globals would instead look them up each time haiku
    calls it -- in ``sim.init`` and whenever ``jax.jit`` (re)traces ``sim.apply``.
    The notebook later rebinds ``sf`` to the structure-function *output* array, so
    a re-init or re-trace would then build the Simulation with that array.
    """
    return Simulation(spatial_grid,
                      sf,
                      binning,
                      projection,
                      radial_bins_mes_errors,
                      censhift_offsets,
                      censhift_errors,
                      broad_offsets,
                      broad_errors)()


# Simulation
sim = hk.transform(_simulation_fn)
Now that the simulator is defined, we need to initialize it. This step is required by Haiku, which I use because it allows parameters to be inherited between classes.
# Init: haiku creates the parameter tree from a random key
pars = sim.init(rng_key())

# Print the power-spectrum parameters of the fluctuation cube
# (the recorded output lists sigma, log_inj and alpha).
power_spectrum_pars = pars['simulation/~/fluctuation_cube/~/kolmogorov_power_spectrum']
for name, value in power_spectrum_pars.items():
    print(name, value)

# JIT
sim_jit = jax.jit(sim.apply)
sigma 250.0 log_inj 2.4771214 alpha 3.6666667
Check one realization
%%time
# First call of the jitted simulator, so the %%time measurement likely includes
# JAX tracing/compilation. NOTE: this rebinds `sf` (the StructureFunction instance
# until now) to the centroid-shift structure-function output.
sim_outputs = sim_jit(pars, rng_key())
dist, sf, sf_std, v_vec, std_vec = sim_outputs
CPU times: user 6.52 s, sys: 388 ms, total: 6.91 s Wall time: 6.91 s
Look at the binned velocity field
# Paint each bin's centroid shift back onto the 232 x 232 pixel grid;
# pixels not listed in (X_pixels, Y_pixels) keep the value 0.
binned_censhift_output = jnp.zeros((232, 232)).at[X_pixels, Y_pixels].set(v_vec[bin_num_pix])

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
im = ax.imshow(binned_censhift_output, cmap=cmr.iceburn, vmin=-150, vmax=150)
plt.colorbar(im, pad=0.04, fraction=0.046, label='Centroid shift [km/s]')
<matplotlib.colorbar.Colorbar at 0x149c78d964a0>
Look at the structure function
# Centroid-shift structure function of the single realization drawn above
fig, ax = plt.subplots(figsize=(6, 4))
ax.loglog(dist, sf, marker='.')
ax.set_xlabel('Separation [pixels]')
ax.set_ylabel('Centroid shift SF [km2/s2]')
Text(0, 0.5, 'Centroid shift SF [km2/s2]')
Check that, on average, the simulator output is close to the observation used as a reference.
# Monte Carlo check: draw Nsims realizations with the same parameters `pars`
# and keep every output so that percentile bands can be computed below.
Nsims = 100

# Outputs are collected in lists and stacked once at the end. The previous
# version preallocated arrays with a hard-coded number of SF bins (19) and
# filled them with `.at[k].set`, which copies the whole array on every call.
censhift_sf_list, broad_sf_list = [], []
censhift_vec_list, broad_vec_list = [], []
for _ in range(Nsims):
    # Distinct loop names, so the globals `sf`, `sf_std`, `v_vec` and `std_vec`
    # from the single-realization check are not overwritten. `dist` is kept,
    # because the plots below use it.
    dist, sf_censhift_k, sf_broad_k, v_vec_k, std_vec_k = sim_jit(pars, rng_key())
    censhift_sf_list.append(sf_censhift_k)
    broad_sf_list.append(sf_broad_k)
    censhift_vec_list.append(v_vec_k)
    broad_vec_list.append(std_vec_k)

# One row per realization: (Nsims, number of SF bins) and (Nsims, nb_bins)
SFs_censhift = jnp.stack(censhift_sf_list)
SFs_broad = jnp.stack(broad_sf_list)
vecs_censhift = jnp.stack(censhift_vec_list)
vecs_broad = jnp.stack(broad_vec_list)
# Reference output structure functions of Observation 5
SFs_real = xr.open_dataset("../../data/XIFU_Observation5/outputs_SFs.nc")


def _plot_sf_panel(axis, dist, sims, real_sf, ylabel):
    """Overlay the simulated structure functions and the Observation 5 reference on one axis.

    Parameters
    ----------
    axis : matplotlib Axes to draw on.
    dist : separations shared by all curves (x values).
    sims : stacked simulated SFs, one realization per row.
    real_sf : SF measured on Observation 5 at the same separations.
    ylabel : label of the y axis.

    Returns
    -------
    The legend created on `axis`.
    """
    # The drawing order matches the original per-axis call order, so matplotlib's
    # property cycle assigns the same colours as before.
    axis.loglog(dist, jnp.mean(sims, axis=0),
                label='Average predicted output SF')
    axis.fill_between(dist,
                      jnp.percentile(sims, 2, axis=0),
                      jnp.percentile(sims, 98, axis=0),
                      alpha=0.2,
                      label='2-98%')
    axis.fill_between(dist,
                      jnp.percentile(sims, 15, axis=0),
                      jnp.percentile(sims, 85, axis=0),
                      alpha=0.5,
                      color='tab:blue',
                      label='15-85%')
    axis.loglog(dist, real_sf, label='Output SF, Obs5', marker='.')
    axis.set_xlabel('Separation [pixels]')
    axis.set_ylabel(ylabel)
    return axis.legend()


# Top: centroid shift, bottom: broadening
fig, ax = plt.subplots(2, 1, figsize=(6, 12))
_plot_sf_panel(ax[0], dist, SFs_censhift,
               SFs_real.sel(SF = 'CentroidShift').to_array()[0],
               'Centroid shift SF [km2/s2]')
_plot_sf_panel(ax[1], dist, SFs_broad,
               SFs_real.sel(SF = 'Broadening').to_array()[0],
               'Broadening SF [km2/s2]')
<matplotlib.legend.Legend at 0x149c70f7eec0>