Skip to content

Commit

Permalink
update for state-handling
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Jan 28, 2022
1 parent 410e78c commit ab03066
Showing 1 changed file with 118 additions and 66 deletions.
184 changes: 118 additions & 66 deletions bozzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from functools import partial
import math

from meshmode.dof_array import thaw
from arraycontext import thaw
from meshmode.mesh import BTAG_ALL, BTAG_NONE # noqa
from grudge.dof_desc import DTAG_BOUNDARY
from grudge.eager import EagerDGDiscretization
Expand All @@ -47,7 +47,10 @@

from mirgecom.navierstokes import ns_operator
from mirgecom.fluid import make_conserved
from mirgecom.artificial_viscosity import (av_operator, smoothness_indicator)
from mirgecom.artificial_viscosity import (
av_laplacian_operator,
smoothness_indicator
)
from mirgecom.simutil import (
check_step,
generate_and_distribute_mesh,
Expand All @@ -65,18 +68,24 @@
euler_step)
from mirgecom.steppers import advance_state
from mirgecom.boundary import (
PrescribedInviscidBoundary,
PrescribedFluidBoundary,
IsothermalNoSlipBoundary
)
from mirgecom.initializers import (Uniform, PlanarDiscontinuity)
from mirgecom.eos import IdealSingleGas
from mirgecom.transport import SimpleTransport

from mirgecom.gas_model import (
GasModel,
make_fluid_state
)
from logpyle import IntervalTimer, set_dt
from mirgecom.euler import extract_vars_for_logging, units_for_logging
from mirgecom.logging_quantities import (
initialize_logmgr, logmgr_add_many_discretization_quantities,
logmgr_add_cl_device_info, logmgr_set_time, LogUserQuantity,
logmgr_add_cl_device_info,
logmgr_set_time,
LogUserQuantity,
logmgr_add_device_memory_usage,
set_sim_state
)

Expand All @@ -93,6 +102,7 @@ class MyRuntimeError(RuntimeError):
@mpi_entry_point
def main(ctx_factory=cl.create_some_context, restart_filename=None,
use_profiling=False, use_logmgr=False, user_input_file=None,
use_overintegration=False,
actx_class=PyOpenCLArrayContext, casename=None):
"""Drive the Y0 nozzle example."""
cl_ctx = ctx_factory()
Expand All @@ -119,6 +129,9 @@ def main(ctx_factory=cl.create_some_context, restart_filename=None,
queue,
allocator=cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue)))

from mirgecom.simutil import global_reduce as _global_reduce
global_reduce = partial(_global_reduce, comm=comm)

# Most of these can be set by the user input file

# default i/o junk frequencies
Expand Down Expand Up @@ -368,9 +381,9 @@ def getIsentropicTemperature(mach, T0, gamma):
transport_model = SimpleTransport(viscosity=mu, thermal_conductivity=kappa)
eos = IdealSingleGas(
gamma=gamma_CO2,
gas_const=R_CO2,
transport_model=transport_model
gas_const=R_CO2
)
gas_model = GasModel(eos=eos, transport=transport_model)
bulk_init = PlanarDiscontinuity(dim=dim, disc_location=-.30, sigma=0.005,
temperature_left=temp_inflow, temperature_right=temp_bkrnd,
pressure_left=pres_inflow, pressure_right=pres_bkrnd,
Expand Down Expand Up @@ -451,14 +464,34 @@ def __call__(self, x_vec, *, time=0, eos, **kwargs):
velocity=vel_outflow
)

inflow = PrescribedInviscidBoundary(fluid_solution_func=inflow_init)
outflow = PrescribedInviscidBoundary(fluid_solution_func=outflow_init)

def _boundary_state_func(discr, state_minus, btag, gas_model,
actx, init_func, **kwargs):
bnd_discr = discr.discr_from_dd(btag)
nodes = thaw(bnd_discr.nodes(), actx)
return make_fluid_state(init_func(x_vec=nodes, eos=gas_model.eos,
**kwargs), gas_model,
temperature_seed=state_minus.temperature)

def _inflow_state_func(discr, btag, gas_model, state_minus, **kwargs):
return _boundary_state_func(discr, state_minus, btag, gas_model,
state_minus.array_context,
inflow_init, **kwargs)

def _outflow_state_func(discr, btag, gas_model, state_minus, **kwargs):
return _boundary_state_func(discr, state_minus, btag, gas_model,
state_minus.array_context,
outflow_init, **kwargs)


inflow = PrescribedFluidBoundary(boundary_state_func=_inflow_state_func)
outflow = PrescribedFluidBoundary(boundary_state_func=_outflow_state_func)
wall = IsothermalNoSlipBoundary()

boundaries = {
DTAG_BOUNDARY("Inflow"): inflow,
DTAG_BOUNDARY("Outflow"): outflow,
DTAG_BOUNDARY("Wall"): wall
DTAG_BOUNDARY("Wall"): wall,
}

