1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
def redidual_block(x, output_channel):
    """Residual block with an identity or downsampling shortcut.

    (Name kept as-is — 'redidual' is a typo for 'residual' — because
    callers in this file reference it by this name.)

    Args:
        x: input tensor, NHWC layout (channels last — the channel count is
            read from the last axis).
        output_channel: number of output channels. Must equal the input
            channel count (plain identity shortcut) or be exactly double it
            (stride-2 downsampling shortcut with zero-padded channels).

    Returns:
        Tensor with `output_channel` channels; spatial dimensions are halved
        when the channel count doubles.

    Raises:
        ValueError: if `output_channel` is neither equal to nor double the
            input channel count.
    """
    input_channel = x.get_shape().as_list()[-1]
    if input_channel * 2 == output_channel:
        # Doubling the channels: downsample spatially with stride 2.
        increase_dim = True
        strides = (2, 2)
    elif input_channel == output_channel:
        increase_dim = False
        strides = (1, 1)
    else:
        # ValueError is more precise than the original bare Exception and
        # stays backward-compatible for callers catching Exception.
        raise ValueError("input channel can't match output channel")

    conv1 = tf.layers.conv2d(
        x,
        output_channel,
        (3, 3),
        strides=strides,
        padding='same',
        activation=tf.nn.relu,
        name='conv1')
    conv2 = tf.layers.conv2d(
        conv1,
        output_channel,
        (3, 3),
        strides=(1, 1),
        padding='same',
        activation=tf.nn.relu,
        name='conv2')

    if increase_dim:
        # Shortcut path: average-pool to halve the spatial dims so it matches
        # the strided conv output...
        pooled_x = tf.layers.average_pooling2d(
            x, (2, 2), (2, 2), padding='valid')
        # ...then zero-pad the channel axis on both sides
        # (input_channel // 2 each) so the channel count reaches
        # 2 * input_channel == output_channel.
        padded_x = tf.pad(
            pooled_x,
            [[0, 0], [0, 0], [0, 0],
             [input_channel // 2, input_channel // 2]])
    else:
        padded_x = x

    output_x = conv2 + padded_x
    return output_x
def res_net(x, num_residual_blocks, num_filter_base, class_num):
    """Assemble a ResNet graph from stacked residual blocks.

    Args:
        x: input image tensor (NHWC).
        num_residual_blocks: residual-block count per stage, e.g. [3, 4, 6, 3].
        num_filter_base: channel count of the first stage; doubles each stage.
        class_num: number of output classes (so the head generalizes across
            datasets).

    Returns:
        Logits tensor of shape (batch, class_num).
    """
    layers = []

    # Stem: a single 3x3 conv at full resolution.
    with tf.variable_scope('conv0'):
        conv0 = tf.layers.conv2d(
            x,
            num_filter_base,
            (3, 3),
            strides=(1, 1),
            activation=tf.nn.relu,
            padding='same',
            name='conv0')
        layers.append(conv0)

    # One stage per entry in num_residual_blocks; the channel target
    # num_filter_base * 2**stage makes the first block of each later stage
    # downsample inside redidual_block.
    for sample_id, block_count in enumerate(num_residual_blocks):
        for i in range(block_count):
            with tf.variable_scope('conv%d_%d' % (sample_id, i)):
                conv = redidual_block(
                    layers[-1],
                    num_filter_base * (2 ** sample_id))
                layers.append(conv)

    # Head: global average pooling over the spatial axes, then a dense layer.
    with tf.variable_scope('fc'):
        global_pool = tf.reduce_mean(layers[-1], [1, 2])
        logits = tf.layers.dense(global_pool, class_num)
        layers.append(logits)

    return layers[-1]
|