00001
00009
00010
00012
#include "SparseMatrix.h"
#include "Matlab.h"
#include <assert.h>
#include <math.h>
#include <vector>
00017
00018 using namespace MeshLib;
00019
/// Returns true when the magnitude of @p x is strictly below 1e-9,
/// the tolerance this file uses to decide that a value counts as zero.
inline bool isZero(double x)
{
    const double tolerance = 1e-9;
    const double magnitude = fabs(x);
    return magnitude < tolerance;
}
00028
00030
00032
00039 CSparseMatrix::CSparseMatrix(int nRows, int nCols, int nnz )
00040 {
00041 m_nRows = nRows;
00042 m_nCols = nCols;
00043 assert(nRows > 0);
00044 assert(nCols > 0);
00045 m_pEntries = new vector<_Entry>();
00046 if (nnz > 0)
00047 m_pEntries->reserve(nnz);
00048
00049 if (nnz > 1024)
00050 {
00051 m_nHashCode = nnz;
00052 }
00053 else
00054 {
00055 m_nHashCode = (nRows > nCols)? nRows * 6: nCols * 6;
00056 }
00057
00058
00059 if (m_nHashCode % 2 == 0)
00060 m_nHashCode++;
00061
00062 m_bAccessed = new bool[m_nHashCode];
00063 memset(m_bAccessed, 0, sizeof(bool) * m_nHashCode);
00064 }
00065
00070 CSparseMatrix::~CSparseMatrix()
00071 {
00072 if (m_pEntries != NULL)
00073 delete m_pEntries;
00074
00075 delete [] m_bAccessed;
00076 }
00084 int CSparseMatrix::CalcHashCode(int nRow, int nCol)
00085 {
00086 return (nRow * 37 + nCol * 103) % m_nHashCode;
00087 }
00088
00095 void CSparseMatrix::AddElement(int nRow, int nCol, double dVal)
00096 {
00097 _Entry entry;
00098 assert(nRow >= 0 && nRow < m_nRows);
00099 assert(nCol >= 0 && nCol < m_nCols);
00100 bool bFound = false;
00101
00102 int hash = CalcHashCode(nRow, nCol);
00103 if (m_bAccessed[hash])
00104 {
00105 for (int i = 0; i < (int)m_pEntries->size(); i++)
00106 {
00107 _Entry& e = m_pEntries->at(i);
00108 if (e.i == nRow && e.j == nCol)
00109 {
00110 e.val += dVal;
00111 if (isZero(e.val))
00112 {
00113 if (m_pEntries->size() > 0)
00114 {
00115 m_pEntries->at(i) = m_pEntries->at(m_pEntries->size() - 1);
00116 }
00117 m_pEntries->resize(m_pEntries->size() - 1);
00118 }
00119 bFound = true;
00120 break;
00121 }
00122 }
00123 }
00124 if (!bFound)
00125 {
00126 m_bAccessed[hash] = true;
00127 entry.i = nRow;
00128 entry.j = nCol;
00129 entry.val = dVal;
00130 m_pEntries->push_back(entry);
00131 }
00132 }
00133
00134
00141 void CSparseMatrix::AddElementTail(int nRow, int nCol, double dVal)
00142 {
00143 _Entry entry;
00144 assert(nRow >= 0 && nRow < m_nRows);
00145 assert(nCol >= 0 && nCol < m_nCols);
00146
00147 entry.i = nRow;
00148 entry.j = nCol;
00149 entry.val = dVal;
00150 m_bAccessed[CalcHashCode(nRow, nCol)] = true;
00151
00152 m_pEntries->push_back(entry);
00153
00154 }
00155
00162 bool CSparseMatrix::CGSolver(double b[], double x[], double eps, int& itrs )
00163 {
00164 if( GetRows()!=GetCols() )
00165 return false;
00166
00167 int size = this->GetRows();
00168
00169 double alpha, beta;
00170 double* r = new double[size];
00171 double* p = new double[size];
00172 double* temp = new double[size];
00173 double* multi = new double[size];
00174 assert( r != NULL && p != NULL && temp != NULL && multi != NULL );
00175
00176
00177 for( int i=0; i<size; i++ )
00178 x[i] = b[i];
00179
00180
00181 Multiply( x, multi );
00182
00183
00184
00185 for( int i=0; i<size; i++ )
00186 r[i] = b[i] - multi[i];
00187
00188
00189
00190
00191 for( int i=0; i<size; i++ )
00192 p[i] = r[i];
00193
00194 double numerator, denominator;
00195
00196 int step = 0;
00197
00198 double norm_b = 0.0;
00199 for( int i=0; i<size; i++ )
00200 norm_b += b[i]*b[i];
00201 norm_b = sqrt( norm_b );
00202
00203 double norm_r = 0.0;
00204
00205 while( true )
00206 {
00207 if( itrs>=0 )
00208 if( step>=itrs )
00209 break;
00210
00211 norm_r = 0.0;
00212 for( int i=0; i<size; i++ )
00213 norm_r += r[i]*r[i];
00214 norm_r = sqrt( norm_r );
00215
00216 if( norm_r/norm_b<eps )
00217 break;
00218
00219
00220 numerator = 0.0;
00221 for( int i=0; i<size; i++ )
00222 numerator += r[i]*r[i];
00223
00224
00225
00226 this->Multiply( p, multi );
00227
00228 denominator = 0.0;
00229 for( int i=0; i<size; i++ )
00230 denominator += p[i]*multi[i];
00231
00232 alpha = numerator/denominator;
00233
00234
00235 for( int i=0; i<size; i++ )
00236 x[i] = x[i] + alpha*p[i];
00237
00238 for( int i=0; i<size; i++ )
00239 temp[i] = r[i];
00240
00241
00242 for( int i=0; i<size; i++ )
00243 r[i] = temp[i] - alpha*multi[i];
00244
00245
00246 numerator = 0.0;
00247 for( int i=0; i<size; i++ )
00248 numerator += r[i]*r[i];
00249 denominator = 0.0;
00250 for( int i=0; i<size; i++ )
00251 denominator += temp[i]*temp[i];
00252 beta = numerator/denominator;
00253
00254
00255 for( int i=0; i<size; i++ )
00256 p[i] = r[i]+beta*p[i];
00257
00258 step++;
00259 }
00260
00261
00262 this->Multiply( x, multi );
00263
00264 double sum = 0.0;
00265 for( int i=0; i<size; i++ )
00266 sum += (multi[i]-b[i])*(multi[i]-b[i]);
00267 sum = sqrt( sum );
00268
00269
00270 printf( "CG iter_num: %d error: %.10f\n", step, norm_r/norm_b );
00271
00272 delete []r;
00273 delete []p;
00274 delete []multi;
00275 delete []temp;
00276
00277 return true;
00278 }
00279
00280
00286 void CSparseMatrix::Multiply(double iVector[], double oVector[])
00287 {
00288 memset(oVector, 0, sizeof(double) * m_nRows);
00289 for (int i = 0; i < (int)m_pEntries->size(); i++)
00290 {
00291 _Entry& entry = m_pEntries->at(i);
00292 oVector[entry.i] += iVector[entry.j] * entry.val;
00293 }
00294 }
00300 void CSparseMatrix::TransMul(double iVector[], double oVector[])
00301 {
00302 memset(oVector, 0, sizeof(double) * m_nCols);
00303 for (int i = 0; i < (int)m_pEntries->size(); i++)
00304 {
00305 _Entry& entry = m_pEntries->at(i);
00306 oVector[entry.j] += iVector[entry.i] * entry.val;
00307 }
00308 }
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346
00347
00348
00349
00350
00351
00352
00353
00354
00355
00356
00357
00358
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00393 bool CSparseMatrix::CGSolverStable(double b[], double x[], double eps, int& itrs )
00394 {
00395 if( GetRows()!=GetCols() )
00396 return false;
00397
00398 int size = this->GetRows();
00399
00400 double alpha, beta;
00401 double sum;
00402 double* r = new double[size];
00403 double* p = new double[size];
00404 double* multi = new double[size];
00405 double* temp = new double[size];
00406
00407
00408 for( int i=0; i<size; i++ )
00409 x[i] = b[i];
00410
00411
00412 sum = 0.0;
00413 for( int i=0; i<size; i++)
00414 sum += x[i];
00415 for( int i=0; i<size; i++ )
00416 {
00417 temp[i] = -(sum-x[i]);
00418 }
00419
00420 this->Multiply( temp, multi );
00421
00422
00423 sum = 0.0;
00424 for( int i=0; i<size; i++)
00425 sum += multi[i];
00426 for( int i=0; i<size; i++ )
00427 {
00428 temp[i] = -(sum-multi[i]);
00429 }
00430
00431 for( int i=0; i<size; i++ )
00432 r[i] = b[i] - temp[i];
00433
00434
00435 for( int i=0; i<size; i++ )
00436 p[i] = r[i];
00437
00438 double numerator, denominator;
00439
00440 int step = 0;
00441
00442 double norm_b = 0.0;
00443 for( int i=0; i<size; i++ )
00444 norm_b += multi[i]*multi[i];
00445 norm_b = sqrt( norm_b );
00446
00447 double norm_r = 0.0;
00448
00449 while( true )
00450 {
00451 if( itrs>=0 )
00452 if( step>=itrs )
00453 break;
00454
00455 norm_r = 0.0;
00456 for( int i=0; i<size; i++ )
00457 norm_r += r[i]*r[i];
00458 norm_r = sqrt( norm_r );
00459
00460 if( norm_r/norm_b<eps )
00461 break;
00462
00463
00464
00465
00466
00467 numerator = 0.0;
00468 for( int i=0; i<size; i++ )
00469 numerator += r[i]*r[i];
00470
00471 sum = 0.0;
00472 for( int i=0; i<size; i++ )
00473 sum += p[i];
00474 for( int i=0; i<size; i++ )
00475 {
00476 temp[i] = -(sum-p[i]);
00477 }
00478
00479 this->Multiply( temp, multi );
00480 sum = 0.0;
00481 for( int i=0; i<size; i++ )
00482 sum += multi[i];
00483 for( int i=0; i<size; i++ )
00484 {
00485 temp[i] = -(sum-multi[i]);
00486 }
00487
00488 denominator = 0.0;
00489 for( int i=0; i<size; i++ )
00490 denominator += p[i]*temp[i];
00491
00492 alpha = numerator/denominator;
00493
00494
00495 for( int i=0; i<size; i++ )
00496 x[i] = x[i] + alpha*p[i];
00497
00498 for( int i=0; i<size; i++ )
00499 multi[i] = r[i];
00500
00501
00502 for( int i=0; i<size; i++ )
00503 r[i] = multi[i] - alpha*temp[i];
00504
00505
00506 numerator = 0.0;
00507 for( int i=0; i<size; i++ )
00508 numerator += r[i]*r[i];
00509 denominator = 0.0;
00510 for( int i=0; i<size; i++ )
00511 denominator += multi[i]*multi[i];
00512 beta = numerator/denominator;
00513
00514
00515 for( int i=0; i<size; i++ )
00516 p[i] = r[i]+beta*p[i];
00517
00518 step++;
00519 }
00520
00521
00522
00523
00524 delete []r;
00525 delete []p;
00526 delete []temp;
00527 delete []multi;
00528
00529 return true;
00530 }
00531
00535 bool Predicate(const _Entry & d1, const _Entry & d2)
00536 {
00537 return ( d1.j < d2.j || (d1.j==d2.j && d1.i < d2.i) );
00538 }
00539
00544 bool CSparseMatrix::SolverUMF(double b[], double x[])
00545 {
00546 std::sort( (*m_pEntries).begin(), (*m_pEntries).end(), Predicate );
00547
00548
00549 int n = m_pEntries->size();
00550
00551
00552
00553
00554
00555
00556
00557
00558 int *Ai = new int[n];
00559 double *Ax = new double[n];
00560
00561 for( int i = 0; i < n ; i ++ )
00562 {
00563 _Entry & e = m_pEntries->at(i);
00564 Ai[i] = e.i;
00565 Ax[i] = e.val;
00566 }
00567
00568 int row = GetRows();
00569 int col = GetCols();
00570
00571
00572 int *Ap = new int[col+1];
00573
00574 for( int i = 0; i < col + 1; i ++ )
00575 {
00576 Ap[i] = 0;
00577 }
00578
00579 for( int i = 0; i < n ; i ++ )
00580 {
00581 _Entry & e = m_pEntries->at(i);
00582 Ap[e.j+1]++;
00583 }
00584
00585 for( int i = 1; i < col + 1; i ++ )
00586 {
00587 Ap[i] = Ap[i] + Ap[i-1];
00588 }
00589
00590 printf("%d - %d\n", Ap[col], n );
00591
00592 void *Symbolic, *Numeric ;
00593 double *null = (double *) NULL ;
00594 (void) umfpack_di_symbolic (row, col, Ap, Ai, Ax, &Symbolic, null, null) ;
00595 (void) umfpack_di_numeric (Ap, Ai, Ax, Symbolic, &Numeric, null, null) ;
00596 umfpack_di_free_symbolic (&Symbolic) ;
00597 (void) umfpack_di_solve (UMFPACK_A, Ap, Ai, Ax, x, b, Numeric, null, null) ;
00598 umfpack_di_free_numeric (&Numeric) ;
00599
00600 delete []Ap;
00601 delete []Ax;
00602 delete []Ai;
00603
00604 return true;
00605 }