viz_path = "viz_data/"
Expand Down Expand Up @@ -511,12 +544,20 @@ def __call__(self, x_vec, *, time=0, eos, **kwargs):
if grid_only:
return 0

discr = EagerDGDiscretization(actx,
local_mesh,
order=order,
mpi_communicator=comm)

nodes = thaw(actx, discr.nodes())
from grudge.dof_desc import DISCR_TAG_BASE, DISCR_TAG_QUAD
from meshmode.discretization.poly_element import \
default_simplex_group_factory, QuadratureSimplexGroupFactory

discr = EagerDGDiscretization(
actx, local_mesh,
discr_tag_to_group_factory={
DISCR_TAG_BASE: default_simplex_group_factory(
base_dim=local_mesh.dim, order=order),
DISCR_TAG_QUAD: QuadratureSimplexGroupFactory(2*order + 1)
},
mpi_communicator=comm
)
nodes = thaw(discr.nodes(), actx)

if discr_only:
return 0
Expand All @@ -525,6 +566,11 @@ def __call__(self, x_vec, *, time=0, eos, **kwargs):
from mirgecom.simutil import boundary_report
boundary_report(discr, boundaries, f"{casename}_boundaries_np{nparts}.yaml")

if use_overintegration:
quadrature_tag = DISCR_TAG_QUAD
else:
quadrature_tag = None

# initialize the sponge field
def gen_sponge():
thickness = 0.15
Expand All @@ -539,12 +585,12 @@ def gen_sponge():

zeros = 0 * nodes[0]
sponge_sigma = gen_sponge()
ref_state = bulk_init(x_vec=nodes, eos=eos, time=0.0)
ref_cv = bulk_init(x_vec=nodes, eos=eos, time=0.0)

if restart_filename:
if rank == 0:
logging.info("Restarting soln.")
current_state = restart_data["state"]
current_cv = restart_data["cv"]
if restart_order != order:
restart_discr = EagerDGDiscretization(
actx,
Expand All @@ -557,15 +603,14 @@ def gen_sponge():
discr.discr_from_dd("vol"),
restart_discr.discr_from_dd("vol")
)
restart_state = restart_data["state"]
current_state = connection(restart_state)
current_cv = connection(restart_data["cv"])
else:
if rank == 0:
logging.info("Initializing soln.")
# for Discontinuity initial conditions
current_state = bulk_init(x_vec=nodes, eos=eos, time=0.0)
# for uniform background initial condition
#current_state = bulk_init(nodes, eos=eos)
current_cv = bulk_init(x_vec=nodes, eos=eos, time=0.0)

current_state = make_fluid_state(current_cv, gas_model)

vis_timer = None
log_cfl = None
Expand Down Expand Up @@ -626,42 +671,46 @@ def gen_sponge():
if rank == 0:
logger.info(init_message)

def _mk_flu_state(cv):
return make_fluid_state(cv=cv, gas_model=gas_model)

make_fluid_state_cmp = actx.compile(_mk_flu_state)

def sponge(cv, cv_ref, sigma):
return (sigma*(cv_ref - cv))

