Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions news/warnings-fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
**Added:**

* Use faster projected-gradient stretch updates after the initial
  constrained-optimizer iterations
* Allow saving the objective log to a tab-delimited text file

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* Remove a runtime warning in the cubic root solver by computing the square
  root of the discriminant in complex arithmetic

**Security:**

* <news item>
154 changes: 144 additions & 10 deletions src/diffpy/stretched_nmf/snmf_class.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
import time

import cvxpy as cp
Expand Down Expand Up @@ -77,14 +78,16 @@ class SNMFOptimizer:
def __init__(
self,
n_components=None,
max_iter=500,
max_iter=200,
min_iter=20,
tol=5e-7,
rho=0,
eta=0,
random_state=None,
show_plots=False,
verbose=False,
stretch_max_iter=8,
stretch_slow_iter=200,
):
"""Initialize an instance of sNMF with estimator
hyperparameters.
Expand Down Expand Up @@ -116,9 +119,22 @@ def __init__(
created by the decomposition. Optional.
show_plots : bool
Enables plotting at each step of the decomposition. Optional.
stretch_max_iter : int
Maximum number of projected-gradient stretch steps per outer-loop
iteration. Optional.
stretch_slow_iter : int
Number of initial outer-loop stretch updates to solve with the

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't change any of these, but moving forward, please use the group standard that descriptions start with "The". In most cases "The" can just be added at the beginning of the existing description. But as I said, don't bother fixing this here, just use the pattern moving forward.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reminder, will do.

slower constrained optimizer before switching to projected-gradient
stretch updates. Optional.
"""
if n_components is not None and n_components < 1:
raise ValueError("n_components must be a positive integer.")
if not isinstance(stretch_max_iter, int) or stretch_max_iter < 1:
raise ValueError("stretch_max_iter must be a positive integer.")
if not isinstance(stretch_slow_iter, int) or stretch_slow_iter < 0:
raise ValueError(
"stretch_slow_iter must be a non-negative integer."
)

