function [s, net, tnet] = update(s0, net0, tnet0, grad)
%UPDATE Apply one flat step vector to s0, net0 and tnet0.
%   [s, net, tnet] = UPDATE(s0, net0, tnet0, grad) treats grad as three
%   consecutive segments laid out as
%
%       [ numel(s0) entries | numel(net0) entries | numel(tnet0) entries ]
%
%   Each segment is reshaped to the size of the matching input and then
%   combined with it:
%     s    = s0 + (first segment)                 plain addition
%     net  = sum_structs(net0,  second segment)   external helper
%     tnet = sum_structs(tnet0, third segment)    external helper
%
%   Entries of grad past the third segment are not read. If grad has fewer
%   entries than the three segments require, the indexing below errors.
%
%   NOTE(review): sum_structs is defined outside this file; presumably it
%   adds its two arguments field-by-field / element-wise -- confirm against
%   its definition.

% Element counts determine how grad is partitioned.
n_s    = numel(s0);
n_net  = numel(net0);
n_tnet = numel(tnet0);

% Offsets of the net and tnet segments inside grad.
net_offset  = n_s;
tnet_offset = n_s + n_net;

% Slice the three consecutive segments out of the flat vector.
sstep = grad(1:n_s);
nstep = grad(net_offset + (1:n_net));
tstep = grad(tnet_offset + (1:n_tnet));

% Restore each segment to its target's shape and apply it.
s    = s0 + reshape(sstep, size(s0));
net  = sum_structs(net0, reshape(nstep, size(net0)));
tnet = sum_structs(tnet0, reshape(tstep, size(tnet0)));
