Skip to content

Commit 9bac4ca

Browse files
committed
some utility methods added for shape handling
1 parent 65c0c71 commit 9bac4ca

File tree

9 files changed

+246
-54
lines changed

9 files changed

+246
-54
lines changed

src/java/org/tensorics/core/lang/Tensorics.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222

2323
package org.tensorics.core.lang;
2424

25-
import static org.tensorics.core.tensor.operations.PositionFunctions.forSupplier;
26-
2725
import java.util.Map;
2826
import java.util.Map.Entry;
2927
import java.util.Set;
@@ -46,9 +44,9 @@
4644
import org.tensorics.core.tensor.lang.OngoingTensorManipulation;
4745
import org.tensorics.core.tensor.lang.QuantityTensors;
4846
import org.tensorics.core.tensor.lang.TensorStructurals;
49-
import org.tensorics.core.tensor.operations.FunctionTensorCreationOperation;
50-
import org.tensorics.core.tensor.operations.SingleValueTensorCreationOperation;
47+
import org.tensorics.core.tensor.operations.TensorInternals;
5148
import org.tensorics.core.tensor.stream.TensorStreams;
49+
import org.tensorics.core.tensorbacked.OngoingTensorbackedCompletion;
5250
import org.tensorics.core.tensorbacked.Tensorbacked;
5351
import org.tensorics.core.tensorbacked.TensorbackedBuilder;
5452
import org.tensorics.core.tensorbacked.Tensorbackeds;
@@ -303,15 +301,15 @@ public static <S> OngoingFlattening<S> flatten(Tensorbacked<S> tensorbacked) {
303301
}
304302

305303
public static <S> Tensor<S> sameValues(Shape shape, S value) {
306-
return new SingleValueTensorCreationOperation<S>(shape, value).perform();
304+
return TensorInternals.sameValues(shape, value);
307305
}
308306

309307
public static <S> Tensor<S> createFrom(Shape shape, Supplier<S> supplier) {
310-
return new FunctionTensorCreationOperation<>(shape, forSupplier(supplier)).perform();
308+
return TensorInternals.createFrom(shape, supplier);
311309
}
312310

313311
public static <S> Tensor<S> createFrom(Shape shape, Function<Position, S> function) {
314-
return new FunctionTensorCreationOperation<>(shape, function).perform();
312+
return TensorInternals.createFrom(shape, function);
315313
}
316314

