From 2d17ba2779f792c78213586c5ffd3e7061308267 Mon Sep 17 00:00:00 2001 From: John Halloran Date: Mon, 22 Jun 2026 09:35:07 -0700 Subject: [PATCH 1/4] feat: final parity changes with MATLAB --- news/warnings-fix.rst | 24 ++++ src/diffpy/stretched_nmf/snmf_class.py | 151 +++++++++++++++++++++++-- 2 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 news/warnings-fix.rst diff --git a/news/warnings-fix.rst b/news/warnings-fix.rst new file mode 100644 index 00000000..6f85e12f --- /dev/null +++ b/news/warnings-fix.rst @@ -0,0 +1,24 @@ +**Added:** + +* Use faster iterations when possible +* Allow storing of objective log + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* Remove warning by explicitly interpreting cubic root as complex + +**Security:** + +* diff --git a/src/diffpy/stretched_nmf/snmf_class.py b/src/diffpy/stretched_nmf/snmf_class.py index 3269570d..6a4f26e9 100644 --- a/src/diffpy/stretched_nmf/snmf_class.py +++ b/src/diffpy/stretched_nmf/snmf_class.py @@ -1,3 +1,4 @@ +import csv import time import cvxpy as cp @@ -77,7 +78,7 @@ class SNMFOptimizer: def __init__( self, n_components=None, - max_iter=500, + max_iter=200, min_iter=20, tol=5e-7, rho=0, @@ -85,6 +86,8 @@ def __init__( random_state=None, show_plots=False, verbose=False, + stretch_max_iter=8, + stretch_slow_iter=200, ): """Initialize an instance of sNMF with estimator hyperparameters. @@ -116,10 +119,23 @@ def __init__( created by the decomposition. Optional. show_plots : bool Enables plotting at each step of the decomposition. Optional. + stretch_max_iter : int + Maximum number of projected-gradient stretch steps per outer-loop + iteration. Optional. + stretch_slow_iter : int + Number of initial outer-loop stretch updates to solve with the + slower constrained optimizer before switching to projected-gradient + stretch updates. Optional. """ if n_components is not None and n_components < 1: raise ValueError("n_components must be a positive integer.") + if not isinstance(stretch_max_iter, int) or stretch_max_iter < 1: + raise ValueError("stretch_max_iter must be a positive integer.") + if not isinstance(stretch_slow_iter, int) or stretch_slow_iter < 0: + raise ValueError( + "stretch_slow_iter must be a non-negative integer." + ) self.n_components = n_components self.max_iter = max_iter @@ -130,9 +146,12 @@ def __init__( self.random_state = random_state self.show_plots = show_plots self.verbose = verbose + self.stretch_max_iter = stretch_max_iter + self.stretch_slow_iter = stretch_slow_iter self._rng = np.random.default_rng(self.random_state) self._plotter = SNMFPlotter() if self.show_plots else None + self._stretch_step_size = None def _initialize_factors( self, @@ -212,6 +231,7 @@ def _initialize_factors( self._init_components = self.components_.copy() self._init_weights = self.weights_.copy() self._init_stretch = self.stretch_.copy() + self._stretch_step_size = None # Second-order spline: Tridiagonal (-2 on diags, 1 on sub/superdiags) self._spline_smooth_operator = 0.25 * diags( @@ -388,6 +408,43 @@ def fit( return self + def save_objective_log(self, filename): + """Save the objective log to a tab-delimited text file. + + Parameters + ---------- + filename : str or path-like + Output file path. + """ + if not hasattr(self, "objective_log"): + raise ValueError("Cannot save objective_log before calling fit.") + + fieldnames = ("step", "iteration", "objective", "dt_ms") + with open(filename, "w", newline="") as fp: + writer = csv.DictWriter( + fp, + fieldnames=fieldnames, + delimiter="\t", + ) + writer.writeheader() + previous_timestamp = None + for row in self.objective_log: + timestamp = row["timestamp"] + dt_ms = ( + 0.0 + if previous_timestamp is None + else 1000 * (timestamp - previous_timestamp) + ) + writer.writerow( + { + "step": row["step"], + "iteration": row["iteration"], + "objective": f"{row['objective']:.6E}", + "dt_ms": f"+{dt_ms:.3g}", + } + ) + previous_timestamp = timestamp + def _normalize_results(self): if self.verbose: print("\nNormalizing results after convergence...") @@ -415,7 +472,8 @@ def _normalize_results(self): self._objective_history = [self.objective_function_] self._outer_iter = 0 self._inner_iter = 0 - for outiter in range(self.max_iter): + normalization_max_iter = max(self.max_iter, 100) + for outiter in range(normalization_max_iter): self._outer_iter = outiter if outiter == 1: self._inner_iter = ( @@ -996,7 +1054,22 @@ def _regularize_function(self, stretch=None): return fun, gra - def _update_stretch(self): + @staticmethod + def _project_stretch(stretch, lower_bound=0.1): + return np.maximum(stretch, lower_bound) + + def _initial_stretch_step_size(self, stretch, gradient): + if self._stretch_step_size is not None: + return self._stretch_step_size + + gradient_norm = np.linalg.norm(gradient, "fro") + if gradient_norm == 0 or not np.isfinite(gradient_norm): + return 1.0 + + stretch_norm = max(np.linalg.norm(stretch, "fro"), 1.0) + return 0.05 * stretch_norm / gradient_norm + + def _update_stretch_trust_constr(self): """Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB).""" @@ -1032,6 +1105,68 @@ def objective(stretch_vec): # Update stretch with the optimized values self.stretch_ = result.x.reshape(self.stretch_.shape) + self._stretch_step_size = None + + def _update_stretch_projected_gradient(self): + if self.verbose: + print("Updating stretch factors...") + + for _ in range(self.stretch_max_iter): + stretch = self.stretch_ + current_objective, gradient = self._regularize_function(stretch) + step_size = self._initial_stretch_step_size(stretch, gradient) + best_stretch = stretch + best_objective = current_objective + + for _ in range(20): + candidate_stretch = self._project_stretch( + stretch - step_size * gradient + ) + step = candidate_stretch - stretch + step_norm_sq = np.linalg.norm(step, "fro") ** 2 + if step_norm_sq == 0: + step_size *= 0.5 + continue + + candidate_residuals = self._get_residual_matrix( + stretch=candidate_stretch + ) + candidate_objective = self._get_objective_function( + residuals=candidate_residuals, + stretch=candidate_stretch, + ) + if candidate_objective < best_objective: + best_stretch = candidate_stretch + best_objective = candidate_objective + + sufficient_decrease = ( + current_objective - 1e-4 * step_norm_sq / step_size + ) + if candidate_objective <= sufficient_decrease: + break + + step_size *= 0.5 + + self._stretch_step_size = step_size + if best_objective >= current_objective: + break + + self.stretch_ = best_stretch + + def _update_stretch(self): + """Update stretching factors with a hybrid strategy. + + The first ``stretch_slow_iter`` outer-loop updates use the original + constrained nonlinear optimizer to find a good non-convex basin. Later + updates switch to the lightweight Algorithm 2 style path: compute the + stretch gradient, take linearized proximal steps, then project back + onto the feasible stretch range. + """ + + if getattr(self, "_outer_iter", 0) < self.stretch_slow_iter: + self._update_stretch_trust_constr() + else: + self._update_stretch_projected_gradient() @staticmethod def _compute_objective_function( @@ -1105,9 +1240,9 @@ def _cubic_largest_real_root(p, q): # Compute discriminant delta = (q / 2) ** 2 + (p / 3) ** 3 - # Compute square root of delta safely - d = np.where(delta >= 0, np.sqrt(delta), np.sqrt(np.abs(delta)) * 1j) - # TODO: this line causes a warning but results seem correct + # Match the MATLAB helper's real-root branch without evaluating invalid + # real square roots for entries that are handled in complex arithmetic. + d = np.sqrt(delta.astype(complex)) # Compute cube roots safely a1 = (-q / 2 + d) ** (1 / 3) @@ -1123,9 +1258,7 @@ def _cubic_largest_real_root(p, q): # Take the largest real root element-wise when delta < 0 r_roots = np.stack([np.real(y1), np.real(y2), np.real(y3)], axis=0) - y = np.max(r_roots, axis=0) * ( - delta < 0 - ) # Keep only real roots when delta < 0 + y = np.where(delta < 0, np.max(r_roots, axis=0), 0.0) return y From a845dd7b51efc64a4dbb0e0338e5cb6fdf274696 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Jun 2026 18:09:27 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit hooks --- src/diffpy/stretched_nmf/snmf_class.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffpy/stretched_nmf/snmf_class.py b/src/diffpy/stretched_nmf/snmf_class.py index 210d2ac4..9e46f6ea 100644 --- a/src/diffpy/stretched_nmf/snmf_class.py +++ b/src/diffpy/stretched_nmf/snmf_class.py @@ -1155,7 +1155,6 @@ def _update_stretch(self): stretch gradient, take linearized proximal steps, then project back onto the feasible stretch range. """ - if getattr(self, "_outer_iter", 0) < self.stretch_slow_iter: self._update_stretch_trust_constr() else: From af620e2e38fd9b56e6014c7d7a32f40c47788cbc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 02:49:59 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit hooks --- src/diffpy/stretched_nmf/snmf_class.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffpy/stretched_nmf/snmf_class.py b/src/diffpy/stretched_nmf/snmf_class.py index cfebfd57..08d02db7 100644 --- a/src/diffpy/stretched_nmf/snmf_class.py +++ b/src/diffpy/stretched_nmf/snmf_class.py @@ -446,7 +446,6 @@ def save_objective_log(self, filename): ) previous_timestamp = timestamp - def _normalize_results(self): if self.verbose: print("\nNormalizing results after convergence...") From d808d6f82e870cab2bbd11b04ddfdfe3e2fff636 Mon Sep 17 00:00:00 2001 From: John Halloran Date: Sat, 27 Jun 2026 23:00:31 -0700 Subject: [PATCH 4/4] fix: restore docstring that got lost in the merge --- src/diffpy/stretched_nmf/snmf_class.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/diffpy/stretched_nmf/snmf_class.py b/src/diffpy/stretched_nmf/snmf_class.py index a34051cf..0fa169a9 100644 --- a/src/diffpy/stretched_nmf/snmf_class.py +++ b/src/diffpy/stretched_nmf/snmf_class.py @@ -1131,6 +1131,23 @@ def _regularize_function(self, stretch=None): return fun, gra def _regularize_function_hessian(self, stretch): + """Calculate the Hessian for the stretch optimization objective. + + The Hessian combines the Gauss-Newton curvature from the stretched + component derivatives, the residual-weighted second derivatives of + those stretched components, and the quadratic smoothing penalty on + neighboring stretch factors. + + Parameters + ---------- + stretch : ndarray of shape (n_components, n_signals) + Stretching factors at which to evaluate the objective curvature. + + Returns + ------- + ndarray of shape (n_components * n_signals, n_components * n_signals) + Symmetric Hessian matrix for the flattened stretch variables. + """ residuals, d_stretch_comps, dd_stretch_comps = ( self._stretch_residual_and_derivatives(stretch) )