| 1 | // This file is part of Eigen, a lightweight C++ template library | 
| 2 | // for linear algebra. | 
| 3 | // | 
| 4 | // Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr> | 
| 5 | // | 
| 6 | // This Source Code Form is subject to the terms of the Mozilla | 
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed | 
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. | 
| 9 |  | 
| 10 | #ifndef EIGEN_SOLVERBASE_H | 
| 11 | #define EIGEN_SOLVERBASE_H | 
| 12 |  | 
| 13 | namespace Eigen { | 
| 14 |  | 
| 15 | namespace internal { | 
| 16 |  | 
/** \internal
  * Generic validator for a solve request on a solver of type \c Derived.
  *
  * It simply forwards to the solver's own (protected) member
  * \c _check_solve_assertion<IsTransposed>(rhs); access is granted because
  * SolverBase befriends every instantiation of this template.
  *
  * \tparam Derived the solver type (SolverBase::solve() passes it through internal::remove_all)
  */
template<typename Derived>
struct solve_assertion
{
    /** Checks \a rhs against \a dec; \c IsTransposed selects the A^T x = b variant of the check. */
    template<bool IsTransposed, typename RhsType>
    static void run(const Derived& dec, const RhsType& rhs)
    {
        dec.template _check_solve_assertion<IsTransposed>(rhs);
    }
};
| 22 |  | 
| 23 | template<typename Derived> | 
| 24 | struct solve_assertion<Transpose<Derived> > | 
| 25 | { | 
| 26 |     typedef Transpose<Derived> type; | 
| 27 |  | 
| 28 |     template<bool Transpose_, typename Rhs> | 
| 29 |     static void run(const type& transpose, const Rhs& b) | 
| 30 |     { | 
| 31 |         internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<true>(transpose.nestedExpression(), b); | 
| 32 |     } | 
| 33 | }; | 
| 34 |  | 
| 35 | template<typename Scalar, typename Derived> | 
| 36 | struct solve_assertion<CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > > | 
| 37 | { | 
| 38 |     typedef CwiseUnaryOp<Eigen::internal::scalar_conjugate_op<Scalar>, const Transpose<Derived> > type; | 
| 39 |  | 
| 40 |     template<bool Transpose_, typename Rhs> | 
| 41 |     static void run(const type& adjoint, const Rhs& b) | 
| 42 |     { | 
| 43 |         internal::solve_assertion<typename internal::remove_all<Transpose<Derived> >::type>::template run<true>(adjoint.nestedExpression(), b); | 
| 44 |     } | 
| 45 | }; | 
| 46 | } // end namespace internal | 
| 47 |  | 
/** \class SolverBase
  * \brief A base class for matrix decompositions and solvers
  *
  * \tparam Derived the actual type of the decomposition/solver.
  *
  * Any matrix decomposition inheriting this base class provides the following API:
  *
  * \code
  * MatrixType A, b, x;
  * DecompositionType dec(A);
  * x = dec.solve(b);             // solve A   * x = b
  * x = dec.transpose().solve(b); // solve A^T * x = b
  * x = dec.adjoint().solve(b);   // solve A'  * x = b
  * \endcode
  *
  * \warning Currently, any other usage of transpose() and adjoint() is not supported and will produce compilation errors.
  *
  * \note The solve-time assertions read \c derived().m_isInitialized, \c derived().rows() and
  * \c derived().cols(), so \c Derived must provide these and \c m_isInitialized must be
  * accessible from this class.
  *
  * \sa class PartialPivLU, class FullPivLU, class HouseholderQR, class ColPivHouseholderQR, class FullPivHouseholderQR, class CompleteOrthogonalDecomposition, class LLT, class LDLT, class SVDBase
  */
template<typename Derived>
class SolverBase : public EigenBase<Derived>
{
  public:

    typedef EigenBase<Derived> Base;
    typedef typename internal::traits<Derived>::Scalar Scalar;
    typedef Scalar CoeffReturnType;

    // internal::solve_assertion calls the protected _check_solve_assertion() declared below,
    // so every instantiation of it is made a friend.
    template<typename Derived_>
    friend struct internal::solve_assertion;

    /** Compile-time size information, all read from internal::traits<Derived>.
      * - SizeAtCompileTime / MaxSizeAtCompileTime combine the (max) row and column counts
      *   through internal::size_at_compile_time.
      * - IsVectorAtCompileTime is true when the maximum number of rows or of columns is 1.
      * - NumDimensions is 0 when MaxSizeAtCompileTime is 1, 1 for vectors, and 2 otherwise.
      */
    enum {
      RowsAtCompileTime = internal::traits<Derived>::RowsAtCompileTime,
      ColsAtCompileTime = internal::traits<Derived>::ColsAtCompileTime,
      SizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::RowsAtCompileTime,
                                                          internal::traits<Derived>::ColsAtCompileTime>::ret),
      MaxRowsAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime,
      MaxColsAtCompileTime = internal::traits<Derived>::MaxColsAtCompileTime,
      MaxSizeAtCompileTime = (internal::size_at_compile_time<internal::traits<Derived>::MaxRowsAtCompileTime,
                                                             internal::traits<Derived>::MaxColsAtCompileTime>::ret),
      IsVectorAtCompileTime = internal::traits<Derived>::MaxRowsAtCompileTime == 1
                           || internal::traits<Derived>::MaxColsAtCompileTime == 1,
      NumDimensions = int(MaxSizeAtCompileTime) == 1 ? 0 : bool(IsVectorAtCompileTime) ? 1 : 2
    };

    /** Default constructor */
    SolverBase()
    {}

    // Non-virtual destructor: this is a CRTP base, objects are handled through Derived.
    ~SolverBase()
    {}

    using Base::derived;

    /** \returns an expression of the solution x of \f$ A x = b \f$ using the current decomposition of A.
      *
      * \param b the right-hand side. Before building the expression, eigen_assert checks that the
      *          solver is initialized and that \c b.rows() equals \c derived().rows().
      */
    template<typename Rhs>
    inline const Solve<Derived, Rhs>
    solve(const MatrixBase<Rhs>& b) const
    {
      // Dispatch on the unqualified solver type; the primary solve_assertion template forwards
      // to _check_solve_assertion<false>() (non-transposed problem).
      internal::solve_assertion<typename internal::remove_all<Derived>::type>::template run<false>(derived(), b);
      return Solve<Derived, Rhs>(derived(), b.derived());
    }

    /** \internal the return type of transpose() */
    typedef typename internal::add_const<Transpose<const Derived> >::type ConstTransposeReturnType;
    /** \returns an expression of the transposed of the factored matrix.
      *
      * A typical usage is to solve for the transposed problem A^T x = b:
      * \code x = dec.transpose().solve(b); \endcode
      *
      * The returned object is a const Transpose<const Derived> wrapping \c derived().
      *
      * \sa adjoint(), solve()
      */
    inline ConstTransposeReturnType transpose() const
    {
      return ConstTransposeReturnType(derived());
    }

    /** \internal the return type of adjoint(): the conjugate of ConstTransposeReturnType for
      * complex scalars, ConstTransposeReturnType itself otherwise. */
    typedef typename internal::conditional<NumTraits<Scalar>::IsComplex,
                        CwiseUnaryOp<internal::scalar_conjugate_op<Scalar>, ConstTransposeReturnType>,
                        ConstTransposeReturnType
                     >::type AdjointReturnType;
    /** \returns an expression of the adjoint of the factored matrix
      *
      * A typical usage is to solve for the adjoint problem A' x = b:
      * \code x = dec.adjoint().solve(b); \endcode
      *
      * For real scalar types, this function is equivalent to transpose().
      *
      * \sa transpose(), solve()
      */
    inline AdjointReturnType adjoint() const
    {
      // Goes through derived().transpose() so that the result is built from the Derived view.
      return AdjointReturnType(derived().transpose());
    }

  protected:

    /** \internal Asserts that a solve with right-hand side \a b is valid.
      *
      * Checks that the solver is initialized, and that \c b.rows() matches \c derived().cols()
      * when \c Transpose_ is true (A^T x = b) or \c derived().rows() otherwise (A x = b).
      * Reached through internal::solve_assertion.
      */
    template<bool Transpose_, typename Rhs>
    void _check_solve_assertion(const Rhs& b) const {
        // Presumably keeps b from being reported as unused when eigen_assert compiles to nothing.
        EIGEN_ONLY_USED_FOR_DEBUG(b);
        eigen_assert(derived().m_isInitialized && "Solver is not initialized." );
        eigen_assert((Transpose_?derived().cols():derived().rows())==b.rows() && "SolverBase::solve(): invalid number of rows of the right hand side matrix b" );
    }
};
| 154 |  | 
| 155 | namespace internal { | 
| 156 |  | 
| 157 | template<typename Derived> | 
| 158 | struct generic_xpr_base<Derived, MatrixXpr, SolverStorage> | 
| 159 | { | 
| 160 |   typedef SolverBase<Derived> type; | 
| 161 |  | 
| 162 | }; | 
| 163 |  | 
| 164 | } // end namespace internal | 
| 165 |  | 
| 166 | } // end namespace Eigen | 
| 167 |  | 
| 168 | #endif // EIGEN_SOLVERBASE_H | 
| 169 |  |