317315
public static <S> OngoingCompletion<S> complete(Tensor<S> tensor) {
@@ -407,4 +405,19 @@ public static <S> Stream<Map.Entry<Position, S>> stream(Tensor<S> tensor) {
407405
public static <S> Stream<Map.Entry<Position, S>> stream(Tensorbacked<S> tensorBacked) {
408406
return TensorStreams.tensorEntryStream(tensorBacked.tensor());
409407
}
408+
409+
/**
410+
* @see Tensorbackeds#shapesOf(Tensorbacked)
411+
*/
412+
public static <TB extends Tensorbacked<?>> Iterable<Shape> shapesOf(Iterable<TB> tensorbackeds) {
413+
return Tensorbackeds.shapesOf(tensorbackeds);
414+
}
415+
416+
/**
417+
* @see Tensorbackeds#complete(Tensorbacked)
418+
*/
419+
public static <S, TB extends Tensorbacked<S>> OngoingTensorbackedCompletion<TB, S> complete(TB tensorbacked) {
420+
return Tensorbackeds.complete(tensorbacked);
421+
}
422+
410423
}

src/java/org/tensorics/core/tensor/Shapes.java

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@
2525
import static com.google.common.base.Preconditions.checkArgument;
2626
import static com.google.common.base.Preconditions.checkNotNull;
2727
import static com.google.common.collect.Collections2.transform;
28-
import static com.google.common.collect.Sets.union;
28+
import static java.util.Objects.requireNonNull;
2929

30+
import java.util.NoSuchElementException;
3031
import java.util.Set;
32+
import java.util.function.BiFunction;
3133
import java.util.function.Function;
3234

35+
import com.google.common.collect.Iterables;
3336
import com.google.common.collect.Sets;
3437

3538
/**
@@ -48,20 +51,59 @@ private Shapes() {
4851

4952
/**
5053
* Creates a shape, containing all the positions that are contained in both given shapes. This only makes sense, if
51-
* the dimensions of the two sets are the same. If they are not, then an {@link IllegalArgumentException} is thrown.
54+
* the dimensions of the two shapes are the same. If they are not, then an {@link IllegalArgumentException} is
55+
* thrown.
5256
*
5357
* @param left the first shape from which to take positions
5458
* @param right the second shape from which to take positions
5559
* @return a shape containing all positions, which are contained in both given shapes
5660
*/
5761
public static Shape intersection(Shape left, Shape right) {
58-
checkLeftRightNotNull(left, right);
59-
if (!left.hasSameDimensionsAs(right)) {
60-
throw new IllegalArgumentException("The two shapes do not have the same dimension, "
61-
+ "therefore the intersection of coordinates cannot be determined. Left dimensions: "
62-
+ left.dimensionSet() + "; Right dimensions: " + right.dimensionSet());
63-
}
64-
return Shape.of(Sets.intersection(left.positionSet(), right.positionSet()));
62+
return combineLeftRightBy(left, right, Sets::intersection);
63+
}
64+
65+
/**
66+
* Creates a shape, containing all the positions that are either contained in the left or the right shape.This only
67+
* makes sense, if the dimensions of the two shapes are the same. If they are not, then an
68+
* {@link IllegalArgumentException} is thrown.
69+
*
70+
* @param left the first shape from which to take positions
71+
* @param right the second shape from which to take positions
72+
* @return a shape containing all positions, which are contained in at least one of the two shapes
73+
*/
74+
public static Shape union(Shape left, Shape right) {
75+
return combineLeftRightBy(left, right, Sets::union);
76+
}
77+
78+
/**
79+
* Creates a shape, containing all the positions that are contained at least in one of the given shapes. This only
80+
* makes sense, if the dimensions of the shapes are the same. If they are not, then an
81+
* {@link IllegalArgumentException} is thrown. Further, it is required that at least one element is contained in the
82+
* iterable.
83+
*
84+
* @param shapes the shapes for which the union shall be found
85+
* @return a shape which represents the union of all the shapes
86+
* @throws IllegalArgumentException if the shapes are not of the same dimension
87+
* @throws NoSuchElementException in case the iterable is empty
88+
* @throws NullPointerException if the given iterable is {@code null}
89+
*/
90+
public static final Shape union(Iterable<Shape> shapes) {
91+
return combineBy(shapes, Shapes::union);
92+
}
93+
94+
/**
95+
* Creates a shape, containing the positions which are contained in each of the given shapes. This only makes sense,
96+
* if the dimensions of the shapes are the same. If they are not, then an {@link IllegalArgumentException} is
97+
* thrown. Further, it is required that at least one element is contained in the iterable.
98+
*
99+
* @param shapes the shapes for which the intersection shall be found
100+
* @return a shape which represents the intersection of all the shapes
101+
* @throws IllegalArgumentException if the shapes are not of the same dimension
102+
* @throws NoSuchElementException in case the iterable is empty
103+
* @throws NullPointerException if the given iterable is {@code null}
104+
*/
105+
public static final Shape intersection(Iterable<Shape> shapes) {
106+
return combineBy(shapes, Shapes::intersection);
65107
}
66108

67109
/**
@@ -89,7 +131,8 @@ public static Set<Class<?>> dimensionalIntersection(Shape left, Shape right) {
89131
public static Shape dimensionStripped(Shape shape, Set<? extends Class<?>> dimensionsToStrip) {
90132
checkNotNull(shape, "shape must not be null");
91133
checkNotNull(dimensionsToStrip, "dimensions must not be null");
92-
return Shape.of(Positions.unique(transform(shape.positionSet(), toGuavaFunction(Positions.stripping(dimensionsToStrip)))));
134+
return Shape.of(Positions
135+
.unique(transform(shape.positionSet(), toGuavaFunction(Positions.stripping(dimensionsToStrip)))));
93136
}
94137

95138
/**
@@ -107,7 +150,7 @@ public static Shape dimensionStripped(Shape shape, Set<? extends Class<?>> dimen
107150
public static Shape outerProduct(Shape left, Shape right) {
108151
checkArgument(dimensionalIntersection(left, right).isEmpty(), "The two shapes have "
109152
+ "overlapping dimensions. The outer product is not foreseen to be used in this situation.");
110-
Shape.Builder builder = Shape.builder(union(left.dimensionSet(), right.dimensionSet()));
153+
Shape.Builder builder = Shape.builder(Sets.union(left.dimensionSet(), right.dimensionSet()));
111154
for (Position leftPosition : left.positionSet()) {
112155
for (Position rightPosition : right.positionSet()) {
113156
builder.add(Positions.union(leftPosition, rightPosition));
@@ -137,4 +180,36 @@ public R apply(T input) {
137180
}
138181
};
139182
}
183+
184+
private static void checkLeftRightSameDimensions(Shape left, Shape right) {
185+
if (!left.hasSameDimensionsAs(right)) {
186+
throw new IllegalArgumentException("The two shapes do not have the same dimension, "
187+
+ "therefore combining of positions does not make sense. Left dimensions: " + left.dimensionSet()
188+
+ "; Right dimensions: " + right.dimensionSet());
189+
}
190+
}
191+
192+
private static Shape combineLeftRightBy(Shape left, Shape right,
193+
BiFunction<Set<Position>, Set<Position>, Set<Position>> combiner) {
194+
checkLeftRightNotNull(left, right);
195+
checkLeftRightSameDimensions(left, right);
196+
return Shape.of(combiner.apply(left.positionSet(), right.positionSet()));
197+
}
198+
199+
private static Shape combineBy(Iterable<Shape> shapes, BiFunction<Shape, Shape, Shape> combiner) {
200+
requireNonNull(shapes, "shapes must not be null");
201+
if (Iterables.isEmpty(shapes)) {
202+
throw new NoSuchElementException("At least one shape is required.");
203+
}
204+
Shape resultingShape = null;
205+
for (Shape shape : shapes) {
206+
if (shape == null) {
207+
resultingShape = shape;
208+
} else {
209+
resultingShape = combiner.apply(resultingShape, shape);
210+
}
211+
}
212+
return resultingShape;
213+
}
214+
140215
}

src/java/org/tensorics/core/tensor/lang/OngoingCompletion.java

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,9 @@
2222

2323
package org.tensorics.core.tensor.lang;
2424

25-
import static com.google.common.base.Preconditions.checkArgument;
26-
import static com.google.common.base.Preconditions.checkNotNull;
27-
28-
import java.util.Map.Entry;
29-
import java.util.Set;
30-
31-
import org.tensorics.core.tensor.ImmutableTensor;
32-
import org.tensorics.core.tensor.ImmutableTensor.Builder;
33-
import org.tensorics.core.tensor.Position;
25+
import org.tensorics.core.tensor.Shape;
3426
import org.tensorics.core.tensor.Tensor;
27+
import org.tensorics.core.tensor.operations.TensorInternals;
3528

3629
import com.google.common.base.Preconditions;
3730

@@ -44,24 +37,11 @@ public class OngoingCompletion<S> {
4437
}
4538

4639
public Tensor<S> with(Tensor<S> second) {
47-
checkNotNull(second, "second tensor must not be null");
48-
checkArgument(second.shape().dimensionSet().equals(dimensions()),
49-
"Tensors do not have the same dimensions! Completion not supported!");
50-
Builder<S> builder = ImmutableTensor.builder(dimensions());
51-
builder.context(tensor.context());
52-
for (Entry<Position, S> entry: second.asMap().entrySet()) {
53-
Position position = entry.getKey();
54-
if (tensor.shape().contains(position)) {
55-
builder.putAt(tensor.get(position), position);
56-
} else {
57-
builder.put(entry);
58-
}
59-
}
60-
return builder.build();
40+
return TensorStructurals.completeWith(tensor, second);
6141
}
6242

63-
private Set<Class<?>> dimensions() {
64-
return tensor.shape().dimensionSet();
43+
public Tensor<S> with(Shape shape, S value) {
44+
return with(TensorInternals.sameValues(shape, value));
6545
}
6646

6747
}

src/java/org/tensorics/core/tensor/lang/TensorStructurals.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222

2323
package org.tensorics.core.tensor.lang;
2424

25+
import static com.google.common.base.Preconditions.checkArgument;
26+
import static com.google.common.base.Preconditions.checkNotNull;
27+
2528
import java.util.HashSet;
2629
import java.util.Map.Entry;
2730
import java.util.Set;
@@ -30,6 +33,8 @@
3033
import org.tensorics.core.tensor.ImmutableTensor;
3134
import org.tensorics.core.tensor.ImmutableTensor.Builder;
3235
import org.tensorics.core.tensor.Position;
36+
import org.tensorics.core.tensor.Shape;
37+
import org.tensorics.core.tensor.Shapes;
3338
import org.tensorics.core.tensor.Tensor;
3439

3540
import com.google.common.collect.Iterables;
@@ -163,4 +168,22 @@ public static final <S> OngoingTensorFiltering<S> filter(Tensor<S> tensor) {
163168
return new OngoingTensorFiltering<>(tensor);
164169
}
165170

171+
public static final <S> Tensor<S> completeWith(Tensor<S> tensor, Tensor<S> second) {
172+
checkNotNull(second, "second tensor must not be null");
173+
checkArgument(second.shape().dimensionSet().equals(tensor.shape().dimensionSet()),
174+
"Tensors do not have the same dimensions! Completion not supported!");
175+
Builder<S> builder = ImmutableTensor.builder(tensor.shape().dimensionSet());
176+
builder.context(tensor.context());
177+
178+
Shape shape = Shapes.union(tensor.shape(), second.shape());
179+
for (Position position : shape.positionSet()) {
180+
if (tensor.shape().contains(position)) {
181+
builder.putAt(tensor.get(position), position);
182+
} else {
183+
builder.putAt(second.get(position), position);
184+
}
185+
}
186+
return builder.build();
187+
}
188+
166189
}

src/java/org/tensorics/core/tensor/operations/TensorInternals.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@
2222

2323
package org.tensorics.core.tensor.operations;
2424

25+
import static org.tensorics.core.tensor.operations.PositionFunctions.forSupplier;
26+
2527
import java.util.Map.Entry;
2628
import java.util.Set;
29+
import java.util.function.Function;
30+
import java.util.function.Supplier;
2731

2832
import org.tensorics.core.tensor.Position;
33+
import org.tensorics.core.tensor.Shape;
2934
import org.tensorics.core.tensor.Tensor;
3035

3136
/**
@@ -47,4 +52,16 @@ public static <T> Set<Entry<Position, T>> entrySetOf(Tensor<T> tensor) {
4752
return tensor.asMap().entrySet();
4853
}
4954

55+
public static <S> Tensor<S> sameValues(Shape shape, S value) {
56+
return new SingleValueTensorCreationOperation<S>(shape, value).perform();
57+
}
58+
59+
public static <S> Tensor<S> createFrom(Shape shape, Supplier<S> supplier) {
60+
return new FunctionTensorCreationOperation<>(shape, forSupplier(supplier)).perform();
61+
}
62+
63+
public static <S> Tensor<S> createFrom(Shape shape, Function<Position, S> function) {
64+
return new FunctionTensorCreationOperation<>(shape, function).perform();
65+
}
66+
5067
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// @formatter:off
2+
/*******************************************************************************
3+
*
4+
* This file is part of tensorics.
5+
*
6+
* Copyright (c) 2008-2011, CERN. All rights reserved.
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this file except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*
20+
******************************************************************************/
21+
// @formatter:on
22+
23+
package org.tensorics.core.tensorbacked;
24+
25+
import static com.google.common.base.Preconditions.checkNotNull;
26+
import static org.tensorics.core.tensor.lang.TensorStructurals.completeWith;
27+
import static org.tensorics.core.tensorbacked.TensorbackedInternals.createBackedByTensor;
28+
29+
import org.tensorics.core.tensor.Shape;
30+
import org.tensorics.core.tensor.Tensor;
31+
import org.tensorics.core.tensor.operations.TensorInternals;
32+
33+
public class OngoingTensorbackedCompletion<TB extends Tensorbacked<S>, S> {
34+
35+
private final TB tensorbacked;
36+
37+
OngoingTensorbackedCompletion(TB tensorbacked) {
38+
this.tensorbacked = checkNotNull(tensorbacked, "tensorbacked must not be null");
39+
}
40+
41+
public TB with(Tensor<S> second) {
42+
Tensor<S> tensor = completeWith(tensorbacked.tensor(), second);
43+
return createBackedByTensor(TensorbackedInternals.classOf(tensorbacked), tensor);
44+
}
45+
46+
public TB with(TB second) {
47+
return with(second.tensor());
48+
}
49+
50+
public TB with(Shape shape, S value) {
51+
return with(TensorInternals.sameValues(shape, value));
52+
}
53+
54+
}

0 commit comments

Comments
 (0)