PRISMS-PF Manual v3.0-pre
All Classes Functions Variables Enumerations Pages
checkpoint_parameters.h
1// SPDX-FileCopyrightText: © 2025 PRISMS Center at the University of Michigan
2// SPDX-License-Identifier: GNU Lesser General Public Version 2.1
3
4#ifndef checkpoint_parameters_h
5#define checkpoint_parameters_h
6
#include <prismspf/config.h>
#include <prismspf/core/conditional_ostreams.h>
#include <prismspf/core/exceptions.h>
#include <prismspf/user_inputs/temporal_discretization.h>
#include <prismspf/utilities/utilities.h>

#include <algorithm>
#include <climits>
#include <cmath>
#include <set>
#include <string>
#include <vector>
16
17PRISMS_PF_BEGIN_NAMESPACE
18
23{
24public:
28 [[nodiscard]] bool
29 should_checkpoint(unsigned int increment) const;
30
34 void
35 postprocess_and_validate(const temporalDiscretization &temporal_discretization);
36
40 void
42
43 // Whether to load from a checkpoint
44 bool load_from_checkpoint = false;
45
46 // Checkpoint condition type
47 std::string condition;
48
49 // Number of checkpoints
50 unsigned int n_checkpoints = 0;
51
52 // User given checkpoint list
53 std::vector<int> user_checkpoint_list;
54
55 // List of increments for checkpoints
56 std::set<unsigned int> checkpoint_list;
57};
58
59inline bool
60checkpointParameters::should_checkpoint(unsigned int increment) const
61{
62 return checkpoint_list.find(increment) != checkpoint_list.end();
63}
64
65inline void
67 const temporalDiscretization &temporal_discretization)
68{
69 // If the user has specified a list and we have list checkpoint use that and return
70 // early
71 if (condition == "LIST")
72 {
73 for (const auto &increment : user_checkpoint_list)
74 {
75 checkpoint_list.insert(static_cast<unsigned int>(increment));
76 }
77 return;
78 }
79
80 // If the number of checkpoints is 0 return early
81 if (n_checkpoints == 0)
82 {
83 return;
84 }
85
86 // If the number of outputs is greater than the number of increments, force them to be
87 // equivalent
88 n_checkpoints = std::min(n_checkpoints, temporal_discretization.total_increments);
89
90 // Determine the output list from the other criteria
91 if (condition == "EQUAL_SPACING")
92 {
93 for (unsigned int iteration = 0;
94 iteration <= temporal_discretization.total_increments;
95 iteration += temporal_discretization.total_increments / n_checkpoints)
96 {
97 checkpoint_list.insert(iteration);
98 }
99 }
100 else if (condition == "LOG_SPACING")
101 {
102 checkpoint_list.insert(0);
103 for (unsigned int output = 1; output <= n_checkpoints; output++)
104 {
105 checkpoint_list.insert(static_cast<unsigned int>(std::round(
106 std::pow(static_cast<double>(temporal_discretization.total_increments),
107 static_cast<double>(output) / static_cast<double>(n_checkpoints)))));
108 }
109 }
110 else if (condition == "N_PER_DECADE")
111 {
112 AssertThrow(temporal_discretization.total_increments > 1,
113 dealii::ExcMessage("For n per decaded spaced outputs, the number of "
114 "increments must be greater than 1."));
115
116 checkpoint_list.insert(0);
117 checkpoint_list.insert(1);
118 for (unsigned int iteration = 2;
119 iteration <= temporal_discretization.total_increments;
120 iteration++)
121 {
122 const auto decade = static_cast<unsigned int>(std::ceil(std::log10(iteration)));
123 const auto step_size =
124 static_cast<unsigned int>(std::pow(10, decade) / n_checkpoints);
125 if (iteration % step_size == 0)
126 {
127 checkpoint_list.insert(iteration);
128 }
129 }
130 }
131 else
132 {
133 AssertThrow(false, UnreachableCode());
134 }
135}
136
137inline void
139{
141 << "================================================\n"
142 << " Checkpoint Parameters\n"
143 << "================================================\n"
144 << "Checkpoint condition: " << condition << "\n"
145 << "Number of checkpoints: " << n_checkpoints << "\n";
146
147 conditionalOStreams::pout_summary() << "Checkpoint iteration list: ";
148 for (const auto &iteration : checkpoint_list)
149 {
150 conditionalOStreams::pout_summary() << iteration << " ";
151 }
152 conditionalOStreams::pout_summary() << "\n\n" << std::flush;
153}
154
155PRISMS_PF_END_NAMESPACE
156
157#endif
static dealii::ConditionalOStream & pout_summary()
Log output stream for writing a summary.log file.
Definition conditional_ostreams.cc:22
Struct that holds checkpoint parameters.
Definition checkpoint_parameters.h:23
void print_parameter_summary() const
Print parameters to summary.log.
Definition checkpoint_parameters.h:138
bool should_checkpoint(unsigned int increment) const
Return whether the current increment should be checkpointed.
Definition checkpoint_parameters.h:60
void postprocess_and_validate(const temporalDiscretization &temporal_discretization)
Postprocess and validate parameters.
Definition checkpoint_parameters.h:66
Struct that holds temporal discretization parameters.
Definition temporal_discretization.h:17