主要记录concat,stack,unstack和split相关操作的作用
import tensorflow as tf import numpy as np tf.__version__ #concat对某个维度进行连接 #假设下面的tensor0和tensor1分别表示4个班级35名同学的8门成绩和两个班级35个同学8门成绩 tensor0 = tf.ones([4,35,8]) tensor1 = tf.ones([2,35,8]) #用concat将第0个维度(班级,axis=0)连接起来,结果是一个[6,35,8]的tensor #表示6个班级35名同学8门成绩的数据 tensor = tf.concat([tensor0, tensor1], axis=0) print("=========>tf.concat([tensor0, tensor1], axis=0).shape:", tensor.shape) #在同学维度进行合并,第1个维度,axis=1 #假设下面的tensor0和tensor1分别表示4个班级32名同学的8门成绩和4个班级3个同学8门成绩 tensor0 = tf.ones([4,32,8]) tensor1 = tf.ones([4,3,8]) #concat合并第一个维度,可以理解为,tensor0先收集到了32名同学的8门成绩 #然后补考的3名同学成绩放到了tensor1上,通过concat进行汇总 tensor = tf.concat([tensor0, tensor1], axis=1) print("=========>tf.concat([tensor0, tensor1], axis=1).shape:", tensor.shape) #concat对于维度有要求,对于不是指定axis的维度要相等才能concat #一个[4,35,8]的tensor和一个[3,15,8]的tensor无法进行concat #concat对某个维度进行连接 #假设下面的tensor0和tensor1分别表示4个班级35名同学的8门成绩和两个班级35个同学8门成绩 tensor0 = tf.ones([4,35,8]) tensor1 = tf.ones([2,35,8]) #用concat将第0个维度(班级,axis=0)连接起来,结果是一个[6,35,8]的tensor #表示6个班级35名同学8门成绩的数据 tensor = tf.concat([tensor0, tensor1], axis=0) print("=========>tf.concat([tensor0, tensor1], axis=0).shape:", tensor.shape) #在同学维度进行合并,第1个维度,axis=1 #假设下面的tensor0和tensor1分别表示4个班级32名同学的8门成绩和4个班级3个同学8门成绩 tensor0 = tf.ones([4,32,8]) tensor1 = tf.ones([4,3,8]) #concat合并第一个维度,可以理解为,tensor0先收集到了32名同学的8门成绩 #然后补考的3名同学成绩放到了tensor1上,通过concat进行汇总 tensor = tf.concat([tensor0, tensor1], axis=1) print("=========>tf.concat([tensor0, tensor1], axis=1).shape:", tensor.shape) #concat对于维度有要求,对于不是指定axis的维度要相等才能concat #一个[4,35,8]的tensor和一个[3,15,8]的tensor无法进行concat #unstack和stack操作相反,会对指定维度进行拆分 tensor = tf.ones([3,4,35,8]) #拆分出3个[4,35,8]的tensor splited = tf.unstack(tensor, axis=0) print("==========>tf.unstack(tensor, axis=0).shape:", splited[0].shape, splited[1].shape, splited[2].shape) #拆分出8个[3,4,35]的tensor splited = tf.unstack(tensor, axis=3) print("==========>tf.unstack(tensor, axis=3).shape:", splited[0].shape, splited[1].shape, splited[2].shape, splited[3].shape, splited[4].shape, splited[5].shape, splited[5].shape, splited[6].shape, splited[7].shape) #拆分出4个[3,35,8]的tensor splited = tf.unstack(tensor, axis=1) print("==========>tf.unstack(tensor, axis=1).shape:", splited[0].shape, splited[1].shape, splited[2].shape, splited[3].shape) #unstack会固定打散指定维度为1 #split则可以指定这个维度划分的比例,通过num_or_size_splits指定 #看个例子就明白了 tensor = tf.ones([2,4,35,8]) #第3个维度划分为2个4维的两个tensor([2,4,35,4]) --- 8 / 2(num_of_size_splits) = 4 splited = tf.split(tensor, axis=3, num_or_size_splits=2) print("==========>split(tensor, axis=3, num_or_size_splits=2).shape:", splited[0].shape, splited[1].shape) #将第3个维度按照2,2,4的比例划分,得到3个tensor splited = tf.split(tensor, axis=3, num_or_size_splits=[2,2,4]) print("==========>split(tensor, axis=3, num_or_size_splits=2).shape:", splited[0].shape, splited[1].shape, splited[2].shape)
运行结果:
猜你喜欢
网友评论
- 搜索
- 最新文章
- 热门文章