Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ci:
submodules: false
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v6.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
Expand All @@ -21,45 +21,45 @@ repos:
- id: check-toml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 26.5.1
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
rev: 7.3.0
hooks:
- id: flake8
- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 9.0.0a3
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
rev: 0.9.1
hooks:
- id: nbstripout
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v6.0.0
hooks:
- id: no-commit-to-branch
name: Prevent Commit to Main Branch
args: ["--branch", "main"]
stages: [pre-commit]
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.2
hooks:
- id: codespell
additional_dependencies:
- tomli
# prettier - multi formatter for .json, .yml, and .md files
- repo: https://github.com/pre-commit/mirrors-prettier
rev: f12edd9c7be1c20cfa42420fd0e6df71e42b51ea # frozen: v4.0.0-alpha.8
rev: v4.0.0-alpha.8
hooks:
- id: prettier
additional_dependencies:
- "prettier@^3.2.4"
# docformatter - PEP 257 compliant docstring formatter
- repo: https://github.com/s-weigand/docformatter
rev: 5757c5190d95e5449f102ace83df92e7d3b06c6c
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.8
hooks:
- id: docformatter
additional_dependencies: [tomli]
Expand Down
25 changes: 25 additions & 0 deletions news/forge-cleanup.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
**Added:**

* <news item>

**Changed:**

* Rename cli entrypoint to 'snmf' from 'diffpy.stretched-nmf'

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* Produce an error if test files are missing
* Use matplotlib-base when installing with conda-forge
* Update to docformatter 1.7.8

**Security:**

* <news item>
8 changes: 0 additions & 8 deletions src/diffpy/stretched_nmf/snmf_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def __init__(
show_plots : bool
Enables plotting at each step of the decomposition. Optional.
"""

if n_components is not None and n_components < 1:
raise ValueError("n_components must be a positive integer.")

Expand Down Expand Up @@ -583,7 +582,6 @@ def _get_residual_matrix(
-------
residuals : (signal_len, n_signals) array
"""

if components is None:
components = self.components_
if weights is None:
Expand Down Expand Up @@ -651,7 +649,6 @@ def _compute_stretched_components(
dd_stretched_components : array, shape (signal_len, n_comps * n_sigs)
Second derivatives with respect to stretch.
"""

# --- Defaults ---
if components is None:
components = self.components_
Expand Down Expand Up @@ -724,7 +721,6 @@ def _apply_transformation_matrix(
"""Computes the transformation matrix `stretch_transformed` for
residuals, using scaling matrix `stretch` and weight
coefficients `weights`."""

if stretch is None:
stretch = self.stretch_
if weights is None:
Expand Down Expand Up @@ -851,7 +847,6 @@ def _solve_quadratic_program(self, t, m):
def _update_components(self):
"""Updates `components` using gradient-based optimization with
adaptive step size."""

# Compute stretched components using the interpolation function
stretched_components, _, _ = (
self._compute_stretched_components()
Expand Down Expand Up @@ -934,7 +929,6 @@ def _update_weights(self):
"""Updates weights by building the stretched component matrix
`stretched_comps` with np.interp and solving a quadratic program
for each signal."""

sample_indices = np.arange(self.signal_length_)
for signal in range(self.n_signals_):
# Stretch factors for this signal across components:
Expand Down Expand Up @@ -999,7 +993,6 @@ def _regularize_function(self, stretch=None):
def _update_stretch(self):
"""Updates stretching matrix using constrained optimization
(equivalent to fmincon in MATLAB)."""

if self.verbose:
print("Updating stretch factors...")

Expand Down Expand Up @@ -1147,7 +1140,6 @@ def _reconstruct_matrix(components, weights, stretch):
-------
reconstructed_matrix : (signal_len, n_signals) array
"""

signal_len = components.shape[0]
n_components = components.shape[1]
n_signals = weights.shape[1]
Expand Down
72 changes: 26 additions & 46 deletions tests/test_snmf_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,42 @@
from pathlib import Path

import numpy as np
import pytest

from diffpy.stretched_nmf.snmf_class import SNMFOptimizer

DATA_DIR = (
Path(__file__).resolve().parents[1]
/ "docs/examples/data/XRD-MgMnO-YCl-real"
)

_required = [
"init-components.txt",
"source-matrix.txt",
"init-stretch.txt",
"init-weights.txt",
]
_missing = [f for f in _required if not (DATA_DIR / f).exists()]


@pytest.fixture(scope="module")
def inputs():
if _missing:
pytest.fail(
f"Missing required test data files in {DATA_DIR}: {_missing}"
)
return {
"components": np.loadtxt(
DATA_DIR / "init-components.txt", dtype=float
),
"source": np.loadtxt(
DATA_DIR / "source-matrix.txt", dtype=float, skiprows=4
),
"stretch": np.loadtxt(DATA_DIR / "init-stretch.txt", dtype=float),
"weights": np.loadtxt(DATA_DIR / "init-weights.txt", dtype=float),
}

def test_fit_recovers_rank_one_factors():
expected_components = np.array(
[
[0.20],
[0.75],
[1.20],
[0.80],
[0.30],
]
)
expected_weights = np.array(
[
[0.20, 0.60, 1.00, 0.40],
]
)
source = expected_components @ expected_weights

@pytest.mark.slow
def test_final_objective_below_threshold(inputs):
model = SNMFOptimizer(
n_components=1,
show_plots=False,
random_state=1,
min_iter=5,
max_iter=5,
rho=1e12,
eta=610,
)
model.fit(
source_matrix=inputs["source"],
init_weights=inputs["weights"],
init_components=inputs["components"],
init_stretch=inputs["stretch"],
min_iter=0,
max_iter=2,
rho=0.0,
eta=0.0,
)
model.fit(source_matrix=source)

# Basic sanity check and the actual assertion
assert np.isfinite(model.objective_function_)
assert model.objective_function_ < 5e6
assert np.allclose(
model.components_, expected_components, rtol=0.2, atol=0.1
)
assert np.allclose(model.weights_, expected_weights, rtol=0.2, atol=0.1)


@pytest.mark.parametrize(
Expand Down
Loading