-
Notifications
You must be signed in to change notification settings - Fork 10
feat: optimize stretch using a Hessian matrix #204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| **Added:** | ||
|
|
||
| * Optimize stretch using a Hessian matrix | ||
|
|
||
| **Changed:** | ||
|
|
||
| * <news item> | ||
|
|
||
| **Deprecated:** | ||
|
|
||
| * <news item> | ||
|
|
||
| **Removed:** | ||
|
|
||
| * <news item> | ||
|
|
||
| **Fixed:** | ||
|
|
||
| * <news item> | ||
|
|
||
| **Security:** | ||
|
|
||
| * <news item> |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -955,10 +955,7 @@ def _update_weights(self): | |
| ) | ||
| self.weights_[:, signal] = new_weight | ||
|
|
||
| def _regularize_function(self, stretch=None): | ||
| if stretch is None: | ||
| stretch = self.stretch_ | ||
|
|
||
| def _stretch_residual_and_derivatives(self, stretch): | ||
| stretched_components, d_stretch_comps, dd_stretch_comps = ( | ||
| self._compute_stretched_components(stretch=stretch) | ||
| ) | ||
|
|
@@ -972,6 +969,15 @@ def _regularize_function(self, stretch=None): | |
| ) | ||
| - self._source_matrix | ||
| ) | ||
| return residuals, d_stretch_comps, dd_stretch_comps | ||
|
|
||
| def _regularize_function(self, stretch=None): | ||
| if stretch is None: | ||
| stretch = self.stretch_ | ||
|
|
||
| residuals, d_stretch_comps, _ = self._stretch_residual_and_derivatives( | ||
| stretch | ||
| ) | ||
|
|
||
| fun = self._get_objective_function(residuals, stretch) | ||
|
|
||
|
|
@@ -986,10 +992,60 @@ def _regularize_function(self, stretch=None): | |
| @ (self._spline_smooth_operator.T @ self._spline_smooth_operator) | ||
| ) | ||
|
|
||
| # Hessian would go here | ||
|
|
||
| return fun, gra | ||
|
|
||
| def _regularize_function_hessian(self, stretch): | ||
| """Calculate the Hessian for the stretch optimization objective. | ||
|
|
||
| The Hessian combines the Gauss-Newton curvature from the stretched | ||
| component derivatives, the residual-weighted second derivatives of | ||
| those stretched components, and the quadratic smoothing penalty on | ||
| neighboring stretch factors. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| stretch : ndarray of shape (n_components, n_signals) | ||
| Stretching factors at which to evaluate the objective curvature. | ||
|
|
||
| Returns | ||
| ------- | ||
| ndarray of shape (n_components * n_signals, n_components * n_signals) | ||
| Symmetric Hessian matrix for the flattened stretch variables. | ||
| """ | ||
| residuals, d_stretch_comps, dd_stretch_comps = ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This could use some documentation and also maybe a more descriptive variable name. "Stretch" is usually a verb, and I would expect a noun here. We were pretty lazy about documenting code in the early part of the project because we wanted to get it going from a low place, but we are at the point in the project where we can take the time to make the code more maintainable moving forward.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I have added a numpy style docstring as used for some of the other functions. As for renaming
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok. The code is really not very readable ATM, which would make it harder to maintain in the future, but I am OK with putting this off for now. |
||
| self._stretch_residual_and_derivatives(stretch) | ||
| ) | ||
| n_variables = self.n_components_ * self.n_signals_ | ||
| hessian = np.zeros((n_variables, n_variables), dtype=float) | ||
|
|
||
| for signal in range(self.n_signals_): | ||
| variable_indices = ( | ||
| np.arange(self.n_components_) * self.n_signals_ + signal | ||
| ) | ||
| d_signal = d_stretch_comps[:, variable_indices] | ||
| dd_signal = dd_stretch_comps[:, variable_indices] | ||
| hessian[np.ix_(variable_indices, variable_indices)] += ( | ||
| d_signal.T @ d_signal | ||
| ) | ||
| hessian[variable_indices, variable_indices] += np.sum( | ||
| dd_signal * residuals[:, signal, None], | ||
| axis=0, | ||
| ) | ||
|
|
||
| smooth_hessian = ( | ||
| self._spline_smooth_operator.T @ self._spline_smooth_operator | ||
| ).toarray() | ||
| for comp in range(self.n_components_): | ||
| component_slice = slice( | ||
| comp * self.n_signals_, | ||
| (comp + 1) * self.n_signals_, | ||
| ) | ||
| hessian[component_slice, component_slice] += ( | ||
| self.rho * smooth_hessian | ||
| ) | ||
|
|
||
| return 0.5 * (hessian + hessian.T) | ||
|
|
||
| def _update_stretch(self): | ||
| """Updates stretching matrix using constrained optimization | ||
| (equivalent to fmincon in MATLAB).""" | ||
|
|
@@ -1009,6 +1065,30 @@ def objective(stretch_vec): | |
| gra = gra.flatten() | ||
| return fun, gra | ||
|
|
||
| def hessian(stretch_vec): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need a test for this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. I have added a test checking for some of the expected properties (right shape, symmetric, correct ordering/cross-coupling). |
||
| stretch_matrix = stretch_vec.reshape(self.stretch_.shape) | ||
| return self._regularize_function_hessian(stretch_matrix) | ||
|
|
||
| unconstrained_result = minimize( | ||
| fun=lambda stretch_vec: objective(stretch_vec)[0], | ||
| x0=stretch_flat_initial, | ||
| method="trust-exact", | ||
| jac=lambda stretch_vec: objective(stretch_vec)[1], | ||
| hess=hessian, | ||
| options={"maxiter": 300}, | ||
| ) | ||
| unconstrained_stretch = unconstrained_result.x.reshape( | ||
| self.stretch_.shape | ||
| ) | ||
| if np.all(unconstrained_stretch >= 0.1): | ||
| current_objective = self._regularize_function(self.stretch_)[0] | ||
| candidate_objective = self._regularize_function( | ||
| unconstrained_stretch | ||
| )[0] | ||
| if candidate_objective <= current_objective: | ||
| self.stretch_ = unconstrained_stretch | ||
| return | ||
|
|
||
| # Optimization constraints: lower bound 0.1, no upper bound | ||
| bounds = [ | ||
| (0.1, None) | ||
|
|
@@ -1020,6 +1100,7 @@ def objective(stretch_vec): | |
| x0=stretch_flat_initial, | ||
| method="trust-constr", # Substitute for 'trust-region-reflective' | ||
| jac=lambda stretch_vec: objective(stretch_vec)[1], # Gradient | ||
| hess=hessian, | ||
| bounds=bounds, | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Always start descriptions of parameters with "The"