Skip to content

Commit d37c3d8

Browse files
Kajetan FuchsbergerKajetan Fuchsberger
authored andcommitted
operations on interface backed objects should also work now
1 parent e396d41 commit d37c3d8

File tree

9 files changed

+91
-4
lines changed

9 files changed

+91
-4
lines changed

src/java/org/tensorics/core/tensor/lang/OngoingTensorManipulation.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.Collection;
3030
import java.util.HashSet;
3131
import java.util.List;
32+
import java.util.Map;
3233
import java.util.Set;
3334

3435
import org.tensorics.core.tensor.Coordinates;
@@ -38,8 +39,10 @@
3839
import org.tensorics.core.tensor.Positions;
3940
import org.tensorics.core.tensor.Tensor;
4041
import org.tensorics.core.tensor.operations.TensorInternals;
42+
import org.tensorics.core.tensor.stream.TensorStreams;
4143

4244
import com.google.common.collect.ImmutableList;
45+
import com.google.common.collect.ImmutableMap;
4346

4447
/**
4548
* Part of the tensoric fluent API which provides methods to describe misc manipulations on a given tensor.
@@ -96,17 +99,27 @@ public V get(Object... coordinates) {
9699
return tensor.get(coordinates);
97100
}
98101

99-
public <C> List<V> list(List<C> listCoordinateValues, Position remainingCoordinates) {
102+
public <C> List<V> list(List<C> listCoordinateValues, Position otherCoordinates) {
100103
return listCoordinateValues.stream().map(Position::of)//
101-
.map(p -> Positions.union(remainingCoordinates, p)) //
104+
.map(p -> Positions.union(otherCoordinates, p)) //
102105
.map(p -> get(p))//
103106
.collect(toImmutableList());
104107
}
105108

106-
public <C> List<V> list(List<C> listCoordinateValues, Object ...otherCoordinates) {
109+
public <C> List<V> list(List<C> listCoordinateValues, Object... otherCoordinates) {
107110
return list(listCoordinateValues, Position.of(otherCoordinates));
108111
}
109112

113+
public <C> Map<C, V> map(Class<C> mapKeyType, Position otherCoordinates) {
114+
Tensor<V> remaining = extract(otherCoordinates);
115+
return TensorStreams.tensorEntryStream(remaining)
116+
.collect(ImmutableMap.toImmutableMap(e -> e.getKey().coordinateFor(mapKeyType), e-> e.getValue())); //
117+
}
118+
119+
public <C> Map<C, V> map(Class<C> mapKeyType, Object... otherCoordinates) {
120+
return map(mapKeyType, Position.of(otherCoordinates));
121+
}
122+
110123
public Tensor<V> extract(Position position) {
111124
return extractTensor(position.coordinates());
112125
}

src/java/org/tensorics/core/tensorbacked/ProxiedInterfaceTensorbackeds.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.lang.reflect.Method;
1212
import java.lang.reflect.Proxy;
1313
import java.util.Objects;
14+
import java.util.Optional;
1415

1516
import org.tensorics.core.tensor.Tensor;
1617
import org.tensorics.core.util.JavaVersions;
@@ -30,6 +31,19 @@ public static <V, T extends Tensorbacked<V>> T create(Class<T> tensorbackedType,
3031
new DelegatingInvocationHandler<>(tensor, tensorbackedType));
3132
}
3233

34+
public static < T extends Tensorbacked<?>> Optional<Class<T>> tensorbackedInterfaceFrom(T object) {
35+
if (!(object instanceof Proxy)) {
36+
return Optional.empty();
37+
}
38+
39+
InvocationHandler handler = Proxy.getInvocationHandler(object);
40+
if (!(handler instanceof DelegatingInvocationHandler)) {
41+
return Optional.empty();
42+
}
43+
44+
return Optional.of((DelegatingInvocationHandler<?, T>) handler).map(h -> h.intfc);
45+
}
46+
3347
private final static class DelegatingInvocationHandler<V, T extends Tensorbacked<V>> implements InvocationHandler {
3448
private final Tensor<V> delegate;
3549
private final Class<T> intfc;

src/java/org/tensorics/core/tensorbacked/TensorbackedInternals.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import java.util.ArrayList;
3030
import java.util.List;
31+
import java.util.Optional;
3132
import java.util.Set;
3233

3334
import org.tensorics.core.lang.Tensorics;
@@ -58,7 +59,7 @@ private TensorbackedInternals() {
5859
* {@link Dimensions} annotation.
5960
*
6061
* @param tensorBackedClass the class for which to determine the dimensions
61-
* @return the set of dimentions (classses of coordinates) which are required to create an instance of the given
62+
* @return the set of dimensions (classses of coordinates) which are required to create an instance of the given
6263
* class.
6364
*/
6465
public static <T extends Tensorbacked<?>> Set<Class<?>> dimensionsOf(Class<T> tensorBackedClass) {
@@ -123,6 +124,10 @@ public static final <TB extends Tensorbacked<?>> Iterable<Shape> shapesOf(Iterab
123124

124125
@SuppressWarnings("unchecked")
125126
public static final <TB extends Tensorbacked<?>> Class<TB> classOf(TB tensorBacked) {
127+
Optional<Class<TB>> proxiedInterface = ProxiedInterfaceTensorbackeds.tensorbackedInterfaceFrom(tensorBacked);
128+
if (proxiedInterface.isPresent()) {
129+
return proxiedInterface.get();
130+
}
126131
return (Class<TB>) tensorBacked.getClass();
127132
}
128133

src/java/org/tensorics/core/tensorbacked/dimtyped/Tensorbacked1d.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,9 @@ public interface Tensorbacked1d<C1, V> extends DimtypedTensorbacked<V> {
55
default V get(C1 c1) {
66
return tensor().get(c1);
77
}
8+
9+
default boolean contains(C1 c1) {
10+
return tensor().contains(c1);
11+
}
812

913
}

src/java/org/tensorics/core/tensorbacked/dimtyped/Tensorbacked2d.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,8 @@ default V get(C1 c1, C2 c2) {
66
return tensor().get(c1, c2);
77
}
88

9+
10+
default boolean contains(C1 c1, C2 c2) {
11+
return tensor().contains(c1, c2);
12+
}
913
}

src/java/org/tensorics/core/tensorbacked/dimtyped/Tensorbacked3d.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,8 @@ default V get(C1 c1, C2 c2, C3 c3) {
66
return tensor().get(c1, c2, c3);
77
}
88

9+
default boolean contains(C1 c1, C2 c2, C3 c3) {
10+
return tensor().contains(c1, c2, c3);
11+
}
12+
913
}

