While running inference with a TF2 model, I noticed the process's memory footprint growing steadily, to the point where the process was killed by the OOM killer; sometimes the model would even load but never run. Searching turned up all sorts of conflicting answers, some blaming TF2 itself. But when I traced the memory, it turned out that the model's dynamic graph was not being released, and the root cause of that was how the input data was being fed to the model.
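A memory leak like this can be pinned down with Python's built-in tracemalloc module. This is a generic sketch, not the original tracing code; the growing list here merely stands in for graph tensors that are never released:

```python
import tracemalloc

tracemalloc.start()

leaked = []  # stands in for objects that are never released
for _ in range(100):
    leaked.append([0.0] * 10_000)  # simulate per-inference allocations

# current = bytes still live, peak = high-water mark since start()
current, peak = tracemalloc.get_traced_memory()

# the top allocation site points straight at the leaking line
top_stat = tracemalloc.take_snapshot().statistics("lineno")[0]

tracemalloc.stop()
```

Taking a snapshot before and after a batch of inference calls and diffing them (`snapshot2.compare_to(snapshot1, "lineno")`) shows exactly which line is accumulating memory.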
    mhc_a_batches = list(chunks(mhc_seqs_a, self.batch_size))
    mhc_b_batches = list(chunks(mhc_seqs_b, self.batch_size))
    pep_batches = list(chunks(pep_seqs, self.batch_size))
    assert len(mhc_a_batches) == len(mhc_b_batches)
    assert len(mhc_a_batches) == len(pep_batches)
    size = len(mhc_a_batches)
    # start predicting
    preds = []
    for i in range(size):
        _preds = self.model([mhc_a_batches[i], mhc_b_batches[i], pep_batches[i]], training=False)
        preds.extend(_preds.numpy().tolist())
    return preds
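The `chunks` helper is not shown in the snippet above; a minimal sketch of what it presumably does (split a sequence into batch-sized slices, the last one possibly shorter):

```python
def chunks(seq, size):
    """Yield successive size-length slices of seq; the last chunk may be shorter."""
    for i in range(0, len(seq), size):
        yield seq[i:i + size]

batches = list(chunks(list(range(10)), 4))
# → [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```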
In this code, plain Python lists are passed directly as the model's input. Although TF2 also accepts list and NumPy inputs, doing so is a hidden hazard here: every call converts the data afresh, and in my tracing it left behind a large number of empty tensors that were never released.
Rewriting it in the following form solved the problem:
    mhc_seqs_a = tf.convert_to_tensor(mhc_seqs_a, dtype=tf.float32)
    mhc_seqs_b = tf.convert_to_tensor(mhc_seqs_b, dtype=tf.float32)
    pep_seqs = tf.convert_to_tensor(pep_seqs, dtype=tf.float32)
    assert len(mhc_seqs_a) == len(mhc_seqs_b)
    assert len(mhc_seqs_a) == len(pep_seqs)
    ds = (tf.data.Dataset.from_tensor_slices((mhc_seqs_a, mhc_seqs_b, pep_seqs))
          .batch(self.batch_size)
          .prefetch(1))
    preds = []
    for x, y, z in ds:
        _preds = self.model([x, y, z], training=False)
        preds.extend(_preds.numpy().tolist())
    return preds
Now inference runs happily, and it is several times faster than before: in my tests, nearly a 30x speedup on GPU. For datasets on the order of hundreds of millions of records, the time saved is considerable.