@@ -14,6 +14,7 @@ type batchQueryRunner struct {
1414 q Query
1515 oneToOneRels []Relationship
1616 oneToManyRels []Relationship
17+ throughRels []Relationship
1718 db squirrel.DBProxy
1819 builder squirrel.SelectBuilder
1920 total int
@@ -29,6 +30,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
2930 var (
3031 oneToOneRels []Relationship
3132 oneToManyRels []Relationship
33+ throughRels []Relationship
3234 )
3335
3436 for _ , rel := range q .getRelationships () {
@@ -37,6 +39,8 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
3739 oneToOneRels = append (oneToOneRels , rel )
3840 case OneToMany :
3941 oneToManyRels = append (oneToManyRels , rel )
42+ case Through :
43+ throughRels = append (throughRels , rel )
4044 }
4145 }
4246
@@ -46,6 +50,7 @@ func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQuer
4650 q : q ,
4751 oneToOneRels : oneToOneRels ,
4852 oneToManyRels : oneToManyRels ,
53+ throughRels : throughRels ,
4954 db : db ,
5055 builder : builder ,
5156 }
@@ -125,8 +130,14 @@ func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) {
125130 return nil , err
126131 }
127132
133+ if len (records ) == 0 {
134+ return nil , nil
135+ }
136+
128137 var ids = make ([]interface {}, len (records ))
138+ var identType Identifier
129139 for i , r := range records {
140+ identType = r .GetID ()
130141 ids [i ] = r .GetID ().Raw ()
131142 }
132143
@@ -136,63 +147,142 @@ func (r *batchQueryRunner) processBatch(rows *sql.Rows) ([]Record, error) {
136147 return nil , err
137148 }
138149
139- for _ , r := range records {
140- err := r . SetRelationship ( rel . Field , indexedResults [ r . GetID (). Raw ()])
141- if err != nil {
142- return nil , err
143- }
150+ err = setIndexedResults ( records , rel , indexedResults )
151+ if err != nil {
152+ return nil , err
153+ }
154+ }
144155
145- // If the relationship is partial, we can not ensure the results
146- // in the field reflect the truth of the database.
147- // In this case, the parent is marked as non-writable.
148- if rel .Filter != nil {
149- r .setWritable (false )
150- }
156+ for _ , rel := range r .throughRels {
157+ indexedResults , err := r .getRecordThroughRelationships (ids , rel , identType )
158+ if err != nil {
159+ return nil , err
160+ }
161+
162+ err = setIndexedResults (records , rel , indexedResults )
163+ if err != nil {
164+ return nil , err
151165 }
152166 }
153167
154168 return records , nil
155169}
156170
171+ func setIndexedResults (records []Record , rel Relationship , indexedResults indexedRecords ) error {
172+ for _ , r := range records {
173+ err := r .SetRelationship (rel .Field , indexedResults [r .GetID ().Raw ()])
174+ if err != nil {
175+ return err
176+ }
177+
178+ // If the relationship is partial, we can not ensure the results
179+ // in the field reflect the truth of the database.
180+ // In this case, the parent is marked as non-writable.
181+ if rel .Filter != nil {
182+ r .setWritable (false )
183+ }
184+ }
185+
186+ return nil
187+ }
188+
157189type indexedRecords map [interface {}][]Record
158190
159191func (r * batchQueryRunner ) getRecordRelationships (ids []interface {}, rel Relationship ) (indexedRecords , error ) {
160192 fk , ok := r .schema .ForeignKey (rel .Field )
161193 if ! ok {
162- return nil , fmt .Errorf ("kallax: cannot find foreign key on field %s for table %s" , rel .Field , r .schema .Table ())
194+ return nil , fmt .Errorf ("kallax: cannot find foreign key on field %s of table %s" , rel .Field , r .schema .Table ())
163195 }
164196
165197 filter := In (fk , ids ... )
166198 if rel .Filter != nil {
167- And (rel .Filter , filter )
168- } else {
169- rel .Filter = filter
199+ filter = And (rel .Filter , filter )
170200 }
171201
172202 q := NewBaseQuery (rel .Schema )
173- q .Where (rel . Filter )
203+ q .Where (filter )
174204 cols , builder := q .compile ()
175205 rows , err := builder .RunWith (r .db ).Query ()
176206 if err != nil {
177207 return nil , err
178208 }
179209
210+ return indexedResultsFromRows (rows , cols , rel .Schema , fk , nil )
211+ }
212+
213+ func (r * batchQueryRunner ) getRecordThroughRelationships (ids []interface {}, rel Relationship , identType Identifier ) (indexedRecords , error ) {
214+ lfk , rfk , ok := r .schema .ForeignKeys (rel .Field )
215+ if ! ok {
216+ return nil , fmt .Errorf ("kallax: cannot find foreign keys for through relationship on field %s of table %s" , rel .Field , r .schema .Table ())
217+ }
218+
219+ filter := In (r .schema .ID (), ids ... )
220+ if rel .Filter != nil {
221+ filter = And (rel .Filter , filter )
222+ }
223+
224+ if rel .IntermediateFilter != nil {
225+ filter = And (rel .IntermediateFilter , filter )
226+ }
227+
228+ q := NewBaseQuery (rel .Schema )
229+ lschema := r .schema .WithAlias (rel .Schema .Alias ())
230+ intSchema := rel .IntermediateSchema .WithAlias (rel .Schema .Alias ())
231+ q .joinThrough (lschema , intSchema , rel .Schema , lfk , rfk )
232+ q .Where (filter )
233+ cols , builder := q .compile ()
234+ // manually add the extra column to also select the parent id
235+ builder = builder .Column (lschema .ID ().QualifiedName (lschema ))
236+ rows , err := builder .RunWith (r .db ).Query ()
237+ if err != nil {
238+ return nil , err
239+ }
240+
241+ // we need to pass a new pointer of the parent identifier type so the
242+ // resultset can fill it and we can know to which record it belongs when
243+ // indexing by parent id.
244+ return indexedResultsFromRows (rows , cols , rel .Schema , rfk , identType .newPtr ())
245+ }
246+
247+ // indexedResultsFromRows returns the results in the given rows indexed by the
248+ // parent id. In the case of many to many relationships, the record odes not
249+ // have a specific field with the ID of the parent to index by it,
250+ // that's why parentIDPtr is passed for these cases. parentIDPtr is a pointer
251+ // to an ID of the type required by the parent to be filled by the result set.
252+ func indexedResultsFromRows (rows * sql.Rows , cols []string , schema Schema , fk SchemaField , parentIDPtr interface {}) (indexedRecords , error ) {
180253 relRs := NewResultSet (rows , false , nil , cols ... )
181254 var indexedResults = make (indexedRecords )
182255 for relRs .Next () {
183- rec , err := relRs .Get (rel .Schema )
184- if err != nil {
185- return nil , err
256+ var (
257+ rec Record
258+ err error
259+ )
260+
261+ if parentIDPtr != nil {
262+ rec , err = relRs .customGet (schema , parentIDPtr )
263+ } else {
264+ rec , err = relRs .Get (schema )
186265 }
187266
188- val , err := rec .Value (fk .String ())
189267 if err != nil {
190268 return nil , err
191269 }
192270
193271 rec .setPersisted ()
194272 rec .setWritable (true )
195- id := val .(Identifier ).Raw ()
273+
274+ var id interface {}
275+ if parentIDPtr != nil {
276+ id = parentIDPtr .(Identifier ).Raw ()
277+ } else {
278+ val , err := rec .Value (fk .String ())
279+ if err != nil {
280+ return nil , err
281+ }
282+
283+ id = val .(Identifier ).Raw ()
284+ }
285+
196286 indexedResults [id ] = append (indexedResults [id ], rec )
197287 }
198288
0 commit comments