PRISMS-PF Manual
Loading...
Searching...
No Matches
linear_solver.h
Go to the documentation of this file.
1// SPDX-FileCopyrightText: © 2025 PRISMS Center at the University of Michigan
2// SPDX-License-Identifier: GNU Lesser General Public Version 2.1
3
4#pragma once
5
6#include <deal.II/lac/precondition.h>
7#include <deal.II/lac/solver_cg.h>
8#include <deal.II/lac/solver_control.h>
9
13#include <prismspf/core/timer.h>
15#include <prismspf/core/types.h>
16
19
21
22#include <prismspf/config.h>
23
25
26template <unsigned int dim, unsigned int degree, typename number>
27class SolveContext;
28
// LinearSolver: per-level linear solve for the fields of one solve group.
// Each level's system is solved with deal.II's SolverCG and
// dealii::PreconditionIdentity (i.e. unpreconditioned CG), using per-level
// MFOperator objects both to assemble the right-hand side and as the
// system operator.
32template <unsigned int dim, unsigned int degree, typename number>
33class LinearSolver : public SolverBase<dim, degree, number>
34{
35protected:
36 using SolverBase<dim, degree, number>::solutions;
37 using SolverBase<dim, degree, number>::solve_context;
38 using SolverBase<dim, degree, number>::solve_group;
39
40public:
// lin_params is looked up by solve_group.id in the user's linear solver
// parameters. NOTE(review): .at() presumably throws (e.g. std::out_of_range)
// when no entry exists for this group -- confirm the container type.
47 , lin_params(
48 solve_context->get_user_inputs().linear_solve_parameters.linear_solvers.at(
49 solve_group.id))
50 {}
51
// init: for every DoF handler (one per relative level) this
//  - reinitializes rhs_vector[level] from the level's full solution vector, and
//  - appends one RHS and one LHS MFOperator, initializes it with `solutions`,
//    and sets its scaling diagonal to the inverse-mass square root returned
//    by the invm manager for this group's field attributes.
// NOTE(review): rhs_operators / lhs_operators are only reserve()d and
// emplace_back()ed; no clear() is visible here, so a second call to init()
// would append a second set while [relative_level] indexing keeps using the
// first one. Confirm init() runs only once per solver instance.
55 void
56 init(const std::list<DependencyMap> &all_dependeny_sets) override
57 {
59 unsigned int num_levels = solve_context->get_dof_manager().get_dof_handlers().size();
60 rhs_vector.resize(num_levels);
61 for (unsigned int relative_level = 0; relative_level < num_levels; ++relative_level)
62 {
63 rhs_vector[relative_level].reinit(
64 solutions.get_solution_full_vector(relative_level));
65 }
66 // Initialize rhs_operators
67 rhs_operators.reserve(num_levels);
68 for (unsigned int relative_level = 0; relative_level < num_levels; ++relative_level)
69 {
70 rhs_operators.emplace_back(solve_context->get_pde_operator(),
72 solve_context->get_field_attributes(),
73 solve_context->get_solution_indexer(),
74 relative_level,
76 solve_context->get_simulation_timer());
77 rhs_operators[relative_level].initialize(solutions);
78 rhs_operators[relative_level].set_scaling_diagonal(
80 solve_context->get_invm_manager().get_invm_sqrt(
81 solve_context->get_field_attributes(),
83 relative_level));
84 }
85 // Initialize lhs_operators
86 lhs_operators.reserve(num_levels);
87 for (unsigned int relative_level = 0; relative_level < num_levels; ++relative_level)
88 {
89 lhs_operators.emplace_back(solve_context->get_pde_operator(),
91 solve_context->get_field_attributes(),
92 solve_context->get_solution_indexer(),
93 relative_level,
95 solve_context->get_simulation_timer());
96 lhs_operators[relative_level].initialize(solutions);
97 lhs_operators[relative_level].set_scaling_diagonal(
99 solve_context->get_invm_manager().get_invm_sqrt(
100 solve_context->get_field_attributes(),
102 relative_level));
103 }
106 }
107
// reinit: reinitializes each per-level rhs_vector from the corresponding
// full solution vector, so it adopts that vector's block/partition layout.
// The number of levels is taken from rhs_vector.size() (set up in init()).
111 void
112 reinit() override
113 {
115 const unsigned int num_levels = rhs_vector.size();
116 for (unsigned int relative_level = 0; relative_level < num_levels; ++relative_level)
117 {
118 rhs_vector[relative_level].reinit(
119 solutions.get_solution_full_vector(relative_level));
120 }
121 }
122
// solve_level: one solve for the given relative level. It zeroes the
// solution ghosts, assembles rhs_vector[level] with the level's RHS operator,
// and solves in place into the level's full solution vector. It then applies
// constraints and updates the ghosts. Both ghost steps run inside named
// Timer sections.
126 void
127 solve_level(unsigned int relative_level) override
128 {
129 // Zero out the ghosts
130 Timer::start_section("Zero ghosts");
131 solutions.zero_out_ghosts(relative_level);
132 Timer::end_section("Zero ghosts");
133
134 // Set up linear solver
135 rhs_operators[relative_level].compute_operator(rhs_vector[relative_level]);
136 do_linear_solve(rhs_vector[relative_level],
137 lhs_operators[relative_level],
138 solutions.get_solution_full_vector(relative_level));
139
140 // Apply constraints
141 solutions.apply_constraints(relative_level);
142
143 // Update the ghosts
144 Timer::start_section("Update ghosts");
145 solutions.update_ghosts(relative_level);
146 Timer::end_section("Update ghosts");
147 }
148
// do_linear_solve: solves lhs_operator * x_vector = b_vector with CG and an
// identity preconditioner. x_vector's current contents are the initial guess
// and receive the result. When output is due for the current increment, the
// final residual and step count are printed. Returns
// linear_solver_control.last_step() whether or not the solve threw.
149 int
153 {
154 // Linear solve
155 try
156 {
157 dealii::SolverCG<BlockVector<number>> cg_solver(linear_solver_control);
158 cg_solver.solve(lhs_operator, x_vector, b_vector, dealii::PreconditionIdentity());
159 if (solve_context->get_user_inputs().output_parameters.should_output(
160 solve_context->get_simulation_timer().get_increment()))
161 {
163 << " Linear solve final residual : "
165 << " Linear steps: " << linear_solver_control.last_step() << "\n"
166 << std::flush;
167 }
168 }
// NOTE(review): catch (...) also swallows exceptions unrelated to
// convergence (e.g. std::bad_alloc or deal.II assertion failures). It reports
// them as non-convergence, and execution continues with whatever x_vector
// holds. Catching `const dealii::SolverControl::NoConvergence &` (the
// exception SolverCG throws when the iteration limit is hit) would keep
// other errors visible.
169 catch (...) // TODO: more specific catch
170 {
172 << "[Increment " << solve_context->get_simulation_timer().get_increment()
173 << "] "
174 << "Warning: linear solver did not converge as per set tolerances before "
175 << lin_params.max_iterations << " iterations.\n";
176 }
177 return linear_solver_control.last_step();
178 }
179
180protected:
// Per-level operators used to assemble the right-hand side (compute_operator).
184 std::vector<MFOperator<dim, degree, number>> rhs_operators;
185
// Per-level operators applied as the system operator in the CG solve.
186 std::vector<MFOperator<dim, degree, number>> lhs_operators;
// Per-level right-hand-side vectors filled by rhs_operators.
187 std::vector<BlockVector<number>> rhs_vector;
188
// normalization_value: returns 1.0 scaled by sqrt(triangulation volume) for
// RMSEPerField / RMSETotal, and by sqrt(number of fields in the solve group)
// for RMSEPerField / IntegratedPerField. Any other tolerance type gets 1.0.
// NOTE(review): `type` is declared on a line not shown in this listing --
// presumably lin_params.tolerance_type; confirm.
189 double
191 {
193 using std::sqrt;
194 double value = 1.0;
195 if (type == RMSEPerField || type == RMSETotal)
196 {
197 value *= sqrt(solve_context->get_triangulation_manager().get_volume());
198 }
199 if (type == RMSEPerField || type == IntegratedPerField)
200 {
201 value *= sqrt(double(solve_group.field_indices.size()));
202 }
203 return value;
204 }
205
206private:
211
// SolverControl handed to SolverCG; last_step() is read after every solve.
215 dealii::SolverControl linear_solver_control;
216};
217
static dealii::ConditionalOStream & pout_summary()
Log output stream for writing a summary.log file.
Definition conditional_ostreams.cc:34
static dealii::ConditionalOStream & pout_base()
Generic parallel output stream. Used for essential information in release and debug mode.
Definition conditional_ostreams.cc:43
This class handles the linear solves of the fields in a single solve group.
Definition linear_solver.h:34
std::vector< MFOperator< dim, degree, number > > rhs_operators
Matrix free operators for each level.
Definition linear_solver.h:184
int do_linear_solve(BlockVector< number > &b_vector, MFOperator< dim, degree, number > &lhs_operator, BlockVector< number > &x_vector)
Definition linear_solver.h:150
LinearSolverParameters lin_params
Linear solver parameters.
Definition linear_solver.h:210
double normalization_value()
Definition linear_solver.h:190
void init(const std::list< DependencyMap > &all_dependeny_sets) override
Initialize the solver.
Definition linear_solver.h:56
std::vector< BlockVector< number > > rhs_vector
Definition linear_solver.h:187
LinearSolver(SolveGroup _solve_group, const SolveContext< dim, degree, number > &_solve_context)
Constructor.
Definition linear_solver.h:44
std::vector< MFOperator< dim, degree, number > > lhs_operators
Definition linear_solver.h:186
dealii::SolverControl linear_solver_control
Solver control. Contains max iterations and tolerance.
Definition linear_solver.h:215
void solve_level(unsigned int relative_level) override
Solve for a single update step.
Definition linear_solver.h:127
void reinit() override
Reinitialize the solver.
Definition linear_solver.h:112
This class exists to evaluate a single user-defined operator for the matrix-free implementation of so...
Definition mf_operator.h:50
This class contains the user implementation of each PDE operator.
Definition pde_operator_base.h:24
This class provides context for a solver with ptrs to all the relevant dependencies.
Definition solve_context.h:36
Structure to hold the attributes of a solve-group.
Definition solve_group.h:59
std::set< Types::Index > field_indices
Indices of the fields to be solved in this group.
Definition solve_group.h:98
DependencyMap dependencies_rhs
Dependencies for the rhs equation(s)
Definition solve_group.h:103
DependencyMap dependencies_lhs
Dependencies for the lhs equation(s)
Definition solve_group.h:107
Definition solver_base.h:32
virtual void reinit()
Reinitialize the solution vectors & apply constraints.
Definition solver_base.h:98
SolveGroup solve_group
Information about the solve group this handler is responsible for.
Definition solver_base.h:259
virtual void init(const std::list< DependencyMap > &all_dependeny_sets)
Initialize the solver.
Definition solver_base.h:83
GroupSolutionHandler< dim, number > solutions
Solution vectors for fields handled by this solver.
Definition solver_base.h:269
const SolveContext< dim, degree, number > * solve_context
Solver context provides access to external information.
Definition solver_base.h:264
static void start_section(const char *name)
Start a new timer section.
Definition timer.cc:116
static void end_section(const char *name)
End the timer section.
Definition timer.cc:127
@ Value
Use value of the variable as a criterion for refinement.
Definition grid_refiner_criterion.h:31
dealii::LinearAlgebra::distributed::BlockVector< number > BlockVector
Typedef for solution block vector.
Definition group_solution_handler.h:29
Definition conditional_ostreams.cc:20
Struct that stores relevant linear solve information of a certain field.
Definition linear_solve_parameters.h:22
unsigned int max_iterations
Definition linear_solve_parameters.h:30
double tolerance
Definition linear_solve_parameters.h:24
SolverToleranceType tolerance_type
Definition linear_solve_parameters.h:27
SolverToleranceType
Solver tolerance type.
Definition type_enums.h:69
@ RMSEPerField
The mean local error averaged over each field is lower than the tolerance.
Definition type_enums.h:81
@ AbsoluteResidual
Legacy tolerance type; normalization_value() applies no volume or field-count scaling to it.
Definition type_enums.h:73
@ RMSETotal
The sum of the average local errors of each field is lower than the tolerance.
Definition type_enums.h:89
@ IntegratedPerField
The integrated error averaged over each field is lower than the tolerance.
Definition type_enums.h:85