diff --git a/news/hessian-fix.rst b/news/hessian-fix.rst
new file mode 100644
index 0000000..65574eb
--- /dev/null
+++ b/news/hessian-fix.rst
@@ -0,0 +1,23 @@
+**Added:**
+
+* Optimize stretch using a Hessian matrix
+
+**Changed:**
+
+*
+
+**Deprecated:**
+
+*
+
+**Removed:**
+
+*
+
+**Fixed:**
+
+*
+
+**Security:**
+
+*
diff --git a/src/diffpy/stretched_nmf/snmf_class.py b/src/diffpy/stretched_nmf/snmf_class.py
index 6c68c28..a2447d6 100644
--- a/src/diffpy/stretched_nmf/snmf_class.py
+++ b/src/diffpy/stretched_nmf/snmf_class.py
@@ -955,10 +955,7 @@ def _update_weights(self):
             )
             self.weights_[:, signal] = new_weight
 
-    def _regularize_function(self, stretch=None):
-        if stretch is None:
-            stretch = self.stretch_
-
+    def _stretch_residual_and_derivatives(self, stretch):
         stretched_components, d_stretch_comps, dd_stretch_comps = (
             self._compute_stretched_components(stretch=stretch)
         )
@@ -972,6 +969,15 @@ def _regularize_function(self, stretch=None):
             )
             - self._source_matrix
         )
+        return residuals, d_stretch_comps, dd_stretch_comps
+
+    def _regularize_function(self, stretch=None):
+        if stretch is None:
+            stretch = self.stretch_
+
+        residuals, d_stretch_comps, _ = self._stretch_residual_and_derivatives(
+            stretch
+        )
 
         fun = self._get_objective_function(residuals, stretch)
 
@@ -986,10 +992,66 @@ def _regularize_function(self, stretch=None):
             @ (self._spline_smooth_operator.T @ self._spline_smooth_operator)
         )
 
-        # Hessian would go here
-
         return fun, gra
 
+    def _regularize_function_hessian(self, stretch):
+        """Calculate the Hessian for the stretch optimization objective.
+
+        The Hessian combines the Gauss-Newton curvature from the stretched
+        component derivatives, the residual-weighted second derivatives of
+        those stretched components, and the quadratic smoothing penalty on
+        neighboring stretch factors.
+
+        Parameters
+        ----------
+        stretch : ndarray of shape (n_components, n_signals)
+            Stretching factors at which to evaluate the objective curvature.
+
+        Returns
+        -------
+        ndarray of shape (n_components * n_signals, n_components * n_signals)
+            Symmetric Hessian matrix for the flattened stretch variables.
+        """
+        residuals, d_stretch_comps, dd_stretch_comps = (
+            self._stretch_residual_and_derivatives(stretch)
+        )
+        n_variables = self.n_components_ * self.n_signals_
+        hessian = np.zeros((n_variables, n_variables), dtype=float)
+
+        for signal in range(self.n_signals_):
+            # Row-major flattening: variable (comp, signal) lives at
+            # index comp * n_signals_ + signal.
+            variable_indices = (
+                np.arange(self.n_components_) * self.n_signals_ + signal
+            )
+            d_signal = d_stretch_comps[:, variable_indices]
+            dd_signal = dd_stretch_comps[:, variable_indices]
+            # Gauss-Newton block coupling all components of this signal.
+            hessian[np.ix_(variable_indices, variable_indices)] += (
+                d_signal.T @ d_signal
+            )
+            # Second-derivative term only touches the diagonal.
+            hessian[variable_indices, variable_indices] += np.sum(
+                dd_signal * residuals[:, signal, None],
+                axis=0,
+            )
+
+        # Smoothing penalty couples signals within a single component.
+        smooth_hessian = (
+            self._spline_smooth_operator.T @ self._spline_smooth_operator
+        ).toarray()
+        for comp in range(self.n_components_):
+            component_slice = slice(
+                comp * self.n_signals_,
+                (comp + 1) * self.n_signals_,
+            )
+            hessian[component_slice, component_slice] += (
+                self.rho * smooth_hessian
+            )
+
+        # Average with the transpose so the result is exactly symmetric.
+        return 0.5 * (hessian + hessian.T)
+
     def _update_stretch(self):
         """Updates stretching matrix using constrained optimization
         (equivalent to fmincon in MATLAB)."""
