19
19
20
20
import * as utils from './utils' ;
21
21
22
+ type Entry = { value : number ; row : number ; col : number } ;
23
+
22
24
/**
23
25
* Internal 2-dimensional sparse matrix class
24
26
*/
25
27
export class SparseMatrix {
26
- private rows : number [ ] ;
27
- private cols : number [ ] ;
28
- private values : number [ ] ;
29
-
30
- private entries = new Map < string , number > ( ) ;
28
+ private entries = new Map < string , Entry > ( ) ;
31
29
32
30
readonly nRows : number = 0 ;
33
31
readonly nCols : number = 0 ;
@@ -38,19 +36,20 @@ export class SparseMatrix {
38
36
values : number [ ] ,
39
37
dims : number [ ]
40
38
) {
41
- // TODO: Assert that rows / cols / vals are the same length.
42
- this . rows = [ ...rows ] ;
43
- this . cols = [ ...cols ] ;
44
- this . values = [ ...values ] ;
45
-
46
- for ( let i = 0 ; i < values . length ; i ++ ) {
47
- const key = this . makeKey ( this . rows [ i ] , this . cols [ i ] ) ;
48
- this . entries . set ( key , i ) ;
39
+ if ( ( rows . length !== cols . length ) || ( rows . length !== values . length ) ) {
40
+ throw new Error ( "rows, cols and values arrays must all have the same length" ) ;
49
41
}
50
42
51
43
// TODO: Assert that dims are legit.
52
44
this . nRows = dims [ 0 ] ;
53
45
this . nCols = dims [ 1 ] ;
46
+ for ( let i = 0 ; i < values . length ; i ++ ) {
47
+ const row = rows [ i ] ;
48
+ const col = cols [ i ] ;
49
+ this . checkDims ( row , col ) ;
50
+ const key = this . makeKey ( row , col ) ;
51
+ this . entries . set ( key , { value : values [ i ] , row, col } ) ;
52
+ }
54
53
}
55
54
56
55
private makeKey ( row : number , col : number ) : string {
@@ -60,74 +59,84 @@ export class SparseMatrix {
60
59
private checkDims ( row : number , col : number ) {
61
60
const withinBounds = row < this . nRows && col < this . nCols ;
62
61
if ( ! withinBounds ) {
63
- throw new Error ( 'array index out of bounds ' ) ;
62
+ throw new Error ( 'row and/or col specified outside of matrix dimensions ' ) ;
64
63
}
65
64
}
66
65
67
66
set ( row : number , col : number , value : number ) {
68
67
this . checkDims ( row , col ) ;
69
68
const key = this . makeKey ( row , col ) ;
70
69
if ( ! this . entries . has ( key ) ) {
71
- this . rows . push ( row ) ;
72
- this . cols . push ( col ) ;
73
- this . values . push ( value ) ;
74
- this . entries . set ( key , this . values . length - 1 ) ;
70
+ this . entries . set ( key , { value, row, col } ) ;
75
71
} else {
76
- const index = this . entries . get ( key ) ! ;
77
- this . values [ index ] = value ;
72
+ this . entries . get ( key ) ! . value = value ;
78
73
}
79
74
}
80
75
81
76
get ( row : number , col : number , defaultValue = 0 ) {
82
77
this . checkDims ( row , col ) ;
83
78
const key = this . makeKey ( row , col ) ;
84
79
if ( this . entries . has ( key ) ) {
85
- const index = this . entries . get ( key ) ! ;
86
- return this . values [ index ] ;
80
+ return this . entries . get ( key ) ! . value ;
87
81
} else {
88
82
return defaultValue ;
89
83
}
90
84
}
91
85
86
+ getAll ( ordered = true ) : { value : number ; row : number ; col : number } [ ] {
87
+ const rowColValues : Entry [ ] = [ ] ;
88
+ this . entries . forEach ( ( value ) => {
89
+ rowColValues . push ( value ) ;
90
+ } ) ;
91
+ if ( ordered ) { // Ordering the result isn't required for processing but it does make it easier to write tests
92
+ rowColValues . sort ( ( a , b ) => {
93
+ if ( a . row === b . row ) {
94
+ return a . col - b . col ;
95
+ } else {
96
+ return a . row - b . row ;
97
+ }
98
+ } ) ;
99
+ }
100
+ return rowColValues ;
101
+ }
102
+
92
103
getDims ( ) : number [ ] {
93
104
return [ this . nRows , this . nCols ] ;
94
105
}
95
106
96
107
getRows ( ) : number [ ] {
97
- return [ ... this . rows ] ;
108
+ return Array . from ( this . entries , ( [ key , value ] ) => value . row ) ;
98
109
}
99
110
100
111
getCols ( ) : number [ ] {
101
- return [ ... this . cols ] ;
112
+ return Array . from ( this . entries , ( [ key , value ] ) => value . col ) ;
102
113
}
103
114
104
115
getValues ( ) : number [ ] {
105
- return [ ... this . values ] ;
116
+ return Array . from ( this . entries , ( [ key , value ] ) => value . value ) ;
106
117
}
107
118
108
119
forEach ( fn : ( value : number , row : number , col : number ) => void ) : void {
109
- for ( let i = 0 ; i < this . values . length ; i ++ ) {
110
- fn ( this . values [ i ] , this . rows [ i ] , this . cols [ i ] ) ;
111
- }
120
+ this . entries . forEach ( ( value ) => fn ( value . value , value . row , value . col ) ) ;
112
121
}
113
122
114
123
map ( fn : ( value : number , row : number , col : number ) => number ) : SparseMatrix {
115
124
let vals : number [ ] = [ ] ;
116
- for ( let i = 0 ; i < this . values . length ; i ++ ) {
117
- vals . push ( fn ( this . values [ i ] , this . rows [ i ] , this . cols [ i ] ) ) ;
118
- }
125
+ this . entries . forEach ( ( value ) => {
126
+ vals . push ( fn ( value . value , value . row , value . col ) ) ;
127
+ } ) ;
119
128
const dims = [ this . nRows , this . nCols ] ;
120
- return new SparseMatrix ( this . rows , this . cols , vals , dims ) ;
129
+ return new SparseMatrix ( this . getRows ( ) , this . getCols ( ) , vals , dims ) ;
121
130
}
122
131
123
132
toArray ( ) {
124
133
const rows : undefined [ ] = utils . empty ( this . nRows ) ;
125
134
const output = rows . map ( ( ) => {
126
135
return utils . zeros ( this . nCols ) ;
127
136
} ) ;
128
- for ( let i = 0 ; i < this . values . length ; i ++ ) {
129
- output [ this . rows [ i ] ] [ this . cols [ i ] ] = this . values [ i ] ;
130
- }
137
+ this . entries . forEach ( ( value ) => {
138
+ output [ value . row ] [ value . col ] = value . value ;
139
+ } ) ;
131
140
return output ;
132
141
}
133
142
}
@@ -338,7 +347,6 @@ function elementWise(
338
347
* search logic depends on this data format.
339
348
*/
340
349
export function getCSR ( x : SparseMatrix ) {
341
- type Entry = { value : number ; row : number ; col : number } ;
342
350
const entries : Entry [ ] = [ ] ;
343
351
344
352
x . forEach ( ( value , row , col ) => {
0 commit comments