Direct Optimization

Let’s leave the physics simulator out of it for now…

Data

shape = (M, N) = (30, 30)

target

For now, let’s just try to find a pre-defined design target (which we will construct with our generator)…

brush_target = notched_square_brush(5, 1)
latent_target = new_latent_design(shape, r=42)
latent_target_t = transform(latent_target, brush_target)
mask_target = generate_feasible_design_mask(latent_target_t, brush_target)

plt.imshow(mask_target, cmap="Greys")
plt.colorbar()
plt.show()

input

brush_input = notched_square_brush(5, 1)
latent_input = new_latent_design(shape, r=0)
latent_input_t = transform(latent_input, brush_input)
mask_input = generate_feasible_design_mask(latent_input_t, brush_input)

plt.imshow(mask_input, cmap="Greys")
plt.colorbar()
plt.show()

Loss Functions



mse

mse(x, y)


huber_loss

huber_loss(x, y, delta=0.5)
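A minimal sketch of what these two loss functions might look like in JAX (the library’s exact formulations may differ, e.g. in how the Huber linear branch is scaled):

```python
import jax.numpy as jnp

def mse(x, y):
    # mean squared error between two arrays
    return ((x - y) ** 2).mean()

def huber_loss(x, y, delta=0.5):
    # quadratic for small errors, linear for large ones;
    # less sensitive to outliers than plain MSE
    abs_err = jnp.abs(x - y)
    quadratic = 0.5 * abs_err**2
    linear = delta * (abs_err - 0.5 * delta)
    return jnp.where(abs_err <= delta, quadratic, linear).mean()
```

Both are smooth enough for gradient-based optimization; the Huber loss is often preferred when the design mask and target disagree on large regions early in the optimization.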

Optimization

The loss function defines what we’re optimizing: how far the current design mask is from the target mask.

def forward(latent, brush):
    latent_t = transform(latent, brush)
    design_mask = generate_feasible_design_mask(latent_t, brush) # differentiable through STE
    return design_mask
    
def loss_fn(latent, brush, target_mask):
    design_mask = forward(latent, brush)
    return huber_loss(design_mask, target_mask)

loss_fn(latent_input, brush_input, mask_target)
Array(0.40687042, dtype=float32)
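The comment above notes that `generate_feasible_design_mask` is made differentiable through a straight-through estimator (STE). A hypothetical sketch of the STE trick in JAX (not the library’s actual implementation): the forward pass applies a hard, non-differentiable binarization, while the backward pass pretends it was the identity.

```python
import jax
import jax.numpy as jnp

def ste_sign(x):
    # forward: hard binarization via sign(x);
    # backward: stop_gradient removes the non-differentiable part
    # from the gradient path, so the gradient is that of identity
    return x + jax.lax.stop_gradient(jnp.sign(x) - x)

values = ste_sign(jnp.array([-0.3, 0.7]))          # hard ±1 values
grads = jax.grad(lambda x: ste_sign(x).sum())(jnp.array([-0.3, 0.7]))
```

This is what lets gradients flow from the binary design mask back into the continuous latent design.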

Using JAX, it’s easy to get the gradient function:

grad_fn = jax.grad(loss_fn, argnums=0)
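Since the step below needs both the loss and its gradient, `jax.value_and_grad` is a handy alternative that returns both in a single forward/backward pass. A self-contained toy example (with a hypothetical `toy_loss` standing in for `loss_fn`):

```python
import jax
import jax.numpy as jnp

def toy_loss(latent, target):
    return ((latent - target) ** 2).mean()

# like jax.grad, but also returns the loss value itself
val_and_grad_fn = jax.value_and_grad(toy_loss, argnums=0)

loss, grads = val_and_grad_fn(jnp.zeros(3), jnp.ones(3))
# grads has the same shape as the first argument
```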

Let’s use an Adam optimizer:

init_fn, update_fn, params_fn = adam(0.1)
state = init_fn(latent_input)
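The `adam` used here is assumed to come from `jax.example_libraries.optimizers`, which returns a triple of pure functions rather than a stateful object:

```python
import jax.numpy as jnp
from jax.example_libraries.optimizers import adam  # assumed source of `adam`

# adam(step_size) returns: an initializer for the optimizer state,
# an update rule, and a getter that extracts parameters from the state
init_fn, update_fn, params_fn = adam(0.1)

state = init_fn(jnp.zeros((30, 30)))   # wrap initial params in optimizer state
params = params_fn(state)              # recover the params at any time
```

Keeping the optimizer as pure functions means the whole update step stays a pure function too, which plays well with `jax.jit`.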

This is the optimization step:

def step_fn(step, state, brush, mask_target):
    latent = params_fn(state)
    loss = loss_fn(latent, brush, mask_target)
    grads = grad_fn(latent, brush, mask_target)
    optim_state = update_fn(step, grads, state)
    return loss, optim_state
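Since the step is a pure function of its inputs, it can be compiled with `jax.jit` so the whole forward/backward/update computation is traced once and reused. A self-contained sketch with a hypothetical `toy_loss` (and `jax.value_and_grad` to avoid the double evaluation of loss and gradient):

```python
import jax
import jax.numpy as jnp
from jax.example_libraries.optimizers import adam

def toy_loss(params):
    return (params ** 2).sum()

init_fn, update_fn, params_fn = adam(0.1)

@jax.jit  # compile the whole update step; subsequent calls are fast
def step_fn(step, state):
    params = params_fn(state)
    loss, grads = jax.value_and_grad(toy_loss)(params)
    return loss, update_fn(step, grads, state)

state = init_fn(jnp.ones(2))
for step in range(100):
    loss, state = step_fn(step, state)
```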

We can now loop over the optimization:

range_ = trange(10) # reduced iterations, so that tests run faster
for step in range_:
    loss, state = step_fn(step, state, brush_input, mask_target)
    range_.set_postfix(loss=float(loss))
latent_input = params_fn(state)
plt.imshow(mask_target, cmap="Greys", vmin=-1, vmax=1)
plt.colorbar()
plt.show()

plt.imshow(forward(latent_input, brush_input), cmap="Greys", vmin=-1, vmax=1)
plt.colorbar()
plt.show()

loss_fn(latent_input, brush_input, mask_target)
Array(0.03643624, dtype=float32)