diff --git a/scripts/builtin/independentSubnetTrain.dml b/scripts/builtin/independentSubnetTrain.dml
new file mode 100644
index 00000000000..2009ab05157
--- /dev/null
+++ b/scripts/builtin/independentSubnetTrain.dml
@@ -0,0 +1,504 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Independent Subnet Training (IST)
+#
+# This builtin implements independent subnet training as a
+# second-order function. It orchestrates distributed/parallel
+# training over disjoint subnets using parfor, while delegating
+# architecture-specific logic to user-provided functions.
+# ------------------------------------------------------------
+# INPUT:
+# model                : initial model parameters; a list of matrices for the NN
+# features             : training features X
+# labels               : training labels Y
+# val_features         : validation features
+# val_labels           : validation labels
+# upd                  : name of the user function that computes gradients and performs the optimizer step
+# agg                  : name of the user function that combines updates of shared parameters across subnets
+# epochs               : number of training epochs
+# batchsize            : mini-batch size for training
+# j                    : number of local gradient steps per aggregation, i.e., the length of an IST round (aggregation frequency)
+# numSubnets           : number of independent subnets/workers
+# hyperparams          : list of hyperparameters (e.g., learning rate, regularization, mask parameters)
+# verbose              : print progress (boolean)
+# paramsPerLayer       : number of parameter blocks per layer (e.g., 2 for W and b)
+# fullyConnectedLayers : list of all FC layer indices (starting at idx=1)
+#
+# OUTPUT:
+# trained_model        : trained model parameters (IST notation: W)
+#
+# ASSUMPTION:
+# - the last layer is the output layer
+# ------------------------------------------------------------
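+#
+# EXAMPLE (illustrative sketch, not part of the builtin): the expected signatures of the
+# user-provided functions follow from how they are invoked below, i.e.
+#   upd: evalList(upd, list(model, mask, features, labels, hyperparams)) -> updated model (list)
+#   agg: eval(agg, list(initialParam, allSubnetsParam, allSubnetsMasks)) -> aggregated parameter (matrix)
+# A hypothetical call for a two-layer MLP (the names myUpdate/myAggregate are placeholders):
+#
+#   trained = independentSubnetTrain(model=params, features=X, labels=Y,
+#     val_features=Xval, val_labels=Yval, upd="myUpdate", agg="myAggregate",
+#     epochs=10, batchsize=64, j=4, numSubnets=2, hyperparams=hp,
+#     verbose=TRUE, paramsPerLayer=2, fullyConnectedLayers=list(1, 2))
+# ------------------------------------------------------------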
+
+m_independentSubnetTrain = function(
+  list[unknown] model,
+  matrix[double] features,
+  matrix[double] labels,
+  matrix[double] val_features,
+  matrix[double] val_labels,
+  string upd,
+  string agg,
+  int epochs,
+  int batchsize,
+  int j,
+  int numSubnets,
+  list[unknown] hyperparams,
+  boolean verbose,
+  int paramsPerLayer,
+  list[int] fullyConnectedLayers
+)
+return (list[unknown] trained_model)
+{
+  # ------------------------------------------------------------
+  # Setup
+  # ------------------------------------------------------------
+  model_out = model
+
+  P = length(model)
+  N = nrow(features)
+  if (P %% paramsPerLayer != 0) stop("Model length not divisible by paramsPerLayer")
+  L = as.integer(P / paramsPerLayer) # total layers
+
+  # I. determine shared parameters
+  isSharedParam = matrix(0, 1, P)
+
+  # - create mask for all FC layers
+  fcLayers = fullyConnectedLayers
+  isFC = matrix(0, rows=1, cols=L)
+  for (i in 1:length(fcLayers)) {
+    idx = as.integer(as.scalar(fcLayers[i]))
+    isFC[1, idx] = 1 # TODO vectorize
+  }
+
+  # - expand layer mask across all parameters
+  isFC_rep = isFC
+  for (r in 2:paramsPerLayer) {
+    isFC_rep = cbind(isFC_rep, isFC)
+  }
+  if (ncol(isFC_rep) != P) stop("Dimension mismatch for FC layer mask.")
+
+  # - all non-FC layers are shared
+  isSharedParam = 1 - isFC_rep
+
+  # - edge case: FC bias parameters are shared in the output layer or at the end of an FC block
+  for (paramId in seq(2, paramsPerLayer, 2)) { # iterate bias blocks only
+    for (l in 1:L) {
+      if (as.scalar(isFC[1,l])==1 & l==L) {
+        p_out_bias = (paramId - 1) * L + L # output bias is shared across subnets
+        isSharedParam[1, p_out_bias] = 1 # TODO vectorize
+      }
+      else if (as.scalar(isFC[1,l])==1 & l<L) {
+        # bias at the end of an FC block (next layer is not FC) is shared as well
+        if (as.scalar(isFC[1,l+1])==0) {
+          p_bias = (paramId - 1) * L + l
+          isSharedParam[1, p_bias] = 1 # TODO vectorize
+        }
+      }
+    }
+  }
+
+  # II. determine steps per epoch
+  if (batchsize < 1 | batchsize > N) {
+    stop("Batch size is out of bounds!")
+  } else {
+    stepsPerEpoch = ceil(N / batchsize)
+  }
+
+  # III. training loop
+  for (epoch in 1:epochs) {
+    if (verbose) print("Entered epoch: " + epoch)
+
+    # A.) reshuffle indices each epoch (S is a permutation matrix built via table)
+    allSampleIndicesRandom = order(target=rand(rows=N, cols=1), by=1, decreasing=FALSE, index.return=TRUE)
+    batchIndices = allSampleIndicesRandom[, 1]
+    b = nrow(batchIndices)
+    I = seq(1, b, 1)
+    V = matrix(1, rows=b, cols=1)
+    S = table(I, batchIndices, V, b, N)
+
+    features_shuffled = S %*% features
+    labels_shuffled = S %*% labels
+
+    # B.) iterate IST rounds
+    for (step in seq(1, stepsPerEpoch, j)) {
+      if (verbose) print("Starting new IST round at step: " + step)
+      round_model = model_out # work on a copy to prevent accidental mutation of model_out
+
+      # 1.) create masks for all subnets
+      [masks, masks_meta_info] = ist_create_disjoint_masks(round_model, numSubnets, L, fcLayers, paramsPerLayer, isFC, verbose)
+
+      # 2.) preallocate lists to store all subnets (TODO: move outside the epoch loop to avoid repeated allocation)
+      updatedSubnets = list()
+      updatedSubnetsMasks = list()
+      for (s in 1:numSubnets) {
+        updatedSubnets = append(updatedSubnets, list())
+        updatedSubnetsMasks = append(updatedSubnetsMasks, list())
+      }
+
+      # 3.) create a template for each subnet based on the input model, which allows indexing in the
+      #     subsequent parfor loop (TODO: move outside the epoch loop to avoid repeated allocation)
+      subnetModelTemplate = list()
+      subnetModelMaskTemplate = list()
+      for (pIdx in 1:P) {
+        subnetModelTemplate = append(subnetModelTemplate, as.matrix(model[pIdx]))
+        subnetModelMaskTemplate = append(subnetModelMaskTemplate, as.matrix(model[pIdx]))
+      }
+
+      # number of local optimization steps in this IST round
+      localSteps = min(j, (stepsPerEpoch - step + 1))
+
+      # 4.) obtain all minibatches for this IST round (sliced once up front so the parfor body
+      #     carries no dependencies on the batch-index arithmetic)
+      shuffled_features = list()
+      shuffled_labels = list()
+      for (localStep in 1:localSteps) {
+        mb = (step-1) + localStep
+        mb_local = mb - 1
+        start = mb_local * batchsize + 1
+        end = min(mb * batchsize, N)
+
+        Xb = features_shuffled[start:end, 1:ncol(features_shuffled)]
+        yb = labels_shuffled[start:end, 1:ncol(labels_shuffled)]
+        shuffled_features = append(shuffled_features, Xb)
+        shuffled_labels = append(shuffled_labels, yb)
+      }
+
+      # 5.) perform 'j' local gradient steps for each subnet
+      parfor (subnet in 1:numSubnets) {
+
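+        # Each worker trains only the parameters selected by its mask row: partitioned
+        # (non-shared) parameters are owned by exactly one subnet, while shared parameters
+        # are trained by all subnets and reconciled by 'agg' in step 6 below.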
+        # a.) obtain masked subnet
+        subnet_model = subnetModelTemplate
+        subnet_model_mask = subnetModelMaskTemplate
+        for (subnet_p in 1:length(round_model)) {
+          param_start_idx = as.integer(as.scalar(masks_meta_info[subnet_p,1]))
+          param_end_idx = as.integer(as.scalar(masks_meta_info[subnet_p,2]))
+          param_rows = as.integer(as.scalar(masks_meta_info[subnet_p,3]))
+          param_cols = as.integer(as.scalar(masks_meta_info[subnet_p,4]))
+
+          vec = masks[subnet, param_start_idx:param_end_idx]
+          param_mask = matrix(vec, rows=param_rows, cols=param_cols, byrow=TRUE)
+          param = as.matrix(round_model[subnet_p])
+
+          subnet_model[subnet_p] = list(param * param_mask) # TODO sparse masking! dense masking likely wastes compute on zeroed entries
+          subnet_model_mask[subnet_p] = list(param_mask)
+        }
+
+        # b.) local optimization steps of this IST round
+        for (localStep in 1:localSteps) {
+          feat = as.matrix(shuffled_features[localStep])
+          lab = as.matrix(shuffled_labels[localStep])
+
+          # compute gradients for this subnet and apply the update step on owned params only
+          subnet_model = as.list(evalList(upd, list(model=subnet_model, mask=subnet_model_mask, features=feat, labels=lab, hyperparams=hyperparams)))
+        }
+
+        # c.) save updated subnet and mask
+        updatedSubnets[subnet] = list(subnet_model)
+        updatedSubnetsMasks[subnet] = list(subnet_model_mask)
+      }
+      if (verbose) print("All subnets have run successfully.")
+
+      # 6.) aggregate updates into the global model (i.e., model_out)
+      for (p in 1:P) {
+        if (as.scalar(isSharedParam[1, p])==1) {
+          # construct the full model update by aggregating shared parameter updates from all subnets
+          subnetParams = list()
+          subnetMasks = list()
+          for (s in 1:numSubnets) {
+            subnet = as.list(updatedSubnets[s])
+            subnetMask = as.list(updatedSubnetsMasks[s])
+
+            subnetParams = append(subnetParams, as.matrix(subnet[p]))
+            subnetMasks = append(subnetMasks, as.matrix(subnetMask[p]))
+          }
+
+          # aggregate shared parameters with the provided function
+          averagedUpdatedParam = eval(agg, list(initialParam=as.matrix(round_model[p]), allSubnetsParam=subnetParams, allSubnetsMasks=subnetMasks))
+          round_model[p] = averagedUpdatedParam
+        }
+        else {
+          # construct the full model update by filling in the disjointly partitioned parameter updates from all subnets
+          initialParam = as.matrix(round_model[p])
+          updatedParam = matrix(0, nrow(initialParam), ncol(initialParam))
+          owned = matrix(0, nrow(initialParam), ncol(initialParam))
+
+          for (s in 1:numSubnets) {
+            subnet = as.list(updatedSubnets[s])
+            subnetMask = as.list(updatedSubnetsMasks[s])
+
+            owned = owned + as.matrix(subnetMask[p])
+            updatedParam = updatedParam + as.matrix(subnet[p])
+          }
+
+          # SANITY CHECK: every non-shared entry must be owned by at most one subnet
+          max_freq = max(owned)
+          if (max_freq > 1) stop("Overlap detected")
+
+          round_model[p] = updatedParam
+        }
+      }
+      if (verbose) print("Aggregation of subnets finished. IST round executed successfully!")
+
+      # 7.) update the global model (end of the IST round)
+      model_out = round_model
+    }
+    # TODO (potentially): add validation for early stopping etc.
+  }
+  trained_model = model_out
+}
+
+
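+# ----------------------------------------------------------------------------------------------------------------------
+# EXAMPLE (illustrative sketch only, not part of this builtin): an aggregation function compatible
+# with the eval(agg, ...) call above could simply average the shared-parameter updates of all
+# subnets; the name aggMeanShared is a placeholder:
+#
+#   aggMeanShared = function(matrix[double] initialParam, list[unknown] allSubnetsParam, list[unknown] allSubnetsMasks)
+#     return (matrix[double] updatedParam)
+#   {
+#     acc = matrix(0, nrow(initialParam), ncol(initialParam))
+#     for (s in 1:length(allSubnetsParam)) {
+#       acc = acc + as.matrix(allSubnetsParam[s])
+#     }
+#     updatedParam = acc / length(allSubnetsParam)
+#   }
+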
+# ----------------------------------------------------------------------------------------------------------------------
+# Independent Subnet Masking
+#
+# This helper function creates two matrices: one contains all flattened masks, the other contains the
+# information needed to reconstruct the mask matrices. Each mask is a binary vector indicating which
+# parameters belong to that subnet.
+# ----------------------------------------------------------------------------------------------------------------------
+# INPUT:
+# model                : list of parameter tensors grouped by parameter type, i.e., blocks
+# numSubnets           : number of independent subnets/workers (K)
+# L                    : total number of layers INCLUDING the output layer (layer indices are assumed to be 1..L)
+# fullyConnectedLayers : list of all FC layer indices (starting at idx=1)
+# paramsPerFCLayer     : number of parameter blocks per FC layer
+# isFC                 : indicator matrix encoding which layers are FC => isFC[l] ∈ {0,1}
+# verbose              : print progress (boolean)
+#
+# OUTPUT:
+# masks_new      : mask matrix defining disjoint neuron ownership across subnets (one row per subnet)
+# masks_new_meta : metadata matrix describing the mask layout; one row [start, end, rows, cols] per parameter
+#                  block, so a parameter's mask is recovered as matrix(masks_new[s,start:end], rows=rows, cols=cols, byrow=TRUE)
+#
+# ASSUMPTIONS:
+# - neuron ownership is defined via bias vectors
+# - model is a list of parameter tensors
+# - trainable parameters are grouped by parameter type, i.e., param blocks like (W_l1, W_l2, ..., b_l1, b_l2, ...)
+# - W and b are always the first two param blocks
+# - optional optimizer-state tensors (e.g., vW_l, vb_l) follow the same grouping, always W followed by b
+# - biases of the output layer and of the last layer of an FC block are shared -> their gradients collide
+#   and must be handled by the aggregation logic
+# ----------------------------------------------------------------------------------------------------------------------
+
+ist_create_disjoint_masks = function(
+  list[unknown] model,
+  int numSubnets,
+  int L,
+  list[int] fullyConnectedLayers,
+  int paramsPerFCLayer,
+  Matrix[Double] isFC,
+  boolean verbose
+)
+  return (
+    Matrix[Double] masks_new,
+    Matrix[Double] masks_new_meta
+  )
+{
+  P = length(model)
+
+  # SANITY CHECKS: ensure the provided model can be masked correctly
+  if (as.integer(P / paramsPerFCLayer) != L) {
+    stop("Layer/parameter mismatch. Please make sure each layer has the same number of parameters.")
+  }
+  if (paramsPerFCLayer < 2 | paramsPerFCLayer %% 2 != 0) {
+    stop("At least one W/b pair must be present, and parameters must come in W/b pairs.")
+  }
+
+  # I.) initialize and preallocate masks
+  masks_new_meta = matrix(0, rows=length(model), cols=4) # columns=[start,end,rows,cols]
+  current_position = 1
+  for (p in 1:length(model)) {
+    M = as.matrix(model[p])
+    param_length = ncol(M) * nrow(M)
+
+    masks_new_meta[p,1] = current_position
+    masks_new_meta[p,2] = current_position + param_length - 1
+    masks_new_meta[p,3] = nrow(M)
+    masks_new_meta[p,4] = ncol(M)
+
+    current_position = current_position + param_length
+  }
+  mask_size = current_position - 1
+  masks_new = matrix(0, rows=numSubnets, cols=mask_size) # all subnets in one matrix
+
+  # II.) iterate all layers
+  for (l in 1:L) {
+
+    # FC layer: create numSubnets disjoint partitions for this layer across all parameters
+    if (as.scalar(isFC[1,l]) == 1) {
+      W = as.matrix(model[l])
+      b = as.matrix(model[l+L])
+      H = ncol(W) # neurons (bias entries) in layer l
+
+      # SANITY CHECKS:
+      if (nrow(b) != 1 | ncol(b) != H) {
+        if (verbose) print("Bias shape mismatch!")
+        if (verbose) print("b: " + nrow(b) + " x " + ncol(b))
+        if (verbose) print("expected: 1 x " + H)
+        stop("Invalid bias shape")
+      }
+      if (l != L & numSubnets > ncol(b)) { # TODO restrict check to layers whose next layer is non-FC
+        if (verbose) print("More subnets than available neurons in layer: " + l)
+        stop("Please use a wider model or decrease the number of subnets.")
+      }
+
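+      # Steps A-C below partition the H neurons of this layer disjointly across the subnets.
+      # Worked example (illustrative): H=10, numSubnets=3 -> chunk_size=3, 1 remaining neuron
+      # that is assigned to one randomly chosen subnet, giving owned counts such as [4, 3, 3].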
+      # A.) shuffle all neuron indices
+      allNeuronIndicesRandom = order(target=rand(rows=H, cols=1), by=1, decreasing=FALSE, index.return=TRUE)
+
+      # B.) determine neuron ownership
+      chunk_size = floor(H / numSubnets)
+      remaining_neurons = H - chunk_size * numSubnets
+      amount_active_neurons = matrix(chunk_size, rows=numSubnets, cols=1)
+      if (remaining_neurons > 0) {
+        randomSubnetIndices = order(target=rand(rows=numSubnets, cols=1, seed=-1), by=1, decreasing=FALSE, index.return=TRUE) # TODO replace seed for experiments
+        for (i in 1:remaining_neurons) {
+          sid = as.integer(as.scalar(randomSubnetIndices[i,1]))
+          amount_active_neurons[sid,1] = as.scalar(amount_active_neurons[sid,1]) + 1 # TODO VECTORIZE
+        }
+      }
+      neuron_end_indices = cumsum(amount_active_neurons)
+      neuron_start_indices = neuron_end_indices - amount_active_neurons + 1
+
+      # C.) obtain masks for all subnets
+      for (s in 1:numSubnets) {
+
+        # 1. obtain owned neurons for this layer
+        start = as.integer(as.scalar(neuron_start_indices[s,1]))
+        end = as.integer(as.scalar(neuron_end_indices[s,1]))
+        current_b_indices = allNeuronIndicesRandom[start:end, 1]
+
+        # 2. create masked bias
+        if (l==L) { # output layer: bias is shared, all entries stay active
+          masked_b = matrix(1, rows=1, cols=ncol(b))
+        }
+        else if (l<L) { # hidden FC layer: only the owned neurons' bias entries are active
+          masked_b = matrix(0, rows=1, cols=ncol(b))
+          for (i in 1:nrow(current_b_indices)) {
+            idx = as.integer(as.scalar(current_b_indices[i,1]))
+            masked_b[1, idx] = 1
+          }
+        }
+
+        # 3. create masked weights
+        masked_W = matrix(0, rows=nrow(W), cols=ncol(W))
+        if (l==1) { # input layer: own the full incoming columns of the selected neurons
+          for (i in 1:nrow(current_b_indices)) {
+            idx = as.integer(as.scalar(current_b_indices[i,1]))
+            masked_W[1:nrow(W), idx] = matrix(1, rows=nrow(W), cols=1)
+          }
+        }
+        else if (l>1 & as.scalar(isFC[1, l-1])==0) { # previous layer is not FC
+          for (i in 1:nrow(current_b_indices)) { # TODO VECTORIZE
+            idx = as.integer(as.scalar(current_b_indices[i,1]))
+            masked_W[1:nrow(W), idx] = matrix(1, rows=nrow(W), cols=1)
+          }
+        }
+        else {
+          # obtain active neurons of the previous layer
+          p = L + (l-1)
+          start = as.integer(as.scalar(masks_new_meta[p,1]))
+          end = as.integer(as.scalar(masks_new_meta[p,2]))
+          r = as.integer(as.scalar(masks_new_meta[p,3]))
+          c = as.integer(as.scalar(masks_new_meta[p,4]))
+          vec = masks_new[s, start:end]
+          previous_masked_b = matrix(vec, rows=r, cols=c, byrow=TRUE)
+
+          # SANITY CHECK: dimensions must match the previous layer
+          if (l > 1 & ncol(previous_masked_b) != nrow(W)) {
+            if (verbose) print("W/prev layer mismatch in layer l=" + l)
+            if (verbose) print("prev_b: " + nrow(previous_masked_b) + " x " + ncol(previous_masked_b))
+            if (verbose) print("W: " + nrow(W) + " x " + ncol(W))
+            stop("Invalid W shape wrt previous layer")
+          }
+
+          if (nrow(previous_masked_b)==1) previous_masked_b = t(previous_masked_b)
+          if (ncol(masked_b) == 1) masked_b = t(masked_b)
+
+          if (l==L) { # output layer
+            masked_W = previous_masked_b %*% matrix(1, 1, ncol(masked_W))
+          }
+          else if (l