@@ -9,8 +9,26 @@ Base.axes(x::ComponentArray) = axes(getdata(x))
99
1010Base. reinterpret (:: Type{T} , x:: ComponentArray , args... ) where T = ComponentArray (reinterpret (T, getdata (x), args... ), getaxes (x))
1111
12- Base. hcat (x:: CV... ) where {CV<: ComponentVector } = ComponentArray (reduce (hcat, getdata .(x)), getaxes (x[1 ])[1 ], FlatAxis ())
12+ # Cats
13+ # TODO : Make this a little less copy-pastey
14+ function Base. hcat (x:: AbstractComponentVecOrMat , y:: AbstractComponentVecOrMat )
15+ ax_x, ax_y = second_axis .((x,y))
16+ if reduce ((accum, key) -> accum || (key in keys (ax_x)), keys (ax_y); init= false ) || getaxes (x)[1 ] != getaxes (y)[1 ]
17+ return hcat (getdata (x), getdata (y))
18+ else
19+ data_x, data_y = getdata .((x, y))
20+ ax_y = reindex (ax_y, size (x,2 ))
21+ idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
22+ axs = getaxes (x)
23+ return ComponentArray (hcat (data_x, data_y), axs[1 ], Axis ((;idxmap_x... , idxmap_y... )), axs[3 : end ]. .. )
24+ end
25+ end
1326
27+ second_axis (ca:: AbstractComponentVecOrMat ) = getaxes (ca)[2 ]
28+ second_axis (:: ComponentVector ) = FlatAxis ()
29+
30+ # Are all these methods necessary?
31+ # TODO : See what we can reduce down to without getting ambiguity errors
1432Base. vcat (x:: ComponentVector , y:: AbstractVector ) = vcat (getdata (x), y)
1533Base. vcat (x:: AbstractVector , y:: ComponentVector ) = vcat (x, getdata (y))
1634function Base. vcat (x:: ComponentVector , y:: ComponentVector )
@@ -24,11 +42,33 @@ function Base.vcat(x::ComponentVector, y::ComponentVector)
2442 return ComponentArray (vcat (data_x, data_y), Axis ((;idxmap_x... , idxmap_y... )))
2543 end
2644end
45+ function Base. vcat (x:: AbstractComponentVecOrMat , y:: AbstractComponentVecOrMat )
46+ ax_x, ax_y = getindex .(getaxes .((x, y)), 1 )
47+ if reduce ((accum, key) -> accum || (key in keys (ax_x)), keys (ax_y); init= false ) || getaxes (x)[2 : end ] != getaxes (y)[2 : end ]
48+ return vcat (getdata (x), getdata (y))
49+ else
50+ data_x, data_y = getdata .((x, y))
51+ ax_y = reindex (ax_y, size (x,1 ))
52+ idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
53+ return ComponentArray (vcat (data_x, data_y), Axis ((;idxmap_x... , idxmap_y... )), getaxes (x)[2 : end ]. .. )
54+ end
55+ end
2756Base. vcat (x:: CV... ) where {CV<: AdjOrTransComponentArray } = ComponentArray (reduce (vcat, map (y-> getdata (y. parent)' , x)), getaxes (x[1 ]))
28- Base. vcat (x:: ComponentVector... ) = reduce (vcat, x)
2957Base. vcat (x:: ComponentVector , args... ) = vcat (getdata (x), getdata .(args)... )
3058Base. vcat (x:: ComponentVector , args:: Vararg{AbstractVector{T}, N} ) where {T,N} = vcat (getdata (x), getdata .(args)... )
3159
60+ function Base. hvcat (row_lengths:: Tuple{Vararg{Int}} , xs:: AbstractComponentVecOrMat... )
61+ i = 1
62+ idxs = UnitRange{Int}[]
63+ for row_length in row_lengths
64+ i_last = i + row_length - 1
65+ push! (idxs, i: i_last)
66+ i = i_last + 1
67+ end
68+ rows = [reduce (hcat, xs[idx]) for idx in idxs]
69+ return vcat (rows... )
70+ end
71+
3272function Base. permutedims (x:: ComponentArray , dims)
3373 axs = getaxes (x)
3474 return ComponentArray (permutedims (getdata (x), dims), map (i-> axs[i], dims)... )
@@ -59,14 +99,14 @@ Base.@propagate_inbounds function Base.getindex(x::ComponentArray, idx::FlatOrCo
5999 return ComponentArray (getdata (x)[idx... ], axs... )
60100end
61101Base. @propagate_inbounds Base. getindex (x:: ComponentArray , :: Colon ) = getdata (x)[:]
62- @inline Base. getindex (x:: ComponentArray , :: Colon... ) = x
63- Base . @propagate_inbounds Base. getindex (x:: ComponentArray , idx... ) = getindex (x, toval .(idx)... )
102+ Base . @propagate_inbounds Base. getindex (x:: ComponentArray , :: Colon... ) = x
103+ @inline Base. getindex (x:: ComponentArray , idx... ) = getindex (x, toval .(idx)... )
64104@inline Base. getindex (x:: ComponentArray , idx:: Val... ) = _getindex (x, idx... )
65105
66106# Set ComponentArray index
67- @inline Base. setindex! (x:: ComponentArray , v, idx:: FlatIdx ... ) = setindex! (getdata (x), v, idx... )
107+ Base . @propagate_inbounds Base. setindex! (x:: ComponentArray , v, idx:: FlatOrColonIdx ... ) = setindex! (getdata (x), v, idx... )
68108Base. @propagate_inbounds Base. setindex! (x:: ComponentArray , v, :: Colon ) = setindex! (getdata (x), v, :)
69- Base . @propagate_inbounds Base. setindex! (x:: ComponentArray , v, idx... ) = setindex! (x, v, toval .(idx)... )
109+ @inline Base. setindex! (x:: ComponentArray , v, idx... ) = setindex! (x, v, toval .(idx)... )
70110@inline Base. setindex! (x:: ComponentArray , v, idx:: Val... ) = _setindex! (x, v, idx... )
71111
72112# Explicitly view
@@ -108,6 +148,56 @@ ArrayInterface.lu_instance(jac_prototype::ComponentArray) = ArrayInterface.lu_in
108148
109149ArrayInterface. parent_type (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = A
110150
151+ for f in [* , \ , / ]
152+ op = nameof (f)
153+ @eval begin
154+ function Base. $op (A:: ComponentVecOrMat , B:: ComponentMatrix )
155+ C = $ op (getdata (A), getdata (B))
156+ ax1 = getaxes (A)[1 ]
157+ ax2 = getaxes (B)[2 ]
158+ return ComponentArray (C, (ax1, ax2))
159+ end
160+ function Base. $op (A:: ComponentVecOrMat , b:: ComponentVector )
161+ c = $ op (getdata (A), getdata (b))
162+ ax1 = getaxes (A)[1 ]
163+ return ComponentArray (c, ax1)
164+ end
165+ function Base. $op (Aᵀ:: ComponentMatrix{T} , B:: AdjOrTransComponentVecOrMat{T} ) where {T}
166+ Cᵀ = $ op (getdata (Aᵀ), getdata (B))
167+ ax1 = getaxes (Aᵀ)[1 ]
168+ ax2 = getaxes (B)[2 ]
169+ return ComponentArray (Cᵀ, ax1, ax2)
170+ end
171+ end
172+ for (adjfun, AdjType) in zip ([adjoint, transpose], [Adjoint, Transpose])
173+ adj = nameof (adjfun)
174+ Adj = nameof (AdjType)
175+ @eval begin
176+ function Base. $op (aᵀ:: $Adj{T,<:ComponentVector} , B:: ComponentMatrix ) where {T}
177+ cᵀ = parent ($ op (getdata (aᵀ), getdata (B)))
178+ ax2 = getaxes (B)[2 ]
179+ return $ adj (ComponentArray (cᵀ, ax2))
180+ end
181+ function Base. $op (Aᵀ:: $Adj{T,<:ComponentMatrix} , B:: AdjOrTransComponentVecOrMat ) where {T}
182+ Cᵀ = $ op (getdata (Aᵀ), getdata (B))
183+ ax1 = getaxes (Aᵀ)[1 ]
184+ ax2 = getaxes (B)[2 ]
185+ return ComponentArray (Cᵀ, ax1, ax2)
186+ end
187+ function Base. $op (Aᵀ:: $Adj{T,<:ComponentMatrix} , B:: ComponentMatrix{T} ) where {T}
188+ Cᵀ = $ op (getdata (Aᵀ), getdata (B))
189+ ax1 = getaxes (Aᵀ)[1 ]
190+ ax2 = getaxes (B)[2 ]
191+ return ComponentArray (Cᵀ, ax1, ax2)
192+ end
193+ function Base. $op (Aᵀ:: $Adj{T,<:ComponentMatrix} , b:: ComponentVector ) where {T}
194+ cᵀ = $ op (getdata (Aᵀ), getdata (b))
195+ ax1 = getaxes (Aᵀ)[1 ]
196+ return ComponentArray (cᵀ, ax1)
197+ end
198+ end
199+ end
200+ end
111201
112202
113203# While there are some cases where these were faster, it is going to be almost impossible to
0 commit comments