// Bodies of the rank-1 and rank-2 constructors: each forwards its extents
// to resizeI as a std::vector<int>, which records the shape and sizes the
// flat storage.
// NOTE(review): the constructor signatures for these two bodies are not
// visible in this chunk — presumably Tensor(int i) and Tensor(int i, int j);
// confirm.
57 resizeI<std::vector<int>>({i});
62 resizeI<std::vector<int>>({i, j});
65 Tensor(
int i,
int j,
int k)
67 resizeI<std::vector<int>>({i, j, k});
70 Tensor(
int i,
int j,
int k,
int l)
72 resizeI<std::vector<int>>({i, j, k, l});
75 template <
typename Sizes>
76 void resizeI(
const Sizes& sizes)
78 if (sizes.size() == 1)
79 dims_ = {(int)sizes[0]};
80 if (sizes.size() == 2)
81 dims_ = {(int)sizes[0], (
int)sizes[1]};
82 if (sizes.size() == 3)
83 dims_ = {(int)sizes[0], (
int)sizes[1], (int)sizes[2]};
84 if (sizes.size() == 4)
85 dims_ = {(int)sizes[0], (
int)sizes[1], (int)sizes[2], (
int)sizes[3]};
87 data_.resize(std::accumulate(begin(dims_), end(dims_), 1.0, std::multiplies<>()));
// Fragment (function signature and loop body are not visible in this
// chunk): rejects a tensor with no axes ("Invalid tensor"), then starts an
// element count from dims_[0] and iterates the remaining axes.
// NOTE(review): presumably the loop multiplies in dims_[i]; confirm. resizeI
// sizes data_ to the product of dims_, so data_.size() may give the same
// count without the int accumulation — verify before relying on it.
92 OPM_ERROR_IF(dims_.size() == 0,
"Invalid tensor");
94 int elements = dims_[0];
95 for (
unsigned int i = 1; i < dims_.size(); i++) {
// Fragment of rank-1 element access (signature, rest of the error message
// and return are not visible in this chunk). It requires exactly one axis
// and 0 <= i < dims_[0].
// NOTE(review): OPM_ERROR_IF is defined elsewhere; presumably it reports
// (throws) the given message when its condition is true — confirm.
103 OPM_ERROR_IF(dims_.size() != 1,
"Invalid indexing for tensor");
105 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
106 fmt::format(
" Invalid i: "
// Mutable element access for a rank-2 tensor.
// Requires dims_.size() == 2, 0 <= i < dims_[0] and 0 <= j < dims_[1].
// Returns a reference to the element at row-major offset dims_[1] * i + j.
// NOTE(review): OPM_ERROR_IF is defined elsewhere; presumably it throws with
// the formatted message when its condition holds — confirm. The
// fmt::format argument lists are truncated in this chunk.
116 T& operator()(
int i,
int j)
118 OPM_ERROR_IF(dims_.size() != 2,
"Invalid indexing for tensor");
119 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
120 fmt::format(
" Invalid i: "
126 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
127 fmt::format(
" Invalid j: "
134 return data_[dims_[1] * i + j];
// Read-only element access for a rank-2 tensor, with the same checks as
// the mutable overload.
// Requires dims_.size() == 2, 0 <= i < dims_[0] and 0 <= j < dims_[1].
// Returns the element at row-major offset dims_[1] * i + j.
// NOTE(review): the fmt::format argument lists are truncated in this chunk.
137 const T& operator()(
int i,
int j)
const
139 OPM_ERROR_IF(dims_.size() != 2,
"Invalid indexing for tensor");
140 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
141 fmt::format(
" Invalid i: "
147 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
148 fmt::format(
" Invalid j: "
154 return data_[dims_[1] * i + j];
// Mutable element access for a rank-3 tensor.
// Requires dims_.size() == 3 and each index within [0, dims_[axis]).
// Returns a reference to the element at row-major offset
// dims_[2] * (dims_[1] * i + j) + k, where the last index varies fastest.
// NOTE(review): the fmt::format argument lists are truncated in this chunk.
157 T& operator()(
int i,
int j,
int k)
159 OPM_ERROR_IF(dims_.size() != 3,
"Invalid indexing for tensor");
160 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
161 fmt::format(
" Invalid i: "
167 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
168 fmt::format(
" Invalid j: "
174 OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
175 fmt::format(
" Invalid k: "
182 return data_[dims_[2] * (dims_[1] * i + j) + k];
// Read-only element access for a rank-3 tensor, with the same checks as
// the mutable overload.
// Requires dims_.size() == 3 and each index within [0, dims_[axis]).
// Returns the element at row-major offset dims_[2] * (dims_[1] * i + j) + k.
// NOTE(review): the fmt::format argument lists are truncated in this chunk.
185 const T& operator()(
int i,
int j,
int k)
const
187 OPM_ERROR_IF(dims_.size() != 3,
"Invalid indexing for tensor");
188 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
189 fmt::format(
" Invalid i: "
195 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
196 fmt::format(
" Invalid j: "
202 OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
203 fmt::format(
" Invalid k: "
210 return data_[dims_[2] * (dims_[1] * i + j) + k];
// Mutable element access for a rank-4 tensor.
// Requires dims_.size() == 4 and each index within [0, dims_[axis]).
// Returns a reference to the element at row-major offset
// dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l.
// NOTE(review): the fmt::format argument lists are truncated in this chunk.
213 T& operator()(
int i,
int j,
int k,
int l)
215 OPM_ERROR_IF(dims_.size() != 4,
"Invalid indexing for tensor");
216 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
217 fmt::format(
" Invalid i: "
223 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
224 fmt::format(
" Invalid j: "
230 OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
231 fmt::format(
" Invalid k: "
237 OPM_ERROR_IF(!(l < dims_[3] && l >= 0),
238 fmt::format(
" Invalid l: "
245 return data_[dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l];
// Read-only element access for a rank-4 tensor, with the same checks as
// the mutable overload.
// Requires dims_.size() == 4 and each index within [0, dims_[axis]).
// Returns the element at row-major offset
// dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l.
// NOTE(review): the fmt::format argument lists are truncated in this chunk.
248 const T& operator()(
int i,
int j,
int k,
int l)
const
250 OPM_ERROR_IF(dims_.size() != 4,
"Invalid indexing for tensor");
251 OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
252 fmt::format(
" Invalid i: "
258 OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
259 fmt::format(
" Invalid j: "
265 OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
266 fmt::format(
" Invalid k: "
272 OPM_ERROR_IF(!(l < dims_[3] && l >= 0),
273 fmt::format(
" Invalid l: "
280 return data_[dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l];
283 void fill(
const T& value)
285 std::fill(data_.begin(), data_.end(), value);
// Element-wise sum fragment. The function signature, the declaration of
// `result`, two std::transform arguments and the return are not visible in
// this chunk. It gives `result` this tensor's shape and storage size, then
// fills it by applying x + y pairwise with std::transform.
// NOTE(review): the guard compares only the number of axes (rank), not the
// extents. Tensors of equal rank but different shape pass the check. If the
// elided transform arguments read other.data_ over data_'s length, a smaller
// other.data_ would be read past its end. Consider `dims_ != other.dims_`,
// which matches the intent of the error message.
291 OPM_ERROR_IF(dims_.size() != other.dims_.size(),
292 "Cannot add tensors with different dimensions");
294 result.dims_ = dims_;
295 result.data_.resize(data_.size());
297 std::transform(data_.begin(),
300 result.data_.begin(),
301 [](
const T& x,
const T& y) { return x + y; });
// Element-wise (Hadamard) product fragment. The function signature, the
// declaration of `result`, two std::transform arguments and the return are
// not visible in this chunk. It gives `result` this tensor's shape and
// storage size, then fills it by applying x * y pairwise with
// std::transform.
// NOTE(review): as in the element-wise sum, only the rank is compared, not
// the extents. Tensors of equal rank but different shape pass the check and
// may lead to an out-of-range read of other.data_. Consider
// `dims_ != other.dims_`.
309 OPM_ERROR_IF(dims_.size() != other.dims_.size(),
310 "Cannot multiply elements with different dimensions");
313 result.dims_ = dims_;
314 result.data_.resize(data_.size());
316 std::transform(data_.begin(),
319 result.data_.begin(),
320 [](
const T& x,
const T& y) { return x * y; });
// Matrix-product fragment. The signature, the loop closers and whatever
// happens to `tmp` afterwards are not visible in this chunk.
// Both operands must be rank 2 with matching inner extents
// (dims_[1] == other.dims_[0]). It accumulates the
// dims_[0] x other.dims_[1] product into `tmp` with the classic triple loop.
// NOTE(review): this relies on `tmp` starting at zero. Presumably the
// two-int constructor value-initialises via resizeI's data_.resize; confirm.
// NOTE(review): every access goes through the bounds-checked operator(), and
// the i-j-k order reads other(k, j) with stride other.dims_[1]. If this is
// hot, consider an i-k-j order over data_ directly.
328 OPM_ERROR_IF(dims_.size() != 2,
"Invalid tensor dimensions");
329 OPM_ERROR_IF(other.dims_.size() != 2,
"Invalid tensor dimensions");
331 OPM_ERROR_IF(dims_[1] != other.dims_[0],
332 "Cannot multiply with different inner dimensions");
334 Tensor tmp(dims_[0], other.dims_[1]);
336 for (
int i = 0; i < dims_[0]; i++) {
337 for (
int j = 0; j < other.dims_[1]; j++) {
338 for (
int k = 0; k < dims_[1]; k++) {
339 tmp(i, j) += (*this)(i, k) * other(k, j);
// Exchanges shape and element storage with `other`. Only the member
// vectors are swapped, so no element is copied.
// NOTE(review): the enclosing function signature is not visible in this
// chunk — presumably a swap(Tensor& other) member; confirm.
349 dims_.swap(other.dims_);
350 data_.swap(other.data_);
// Extent of each axis; dims_.size() is the tensor's rank.
353 std::vector<int> dims_;
// Flat element storage in row-major order (last index varies fastest),
// sized by resizeI to the product of dims_.
354 std::vector<T> data_;