Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions arrayfire.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ test-suite test
ArrayFire.SignalSpec
ArrayFire.SparseSpec
ArrayFire.StatisticsSpec
ArrayFire.TestHelper
ArrayFire.UtilSpec
ArrayFire.VisionSpec

Expand Down
53 changes: 43 additions & 10 deletions cbits/eigsh.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,24 @@ static int jacobi_d(int n, double *a, double *evals)
memset(v, 0, (size_t)n * n * sizeof(double));
for (int i = 0; i < n; i++) ELEM(v, i, i, n) = 1.0;

/* Up to 50 full sweeps; typical convergence is << 10 for moderate n. */
for (int sweep = 0; sweep < 50 * n; sweep++) {
/* Scale-invariant convergence threshold. */
double amax = 0.0;
for (int c = 0; c < n; c++)
for (int r = 0; r < n; r++) {
double val = fabs(ELEM(a, r, c, n));
if (val > amax) amax = val;
}
double tol = 1e-14 * (amax > 0.0 ? amax : 1.0);

/* Classical Jacobi performs one rotation per iteration; a sweep is
* ~n^2/2 rotations and convergence typically needs O(log) sweeps, so
* 10*n*n rotations is a generous budget. Hitting it means we failed to
* converge and must report an error rather than silently return
* inaccurate results (the old cap of 50*n was routinely exhausted for
* n in the low hundreds). */
long max_rot = 10L * n * n + 100;
int converged = (n <= 1);
for (long rot = 0; rot < max_rot; rot++) {
/* Locate largest off-diagonal element */
int p = 0, q = 1;
double max_off = 0.0;
Expand All @@ -58,7 +74,7 @@ static int jacobi_d(int n, double *a, double *evals)
if (val > max_off) { max_off = val; p = r; q = c; }
}
}
if (max_off < 1e-14) break;
if (max_off < tol) { converged = 1; break; }

double apq = ELEM(a, p, q, n);
double tau = (ELEM(a, q, q, n) - ELEM(a, p, p, n)) / (2.0 * apq);
Expand Down Expand Up @@ -88,7 +104,7 @@ static int jacobi_d(int n, double *a, double *evals)
for (int i = 0; i < n; i++) evals[i] = ELEM(a, i, i, n);
memcpy(a, v, (size_t)n * n * sizeof(double));
free(v);
return 0;
return converged ? 0 : 2;
}

static int jacobi_f(int n, float *a, float *evals)
Expand All @@ -99,7 +115,18 @@ static int jacobi_f(int n, float *a, float *evals)
memset(v, 0, (size_t)n * n * sizeof(float));
for (int i = 0; i < n; i++) ELEM(v, i, i, n) = 1.0f;

for (int sweep = 0; sweep < 50 * n; sweep++) {
/* Scale-invariant convergence threshold. */
float amax = 0.0f;
for (int c = 0; c < n; c++)
for (int r = 0; r < n; r++) {
float val = fabsf(ELEM(a, r, c, n));
if (val > amax) amax = val;
}
float tol = 1e-6f * (amax > 0.0f ? amax : 1.0f);

long max_rot = 10L * n * n + 100;
int converged = (n <= 1);
for (long rot = 0; rot < max_rot; rot++) {
int p = 0, q = 1;
float max_off = 0.0f;
for (int c = 1; c < n; c++) {
Expand All @@ -108,7 +135,7 @@ static int jacobi_f(int n, float *a, float *evals)
if (val > max_off) { max_off = val; p = r; q = c; }
}
}
if (max_off < 1e-6f) break;
if (max_off < tol) { converged = 1; break; }

float apq = ELEM(a, p, q, n);
float tau = (ELEM(a, q, q, n) - ELEM(a, p, p, n)) / (2.0f * apq);
Expand Down Expand Up @@ -136,7 +163,7 @@ static int jacobi_f(int n, float *a, float *evals)
for (int i = 0; i < n; i++) evals[i] = ELEM(a, i, i, n);
memcpy(a, v, (size_t)n * n * sizeof(float));
free(v);
return 0;
return converged ? 0 : 2;
}

/* Selection sort on eigenvalues, mirroring the column swaps in evecs. */
Expand Down Expand Up @@ -199,7 +226,10 @@ static af_err eigsh_cpu(af_array *evals_out, af_array *evecs_out,

int ret = (dtype == f64) ? jacobi_d(n, (double *)A, (double *)W)
: jacobi_f(n, (float *)A, (float *)W);
if (ret != 0) { free(A); free(W); return AF_ERR_NO_MEM; }
if (ret != 0) {
free(A); free(W);
return (ret == 1) ? AF_ERR_NO_MEM : AF_ERR_RUNTIME;
}

if (dtype == f64) sort_eigs_d(n, (double *)W, (double *)A);
else sort_eigs_f(n, (float *)W, (float *)A);
Expand Down Expand Up @@ -368,6 +398,11 @@ af_err af_eigsh(af_array *evals_out, af_array *evecs_out, const af_array input)
if ((err = af_get_type(&dtype, input)) != AF_SUCCESS) return err;
if (dtype != f64 && dtype != f32) return AF_ERR_TYPE;

dim_t d0, d1, d2, d3;
if ((err = af_get_dims(&d0, &d1, &d2, &d3, input)) != AF_SUCCESS) return err;
if (d0 < 1 || d0 != d1 || d2 != 1 || d3 != 1 || d0 > 0x7fffffff)
return AF_ERR_SIZE;

af_backend backend;
if ((err = af_get_active_backend(&backend)) != AF_SUCCESS) return err;

Expand All @@ -377,8 +412,6 @@ af_err af_eigsh(af_array *evals_out, af_array *evecs_out, const af_array input)
if (ensure_init() != AF_SUCCESS)
return eigsh_cpu(evals_out, evecs_out, input);

dim_t d0, d1, d2, d3;
if ((err = af_get_dims(&d0, &d1, &d2, &d3, input)) != AF_SUCCESS) return err;
int n = (int)d0;

af_array evecs;
Expand Down
7 changes: 6 additions & 1 deletion src/ArrayFire/Algorithm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ module ArrayFire.Algorithm where
import Data.Word (Word32)
import Foreign.C.Types (CBool)

import ArrayFire.Arith (cast)
import ArrayFire.FFI
import ArrayFire.Internal.Algorithm
import ArrayFire.Internal.Types
Expand Down Expand Up @@ -196,7 +197,11 @@ count
-- ^ Dimension along which to count
-> Array Int
-- ^ Count of all elements along dimension
count x (fromIntegral -> n) = x `op1` (\p a -> af_count p a n)
count x (fromIntegral -> n) =
-- af_count produces a u32 array; cast to s64 so the data matches the
-- declared element type (otherwise host reads via toVector/toList would
-- read 8 bytes per element from a 4-byte-per-element buffer).
cast (x `op1` (\p a -> af_count p a n) :: Array Word32)

-- | Sum all elements in an 'Array' along all dimensions
--
Expand Down
Loading
Loading