Skip to content

Commit a031cad

Browse files
committed
Ensure all functions require explicitly providing a prng
1 parent 55e5115 commit a031cad

File tree

7 files changed

+64
-31
lines changed

7 files changed

+64
-31
lines changed

src/heap.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@
5757
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
5858
*/
5959

60+
import { RandomFn } from './umap';
6061
import * as utils from './utils';
62+
6163
export type Heap = number[][][];
6264

6365
/**
@@ -92,7 +94,7 @@ export function makeHeap(nPoints: number, size: number): Heap {
9294
export function rejectionSample(
9395
nSamples: number,
9496
poolSize: number,
95-
random: () => number
97+
random: RandomFn
9698
) {
9799
const result = utils.zeros(nSamples);
98100
for (let i = 0; i < nSamples; i++) {
@@ -227,7 +229,7 @@ export function buildCandidates(
227229
nVertices: number,
228230
nNeighbors: number,
229231
maxCandidates: number,
230-
random: () => number
232+
random: RandomFn
231233
) {
232234
const candidateNeighbors = makeHeap(nVertices, maxCandidates);
233235
for (let i = 0; i < nVertices; i++) {

src/nn_descent.ts

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ import * as heap from './heap';
6161
import * as matrix from './matrix';
6262
import * as tree from './tree';
6363
import * as utils from './utils';
64-
import { Vectors, DistanceFn } from './umap';
64+
import { RandomFn, Vectors, DistanceFn } from './umap';
6565

6666
/**
6767
* Create a version of nearest neighbor descent.
6868
*/
69-
export function makeNNDescent(distanceFn: DistanceFn, random: () => number) {
69+
export function makeNNDescent(distanceFn: DistanceFn, random: RandomFn) {
7070
return function nNDescent(
7171
data: Vectors,
7272
leafArray: Vectors,
@@ -150,25 +150,28 @@ export type InitFromRandomFn = (
150150
nNeighbors: number,
151151
data: Vectors,
152152
queryPoints: Vectors,
153-
_heap: heap.Heap
153+
_heap: heap.Heap,
154+
random: RandomFn
154155
) => void;
155156

156157
export type InitFromTreeFn = (
157158
_tree: tree.FlatTree,
158159
data: Vectors,
159160
queryPoints: Vectors,
160-
_heap: heap.Heap
161+
_heap: heap.Heap,
162+
random: RandomFn
161163
) => void;
162164

163165
export function makeInitializations(distanceFn: DistanceFn) {
164166
function initFromRandom(
165167
nNeighbors: number,
166168
data: Vectors,
167169
queryPoints: Vectors,
168-
_heap: heap.Heap
170+
_heap: heap.Heap,
171+
random: RandomFn
169172
) {
170173
for (let i = 0; i < queryPoints.length; i++) {
171-
const indices = utils.rejectionSample(nNeighbors, data.length);
174+
const indices = utils.rejectionSample(nNeighbors, data.length, random);
172175
for (let j = 0; j < indices.length; j++) {
173176
if (indices[j] < 0) {
174177
continue;
@@ -183,10 +186,11 @@ export function makeInitializations(distanceFn: DistanceFn) {
183186
_tree: tree.FlatTree,
184187
data: Vectors,
185188
queryPoints: Vectors,
186-
_heap: heap.Heap
189+
_heap: heap.Heap,
190+
random: RandomFn
187191
) {
188192
for (let i = 0; i < queryPoints.length; i++) {
189-
const indices = tree.searchFlatTree(queryPoints[i], _tree);
193+
const indices = tree.searchFlatTree(queryPoints[i], _tree, random);
190194

191195
for (let j = 0; j < indices.length; j++) {
192196
if (indices[j] < 0) {
@@ -252,13 +256,14 @@ export function initializeSearch(
252256
queryPoints: Vectors,
253257
nNeighbors: number,
254258
initFromRandom: InitFromRandomFn,
255-
initFromTree: InitFromTreeFn
259+
initFromTree: InitFromTreeFn,
260+
random: RandomFn
256261
) {
257262
const results = heap.makeHeap(queryPoints.length, nNeighbors);
258-
initFromRandom(nNeighbors, data, queryPoints, results);
263+
initFromRandom(nNeighbors, data, queryPoints, results, random);
259264
if (forest) {
260265
for (let tree of forest) {
261-
initFromTree(tree, data, queryPoints, results);
266+
initFromTree(tree, data, queryPoints, results, random);
262267
}
263268
}
264269
return results;

src/tree.ts

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
*/
5959

6060
import * as utils from './utils';
61-
import { Vector, Vectors } from './umap';
61+
import { RandomFn, Vector, Vectors } from './umap';
6262

6363
/**
6464
* Tree functionality for approximating nearest neighbors
@@ -88,7 +88,7 @@ export function makeForest(
8888
data: Vectors,
8989
nNeighbors: number,
9090
nTrees: number,
91-
random: () => number
91+
random: RandomFn
9292
) {
9393
const leafSize = Math.max(10, nNeighbors);
9494

@@ -108,7 +108,7 @@ function makeTree(
108108
data: Vectors,
109109
leafSize = 30,
110110
n: number,
111-
random: () => number
111+
random: RandomFn
112112
): RandomProjectionTreeNode {
113113
const indices = utils.range(data.length);
114114
const tree = makeEuclideanTree(data, indices, leafSize, n, random);
@@ -120,7 +120,7 @@ function makeEuclideanTree(
120120
indices: number[],
121121
leafSize = 30,
122122
q: number,
123-
random: () => number
123+
random: RandomFn
124124
): RandomProjectionTreeNode {
125125
if (indices.length > leafSize) {
126126
const splitResults = euclideanRandomProjectionSplit(data, indices, random);
@@ -160,7 +160,7 @@ function makeEuclideanTree(
160160
function euclideanRandomProjectionSplit(
161161
data: Vectors,
162162
indices: number[],
163-
random: () => number
163+
random: RandomFn
164164
) {
165165
const dim = data[0].length;
166166

@@ -343,14 +343,19 @@ export function makeLeafArray(rpForest: FlatTree[]): number[][] {
343343
/**
344344
* Selects the side of the tree to search during flat tree search.
345345
*/
346-
function selectSide(hyperplane: number[], offset: number, point: Vector) {
346+
function selectSide(
347+
hyperplane: number[],
348+
offset: number,
349+
point: Vector,
350+
random: RandomFn
351+
) {
347352
let margin = offset;
348353
for (let d = 0; d < point.length; d++) {
349354
margin += hyperplane[d] * point[d];
350355
}
351356

352357
if (margin === 0) {
353-
const side = utils.tauRandInt(2);
358+
const side = utils.tauRandInt(2, random);
354359
return side;
355360
} else if (margin > 0) {
356361
return 0;
@@ -362,10 +367,19 @@ function selectSide(hyperplane: number[], offset: number, point: Vector) {
362367
/**
363368
* Searches a flattened rp-tree for a point.
364369
*/
365-
export function searchFlatTree(point: Vector, tree: FlatTree) {
370+
export function searchFlatTree(
371+
point: Vector,
372+
tree: FlatTree,
373+
random: RandomFn
374+
) {
366375
let node = 0;
367376
while (tree.children[node][0] > 0) {
368-
const side = selectSide(tree.hyperplanes[node], tree.offsets[node], point);
377+
const side = selectSide(
378+
tree.hyperplanes[node],
379+
tree.offsets[node],
380+
point,
381+
random
382+
);
369383
if (side === 0) {
370384
node = tree.children[node][0];
371385
} else {

src/umap.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ import * as utils from './utils';
6565
import LM from 'ml-levenberg-marquardt';
6666

6767
export type DistanceFn = (x: Vector, y: Vector) => number;
68+
export type RandomFn = () => number;
6869
export type EpochCallback = (epoch: number) => boolean | void;
6970
export type Vector = number[];
7071
export type Vectors = Vector[];
@@ -142,7 +143,7 @@ export interface UMAPParameters {
142143
* The pseudo-random number generator used by the stochastic parts of the
143144
* algorithm.
144145
*/
145-
random?: () => number;
146+
random?: RandomFn;
146147
/**
147148
* Interpolate between (fuzzy) union and intersection as the set operation
148149
* used to combine local fuzzy simplicial sets to obtain a global fuzzy
@@ -413,7 +414,8 @@ export class UMAP {
413414
toTransform,
414415
nNeighbors,
415416
this.initFromRandom,
416-
this.initFromTree
417+
this.initFromTree,
418+
this.random
417419
);
418420

419421
const result = this.search(rawData, this.searchGraph, init, toTransform);

src/utils.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,19 @@
1717
* ==============================================================================
1818
*/
1919

20+
import { RandomFn } from './umap';
21+
2022
/**
2123
* Simple random integer function
2224
*/
23-
export function tauRandInt(n: number, random = Math.random) {
25+
export function tauRandInt(n: number, random: RandomFn) {
2426
return Math.floor(random() * n);
2527
}
2628

2729
/**
2830
* Simple random float function
2931
*/
30-
export function tauRand(random = Math.random) {
32+
export function tauRand(random: RandomFn) {
3133
return random();
3234
}
3335
/**
@@ -132,12 +134,16 @@ export function max2d(input: number[][]): number {
132134
* integer is selected twice. The duplication constraint is achieved via
133135
* rejection sampling.
134136
*/
135-
export function rejectionSample(nSamples: number, poolSize: number): number[] {
137+
export function rejectionSample(
138+
nSamples: number,
139+
poolSize: number,
140+
random: RandomFn
141+
): number[] {
136142
const result = zeros(nSamples);
137143
for (let i = 0; i < nSamples; i++) {
138144
let rejectSample = true;
139145
while (rejectSample) {
140-
const j = tauRandInt(poolSize);
146+
const j = tauRandInt(poolSize, random);
141147
let broken = false;
142148
for (let k = 0; k < i; k++) {
143149
if (j === result[k]) {

test/umap.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import {
2121
UMAP,
2222
findABParams,
2323
euclidean,
24+
RandomFn,
2425
TargetMetric,
2526
Vector,
2627
} from '../src/umap';
@@ -32,12 +33,11 @@ import {
3233
testLabels,
3334
testResults2D,
3435
testResults3D,
35-
transformResult2d,
3636
} from './test_data';
3737
import Prando from 'prando';
3838

3939
describe('UMAP', () => {
40-
let random: () => number;
40+
let random: RandomFn;
4141

4242
// Expected "clustering" ratios, representing inter-cluster distance vs mean
4343
// distance to other points.

test/utils.test.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818
*/
1919

2020
import * as utils from '../src/utils';
21+
import Prando from 'prando';
2122

2223
describe('umap utils', () => {
24+
const prando = new Prando(42);
25+
const random = () => prando.next();
26+
2327
test('norm function', () => {
2428
const results = utils.norm([1, 2, 3, 4]);
2529
expect(results).toEqual(Math.sqrt(30));
@@ -81,7 +85,7 @@ describe('umap utils', () => {
8185
});
8286

8387
test('rejection sample', () => {
84-
const results = utils.rejectionSample(5, 10);
88+
const results = utils.rejectionSample(5, 10, random);
8589
const entries = new Set<number>();
8690
for (const r of results) {
8791
expect(entries.has(r)).toBe(false);

0 commit comments

Comments
 (0)