src/java/org/tensorics/core/tensorbacked/dimtyped/Tensorbacked4d.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,8 @@ default V get(C1 c1, C2 c2, C3 c3, C4 c4) {
66
return tensor().get(c1, c2, c3, c4);
77
}
88

9+
default boolean contains(C1 c1, C2 c2, C3 c3, C4 c4) {
10+
return tensor().contains(c1, c2, c3, c4);
11+
}
12+
913
}

src/java/org/tensorics/core/tensorbacked/dimtyped/Tensorbacked5d.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,10 @@ public interface Tensorbacked5d<C1, C2, C3, C4, C5, V> extends DimtypedTensorbac
55
default V get(C1 c1, C2 c2, C3 c3, C4 c4, C5 c5) {
66
return tensor().get(c1, c2, c3, c4, c5);
77
}
8+
9+
default boolean contains(C1 c1, C2 c2, C3 c3, C4 c4, C5 c5) {
10+
return tensor().contains(c1, c2, c3, c4, c5);
11+
}
12+
813

914
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package org.tensorics.core.tensorbacked.dimtyped;
2+
3+
import org.assertj.core.api.Assertions;
4+
import org.junit.Rule;
5+
import org.junit.Test;
6+
import org.junit.rules.ExpectedException;
7+
import org.tensorics.core.lang.TensoricDoubles;
8+
import org.tensorics.core.lang.Tensorics;
9+
10+
import java.util.NoSuchElementException;
11+
import java.util.Set;
12+
13+
import static org.assertj.core.api.Assertions.assertThat;
14+
15+
public class DimtypedOperationsTest {
16+
17+
@Rule
18+
public ExpectedException thrown = ExpectedException.none();
19+
20+
21+
@Test
22+
public void tensorbacked1dElementMultiplication() {
23+
ADoubleBackedVector vector = Tensorics.builderFor1D(ADoubleBackedVector.class).put("a", 1.2).build();
24+
25+
ADoubleBackedVector newVector = TensoricDoubles.calculate(vector).elementTimesV(2.0);
26+
Assertions.assertThat(newVector.get("a")).isEqualTo(2.4);
27+
}
28+
29+
30+
public interface ADoubleBackedVector extends Tensorbacked1d<String, Double> {
31+
32+
}
33+
34+
}

0 commit comments

Comments
 (0)