Skip to content

Commit b94d23c

Browse files
authored
Merge pull request PAIR-code#25 from PAIR-code/fix-csr
Fix incorrect sort method in getCSR
2 parents 0977e70 + f6ba61a commit b94d23c

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

src/matrix.ts

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ export class SparseMatrix {
3636
values: number[],
3737
dims: number[]
3838
) {
39-
if ((rows.length !== cols.length) || (rows.length !== values.length)) {
40-
throw new Error("rows, cols and values arrays must all have the same length");
39+
if (rows.length !== cols.length || rows.length !== values.length) {
40+
throw new Error(
41+
'rows, cols and values arrays must all have the same length'
42+
);
4143
}
4244

4345
// TODO: Assert that dims are legit.
@@ -85,10 +87,11 @@ export class SparseMatrix {
8587

8688
getAll(ordered = true): { value: number; row: number; col: number }[] {
8789
const rowColValues: Entry[] = [];
88-
this.entries.forEach((value) => {
90+
this.entries.forEach(value => {
8991
rowColValues.push(value);
9092
});
91-
if (ordered) { // Ordering the result isn't required for processing but it does make it easier to write tests
93+
if (ordered) {
94+
// Ordering the result isn't required for processing but it does make it easier to write tests
9295
rowColValues.sort((a, b) => {
9396
if (a.row === b.row) {
9497
return a.col - b.col;
@@ -117,12 +120,12 @@ export class SparseMatrix {
117120
}
118121

119122
forEach(fn: (value: number, row: number, col: number) => void): void {
120-
this.entries.forEach((value) => fn(value.value, value.row, value.col));
123+
this.entries.forEach(value => fn(value.value, value.row, value.col));
121124
}
122125

123126
map(fn: (value: number, row: number, col: number) => number): SparseMatrix {
124127
let vals: number[] = [];
125-
this.entries.forEach((value) => {
128+
this.entries.forEach(value => {
126129
vals.push(fn(value.value, value.row, value.col));
127130
});
128131
const dims = [this.nRows, this.nCols];
@@ -134,7 +137,7 @@ export class SparseMatrix {
134137
const output = rows.map(() => {
135138
return utils.zeros(this.nCols);
136139
});
137-
this.entries.forEach((value) => {
140+
this.entries.forEach(value => {
138141
output[value.row][value.col] = value.value;
139142
});
140143
return output;
@@ -357,7 +360,7 @@ export function getCSR(x: SparseMatrix) {
357360
if (a.row === b.row) {
358361
return a.col - b.col;
359362
} else {
360-
return a.row - b.col;
363+
return a.row - b.row;
361364
}
362365
});
363366

test/matrix.test.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ describe('sparse matrix', () => {
6969
{ row: 0, col: 0, value: 1 },
7070
{ row: 0, col: 1, value: 2 },
7171
{ row: 1, col: 0, value: 3 },
72-
{ row: 1, col: 1, value: 4 }
72+
{ row: 1, col: 1, value: 4 },
7373
]);
7474
});
7575

@@ -223,8 +223,8 @@ describe('normalize method', () => {
223223

224224
test('getCSR function', () => {
225225
const { indices, values, indptr } = getCSR(A);
226-
expect(indices).toEqual([0, 1, 2, 0, 0, 1, 2, 1, 2]);
227-
expect(values).toEqual([1, 2, 3, 7, 4, 5, 6, 8, 9]);
228-
expect(indptr).toEqual([0, 3, 4, 7]);
226+
expect(indices).toEqual([0, 1, 2, 0, 1, 2, 0, 1, 2]);
227+
expect(values).toEqual([1, 2, 3, 4, 5, 6, 7, 8, 9]);
228+
expect(indptr).toEqual([0, 3, 6]);
229229
});
230230
});

0 commit comments

Comments
 (0)