def my_rhs(t, state):
return (
ns_operator(discr, cv=state, t=t, boundaries=boundaries, eos=eos)
+ make_conserved(
dim, q=av_operator(discr, q=state.join(), boundaries=boundaries,
boundary_kwargs={"time": t, "eos": eos},
alpha=alpha_sc, s0=s0_sc, kappa=kappa_sc)
) + sponge(cv=state, cv_ref=ref_state, sigma=sponge_sigma)
cv = state
fluid_state = make_fluid_state(cv=cv, gas_model=gas_model)
cv_rhs = (
ns_operator(discr, state=fluid_state, time=t, boundaries=boundaries,
gas_model=gas_model)
+ av_laplacian_operator(discr, fluid_state=fluid_state,
boundaries=boundaries,
boundary_kwargs={"time": t,
"gas_model": gas_model},
alpha=alpha_sc, s0=s0_sc, kappa=kappa_sc)
+ sponge(cv=fluid_state.cv, cv_ref=ref_cv, sigma=sponge_sigma)
)
return cv_rhs

def my_write_viz(step, t, dt, state, dv=None, tagged_cells=None, ts_field=None):
if dv is None:
dv = eos.dependent_vars(state)
if tagged_cells is None:
tagged_cells = smoothness_indicator(discr, state.mass, s0=s0_sc,
kappa=kappa_sc)
if ts_field is None:
ts_field, cfl, dt = my_get_timestep(t, dt, state, True)

viz_fields = [("cv", state),
def my_write_viz(step, t, dt, cv, dv, ts_field):
tagged_cells = smoothness_indicator(discr, cv.mass, s0=s0_sc,
kappa=kappa_sc)
viz_fields = [("cv", cv),
("dv", dv),
("sponge_sigma", gen_sponge()),
("tagged_cells", tagged_cells),
("dt" if constant_cfl else "cfl", ts_field)]
write_visfile(discr, viz_fields, visualizer, vizname=vizname,
step=step, t=t, overwrite=True)

def my_write_restart(step, t, state):
def my_write_restart(step, t, cv):
restart_fname = restart_pattern.format(cname=casename, step=step, rank=rank)
if restart_fname != restart_filename:
restart_data = {
"local_mesh": local_mesh,
"state": state,
"cv": cv,
"t": t,
"step": step,
"order": order,
Expand All @@ -683,33 +732,37 @@ def my_health_check(dv):

return health_error

def my_get_timestep(t, dt, state, force=0):
def my_get_timestep(t, dt, state):
t_remaining = max(0, t_final - t)
cfl = current_cfl
dt = dt
ts_field = None

if constant_cfl:
from mirgecom.viscous import get_viscous_timestep
ts_field = current_cfl * get_viscous_timestep(discr, eos=eos, cv=state)
ts_field = current_cfl * get_viscous_timestep(discr, state=state)
from grudge.op import nodal_min
dt = nodal_min(discr, "vol", ts_field)
elif force:
else:
from mirgecom.viscous import get_viscous_cfl
ts_field = get_viscous_cfl(discr, eos=eos, dt=dt, cv=state)
ts_field = get_viscous_cfl(discr, dt=dt, state=state)
from grudge.op import nodal_max
cfl = nodal_max(discr, "vol", ts_field)

return ts_field, cfl, min(t_remaining, dt)

def my_pre_step(step, t, dt, state):
cv = state
fluid_state = make_fluid_state_cmp(cv=cv)
dv = fluid_state.dv

try:
dv = None

if logmgr:
logmgr.tick_before()

ts_field, cfl, dt = my_get_timestep(t, dt, state, logDependent)
ts_field, cfl, dt = my_get_timestep(t, dt, fluid_state)

if logDependent:
log_cfl.set_quantity(cfl)

Expand All @@ -718,28 +771,24 @@ def my_pre_step(step, t, dt, state):
do_health = check_step(step=step, interval=nhealth)

if do_health:
dv = eos.dependent_vars(state)
from mirgecom.simutil import allsync
health_errors = allsync(my_health_check(dv), comm,
op=MPI.LOR)
health_errors = global_reduce(my_health_check(dv), op="lor")

if health_errors:
if rank == 0:
logger.info("Fluid solution failed health check.")
raise MyRuntimeError("Failed simulation health check.")

if do_restart:
my_write_restart(step=step, t=t, state=state)
my_write_restart(step=step, t=t, cv=cv)

if do_viz:
if dv is None:
dv = eos.dependent_vars(state)
my_write_viz(step=step, t=t, dt=dt, state=state, dv=dv)
my_write_viz(step=step, t=t, dt=dt, cv=cv, dv=dv, ts_field=ts_field)

except MyRuntimeError:
if rank == 0:
logger.info("Errors detected; attempting graceful exit.")
my_write_viz(step=step, t=t, dt=dt, state=state)
my_write_restart(step=step, t=t, state=state)
my_write_viz(step=step, t=t, dt=dt, cv=cv, dv=dv, ts_field=ts_field)
my_write_restart(step=step, t=t, cv=cv)
raise

return state, dt
Expand All @@ -749,7 +798,7 @@ def my_post_step(step, t, dt, state):
# imo this is a design/scope flaw
if logmgr:
set_dt(logmgr, dt)
set_sim_state(logmgr, dim, state, eos)
set_sim_state(logmgr, dim, state, gas_model.eos)
logmgr.tick_after()
return state, dt

Expand All @@ -760,22 +809,25 @@ def my_post_step(step, t, dt, state):
logging.info("Stepping.")

current_dt = get_sim_timestep(discr, current_state, current_t, current_dt,
current_cfl, eos, t_final, constant_cfl)
current_cfl, t_final, constant_cfl)

(current_step, current_t, current_state) = \
(current_step, current_t, current_cv) = \
advance_state(rhs=my_rhs, timestepper=timestepper,
pre_step_callback=my_pre_step,
post_step_callback=my_post_step,
state=current_state, dt=current_dt,
state=current_cv, dt=current_dt,
t_final=t_final, t=current_t, istep=current_step)

# Dump the final data
if rank == 0:
logger.info("Checkpointing final state ...")
final_dv = eos.dependent_vars(current_state)
my_write_viz(step=current_step, t=current_t, dt=current_dt, state=current_state,
dv=final_dv)
my_write_restart(step=current_step, t=current_t, state=current_state)
final_state = make_fluid_state(current_cv, gas_model)
final_dv = final_state.dv
ts_field, cfl, dt = my_get_timestep(current_t, current_dt, final_state)

my_write_viz(step=current_step, t=current_t, dt=current_dt, cv=current_cv,
dv=final_dv, ts_field=ts_field)
my_write_restart(step=current_step, t=current_t, cv=current_cv)

if logmgr:
logmgr.close()
Expand Down

0 comments on commit ab03066

Please sign in to comment.