Skip to content

Commit 85054df

Browse files
authored
Merge pull request PAIR-code#23 from ProductiveRage/optimize
Performance improvement for getting the data out of SparseMatrix
2 parents 6109ba8 + 6e54c7c commit 85054df

File tree

3 files changed

+68
-46
lines changed

3 files changed

+68
-46
lines changed

src/matrix.ts

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,13 @@
1919

2020
import * as utils from './utils';
2121

22+
type Entry = { value: number; row: number; col: number };
23+
2224
/**
2325
* Internal 2-dimensional sparse matrix class
2426
*/
2527
export class SparseMatrix {
26-
private rows: number[];
27-
private cols: number[];
28-
private values: number[];
29-
30-
private entries = new Map<string, number>();
28+
private entries = new Map<string, Entry>();
3129

3230
readonly nRows: number = 0;
3331
readonly nCols: number = 0;
@@ -38,19 +36,20 @@ export class SparseMatrix {
3836
values: number[],
3937
dims: number[]
4038
) {
41-
// TODO: Assert that rows / cols / vals are the same length.
42-
this.rows = [...rows];
43-
this.cols = [...cols];
44-
this.values = [...values];
45-
46-
for (let i = 0; i < values.length; i++) {
47-
const key = this.makeKey(this.rows[i], this.cols[i]);
48-
this.entries.set(key, i);
39+
if ((rows.length !== cols.length) || (rows.length !== values.length)) {
40+
throw new Error("rows, cols and values arrays must all have the same length");
4941
}
5042

5143
// TODO: Assert that dims are legit.
5244
this.nRows = dims[0];
5345
this.nCols = dims[1];
46+
for (let i = 0; i < values.length; i++) {
47+
const row = rows[i];
48+
const col = cols[i];
49+
this.checkDims(row, col);
50+
const key = this.makeKey(row, col);
51+
this.entries.set(key, { value: values[i], row, col });
52+
}
5453
}
5554

5655
private makeKey(row: number, col: number): string {
@@ -60,74 +59,84 @@ export class SparseMatrix {
6059
private checkDims(row: number, col: number) {
6160
const withinBounds = row < this.nRows && col < this.nCols;
6261
if (!withinBounds) {
63-
throw new Error('array index out of bounds');
62+
throw new Error('row and/or col specified outside of matrix dimensions');
6463
}
6564
}
6665

6766
set(row: number, col: number, value: number) {
6867
this.checkDims(row, col);
6968
const key = this.makeKey(row, col);
7069
if (!this.entries.has(key)) {
71-
this.rows.push(row);
72-
this.cols.push(col);
73-
this.values.push(value);
74-
this.entries.set(key, this.values.length - 1);
70+
this.entries.set(key, { value, row, col });
7571
} else {
76-
const index = this.entries.get(key)!;
77-
this.values[index] = value;
72+
this.entries.get(key)!.value = value;
7873
}
7974
}
8075

8176
get(row: number, col: number, defaultValue = 0) {
8277
this.checkDims(row, col);
8378
const key = this.makeKey(row, col);
8479
if (this.entries.has(key)) {
85-
const index = this.entries.get(key)!;
86-
return this.values[index];
80+
return this.entries.get(key)!.value;
8781
} else {
8882
return defaultValue;
8983
}
9084
}
9185

86+
getAll(ordered = true): { value: number; row: number; col: number }[] {
87+
const rowColValues: Entry[] = [];
88+
this.entries.forEach((value) => {
89+
rowColValues.push(value);
90+
});
91+
if (ordered) { // Ordering the result isn't required for processing but it does make it easier to write tests
92+
rowColValues.sort((a, b) => {
93+
if (a.row === b.row) {
94+
return a.col - b.col;
95+
} else {
96+
return a.row - b.row;
97+
}
98+
});
99+
}
100+
return rowColValues;
101+
}
102+
92103
getDims(): number[] {
93104
return [this.nRows, this.nCols];
94105
}
95106

96107
getRows(): number[] {
97-
return [...this.rows];
108+
return Array.from(this.entries, ([key, value]) => value.row);
98109
}
99110

100111
getCols(): number[] {
101-
return [...this.cols];
112+
return Array.from(this.entries, ([key, value]) => value.col);
102113
}
103114

104115
getValues(): number[] {
105-
return [...this.values];
116+
return Array.from(this.entries, ([key, value]) => value.value);
106117
}
107118

108119
forEach(fn: (value: number, row: number, col: number) => void): void {
109-
for (let i = 0; i < this.values.length; i++) {
110-
fn(this.values[i], this.rows[i], this.cols[i]);
111-
}
120+
this.entries.forEach((value) => fn(value.value, value.row, value.col));
112121
}
113122

114123
map(fn: (value: number, row: number, col: number) => number): SparseMatrix {
115124
let vals: number[] = [];
116-
for (let i = 0; i < this.values.length; i++) {
117-
vals.push(fn(this.values[i], this.rows[i], this.cols[i]));
118-
}
125+
this.entries.forEach((value) => {
126+
vals.push(fn(value.value, value.row, value.col));
127+
});
119128
const dims = [this.nRows, this.nCols];
120-
return new SparseMatrix(this.rows, this.cols, vals, dims);
129+
return new SparseMatrix(this.getRows(), this.getCols(), vals, dims);
121130
}
122131

123132
toArray() {
124133
const rows: undefined[] = utils.empty(this.nRows);
125134
const output = rows.map(() => {
126135
return utils.zeros(this.nCols);
127136
});
128-
for (let i = 0; i < this.values.length; i++) {
129-
output[this.rows[i]][this.cols[i]] = this.values[i];
130-
}
137+
this.entries.forEach((value) => {
138+
output[value.row][value.col] = value.value;
139+
});
131140
return output;
132141
}
133142
}
@@ -338,7 +347,6 @@ function elementWise(
338347
* search logic depends on this data format.
339348
*/
340349
export function getCSR(x: SparseMatrix) {
341-
type Entry = { value: number; row: number; col: number };
342350
const entries: Entry[] = [];
343351

344352
x.forEach((value, row, col) => {

src/umap.ts

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -809,16 +809,15 @@ export class UMAP {
809809
const weights: number[] = [];
810810
const head: number[] = [];
811811
const tail: number[] = [];
812-
for (let i = 0; i < graph.nRows; i++) {
813-
for (let j = 0; j < graph.nCols; j++) {
814-
const value = graph.get(i, j);
815-
if (value) {
816-
weights.push(value);
817-
tail.push(i);
818-
head.push(j);
819-
}
812+
const rowColValues = graph.getAll();
813+
for (let i = 0; i < rowColValues.length; i++) {
814+
const entry = rowColValues[i];
815+
if (entry.value) {
816+
weights.push(entry.value);
817+
tail.push(entry.row);
818+
head.push(entry.col);
820819
}
821-
}
820+
}
822821
const epochsPerSample = this.makeEpochsPerSample(weights, nEpochs);
823822

824823
return { head, tail, epochsPerSample };

test/matrix.test.ts

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ describe('sparse matrix', () => {
3636
test('constructs a sparse matrix from rows/cols/vals ', () => {
3737
const rows = [0, 0, 1, 1];
3838
const cols = [0, 1, 0, 1];
39-
const vals = [1, 2];
39+
const vals = [1, 2, 3, 4];
4040
const dims = [2, 2];
4141
const matrix = new SparseMatrix(rows, cols, vals, dims);
4242
expect(matrix.getRows()).toEqual(rows);
@@ -58,6 +58,21 @@ describe('sparse matrix', () => {
5858
expect(matrix.get(0, 1)).toEqual(9);
5959
});
6060

61+
test('sparse matrix has getAll method', () => {
62+
const rows = [0, 0, 1, 1];
63+
const cols = [0, 1, 0, 1];
64+
const vals = [1, 2, 3, 4];
65+
const dims = [2, 2];
66+
const matrix = new SparseMatrix(rows, cols, vals, dims);
67+
68+
expect(matrix.getAll()).toEqual([
69+
{ row: 0, col: 0, value: 1 },
70+
{ row: 0, col: 1, value: 2 },
71+
{ row: 1, col: 0, value: 3 },
72+
{ row: 1, col: 1, value: 4 }
73+
]);
74+
});
75+
6176
test('sparse matrix has toArray method', () => {
6277
const rows = [0, 0, 1, 1];
6378
const cols = [0, 1, 0, 1];

0 commit comments

Comments
 (0)