@@ -1009,6 +1071,34 @@ def objective(stretch_vec):
             gra = gra.flatten()
             return fun, gra
 
+        # Dense Hessian callback for the flattened stretch vector.
+        def hessian(stretch_vec):
+            stretch_matrix = stretch_vec.reshape(self.stretch_.shape)
+            return self._regularize_function_hessian(stretch_matrix)
+
+        # First try an unconstrained exact trust-region solve; its result
+        # is kept only if it honors the 0.1 lower bound and does not
+        # increase the objective. Otherwise use the bounded solver.
+        unconstrained_result = minimize(
+            fun=objective,
+            x0=stretch_flat_initial,
+            method="trust-exact",
+            jac=True,  # objective returns (fun, gra)
+            hess=hessian,
+            options={"maxiter": 300},
+        )
+        unconstrained_stretch = unconstrained_result.x.reshape(
+            self.stretch_.shape
+        )
+        if np.all(unconstrained_stretch >= 0.1):
+            current_objective = self._regularize_function(self.stretch_)[0]
+            candidate_objective = self._regularize_function(
+                unconstrained_stretch
+            )[0]
+            if candidate_objective <= current_objective:
+                self.stretch_ = unconstrained_stretch
+                return
+
         # Optimization constraints: lower bound 0.1, no upper bound
         bounds = [
             (0.1, None)
@@ -1020,6 +1110,7 @@ def objective(stretch_vec):
             x0=stretch_flat_initial,
             method="trust-constr",  # Substitute for 'trust-region-reflective'
             jac=lambda stretch_vec: objective(stretch_vec)[1],  # Gradient
+            hess=hessian,
             bounds=bounds,
         )
 
diff --git a/tests/test_snmf_optimizer.py b/tests/test_snmf_optimizer.py
index 39406cf..0d2bedf 100644
--- a/tests/test_snmf_optimizer.py
+++ b/tests/test_snmf_optimizer.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+from scipy.sparse import csr_matrix
 
 from diffpy.stretched_nmf.snmf_class import SNMFOptimizer
 
@@ -110,3 +111,47 @@ def test_compute_objective_function(inputs, expected):
         spline_smooth_operator=operator,
     )
     assert np.isclose(result, expected)
+
+
+def test_regularize_function_hessian_has_expected_structure():
+    model = SNMFOptimizer(n_components=2, rho=0.5)
+    model.n_components_ = 2
+    model.n_signals_ = 3
+    model._spline_smooth_operator = csr_matrix(
+        [[1.0, -1.0, 0.0], [0.0, 1.0, -1.0]]
+    )
+
+    residuals = np.array([[2.0, -1.0, 4.0], [1.0, 3.0, -2.0]])
+    d_stretch_comps = np.array(
+        [
+            [1.0, 0.0, 1.0, 2.0, 0.0, 1.0],
+            [0.0, 1.0, 1.0, 0.0, 3.0, -1.0],
+        ]
+    )
+    dd_stretch_comps = np.array(
+        [
+            [0.5, 1.0, 0.0, 1.0, 0.0, -0.5],
+            [1.0, 0.0, 0.25, 0.0, 1.0, 0.5],
+        ]
+    )
+    model._stretch_residual_and_derivatives = lambda stretch: (
+        residuals,
+        d_stretch_comps,
+        dd_stretch_comps,
+    )
+
+    hessian = model._regularize_function_hessian(np.ones((2, 3)))
+
+    expected = np.array(
+        [
+            [3.5, -0.5, 0.0, 2.0, 0.0, 0.0],
+            [-0.5, 1.0, -0.5, 0.0, 3.0, 0.0],
+            [0.0, -0.5, 2.0, 0.0, 0.0, 0.0],
+            [2.0, 0.0, 0.0, 6.5, -0.5, 0.0],
+            [0.0, 3.0, 0.0, -0.5, 13.0, -0.5],
+            [0.0, 0.0, 0.0, 0.0, -0.5, -0.5],
+        ]
+    )
+    assert hessian.shape == (6, 6)
+    assert np.allclose(hessian, hessian.T)
+    assert np.allclose(hessian, expected)