shape = (M, N) = (30, 30)
Direct Optimization
Let’s leave the physics simulator out of it for now…
Data
target
Let’s, for now, just try to recover a pre-defined design target (which we will construct with our generator)…
brush_target = notched_square_brush(5, 1)
latent_target = new_latent_design(shape, r=42)
latent_target_t = transform(latent_target, brush_target)
mask_target = generate_feasible_design_mask(latent_target_t, brush_target)
plt.imshow(mask_target, cmap="Greys")
plt.colorbar()
plt.show()
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
input
brush_input = notched_square_brush(5, 1)
latent_input = new_latent_design(shape, r=0)
latent_input_t = transform(latent_input, brush_input)
mask_input = generate_feasible_design_mask(latent_input_t, brush_input)
plt.imshow(mask_input, cmap="Greys")
plt.colorbar()
plt.show()
Loss Functions
mse
def mse(x, y):
    return ((x - y) ** 2).mean()
huber_loss
def huber_loss(x, y, delta=0.5):
    abs_diff = jnp.abs(x - y)
    return jnp.where(abs_diff > delta,
                     delta * (abs_diff - 0.5 * delta),
                     0.5 * abs_diff ** 2).mean()
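As a quick sanity check, here is a self-contained sketch of a Huber loss with `delta=0.5` (the `huber_loss_demo` name is hypothetical, to avoid clashing with the notebook's own definition). It behaves quadratically for small errors and linearly for large ones, which makes it less sensitive to outliers than MSE:

```python
import jax.numpy as jnp

def huber_loss_demo(x, y, delta=0.5):
    # quadratic near zero, linear once |x - y| exceeds delta
    abs_diff = jnp.abs(x - y)
    return jnp.where(abs_diff > delta,
                     delta * (abs_diff - 0.5 * delta),
                     0.5 * abs_diff ** 2).mean()

small = huber_loss_demo(jnp.array([0.1]), jnp.array([0.0]))  # 0.5 * 0.1**2 = 0.005
large = huber_loss_demo(jnp.array([2.0]), jnp.array([0.0]))  # 0.5 * (2.0 - 0.25) = 0.875
```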
Optimization
The loss function defines what we’re optimizing.
def forward(latent, brush):
    latent_t = transform(latent, brush)
    design_mask = generate_feasible_design_mask(latent_t, brush)  # differentiable through STE
    return design_mask
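The mask generation is differentiable thanks to a straight-through estimator (STE): the forward pass applies a hard binarization, while the backward pass pretends that operation was the identity. A minimal sketch of the trick in JAX (the `ste_binarize` helper is illustrative, not part of this codebase):

```python
import jax
import jax.numpy as jnp

def ste_binarize(x):
    # forward pass: hard ±1 threshold; backward pass: identity gradient
    hard = jnp.where(x > 0, 1.0, -1.0)
    return x + jax.lax.stop_gradient(hard - x)

values = ste_binarize(jnp.array([0.3, -0.7]))                      # [1., -1.]
grads = jax.grad(lambda x: ste_binarize(x).sum())(jnp.array([0.3, -0.7]))
# the hard threshold is opaque to autodiff, yet gradients flow through as identity: [1., 1.]
```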
def loss_fn(latent, brush, target_mask):
    design_mask = forward(latent, brush)
    return huber_loss(design_mask, target_mask)
loss_fn(latent_input, brush_input, mask_target)
Array(0.40687042, dtype=float32)
Using JAX, it’s easy to get the gradient function:
grad_fn = jax.grad(loss_fn, argnums=0)
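Here, `argnums=0` means we differentiate with respect to the first argument (the latent) only; the brush and target mask are treated as constants. A tiny standalone example of the same `jax.grad` pattern:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sum(x ** 2 * y)

df_dx = jax.grad(f, argnums=0)  # gradient w.r.t. x only; y is held constant
result = df_dx(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))  # 2*x*y -> [6., 16.]
```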
Let’s use an Adam optimizer:
init_fn, update_fn, params_fn = adam(0.1)
state = init_fn(latent_input)
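The `(init_fn, update_fn, params_fn)` triple matches the functional optimizer API of `jax.example_libraries.optimizers` (an assumption about where `adam` is imported from here). A toy usage sketch of that pattern, minimizing a simple quadratic:

```python
import jax
import jax.numpy as jnp
from jax.example_libraries.optimizers import adam

init_fn, update_fn, params_fn = adam(0.1)   # step size 0.1
state = init_fn(jnp.array(5.0))             # optimizer state wraps the parameters
grad = jax.grad(lambda p: (p ** 2).sum())

for i in range(100):                        # minimize p**2, so p should approach 0
    state = update_fn(i, grad(params_fn(state)), state)
```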
This is the optimization step:
def step_fn(step, state, brush, mask_target):
    latent = params_fn(state)
    loss = loss_fn(latent, brush, mask_target)
    grads = grad_fn(latent, brush, mask_target)
    optim_state = update_fn(step, grads, state)
    return loss, optim_state
We can now loop over the optimization:
range_ = trange(10)  # reduced iterations, so that tests run faster
for step in range_:
    loss, state = step_fn(step, state, brush_input, mask_target)
    range_.set_postfix(loss=float(loss))
latent_input = params_fn(state)
="Greys", vmin=-1, vmax=1)
plt.imshow(mask_target, cmap
plt.colorbar() plt.show()
="Greys", vmin=-1, vmax=1)
plt.imshow(forward(latent_input, brush_input), cmap
plt.colorbar() plt.show()
loss_fn(latent_input, brush_input, mask_target)
Array(0.03643624, dtype=float32)