diff --git a/news/warnings-fix.rst b/news/warnings-fix.rst new file mode 100644 index 0000000..6f85e12 --- /dev/null +++ b/news/warnings-fix.rst @@ -0,0 +1,24 @@ +**Added:** + +* Use faster iterations when possible +* Allow storing of objective log + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* Remove warning by explicitly interpreting cubic root as complex + +**Security:** + +* diff --git a/src/diffpy/stretched_nmf/snmf_class.py b/src/diffpy/stretched_nmf/snmf_class.py index 7099312..0fa169a 100644 --- a/src/diffpy/stretched_nmf/snmf_class.py +++ b/src/diffpy/stretched_nmf/snmf_class.py @@ -1,3 +1,4 @@ +import csv import time import cvxpy as cp @@ -77,7 +78,7 @@ class SNMFOptimizer: def __init__( self, n_components=None, - max_iter=500, + max_iter=200, min_iter=20, tol=5e-7, rho=0, @@ -85,6 +86,8 @@ def __init__( random_state=None, show_plots=False, verbose=False, + stretch_max_iter=8, + stretch_slow_iter=200, ): """Initialize an instance of sNMF with estimator hyperparameters. @@ -116,9 +119,22 @@ def __init__( created by the decomposition. Optional. show_plots : bool Enables plotting at each step of the decomposition. Optional. + stretch_max_iter : int + Maximum number of projected-gradient stretch steps per outer-loop + iteration. Optional. + stretch_slow_iter : int + Number of initial outer-loop stretch updates to solve with the + slower constrained optimizer before switching to projected-gradient + stretch updates. Optional. """ if n_components is not None and n_components < 1: raise ValueError("n_components must be a positive integer.") + if not isinstance(stretch_max_iter, int) or stretch_max_iter < 1: + raise ValueError("stretch_max_iter must be a positive integer.") + if not isinstance(stretch_slow_iter, int) or stretch_slow_iter < 0: + raise ValueError( + "stretch_slow_iter must be a non-negative integer." + ) self.n_components = n_components self.max_iter = max_iter @@ -129,10 +145,13 @@ def __init__( self.random_state = random_state self.show_plots = show_plots self.verbose = verbose + self.stretch_max_iter = stretch_max_iter + self.stretch_slow_iter = stretch_slow_iter self._rng = np.random.default_rng(self.random_state) self._plotter = SNMFPlotter() if self.show_plots else None self._fill_tail_zero = False + self._stretch_step_size = None def _initialize_factors( self, @@ -213,6 +232,7 @@ def _initialize_factors( self._init_weights = self.weights_.copy() self._init_stretch = self.stretch_.copy() self._fill_tail_zero = False + self._stretch_step_size = None # Second-order spline: Tridiagonal (-2 on diags, 1 on sub/superdiags) self._spline_smooth_operator = 0.25 * diags( @@ -389,6 +409,43 @@ def fit( return self + def save_objective_log(self, filename): + """Save the objective log to a tab-delimited text file. + + Parameters + ---------- + filename : str or path-like + Output file path. + """ + if not hasattr(self, "objective_log"): + raise ValueError("Cannot save objective_log before calling fit.") + + fieldnames = ("step", "iteration", "objective", "dt_ms") + with open(filename, "w", newline="") as fp: + writer = csv.DictWriter( + fp, + fieldnames=fieldnames, + delimiter="\t", + ) + writer.writeheader() + previous_timestamp = None + for row in self.objective_log: + timestamp = row["timestamp"] + dt_ms = ( + 0.0 + if previous_timestamp is None + else 1000 * (timestamp - previous_timestamp) + ) + writer.writerow( + { + "step": row["step"], + "iteration": row["iteration"], + "objective": f"{row['objective']:.6E}", + "dt_ms": f"+{dt_ms:.3g}", + } + ) + previous_timestamp = timestamp + def _normalize_results(self): if self.verbose: print("\nNormalizing results after convergence...") @@ -410,6 +467,7 @@ def _normalize_results(self): self._prev_grad_components = np.zeros_like( self.components_ ) # Previous gradient of X (zeros for now) + self._fill_tail_zero = True try: self.residuals_ = self._get_residual_matrix() @@ -418,7 +476,9 @@ def _normalize_results(self): self._objective_history = [self.objective_function_] self._outer_iter = 0 self._inner_iter = 0 - for outiter in range(self.max_iter): + + normalization_max_iter = max(self.max_iter, 100) + for outiter in range(normalization_max_iter): self._outer_iter = outiter if outiter == 1: self._inner_iter = ( @@ -451,7 +511,7 @@ def _normalize_results(self): print( f"\n--- Iteration {outiter} after normalization---" f"\nTotal Objective : {self.objective_function_:.5e}" - "\nConvergence Check : Δ " + "\nConvergence Check : Delta " f"({self.objective_difference_:.2e})" f" < Threshold ({convergence_threshold:.2e})\n" ) @@ -1122,7 +1182,22 @@ def _regularize_function_hessian(self, stretch): return 0.5 * (hessian + hessian.T) - def _update_stretch(self): + @staticmethod + def _project_stretch(stretch, lower_bound=0.1): + return np.maximum(stretch, lower_bound) + + def _initial_stretch_step_size(self, stretch, gradient): + if self._stretch_step_size is not None: + return self._stretch_step_size + + gradient_norm = np.linalg.norm(gradient, "fro") + if gradient_norm == 0 or not np.isfinite(gradient_norm): + return 1.0 + + stretch_norm = max(np.linalg.norm(stretch, "fro"), 1.0) + return 0.05 * stretch_norm / gradient_norm + + def _update_stretch_trust_constr(self): """Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB).""" if self.verbose: @@ -1182,6 +1257,67 @@ def hessian(stretch_vec): # Update stretch with the optimized values self.stretch_ = result.x.reshape(self.stretch_.shape) + self._stretch_step_size = None + + def _update_stretch_projected_gradient(self): + if self.verbose: + print("Updating stretch factors...") + + for _ in range(self.stretch_max_iter): + stretch = self.stretch_ + current_objective, gradient = self._regularize_function(stretch) + step_size = self._initial_stretch_step_size(stretch, gradient) + best_stretch = stretch + best_objective = current_objective + + for _ in range(20): + candidate_stretch = self._project_stretch( + stretch - step_size * gradient + ) + step = candidate_stretch - stretch + step_norm_sq = np.linalg.norm(step, "fro") ** 2 + if step_norm_sq == 0: + step_size *= 0.5 + continue + + candidate_residuals = self._get_residual_matrix( + stretch=candidate_stretch + ) + candidate_objective = self._get_objective_function( + residuals=candidate_residuals, + stretch=candidate_stretch, + ) + if candidate_objective < best_objective: + best_stretch = candidate_stretch + best_objective = candidate_objective + + sufficient_decrease = ( + current_objective - 1e-4 * step_norm_sq / step_size + ) + if candidate_objective <= sufficient_decrease: + break + + step_size *= 0.5 + + self._stretch_step_size = step_size + if best_objective >= current_objective: + break + + self.stretch_ = best_stretch + + def _update_stretch(self): + """Update stretching factors with a hybrid strategy. + + The first ``stretch_slow_iter`` outer-loop updates use the original + constrained nonlinear optimizer to find a good non-convex basin. Later + updates switch to the lightweight Algorithm 2 style path: compute the + stretch gradient, take linearized proximal steps, then project back + onto the feasible stretch range. + """ + if getattr(self, "_outer_iter", 0) < self.stretch_slow_iter: + self._update_stretch_trust_constr() + else: + self._update_stretch_projected_gradient() @staticmethod def _compute_objective_function( @@ -1255,9 +1391,9 @@ def _cubic_largest_real_root(p, q): # Compute discriminant delta = (q / 2) ** 2 + (p / 3) ** 3 - # Compute square root of delta safely - d = np.where(delta >= 0, np.sqrt(delta), np.sqrt(np.abs(delta)) * 1j) - # TODO: this line causes a warning but results seem correct + # Match the MATLAB helper's real-root branch without evaluating invalid + # real square roots for entries that are handled in complex arithmetic. + d = np.sqrt(delta.astype(complex)) # Compute cube roots safely a1 = (-q / 2 + d) ** (1 / 3) @@ -1273,9 +1409,7 @@ def _cubic_largest_real_root(p, q): # Take the largest real root element-wise when delta < 0 r_roots = np.stack([np.real(y1), np.real(y2), np.real(y3)], axis=0) - y = np.max(r_roots, axis=0) * ( - delta < 0 - ) # Keep only real roots when delta < 0 + y = np.where(delta < 0, np.max(r_roots, axis=0), 0.0) return y