self.n_components = n_components
self.max_iter = max_iter
Expand All @@ -129,10 +145,13 @@ def __init__(
self.random_state = random_state
self.show_plots = show_plots
self.verbose = verbose
self.stretch_max_iter = stretch_max_iter
self.stretch_slow_iter = stretch_slow_iter

self._rng = np.random.default_rng(self.random_state)
self._plotter = SNMFPlotter() if self.show_plots else None
self._fill_tail_zero = False
self._stretch_step_size = None

def _initialize_factors(
self,
Expand Down Expand Up @@ -213,6 +232,7 @@ def _initialize_factors(
self._init_weights = self.weights_.copy()
self._init_stretch = self.stretch_.copy()
self._fill_tail_zero = False
self._stretch_step_size = None

# Second-order spline: Tridiagonal (-2 on diags, 1 on sub/superdiags)
self._spline_smooth_operator = 0.25 * diags(
Expand Down Expand Up @@ -389,6 +409,43 @@ def fit(

return self

def save_objective_log(self, filename):
    """Save the objective log to a tab-delimited text file.

    One header row is written, followed by one row per entry of
    ``self.objective_log`` with the columns ``step``, ``iteration``,
    ``objective`` (scientific notation, six decimals) and ``dt_ms``
    (signed difference to the previous entry's timestamp, times 1000;
    ``+0`` for the first entry).

    Parameters
    ----------
    filename : str or path-like
        Output file path. An existing file is overwritten.

    Raises
    ------
    ValueError
        If the instance has no ``objective_log`` attribute.
    """
    if not hasattr(self, "objective_log"):
        raise ValueError("Cannot save objective_log before calling fit.")

    fieldnames = ("step", "iteration", "objective", "dt_ms")
    # newline="" lets the csv module control line endings; the explicit
    # encoding keeps the output independent of the platform locale.
    with open(filename, "w", newline="", encoding="utf-8") as fp:
        writer = csv.DictWriter(
            fp,
            fieldnames=fieldnames,
            delimiter="\t",
        )
        writer.writeheader()
        previous_timestamp = None
        for row in self.objective_log:
            # Each entry is read as a mapping with the keys "step",
            # "iteration", "objective" and "timestamp".
            # NOTE(review): timestamps are presumably in seconds -- confirm
            # against the code that fills objective_log.
            timestamp = row["timestamp"]
            dt_ms = (
                0.0
                if previous_timestamp is None
                else 1000 * (timestamp - previous_timestamp)
            )
            writer.writerow(
                {
                    "step": row["step"],
                    "iteration": row["iteration"],
                    "objective": f"{row['objective']:.6E}",
                    # The "+" format option emits exactly one sign, so a
                    # negative delta is written as "-12.5", not "+-12.5".
                    "dt_ms": f"{dt_ms:+.3g}",
                }
            )
            previous_timestamp = timestamp

def _normalize_results(self):
if self.verbose:
print("\nNormalizing results after convergence...")
Expand All @@ -410,6 +467,7 @@ def _normalize_results(self):
self._prev_grad_components = np.zeros_like(
self.components_
) # Previous gradient of X (zeros for now)

self._fill_tail_zero = True
try:
self.residuals_ = self._get_residual_matrix()
Expand All @@ -418,7 +476,9 @@ def _normalize_results(self):
self._objective_history = [self.objective_function_]
self._outer_iter = 0
self._inner_iter = 0
for outiter in range(self.max_iter):

normalization_max_iter = max(self.max_iter, 100)
for outiter in range(normalization_max_iter):
self._outer_iter = outiter
if outiter == 1:
self._inner_iter = (
Expand Down Expand Up @@ -451,7 +511,7 @@ def _normalize_results(self):
print(
f"\n--- Iteration {outiter} after normalization---"
f"\nTotal Objective : {self.objective_function_:.5e}"
"\nConvergence Check : Δ "
"\nConvergence Check : Delta "
f"({self.objective_difference_:.2e})"
f" < Threshold ({convergence_threshold:.2e})\n"
)
Expand Down Expand Up @@ -1122,7 +1182,22 @@ def _regularize_function_hessian(self, stretch):

return 0.5 * (hessian + hessian.T)

def _update_stretch(self):
@staticmethod
def _project_stretch(stretch, lower_bound=0.1):
return np.maximum(stretch, lower_bound)

def _initial_stretch_step_size(self, stretch, gradient):
if self._stretch_step_size is not None:
return self._stretch_step_size

gradient_norm = np.linalg.norm(gradient, "fro")
if gradient_norm == 0 or not np.isfinite(gradient_norm):
return 1.0

stretch_norm = max(np.linalg.norm(stretch, "fro"), 1.0)
return 0.05 * stretch_norm / gradient_norm

def _update_stretch_trust_constr(self):
"""Updates stretching matrix using constrained optimization
(equivalent to fmincon in MATLAB)."""
if self.verbose:
Expand Down Expand Up @@ -1182,6 +1257,67 @@ def hessian(stretch_vec):

# Update stretch with the optimized values
self.stretch_ = result.x.reshape(self.stretch_.shape)
self._stretch_step_size = None

def _update_stretch_projected_gradient(self):
if self.verbose:
print("Updating stretch factors...")

for _ in range(self.stretch_max_iter):
stretch = self.stretch_
current_objective, gradient = self._regularize_function(stretch)
step_size = self._initial_stretch_step_size(stretch, gradient)
best_stretch = stretch
best_objective = current_objective

for _ in range(20):
candidate_stretch = self._project_stretch(
stretch - step_size * gradient
)
step = candidate_stretch - stretch
step_norm_sq = np.linalg.norm(step, "fro") ** 2
if step_norm_sq == 0:
step_size *= 0.5
continue

candidate_residuals = self._get_residual_matrix(
stretch=candidate_stretch
)
candidate_objective = self._get_objective_function(
residuals=candidate_residuals,
stretch=candidate_stretch,
)
if candidate_objective < best_objective:
best_stretch = candidate_stretch
best_objective = candidate_objective

sufficient_decrease = (
current_objective - 1e-4 * step_norm_sq / step_size
)
if candidate_objective <= sufficient_decrease:
break

step_size *= 0.5

self._stretch_step_size = step_size
if best_objective >= current_objective:
break

self.stretch_ = best_stretch

def _update_stretch(self):
"""Update stretching factors with a hybrid strategy.

The first ``stretch_slow_iter`` outer-loop updates use the original
constrained nonlinear optimizer to find a good non-convex basin. Later
updates switch to the lightweight Algorithm 2 style path: compute the
stretch gradient, take linearized proximal steps, then project back
onto the feasible stretch range.
"""
if getattr(self, "_outer_iter", 0) < self.stretch_slow_iter:
self._update_stretch_trust_constr()
else:
self._update_stretch_projected_gradient()

@staticmethod
def _compute_objective_function(
Expand Down Expand Up @@ -1255,9 +1391,9 @@ def _cubic_largest_real_root(p, q):
# Compute discriminant
delta = (q / 2) ** 2 + (p / 3) ** 3

# Compute square root of delta safely
d = np.where(delta >= 0, np.sqrt(delta), np.sqrt(np.abs(delta)) * 1j)
# TODO: this line causes a warning but results seem correct
# Match the MATLAB helper's real-root branch without evaluating invalid
# real square roots for entries that are handled in complex arithmetic.
d = np.sqrt(delta.astype(complex))

# Compute cube roots safely
a1 = (-q / 2 + d) ** (1 / 3)
Expand All @@ -1273,9 +1409,7 @@ def _cubic_largest_real_root(p, q):

# Take the largest real root element-wise when delta < 0
r_roots = np.stack([np.real(y1), np.real(y2), np.real(y3)], axis=0)
y = np.max(r_roots, axis=0) * (
delta < 0
) # Keep only real roots when delta < 0
y = np.where(delta < 0, np.max(r_roots, axis=0), 0.0)

return y

Expand Down
Loading