|
26 | 26 |
|
27 | 27 | namespace it_lab_ai { |
28 | 28 |
|
| 29 | +namespace { |
| 30 | +template <typename T> |
| 31 | +std::shared_ptr<Layer> clone_layer_checked( |
| 32 | + const std::shared_ptr<Layer>& layer) { |
| 33 | + const auto* casted = dynamic_cast<const T*>(layer.get()); |
| 34 | + if (casted == nullptr) { |
| 35 | + throw std::invalid_argument("Layer type mismatch while cloning"); |
| 36 | + } |
| 37 | + return std::make_shared<T>(*casted); |
| 38 | +} |
| 39 | +} // namespace |
| 40 | + |
29 | 41 | void Graph::clone(Graph& result, Tensor& out, |
30 | 42 | const RuntimeOptions& options) const { |
31 | 43 | result.arrayE_ = this->arrayE_; |
std::shared_ptr<Layer> layer_based_shared_copy(
    const std::shared_ptr<Layer>& layer, const RuntimeOptions& options) {
  // Factory that deep-copies `layer` into a fresh shared_ptr, dispatching on
  // the layer's runtime type tag (getName()). Layer kinds that have a
  // oneDNN-accelerated variant consult options.backend to pick the concrete
  // class; clone_layer_checked<T> throws std::invalid_argument if the tag
  // does not match the object's actual dynamic type.
  switch (layer->getName()) {
    case it_lab_ai::kInput: {
      return clone_layer_checked<InputLayer>(layer);
    }
    case it_lab_ai::kPooling: {
      // Backend-dependent: oneDNN pooling has its own concrete class.
      if (options.backend == Backend::kOneDnn) {
        return clone_layer_checked<PoolingLayerOneDnn>(layer);
      }
      return clone_layer_checked<PoolingLayer>(layer);
    }
    case it_lab_ai::kElementWise: {
      if (options.backend == Backend::kOneDnn) {
        return clone_layer_checked<EwLayerOneDnn>(layer);
      }
      return clone_layer_checked<EWLayer>(layer);
    }
    case it_lab_ai::kConvolution: {
      if (options.backend == Backend::kOneDnn) {
        return clone_layer_checked<ConvLayerOneDnn>(layer);
      }
      return clone_layer_checked<ConvolutionalLayer>(layer);
    }
    case it_lab_ai::kFullyConnected: {
      return clone_layer_checked<FCLayer>(layer);
    }
    case it_lab_ai::kFlatten: {
      return clone_layer_checked<FlattenLayer>(layer);
    }
    case it_lab_ai::kConcat: {
      return clone_layer_checked<ConcatLayer>(layer);
    }
    case it_lab_ai::kDropout: {
      return clone_layer_checked<DropOutLayer>(layer);
    }
    case it_lab_ai::kSplit: {
      return clone_layer_checked<SplitLayer>(layer);
    }
    case it_lab_ai::kBinaryOp: {
      if (options.backend == Backend::kOneDnn) {
        return clone_layer_checked<BinaryOpLayerOneDnn>(layer);
      }
      return clone_layer_checked<BinaryOpLayer>(layer);
    }
    case it_lab_ai::kTranspose: {
      return clone_layer_checked<TransposeLayer>(layer);
    }
    case it_lab_ai::kMatmul: {
      return clone_layer_checked<MatmulLayer>(layer);
    }
    case it_lab_ai::kReshape: {
      return clone_layer_checked<ReshapeLayer>(layer);
    }
    case it_lab_ai::kSoftmax: {
      return clone_layer_checked<SoftmaxLayer>(layer);
    }
    case it_lab_ai::kReduce: {
      if (options.backend == Backend::kOneDnn) {
        return clone_layer_checked<ReduceLayerOneDnn>(layer);
      }
      return clone_layer_checked<ReduceLayer>(layer);
    }
    case it_lab_ai::kBatchNormalization: {
      return clone_layer_checked<BatchNormalizationLayer>(layer);
    }
    case it_lab_ai::kOutput: {
      return clone_layer_checked<OutputLayer>(layer);
    }
    default: {
      // Unknown/unsupported tag: refuse rather than return nullptr.
      throw std::invalid_argument("No such layer type");
|
0 commit comments