import numpy as np
import ceviche_challenges
from ceviche_challenges import units as u
import ceviche
from inverse_design.brushes import notched_square_brush, circular_brush
from inverse_design.conditional_generator import (
    new_latent_design, transform
)
from tqdm.notebook import trange
import autograd
import autograd.numpy as npa
Ceviche Challenges
Integration with the ceviche_challenges library to test the optimization.
import jax
import jax.numpy as jnp
from javiche import jaxit
import matplotlib.pylab as plt
import numpy as np
from inverse_design.local_generator import generate_feasible_design_mask
from jax.example_libraries.optimizers import adam
spec = ceviche_challenges.waveguide_bend.prefabs.waveguide_bend_2umx2um_spec(
    wg_width=400 * u.nm, variable_region_size=(1600 * u.nm, 1600 * u.nm), cladding_permittivity=2.25
)
params = ceviche_challenges.waveguide_bend.prefabs.waveguide_bend_sim_params(resolution=25 * u.nm)
model = ceviche_challenges.waveguide_bend.model.WaveguideBendModel(params, spec)
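The model exposes the pixel grid of its variable region; a quick look (this check is an addition, not part of the original notebook) shows the shape the latent design below has to match:

# addition, not in the original: the design grid the optimizer will work on
print(model.design_variable_shape)  # the 1600 nm x 1600 nm variable region sampled at 25 nm resolution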
def forward(latent_weights, brush):
    latent_t = transform(latent_weights, brush)  # .reshape((Nx, Ny))
    design_mask = generate_feasible_design_mask(latent_t, brush, verbose=False)
    design = (design_mask + 1.0) / 2.0  # map the mask from [-1, 1] to a [0, 1] density
    return design
brush = circular_brush(5)
latent = new_latent_design(model.design_variable_shape, bias=0.1, r=1, r_scale=1e-3)
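As a quick sanity check (an addition, not part of the original notebook), the forward pass should already turn this fresh latent into a feasible design of that same shape; design0 is a hypothetical name used only for this check:

# addition, not in the original: run the forward pass once on the initial latent
design0 = forward(latent, brush)
print(design0.shape, float(design0.min()), float(design0.max()))  # expect design_variable_shape and values in [0, 1]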
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
@jaxit()
def inner_loss_fn(design):
    s_params, fields = model.simulate(design)
    s11 = npa.abs(s_params[:, 0, 0])
    s21 = npa.abs(s_params[:, 0, 1])

    global debug_fields
    debug_fields = fields
    global debug_design
    debug_design = design

    return npa.mean(s11) - npa.mean(s21)  # minimize reflection (s11) while rewarding transmission (s21)
def loss_fn(latent):
    design = forward(latent, brush)
    return inner_loss_fn(design)
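Before optimizing, a single evaluation of the combined loss on the initial latent (an addition, not in the original) confirms the mask generation and simulation are wired together:

# addition, not in the original: one evaluation of the loss on the initial latent
print(float(loss_fn(latent)))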
# Number of epochs in the optimization
Nsteps = 150

# Step size for the Adam optimizer
def step_size(idx):
    """reduce the step size geometrically from start to stop over Nsteps (it stabilizes afterwards just in case)"""
    start = 0.1
    stop = 5e-3
    return start * (stop / start) ** (idx / Nsteps)
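Evaluating the schedule at a few steps (a small worked check, not in the original) shows the geometric decay from 0.1 towards 5e-3:

# addition, not in the original: sample the schedule at the start, middle and end of the run
print([round(step_size(i), 4) for i in (0, Nsteps // 2, Nsteps)])  # approximately [0.1, 0.0224, 0.005]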
step_size = 0.01  # note: this constant overrides the schedule defined above
def visualize_latent(latent):
    global debug_design, debug_fields
    design = forward(latent, brush)
    s_params, fields = model.simulate(design)
    debug_design = design
    debug_fields = fields
    visualize_debug()
def visualize_debug():
    global debug_design, debug_fields
    if not isinstance(debug_fields, np.ndarray):
        # unwrap autograd-traced values before plotting
        debug_fields = debug_fields._value
        debug_design = debug_design._value
    ceviche.viz.abs(np.squeeze(np.asarray(debug_fields)), model.density(np.asarray(debug_design)))
    plt.grid()
    plt.show()
grad_fn = jax.grad(loss_fn)

init_fn, update_fn, params_fn = adam(step_size)
state = init_fn(latent)  # .flatten()
# value_and_grad seems to have a problem. Figure out why!
def step_fn(step, state):
    latent = params_fn(state)  # we need autograd arrays here...
    grads = grad_fn(latent)
    loss = loss_fn(latent)
    # loss = loss_fn(latent)
    optim_state = update_fn(step, grads, state)
    # optim_latent = params_fn(optim_state)
    # optim_latent = optim_latent/optim_latent.std()
    visualize_debug()
    return loss, optim_state
latent = params_fn(state)
visualize_latent(latent)
range_ = trange(Nsteps)
losses = np.ndarray(Nsteps)
for step in range_:
    loss, state = step_fn(step, state)
    losses[step] = loss
    range_.set_postfix(loss=float(loss))
latent = params_fn(state)
design = forward(latent, brush)
s_params, fields = model.simulate(design)
epsr = model.epsilon_r(design)
ceviche.viz.abs(np.squeeze(fields), model.density(design))
plt.grid()
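The scattering parameters of the optimized bend are already available from the simulation above; printing the transmitted power into the output port (an addition, not in the original) gives a simple figure of merit:

# addition, not in the original: transmitted power |S21|^2 of the optimized design
print(np.abs(s_params[:, 0, 1]) ** 2)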