From 38c998f5f35d262ee3a26b6adcafb66edda81400 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Wed, 10 Sep 2025 14:30:15 -0400 Subject: [PATCH 1/5] Minimal concatenated input example --- example/concatenate.f90 | 63 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 example/concatenate.f90 diff --git a/example/concatenate.f90 b/example/concatenate.f90 new file mode 100644 index 00000000..2bf492c9 --- /dev/null +++ b/example/concatenate.f90 @@ -0,0 +1,63 @@ +program concatenate + use nf, only: dense, input, network, sgd + implicit none + + type(network) :: net1, net2, net3 + real, allocatable :: x1(:), y1(:) + real, allocatable :: x2(:), y2(:) + real, allocatable :: x3(:), y3(:) + integer, parameter :: num_iterations = 500 + integer :: n + + ! Network 1 + net1 = network([ & + input(3), & + dense(2) & + ]) + + x1 = [0.2, 0.4, 0.6] + y1 = [0.123456, 0.246802] + + do n = 1, num_iterations + call net1 % forward(x1) + call net1 % backward(y1) + call net1 % update(optimizer=sgd(learning_rate=1.)) + end do + + print *, "net1 output: ", net1 % predict(x1) + + ! Network 2 + net2 = network([ & + input(3), & + dense(3) & + ]) + + x2 = [0.7, 0.5, 0.3] + y2 = [0.369258, 0.482604, 0.505050] + + do n = 1, num_iterations + call net2 % forward(x2) + call net2 % backward(y2) + call net2 % update(optimizer=sgd(learning_rate=1.)) + end do + + print *, "net2 output: ", net2 % predict(x2) + + ! Network 3 + net3 = network([ & + input(size(net1 % predict(x1)) + size(net2 % predict(x2))), & + dense(5) & + ]) + + x3 = [net1 % predict(x1), net2 % predict(x2)] + y3 = [0.111111, 0.222222, 0.333333, 0.444444, 0.555555] + + do n = 1, num_iterations + call net3 % forward(x3) + call net3 % backward(y3) + call net3 % update(optimizer=sgd(learning_rate=1.)) + end do + + print *, "net3 output: ", net3 % predict(x3) + +end program concatenate \ No newline at end of file From 165a6c40126487d91dd09022b910735961553f99 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 11 Sep 2025 11:37:16 -0400 Subject: [PATCH 2/5] Update example of merging 2 networks to feed into a 3rd network --- example/concatenate.f90 | 63 ------------------- example/merge_networks.f90 | 122 +++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 63 deletions(-) delete mode 100644 example/concatenate.f90 create mode 100644 example/merge_networks.f90 diff --git a/example/concatenate.f90 b/example/concatenate.f90 deleted file mode 100644 index 2bf492c9..00000000 --- a/example/concatenate.f90 +++ /dev/null @@ -1,63 +0,0 @@ -program concatenate - use nf, only: dense, input, network, sgd - implicit none - - type(network) :: net1, net2, net3 - real, allocatable :: x1(:), y1(:) - real, allocatable :: x2(:), y2(:) - real, allocatable :: x3(:), y3(:) - integer, parameter :: num_iterations = 500 - integer :: n - - ! Network 1 - net1 = network([ & - input(3), & - dense(2) & - ]) - - x1 = [0.2, 0.4, 0.6] - y1 = [0.123456, 0.246802] - - do n = 1, num_iterations - call net1 % forward(x1) - call net1 % backward(y1) - call net1 % update(optimizer=sgd(learning_rate=1.)) - end do - - print *, "net1 output: ", net1 % predict(x1) - - ! Network 2 - net2 = network([ & - input(3), & - dense(3) & - ]) - - x2 = [0.7, 0.5, 0.3] - y2 = [0.369258, 0.482604, 0.505050] - - do n = 1, num_iterations - call net2 % forward(x2) - call net2 % backward(y2) - call net2 % update(optimizer=sgd(learning_rate=1.)) - end do - - print *, "net2 output: ", net2 % predict(x2) - - ! 
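  ! Note that in this minimal version net1 and net2 are trained only against
  ! their own targets (y1 and y2) and net3 then trains on their concatenated
  ! predictions, so no gradient from net3 reaches net1 or net2.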
Network 3 - net3 = network([ & - input(size(net1 % predict(x1)) + size(net2 % predict(x2))), & - dense(5) & - ]) - - x3 = [net1 % predict(x1), net2 % predict(x2)] - y3 = [0.111111, 0.222222, 0.333333, 0.444444, 0.555555] - - do n = 1, num_iterations - call net3 % forward(x3) - call net3 % backward(y3) - call net3 % update(optimizer=sgd(learning_rate=1.)) - end do - - print *, "net3 output: ", net3 % predict(x3) - -end program concatenate \ No newline at end of file diff --git a/example/merge_networks.f90 b/example/merge_networks.f90 new file mode 100644 index 00000000..f69283c3 --- /dev/null +++ b/example/merge_networks.f90 @@ -0,0 +1,122 @@ +program merge_networks + use nf, only: dense, input, network, sgd + use nf_dense_layer, only: dense_layer + implicit none + + type(network) :: net1, net2, net3 + real, allocatable :: x1(:), x2(:) + real, allocatable :: y1(:), y2(:) + real, allocatable :: y(:) + integer, parameter :: num_iterations = 500 + integer :: n, nn + integer :: net1_output_size, net2_output_size + + x1 = [0.1, 0.3, 0.5] + x2 = [0.2, 0.4] + y = [0.123456, 0.246802, 0.369258, 0.482604, 0.505050, 0.628406, 0.741852] + + net1 = network([ & + input(3), & + dense(2), & + dense(3), & + dense(2) & + ]) + + net2 = network([ & + input(2), & + dense(5), & + dense(3) & + ]) + + net1_output_size = product(net1 % layers(size(net1 % layers)) % layer_shape) + net2_output_size = product(net2 % layers(size(net2 % layers)) % layer_shape) + + ! Network 3 + net3 = network([ & + input(net1_output_size + net2_output_size), & + dense(7) & + ]) + + do n = 1, num_iterations + + ! Forward propagate two network branches + call net1 % forward(x1) + call net2 % forward(x2) + + ! Get outputs of net1 and net2, concatenate, and pass to net3 + ! A helper function could be made to take any number of networks + ! and return the concatenated output. Such function would turn the following + ! block into a one-liner. + select type (net1_output_layer => net1 % layers(size(net1 % layers)) % p) + type is (dense_layer) + y1 = net1_output_layer % output + end select + + select type (net2_output_layer => net2 % layers(size(net2 % layers)) % p) + type is (dense_layer) + y2 = net2_output_layer % output + end select + + call net3 % forward([y1, y2]) + + ! Compute the gradients on the 3rd network + call net3 % backward(y) + + ! net3 % update() will clear the gradients immediately after updating + ! the weights, so we need to pass the gradients to net1 and net2 first + + ! For net1 and net2, we can't use the existing net % backward() because + ! it currently assumes that the output layer gradients are computed based + ! on the loss function and not the gradient from the next layer. + ! For now, we need to manually pass the gradient from the first hidden layer + ! of net3 to the output layers of net1 and net2. + select type (next_layer => net3 % layers(2) % p) + ! Assume net3's first hidden layer is dense; + ! would need to be generalized to others. + type is (dense_layer) + + nn = size(net1 % layers) + call net1 % layers(nn) % backward( & + net1 % layers(nn - 1), next_layer % gradient(1:net1_output_size) & + ) + + nn = size(net2 % layers) + call net2 % layers(nn) % backward( & + net2 % layers(nn - 1), next_layer % gradient(net1_output_size+1:size(next_layer % gradient)) & + ) + + end select + + ! 
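    ! With the layer sizes above, the gradient of net3's first hidden layer
    ! with respect to its input has 5 elements: elements 1:2 flow back to net1
    ! (output size 2) and elements 3:5 flow back to net2 (output size 3),
    ! matching the order in which [y1, y2] were concatenated.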
Compute the gradients on hidden layers of net1, if any + do nn = size(net1 % layers)-1, 2, -1 + select type (next_layer => net1 % layers(nn + 1) % p) + type is (dense_layer) + call net1 % layers(nn) % backward( & + net1 % layers(nn - 1), next_layer % gradient & + ) + end select + end do + + ! Compute the gradients on hidden layers of net2, if any + do nn = size(net2 % layers)-1, 2, -1 + select type (next_layer => net2 % layers(nn + 1) % p) + type is (dense_layer) + call net2 % layers(nn) % backward( & + net2 % layers(nn - 1), next_layer % gradient & + ) + end select + end do + + ! Gradients are now computed on all networks and we can update the weights + call net1 % update(optimizer=sgd(learning_rate=1.)) + call net2 % update(optimizer=sgd(learning_rate=1.)) + call net3 % update(optimizer=sgd(learning_rate=1.)) + + if (mod(n, 50) == 0) then + print *, "Iteration ", n, ", output RMSE = ", & + sqrt(sum((net3 % predict([net1 % predict(x1), net2 % predict(x2)]) - y)**2) / size(y)) + end if + + end do + +end program merge_networks \ No newline at end of file From f6767805b4d9a781bdb93adc47a3d0b0bea411b4 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Mon, 15 Sep 2025 14:06:41 -0400 Subject: [PATCH 3/5] Allow passing gradient to network % backward() to bypass loss function --- example/merge_networks.f90 | 48 ++----------- src/nf/nf_network.f90 | 8 ++- src/nf/nf_network_submodule.f90 | 118 ++++++++++++++++++-------------- 3 files changed, 78 insertions(+), 96 deletions(-) diff --git a/example/merge_networks.f90 b/example/merge_networks.f90 index f69283c3..c1a55008 100644 --- a/example/merge_networks.f90 +++ b/example/merge_networks.f90 @@ -34,7 +34,7 @@ program merge_networks ! Network 3 net3 = network([ & input(net1_output_size + net2_output_size), & - dense(7) & + dense(7) & ]) do n = 1, num_iterations @@ -59,54 +59,16 @@ program merge_networks call net3 % forward([y1, y2]) - ! Compute the gradients on the 3rd network + ! First compute the gradients on net3, then pass the gradients from the first + ! hidden layer on net3 to net1 and net2, and compute their gradients. call net3 % backward(y) - ! net3 % update() will clear the gradients immediately after updating - ! the weights, so we need to pass the gradients to net1 and net2 first - - ! For net1 and net2, we can't use the existing net % backward() because - ! it currently assumes that the output layer gradients are computed based - ! on the loss function and not the gradient from the next layer. - ! For now, we need to manually pass the gradient from the first hidden layer - ! of net3 to the output layers of net1 and net2. select type (next_layer => net3 % layers(2) % p) - ! Assume net3's first hidden layer is dense; - ! would need to be generalized to others. type is (dense_layer) - - nn = size(net1 % layers) - call net1 % layers(nn) % backward( & - net1 % layers(nn - 1), next_layer % gradient(1:net1_output_size) & - ) - - nn = size(net2 % layers) - call net2 % layers(nn) % backward( & - net2 % layers(nn - 1), next_layer % gradient(net1_output_size+1:size(next_layer % gradient)) & - ) - + call net1 % backward(y, gradient=next_layer % gradient(1:net1_output_size)) + call net2 % backward(y, gradient=next_layer % gradient(net1_output_size+1:size(next_layer % gradient))) end select - ! 
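    ! Note that the first argument (y, net3's target) only satisfies the
    ! required output argument of backward(); it is not used by net1 or net2
    ! when the gradient argument is supplied.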
Compute the gradients on hidden layers of net1, if any - do nn = size(net1 % layers)-1, 2, -1 - select type (next_layer => net1 % layers(nn + 1) % p) - type is (dense_layer) - call net1 % layers(nn) % backward( & - net1 % layers(nn - 1), next_layer % gradient & - ) - end select - end do - - ! Compute the gradients on hidden layers of net2, if any - do nn = size(net2 % layers)-1, 2, -1 - select type (next_layer => net2 % layers(nn + 1) % p) - type is (dense_layer) - call net2 % layers(nn) % backward( & - net2 % layers(nn - 1), next_layer % gradient & - ) - end select - end do - ! Gradients are now computed on all networks and we can update the weights call net1 % update(optimizer=sgd(learning_rate=1.)) call net2 % update(optimizer=sgd(learning_rate=1.)) diff --git a/src/nf/nf_network.f90 b/src/nf/nf_network.f90 index 2743ff5b..cf335aa6 100644 --- a/src/nf/nf_network.f90 +++ b/src/nf/nf_network.f90 @@ -195,7 +195,7 @@ end function predict_batch_3d interface - module subroutine backward(self, output, loss) + module subroutine backward(self, output, loss, gradient) !! Apply one backward pass through the network. !! This changes the state of layers on the network. !! Typically used only internally from the `train` method, @@ -206,6 +206,12 @@ module subroutine backward(self, output, loss) !! Output data class(loss_type), intent(in), optional :: loss !! Loss instance to use. If not provided, the default is quadratic(). + real, intent(in), optional :: gradient(:) + !! Gradient to use for the output layer. + !! If not provided, the gradient in the last layer is computed using + !! the loss function. + !! Passing the gradient is useful for merging/concatenating multiple + !! networks. end subroutine backward module integer function get_num_params(self) diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index df95963a..2d03c8e7 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -115,10 +115,11 @@ module function network_from_layers(layers) result(res) end function network_from_layers - module subroutine backward(self, output, loss) + module subroutine backward(self, output, loss, gradient) class(network), intent(in out) :: self real, intent(in) :: output(:) class(loss_type), intent(in), optional :: loss + real, intent(in), optional :: gradient(:) integer :: n, num_layers ! Passing the loss instance is optional. If not provided, and if the @@ -140,58 +141,71 @@ module subroutine backward(self, output, loss) ! Iterate backward over layers, from the output layer ! to the first non-input layer - do n = num_layers, 2, -1 - - if (n == num_layers) then - ! Output layer; apply the loss function - select type(this_layer => self % layers(n) % p) - type is(dense_layer) - call self % layers(n) % backward( & - self % layers(n - 1), & - self % loss % derivative(output, this_layer % output) & - ) - type is(flatten_layer) - call self % layers(n) % backward( & - self % layers(n - 1), & - self % loss % derivative(output, this_layer % output) & - ) - end select - else - ! 
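    ! The extended interface therefore supports both call forms
    ! (net, y, and g here are placeholders):
    !   call net % backward(y)              ! output layer seeded from the loss derivative
    !   call net % backward(y, gradient=g)  ! output layer seeded with an external
    !                                       ! gradient, e.g. from a downstream network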
Hidden layer; take the gradient from the next layer - select type(next_layer => self % layers(n + 1) % p) - type is(dense_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(dropout_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(conv2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(flatten_layer) - if (size(self % layers(n) % layer_shape) == 2) then - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d) - else - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d) - end if - type is(maxpool2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(reshape3d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(linear2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(self_attention_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(maxpool1d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(reshape2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(conv1d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(locally_connected2d_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - type is(layernorm_layer) - call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) - end select - end if + ! Output layer first + n = num_layers + if (present(gradient)) then + + ! If the gradient is passed, use it directly for the output layer + select type(this_layer => self % layers(n) % p) + type is(dense_layer) + call self % layers(n) % backward(self % layers(n - 1), gradient) + type is(flatten_layer) + call self % layers(n) % backward(self % layers(n - 1), gradient) + end select + + else + + ! Apply the loss function + select type(this_layer => self % layers(n) % p) + type is(dense_layer) + call self % layers(n) % backward( & + self % layers(n - 1), & + self % loss % derivative(output, this_layer % output) & + ) + type is(flatten_layer) + call self % layers(n) % backward( & + self % layers(n - 1), & + self % loss % derivative(output, this_layer % output) & + ) + end select + + end if + + ! 
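    ! Note that a gradient supplied by the caller must have the same size as
    ! the output layer's output and that, as with the loss path above, only
    ! dense and flatten output layers are handled here.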
Hidden layers; take the gradient from the next layer + do n = num_layers - 1, 2, -1 + select type(next_layer => self % layers(n + 1) % p) + type is(dense_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(dropout_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(conv2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(flatten_layer) + if (size(self % layers(n) % layer_shape) == 2) then + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d) + else + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d) + end if + type is(maxpool2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(reshape3d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(linear2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(self_attention_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(maxpool1d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(reshape2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(conv1d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(locally_connected2d_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + type is(layernorm_layer) + call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient) + end select end do end subroutine backward From 4aea615272916d3ab72355bb59ef8e2254d76000 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Thu, 25 Sep 2025 12:08:09 -0400 Subject: [PATCH 4/5] Add network % get_output() subroutine that returns a pointer to the outputs --- example/merge_networks.f90 | 17 +++-------------- src/nf/nf_network.f90 | 18 ++++++++++++++---- src/nf/nf_network_submodule.f90 | 22 ++++++++++++++++++++++ 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/example/merge_networks.f90 b/example/merge_networks.f90 index c1a55008..590deb26 100644 --- a/example/merge_networks.f90 +++ b/example/merge_networks.f90 @@ -5,7 +5,7 @@ program merge_networks type(network) :: net1, net2, net3 real, allocatable :: x1(:), x2(:) - real, allocatable :: y1(:), y2(:) + real, pointer :: y1(:), y2(:) real, allocatable :: y(:) integer, parameter :: num_iterations = 500 integer :: n, nn @@ -44,19 +44,8 @@ program merge_networks call net2 % forward(x2) ! Get outputs of net1 and net2, concatenate, and pass to net3 - ! A helper function could be made to take any number of networks - ! and return the concatenated output. Such function would turn the following - ! block into a one-liner. - select type (net1_output_layer => net1 % layers(size(net1 % layers)) % p) - type is (dense_layer) - y1 = net1_output_layer % output - end select - - select type (net2_output_layer => net2 % layers(size(net2 % layers)) % p) - type is (dense_layer) - y2 = net2_output_layer % output - end select - + call net1 % get_output(y1) + call net2 % get_output(y2) call net3 % forward([y1, y2]) ! 
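    ! y1 and y2 point into the output-layer storage of net1 and net2 rather
    ! than holding copies, so they must not be deallocated by the caller; the
    ! concatenation [y1, y2] makes a temporary copy that is safe to pass to
    ! net3 % forward.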
First compute the gradients on net3, then pass the gradients from the first diff --git a/src/nf/nf_network.f90 b/src/nf/nf_network.f90 index cf335aa6..3cfeb521 100644 --- a/src/nf/nf_network.f90 +++ b/src/nf/nf_network.f90 @@ -33,6 +33,7 @@ module nf_network procedure, private :: forward_1d_int procedure, private :: forward_2d procedure, private :: forward_3d + procedure, private :: get_output_1d procedure, private :: predict_1d procedure, private :: predict_1d_int procedure, private :: predict_2d @@ -42,6 +43,7 @@ module nf_network generic :: evaluate => evaluate_batch_1d generic :: forward => forward_1d, forward_1d_int, forward_2d, forward_3d + generic :: get_output => get_output_1d generic :: predict => predict_1d, predict_1d_int, predict_2d, predict_3d generic :: predict_batch => predict_batch_1d, predict_batch_3d @@ -131,7 +133,7 @@ end subroutine forward_3d end interface forward - interface output + interface predict module function predict_1d(self, input) result(res) !! Return the output of the network given the input 1-d array. @@ -169,9 +171,10 @@ module function predict_3d(self, input) result(res) real, allocatable :: res(:) !! Output of the network end function predict_3d - end interface output - interface output_batch + end interface predict + + interface predict_batch module function predict_batch_1d(self, input) result(res) !! Return the output of the network given an input batch of 3-d data. class(network), intent(in out) :: self @@ -191,7 +194,14 @@ module function predict_batch_3d(self, input) result(res) real, allocatable :: res(:,:) !! Output of the network; the last dimension is the batch end function predict_batch_3d - end interface output_batch + end interface predict_batch + + interface get_output + module subroutine get_output_1d(self, output) + class(network), intent(in), target :: self + real, pointer, intent(out) :: output(:) + end subroutine get_output_1d + end interface get_output interface diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index 2d03c8e7..e44b016b 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -511,6 +511,28 @@ module subroutine print_info(self) end subroutine print_info + module subroutine get_output_1d(self, output) + class(network), intent(in), target :: self + real, pointer, intent(out) :: output(:) + integer :: last + + last = size(self % layers) + + select type(output_layer => self % layers(last) % p) + type is(dense_layer) + output => output_layer % output + type is(dropout_layer) + output => output_layer % output + type is(flatten_layer) + output => output_layer % output + class default + error stop 'network % get_output not implemented for ' // & + trim(self % layers(last) % name) // ' layer' + end select + + end subroutine get_output_1d + + module function get_num_params(self) class(network), intent(in) :: self integer :: get_num_params From f0d5ca24e3d93db118b2a69608755f8b856b58d0 Mon Sep 17 00:00:00 2001 From: milancurcic Date: Mon, 13 Oct 2025 13:06:45 -0400 Subject: [PATCH 5/5] Allow getting output pointer for all layers --- src/nf/nf_network_submodule.f90 | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90 index e44b016b..13df77c0 100644 --- a/src/nf/nf_network_submodule.f90 +++ b/src/nf/nf_network_submodule.f90 @@ -519,12 +519,26 @@ module subroutine get_output_1d(self, output) last = size(self % layers) select type(output_layer => self % layers(last) % p) - 
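      ! For layers whose output array is rank 2 or 3, the bounds-remapping
      ! pointer assignment below, output(1:size(...)) => ..., exposes the
      ! contiguous output storage as a flat 1-d view.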
type is(dense_layer) + type is (conv1d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is(conv2d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (dense_layer) output => output_layer % output - type is(dropout_layer) + type is (dropout_layer) output => output_layer % output - type is(flatten_layer) + type is (flatten_layer) output => output_layer % output + type is (layernorm_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (linear2d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (locally_connected2d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (maxpool1d_layer) + output(1:size(output_layer % output)) => output_layer % output + type is (maxpool2d_layer) + output(1:size(output_layer % output)) => output_layer % output class default error stop 'network % get_output not implemented for ' // & trim(self % layers(last) % name) // ' layer'
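As a quick usage reference for the get_output addition, here is a minimal sketch. It is not part of the patch series; the program name and input values are made up, and it assumes a build of the library that includes the changes above.

program get_output_demo

  use nf, only: dense, input, network
  implicit none

  type(network) :: net
  real, pointer :: y(:)

  net = network([input(3), dense(2)])

  ! Run a forward pass and view its result without copying and without
  ! re-running the network, which is what distinguishes get_output from predict.
  call net % forward([0.1, 0.2, 0.3])
  call net % get_output(y)
  print *, "Output after first forward pass:  ", y

  ! y is a view of the output layer's storage, so re-fetch it after each
  ! forward pass rather than treating it as an independent copy.
  call net % forward([0.9, 0.8, 0.7])
  call net % get_output(y)
  print *, "Output after second forward pass: ", y

end program get_output_demo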