1. 모델 분석/Segmentation

[2D Multi-Class Semantic Segmentation] 1. Data Preprocessing

M_AI 2021. 5. 31. 20:49

 

2021.05.11 (화) 수정 내역

 

1. 데이터 추가

- 검증 데이터가 적어 모델에 대한 신뢰성이 떨어져, DIARETDB1 데이터셋을 추가.

 

2. 클래스 변경

- 기존에는 optic disc를 추가하여 함께 segmentation을 했으나 새로 추가하는 DIARETDB1에는 optic disc 마스크가 없기에, 여기에 맞춰서 optic disc를 제외하여 클래스 수를 4 -> 3으로 변경.

 

Class

0 - Background

1 - MA + HE

2 - EX + SE

 

3. 데이터 증식(Data Augmentation)

- 기존에는 회전은 5도 마다 해주었지만, 너무 많은 관계로 20도로 변경.

 

 

4. 데이터셋 생성 코드 수정

- 모델 검증을 K-Fold Cross Validation으로 하기 위해 tf.data.Dataset.from_generator을 사용하지 못하고, 데이터셋 전처리를 별도로 실행해야 하기에 코드를 분할 및 수정하였다. 하지만 전처리 코드 방식은 거의 동일하기 때문에 기존의 코드는 삭제하지 않고 그대로 두기로 하고, 수정한 전처리 코드를 업로드. ( 기존의 코드와 구분하기 위해 밝은색 배경에 코드를 업로드한다.)

 


 

안녕하세요! M_AI 입니다!

 

이번에는 2D 데이터에서 세그멘테이션하는 것을 진행하도록 합니다. 이번 데이터셋은 당뇨망막병증(Diabetic Retinopathy, DR)가 있는 안저 영상(fundus image)을 사용하며, 여기서 병변에 대한 분할(세그멘테이션, segmentation)을 하고자 합니다.

 

사실 이 주제로 진행하는 이유는,

본인 학사 학위 논문 주제를 당뇨망막병증의 삼출물(Exudates)을 효과적으로 추출하는 방법으로 연구했으나,

그 당시에는 딥러닝에 완전 초짜였고(사실 지금도 초짜지만), 촉박한 시간 내에 연구 주제 선정 및 계획을 세우고....

 

테스트하고 수정하고, 또 테스트하고....를 반복하다보니 너무나 아쉬운 성능이 보였습니다.

 

그래서 conclusion에서 부족하고 아쉬웠던 부분을 조금 만회하고자 이 주제로 선정하였습니다.

 

본인 학사 학위 논문 연구와 위의 세 단계를 취합한 결과 또한 project 카테고리 항목에 정리하여 올리도록 하겠습니다.

 

늘 그렇듯이, 여기서는 상세한 이론을 설명하기 보다는 코드 분석에 초점을 두기 때문에 상세한 이론을 원하시는 분들은 다른 곳에서 구글링을 하시거나, 제가 걸어둔 링크에서 찾아보시는 것을 추천드립니다.


PC 화면으로 보는 것을 추천!!!!!

 

이 글을 읽으면 좋을 것 같은 사람들!

- 기본적인 분류(classfication) 모델을 Tensorflow 2.x 버전으로 구현할 줄 아는 사람

- Clssification을 넘어서 Object detection을 하길 원하는데 오픈 소스를 봐도 이해가 가질 않는 사람.

- 오픈 소스를 이해하더라도 입맛대로 수정하는 방법을 모르는 사람.

 

 

본 게시글의 목적

※ 본 게시글은 딥러닝 모델의 상세한 이론적 설명보다는 모델 구조를 바탕으로 코드 분석하고 이해하는 데에 목적에 있습니다!

※ 상세한 이론을 원하면 다른 곳에서 찾아보시길 바랍니다!

 

본 게시글의 특징

① Framework : Tensorflow 2.x

② 학습 환경

- Google Colab

③ Batch size = 1

 IDRiD Diabetic Retinopathy, DIARETDB1 데이터셋 사용하여 Fundus Image에서 병변 분할(Lesion Segmentation) 목적

( IDRiD 데이터셋 출처 : https://idrid.grand-challenge.org/Data/)

 

IDRiD - Grand Challenge

This challenge evaluates automated techniques for analysis of fundus photographs. We target segmentation of retinal lesions like exudates, microaneurysms, and hemorrhages and detection of the optic disc and fovea. Also, we seek grading of fundus images acc

idrid.grand-challenge.org

 

( DIARETDB1 데이터셋 출처 : https://www.it.lut.fi/project/imageret/diaretdb1/#DOWNLOAD)

 

DIARETDB1 - STANDARD DIABETIC RETINOPATHY DATABASE

DIARETDB1 - Standard Diabetic Retinopathy Database Calibration level 1 This is a public database for benchmarking diabetic retinopathy detection from digital images. The main objective of the design has been to unambiguously define a database and a testing

www.it.lut.fi

⑤ 게시글에서는 반말

 


 

개 요

우선 DR은 많은 실명 원인 중 하나인 당뇨망막병증은 당뇨 합병증으로 인한 망막의 혈관 손상  출혈(hemorrhages), 비정상적인 미세혈관 생성, 동맥류(aneurysm) 그리고 삼출물(exudates) 등이 발생한다. (Figure 1)

본 게시글에서 사용되는 데이터는 당뇨망막병증이 있는 안저 영상이 사용되며, 아래 (Figure 1)은 당뇨망막병증으로 망막에 생기는 병변 5종류를 보기 쉽게 그림으로 표현된 것이다. 하지만 이 중에서 비정상적인 혈관(abnormal growth of blood vessels)은 매우 작아 실제 fundus 영상에서 찾기가 힘들며, 세그멘테이션을 위한 마스크(mask)도 따로 없어 제외하도록 한다. 그리고 cotton wool spots은 자주 포착되지만 본 게시글에서 사용하는 데이터셋에서는 없으므로 이 또한 제외하도록 한다.

 

Figure 1. 당뇨망막병증으로 망막에 생기는 현상 ( 출처 :  https://www.eyeops.com/contents/our-services/eye-diseases/diabetic-retinopathy)

 

그래서 본 게시글에서는 남은 3종류의 병변을 크게 두 클래스로 나눈다. 하나는 혈액(blood)과 관련된 출혈(hemorrhages)과 동맥류(aneurysm or microaneurysms), 또 다른 하나는 삼출물(exudates)이다.

또한, 세그멘테이션을 하기 위해 사용될 딥러닝 모델은 U-Net을 사용한다. U-Net은 Encoder와 Decoder 구조로 구성되어, Encoder에서는 입력 영상에서의 특징(feature)를 추출하고, Decoder에서는 label인 병변 마스크로 복구하는 과정이 이루어진다. 본 게시글에서는 Encoder 부분을 ImageNet으로 사전 학습된(pre-trained) 분류(classification)모델인 VGG16 모델을 백본(backbone)으로 둔 세 가지 U-Net 모델의 성능을 비교할 것이다.

 

1. 데이터 설명 및 전처리

2. U-Net, 백본 모델과 손실 함수 설명

3. 학습과 결과 및 성능 비교

 

로 진행될 것이다.

 


 

1. Dataset and Preprocessing

위에서도 언급했듯이 본 게시글에서는 출혈, 동맥류와 삼출물을 분할한다고 하였다.

본 게시글에서 사용될 데이터셋은 IDRiD으로 "ISBI-2018: Diabetic Retinopathy: Segmentation and Grading Challenge Workshop"라는 데이터 챌린지를 위해 공개한 당뇨망막병증 안저 영상 데이터셋이다.

 

해당 데이터셋은 당뇨망막병증 환자 81명의 4288 × 2848 크기 안저 영상이 있으며, 마스크의 종류는 5가지가 존재한다.

 

- 미세동맥류 microaneurysms (MA)

- 출혈 hemorrhages (HE)

- 작은 삼출물 soft exudates (SE)

- 심한 삼출물 hard exudates (EX)

- 시신경유두 optic disc(OD)

 

어떤 영상 데이터에서는 크기가 작은 SE가 없기도 해서 이에 대한 마스크가 없는 경우가 있다.

 

새로 추가한 DIARETDB1 는 환자 89명의 1500 x 1152 크기 영상으로, 마스크의 종류는 4가지가 존재한다.

 

- Reddot

- 출혈 hemorrhages (HE)

- 작은 삼출물 soft exudates (SE)

- 심한 삼출물 hard exudates (EX)

 

DIARETDB1에는 IDRiD과 다르게 reddot으로 되어있고, 시신경 유두(optic disc)가 존재하지 않는다. 자세한 구성은 아래에서 다루도록 하겠다.

 

이어서, 위에서 언급했듯이 본 게시글에서는 동맥류와 출혈을 하나의 클래스로 한다고 하였다. 추가적으로 해당 데이터에서는 삼출물이 soft와 hard로 나뉘는데 여기서는 하나로 합쳐 하나의 클래스로 하겠다. Optic disc는 병변이 아니지만, 해당 데이터에 마스크로 포함되어있으므로 이 또한 세그멘테이션을 진행하도록 하겠다.(새로운 데이터 추가로 인해 제거)

 

즉, 클래스는 총 4개로 다음과 같다.

 

0 : 배경(Background)

1 : 혈관 혈액 관련 MA + HE + Reddot

2 : 삼출물 SE + EX

3 : 시신경유두 OD(새로운 데이터 추가로 인해 제거)

 

우리는 이렇게 따로 분리된 마스크를 하나로 통합해야만 한다. 이는 아래 1.1절에서 진행하도록 한다.

 

 

1.1. Mask

1.1.1. Mask integration

IDRiD 하나의 영상 데이터에 대해서 마스크는 (Figure 2)와 같이 존재한다.

 

Figure 2-1. IDRiD에서 하나의 원본 영상과 그에 대한 마스크 데이터 ( 출처 : IDRiD )

 

 

Figure 2-2. DIARETDB에서 하나의 원본 영상과 그에 대한 마스크 데이터 ( 출처 :DIARETDB )

 

Figure 2-1은 IDRiD 데이터셋에서의 일부 데이터로 (a)부터 영상 데이터, MA, HE, EX, OD 마스크이며 해당 데이에는 SE가 없다.

 

Figure 2-2은 DIARETDB 에서의 일부 데이터로 (a)부터 영상 데이터, Red dot, HE, EX, SE 마스크이며, 위에서도 언급했듯이 해당 데이터셋에는 시신경 유두(Optic disc, OD)가 없어 OD 마스크가 없다.

 

IDRiD과 DIARETDB의 마스크 형식이 다른데,

 

IDRiD의 마스크 주변과 정확하게 구분되는 곳을 marking을 했다면,

 

DIARETDB의 마스크 해당 병변이 위치하는 영역 자체 marking하여 조금의 차이가 있다. 또한 하나의 마스크에 대해서도 밝기에 따라 다르게 영역이 잡혀있는 것을 볼 수 있는데, 이러한 이유는 데이터셋에 첨부된 논문[1] "DIARETDB1 diabetic retinopathy database and evaluation protocol" 에 따르면 해당 데이터 제작 과정에서 병변 영역에 대해 의료진 간의 큰 차이가 있어 이들의 정보를 버리지 않기 위해 모두 포함하였다고 한다. 그로 인해 픽셀값으로 다르게 표기했다.

 

Figure 2-3. DIARETDB의 일부 데이터에서 다르게 marking된 Hemorrhages 마스크( 출처 :"DIARETDB1 diabetic retinopathy database and evaluation protocol" )

 

Figure 2-3에 따르면, (b), (c), (d), (e)의 픽셀값을 0.25, 0.50, 0.75, 1.00 이라 한다고 하면(실제 픽셀값은 [63, 127, 189, 252]에 맞춰져 있다.), 논문에서 사용한 신뢰값(역주 아마 threshold 의미)을 0.75로 하였다고 한다. 이에 따라서 본 게시글에서도 동일하게 0.75 신뢰성에 대해 이진화(binaryzation)한 후에 마스크를 통합한다.

 

주의할 점은 여기서 0.75에 대해 이진화하게 되면, 0.75 이상 값이 없는 마스크에 대해서는 모두 0이 되어버리기 때문에 상당수의 데이터를 사용하지 못한다. 그래서 원래는 89개 였지만 이 과정을 거치면 34개 데이터를 잃어 55개만 남게 된다.

 

데이터 부족으로 인해 새로운 데이터를 추가했는데, 이 과정에서 새로 추가할 데이터 절반이 날라가버린 상황이 발생하였지만 데이터의 신뢰성을 위해 그대로 진행하기로 한다.

 

위에서 언급했듯이 시신경 유두(optic disc)를 제외한 분리된 마스크를 하나의 마스크로 통합해야 한다. 통합방식은 다음과 같다.

 

➀ 0으로 구성된 (2848, 4288) 크기의 배열을 생성한다. 이는 배경을 의미한다.

➁ MA or Red dot과 HE를 비트 연산 or을 통해 병변이 있는 픽셀값을 1로 할당한다.

➂ SE와 EX를 비트 연산 or을 통해 병변이 있는 픽셀값을 2로 할당한다.

➃ OD가 있는 픽셀값은 3으로 할당한다.

➄ (➀ ~ ➃)에서 생성한 분리된 마스크를 모두 더하여, 1차원의 마스크를 생성한다.

 

 

1.1.2. Crop, Padding and Resize

(Figure 2-1-(a))에서 보다시피 IDRiD 원본 영상에서는 좌우로 여백이 크다는 것을 알 수 있다. 그래서 이를 의미 있는 부분만 잘라주도록 한다.

(이를 Crop 한다고 하는데, Cropping이라고 표현하는 것을 보지 못한 듯 하다. 그리고 Crop을 안해줘도 되지만, 쓸데없는 연산량을 줄이기 위해 진행한다.)

 

Figure 2-2-(a)를 보면 DIARETDB 데이터에서의 여백은 그리 크지 않기 때문에, Crop 과정은 IDRiD 데이터셋만 진행한다.

 

그래서 눈이 있는 위치의 가장 왼쪽 좌표와 오른쪽 좌표를 구하여 해당 구간만 Crop을 해야 한다. Crop한 영상은 일반적으로 대략 (2848, 3400) 정도의 직사각형이기에 정사각형을 만들어주기 위해 위, 아래, 좌우로 제로 패딩(zero-padding)을 진행해주어야 한다. 과정은 다음과 같다.

 

➀ 이진화(binarization)로 눈과 여백을 구분한다.

➁ 윤곽선(contour)을 따서 x좌표에 대해서 최소와 최댓값 찾기. (최대 - 최소 = 너비 길이)

➂ 해당 구간만큼 원본 영상 데이터와 통합한 마스크 데이터 crop하기.

➃ Crop한 영상 데이터와 마스크 데이터 위아래에 각각 500씩, 좌우로 200씩 zero-padding

➄ 영상 데이터와 마스크 데이터를 (1024, 1024)로 Resize

 

자, ➃를 보면 위아래와 좌우까지 각각 패딩 해준다고 하였다. 눈치가 빠른 사람들은 알겠지만, “위아래만 해도 되는데 좌우까지 패딩을 해주면서 정사각형을 만들어줄까?”라는 의문이 들 것이다.

이러한 이유는 본인이 위에서 ‘Crop한 영상은 일반적으로 (2848, 3390)’라 하였다. 이는 다른 말로 일반적이지 않은 데이터도 있다는 것이다. 이는 (Figure 3)에서 비교를 해보았다.

 

Figure 3. 영상 크기 비교 ( 출처 : IDRiD )

 

 

일부 데이터 (Figure 3-(b))에서는 눈이 있는 부분이 다른 데이터에 비해 상당히 커서 Crop한 영상 크기는 대략 (2848, 3800) 정도가 된다. 그래서 일반적이지 않은 데이터에 맞춰주기 위해서 (2848, 3400)을 (3848, 3800)으로, (2848, 3800)을 (3848, 3800) 크기로 만들고 (1024, 1024)로 리사이즈해준다. (Figure 4-(a)는 Crop, padding, resize한 입력 영상이고, (b)-(e)는 각 클래스 별를 나누어 각기 나타낸 마스크 영상이다. ( 데이터 추가로 인해 클래스를 변경하여 (e) 시신경 유두(optic disc) 마스크는 제외한다. )

 

Figure 4. 전처리한 데이터와 클래스 별로 분리한 마스크 데이터

 

아래 (Figure 5)는 다른 입력 영상과 그에 대한 하나로 합친 마스크이다. 통합한 마스크를 자세히 보면 각 클래스 별로 픽셀값이 달라 밝기가 다른 것을 알 수 있다. 가장 밝은 것이 optic disc며, 그 다음 순서는 삼출물(exudates), 그 다음은 출혈과 동맥류, 마지막으로 가장 어두운 부분이 배경이다.

 

 

Figure 5. 전처리한 데이터와 통합한 마스크 데이터

 

 

1.1.3. Data Augmentation

해당 데이터는 136개(81 + 55)로 아주 적은 데이터로 구성되어 데이터 증식(augmentation)은 필수이다. 여기서 본인은 훈련 데이터와 테스트 데이터를 7:3인 100(60 + 40) / 36(21 + 15)개로 나누었다.

데이터 증식 방법으로는 좌우 반전(flip) 회전(rotation)을 사용했다. 훈련 데이터에 대해서 좌우 반전한 데이터를 20도씩 회전하였다. 즉, 100개를 100 * 2 * 18 = 3,600개로 증식하였다.

이후 데이터셋은 tf.data.Dataset.from_generator를 사용하여 생성하였다.

 


1.2. Code

"""

이하 수정한 코드 설명

`21.05.11 수정한 코드

"""

해당 코드는 데이터 추가하고, 수정한 Dataset_Generator()이며, 데이터셋이 두 개이다 보니, 모든 데이터셋을 하나로 모았다.

 

image 폴더에는 원본 fundus 이미

지를,

mask 폴더에는 MA, HE, EX, SE라는 폴더를 또 생성하여 각 폴더에 다음 마스크를 넣었다.

 

- MA : Microaneurysms (IDRiD), RedDot (DIARETDB)

- HE : Haemorrhages

- EX : Hard Exudates

- SE : Soft Exudates

 

그래서 전처리 전 디렉토리 구조는 다음과 같다.

/base_dir
    ├─image
    │   ├─ IDRiD_01.jpg
    │   ├─ IDRiD_02.jpg
    │   ├─ IDRiD_03.jpg
    │   ....
    │   ├─ image088.png
    │   └─ image089.png
    │
    └─ mask
        ├─ MA
        │   ├─ IDRiD_01_MA.jpg
        │   ├─ IDRiD_02_MA.jpg
        │   ....
        │   ├─ image088.png
        │   └─ image089.png
        │ 
        ├─ HE
        │   ├─ IDRiD_01_HE.jpg
        │   ├─ IDRiD_02_HE.jpg
        │   ....
        │   ├─ image088.png
        │   └─ image089.png
        │ 
        ├─ EX
        │   ├─ IDRiD_01_EX.jpg
        │   ├─ IDRiD_02_EX.jpg
        │   ....
        │   ├─ image088.png
        │   └─ image089.png
        │ 
        └─ SE
             ├─ IDRiD_01_SE.jpg
             ├─ IDRiD_02_SE.jpg
            ....
             ├─ image088.png
             └─ image089.png 

 

 

전처리 후 디렉토리 구조는 다음과 같다.

편의상, 폴더 내에 파일 표현은 생략했다.

/base_dir
    ├─image
    ├─mask
    │   ├─ MA
    │   ├─ HE
    │   ├─ EX
    │   └─ SE
    │ 
    ├─ Training
    │   ├─ images
    │   └─ masks
    │
    └─ Test
         ├─ images
         └─ masks

코드가 길어 분할하도록 한다.

 

Dataset_Generator()는 Class로 구성되어 있으며, 내부의 멤버 메소드는 총 4개가 있다.

class Dataset_Generator():
    def __init__(self,
                 base_dir = BASE_DIR,
                 num_classes = NUM_CLASSES,
                 batch_size = BATCH_SIZE,
                 height = HEIGHT,
                 width = WIDTH
                ):
        
        self.base_dir = BASE_DIR
        self.num_classes = float(num_classes)
        self.batch_size = batch_size
        self.height = HEIGHT
        self.width = WIDTH
        #self.images_list = []
        self.images_list = os.listdir(self.base_dir + "Training/images/")
        random.shuffle(self.images_list) # 데이터셋 셔플하기 위해
        
    def __del__(self):
        print("Dataset Generator is destructed")

 

1.2.1. Member Method

1. _preprocessor():

해당 함수로 DIARETDB 마스크를 정리하고, 마스크들을 통합한다. 또한 Crop(IDRiD dataset에 한하여), Padding, Resize 후, Training / Test dataset으로 분할하여 저장한다. 이는 아래 train_generator 함수에서 실행되도록 한다.

 

# 6. Padding 에서는

 

if (x_max-x_min)/2848 >= 1.25:

 

가 이해가 안 갈수 도 있으니 설명하도록 하겠다.

 

위에서 일반적인 데이터를 Crop하면 영상 크기가 (2848, 3390)가 된다고 하였는데, 이 때 높이에 대한 너비 비율이 대략 1.19정도이다. 하지만 (Figure3 -(b))와 같은 일부 데이터를 Crop한 크기는 (2848, 3800)으로 비율이 대략 1.33이다. 이들을 구분하기 위해서 그 사이값인 1.25를 경곗값(threshold)으로 잡아 패딩을 다르게 하였다.

    def _preprocessor(self):
        # 전처리한 파일을 저장할 폴더 생성
        try:
            os.mkdir(self.base_dir+"Training")
            os.mkdir(self.base_dir+"Test")
            os.mkdir(self.base_dir+"Training/images")
            os.mkdir(self.base_dir+"Test/images")
            os.mkdir(self.base_dir+"Training/masks")
            os.mkdir(self.base_dir+"Test/masks")
        except FileExistsError:
            pass
        
        idrid_cnt = diaretdb_cnt = 0 # 훈련 set, 테스트 set 분할

        # 파일명 정리
        image_list = os.listdir(self.base_dir + "image/")
        for i, file_name in enumerate(image_list):
            image_list[i] = file_name.split(".")[0]
        image_list.sort()

        # 마스크 클래스 별
        mask_class_dir = ["MA", "HE", "EX", "SE"]
        mask_file_list = []

        for cls in mask_class_dir:
            mask_file_list.append(os.listdir(self.base_dir + f"mask/{cls}"))

        zero_1 = np.zeros([2848, 4288], dtype = np.uint8)
        zero_2 = np.zeros([1152, 1500], dtype = np.uint8)

        loss_cnt = 0
        # 전처리
        for i, file_name in enumerate(image_list):
            if "IDRiD" in file_name:
                zero = zero_1
                thres = 1
            elif "image" in file_name:
                zero = zero_2
                # [63, 127, 189, 252]
                thres = 127

            mask_list = []
            # 1. mask  파일 찾기
            for cls in range(4):
                # 1.1. 파일명 확정
                if "IDRiD" in file_name:
                    mask_file_name = f"{file_name}_{mask_class_dir[cls]}.tif"
                elif "image" in file_name:
                    mask_file_name = f"{file_name}.png"

                # 1.2. 마스크 유무 확인
                if mask_file_name in mask_file_list[cls]:
                    mask = cv2.imread(f"{self.base_dir}mask/{mask_class_dir[cls]}/{mask_file_name}", 0)
                    _, mask = cv2.threshold(mask, thres, 1, cv2.THRESH_BINARY)
                else:
                    mask = zero
                mask_list.append(mask)

            # 2. 마스크 통합
            Class_1 = cv2.bitwise_or(mask_list[0], mask_list[1]) * 100
            Class_2 = cv2.bitwise_or(mask_list[2], mask_list[3]) * 200
            mask = Class_1 + Class_2
            del Class_1, Class_2, mask_list


            # 빈 마스크 확인
            if np.all(mask == zero):
                loss_cnt += 1
                print(f"{file_name} has no mask")
            else:
                # 5.0. Binaryzation
                # 5.1. 파일명 확정
                if "IDRiD" in file_name:
                    file_name = f"{file_name}.jpg"
                elif "image" in file_name:
                    file_name = f"{file_name}.png"
                img = cv2.imread(f"{self.base_dir}image/{file_name}")

                # 5.2 IDRiD 데이터셋에서만 crop 과정
                if "IDRiD" in file_name:
                    gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
                    if i==3 or i == 10:
                        thres = 10
                    else:
                        thres = 30

                    _, binary_img = cv2.threshold(gray_img, thres, 255, cv2.THRESH_BINARY)
                    del gray_img

                    # 5.2.1. contours
                    contours, hierachy = cv2.findContours(binary_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

                    # 5.2.2 x_min, x_max 찾기
                    x_min = np.min(contours[-1], axis = 0)
                    x_max = np.max(contours[-1], axis = 0)
                    x_min, x_max = x_min[0][0], x_max[0][0]
                    del contours, hierachy

                    # 5.2.3. Crop
                    img = img[:, x_min:x_max+1]
                    mask = mask[:, x_min:x_max+1]

                    # 6. Padding
                    if (x_max-x_min)/2848 >= 1.25:
                        pad_left, pad_right = 0, 0
                    else:
                        pad_left, pad_right = 200, 200

                    img = cv2.copyMakeBorder(img, 500, 500, pad_left, pad_right, cv2.BORDER_CONSTANT,value=0)
                    mask = cv2.copyMakeBorder(mask, 500, 500, pad_left, pad_right, cv2.BORDER_CONSTANT,value=0)

                # 5.3. diaretdb에서는 Padding만
                elif "image" in file_name:
                    img = cv2.copyMakeBorder(img, 174, 174, 0, 0, cv2.BORDER_CONSTANT,value=0)
                    mask = cv2.copyMakeBorder(mask, 174, 174, 0, 0, cv2.BORDER_CONSTANT,value=0)

                # 7. Resize
                img = cv2.resize(img, dsize=(self.height, self.width), interpolation=cv2.INTER_AREA)
                mask = cv2.resize(mask, dsize=(self.height, self.width), interpolation=cv2.INTER_AREA)

                # 8. 파일 저장
                # 훈련, 테스트셋 별도 저장
                if "IDRiD" in file_name and idrid_cnt < 60:
                    cv2.imwrite(f'{self.base_dir}Training/images/{file_name}', img)
                    cv2.imwrite(f'{self.base_dir}Training/masks/{file_name}', mask)
                    idrid_cnt += 1

                elif "IDRiD" in file_name and idrid_cnt >= 60:
                    cv2.imwrite(f'{self.base_dir}Test/images/{file_name}', img)
                    cv2.imwrite(f'{self.base_dir}Test/masks/{file_name}', mask)
                    idrid_cnt += 1

                elif "image" in file_name and diaretdb_cnt < 40:
                    cv2.imwrite(f'{self.base_dir}Training/images/{file_name}', img)
                    cv2.imwrite(f'{self.base_dir}Training/masks/{file_name}', mask)
                    diaretdb_cnt += 1

                elif "image" in file_name and diaretdb_cnt >= 40:
                    cv2.imwrite(f'{self.base_dir}Test/images/{file_name}', img)
                    cv2.imwrite(f'{self.base_dir}Test/masks/{file_name}', mask)
                    diaretdb_cnt += 1

                print(f"{file_name} completed!")
        self.images_list = os.listdir(self.base_dir + "Training/images/")
        random.shuffle(self.images_list)
        print(f"Preprocessing completed!. Number of no mask data : {loss_cnt}")

 

2. _Image_Reshaper():

- 영상 데이터 : (1024, 1024, 3) -> (1, 1024, 1024, 3) 차원으로 변환

- 마스크 데이터 : (1024, 1024, 1) -> (1, 1024, 1024, 1) 차원으로 변환

 

3. train_generator():

_preprocessor()로 전처리한 파일을 만든 후, Training 데이터 셋을 augmentation하고 yield한다.

 

4. valid_generator():

_preprocessor()로 전처리한 파일을 만든 후, Validatation 데이터 셋을 augmentation하고 yield한다.

3번 train_generator()와 함께 K-fold Cross validation하기 위해서, 입력 K를 설정하였다.

본 게시글에서는 5-fold를 진행할 것이므로, K의 범위는 1~5이다. 이때, 각 1,2,3,4,5일 때 Training 데이터셋에서 validation 데이터셋을 다르게 split해주어야하기에, 입력 k에 따라 범위를 다르게 주어 split하도록 하였다.

generator를 사용하지 않는다면 (1024, 1024, 3) 크기의 영상 데이터가 몇 천장으로 존재하면 몇 십 GB는 우습게 잡아먹기에 RAM에서 Out-Of-Memory가 발생한다. 그래서 이를 방지하기 위해 하나씩 빼주는 generator 사용은 필수이다.

 

train_generator와 valid_generator 코드를 보면 아래와 같은 코드가 있다.

for _ in range(self.epochs):
    for i, file_name in enumerate(self.images_list):

 

이것이 의미하는 것은 다음 설명과 같다. tensorflow에서 모델을 학습 시에 model.fit 함수를 사용하는데 이때, 모델에 들어가는 데이터셋이 epoch까지 고려해서 넣어주어야하기에 반복문으로 정해진 epoch만큼 설정 해줘야한다.

만약 for _ in range(self.epochs): 가 존재하지 않으면, 설령 모델에서 1 epochs 초과로 학습하도록 설정하더라도 1 epoch가 끝나면 더 이상 generator로 yield를 할 수 없어 StopIteration가 발생하고 실행한 코드가 정지한다.

만약 1 epoch만 학습하고 싶다면 굳이 for _ in range(self.epochs):를 사용하지 않아도 된다.

 

 

5. test_generator():

Test 데이터 셋을 yield한다.

    def _Image_Reshape(self, image, mask):
        image = np.reshape(image, ((self.batch_size,) + image.shape))
        mask = np.reshape(mask, ((self.batch_size,) + mask.shape))
        return (image/255, mask/200.)
    
    def train_generator(self, k):
        """
        Training Data Augmentation
        """
        # 전처리 했으면 skip, 전처리 안했으면 Go!
        if self.images_list:
            pass
        else:
            self._preprocessor()
        x_center, y_center = self.width/2, self.height/2

        for _ in range(self.epochs):
            for i, file_name in enumerate(self.images_list):
                if 20*k-20 <= i < 20*k:
                    pass
                else:
                    # 원본 이미지
                    img = cv2.imread(f"{self.base_dir}Training/images/{file_name}")
                    mask = cv2.imread(f"{self.base_dir}Training/masks/{file_name}", 0)
                    yield self._Image_Reshape(img, mask)

                    # 좌우 반전
                    flip_img = cv2.flip(img, 1)
                    flip_mask = cv2.flip(mask, 1)
                    yield self._Image_Reshape(flip_img, flip_mask)

                    for degree in range(20, 360, 20):
                        matrix = cv2.getRotationMatrix2D((x_center, y_center), degree, 1)
                            
                        # 원본 이미지에 대한 회전
                        rot_img = cv2.warpAffine(img, matrix, (self.width, self.height))
                        rot_mask = cv2.warpAffine(mask, matrix, (self.width, self.height))
                        yield self._Image_Reshape(rot_img, rot_mask)

                        # filp 이미지에 대한 회전
                        rot_flip_img = cv2.warpAffine(flip_img, matrix, (self.width, self.height))
                        rot_flip_mask = cv2.warpAffine(flip_mask, matrix, (self.width, self.height))
                        yield self._Image_Reshape(rot_flip_img, rot_flip_mask)

    def valid_generator(self, k):
        """
        Validataion Data Augmentation
        """
        x_center, y_center = self.width/2, self.height/2
        for _ in range(self.epochs):
            for i, file_name in enumerate(self.images_list):
                if (20*k-20) <= i < 20*k:
                    # 원본 이미지
                    img = cv2.imread(f"{self.base_dir}Training/images/{file_name}")
                    mask = cv2.imread(f"{self.base_dir}Training/masks/{file_name}", 0)
                    yield self._Image_Reshape(img, mask)

                    # 좌우 반전
                    flip_img = cv2.flip(img, 1)
                    flip_mask = cv2.flip(mask, 1)
                    yield self._Image_Reshape(flip_img, flip_mask)

                    for degree in range(20, 360, 20):
                        matrix = cv2.getRotationMatrix2D((x_center, y_center), degree, 1)
                            
                        # 원본 이미지에 대한 회전
                        rot_img = cv2.warpAffine(img, matrix, (self.width, self.height))
                        rot_mask = cv2.warpAffine(mask, matrix, (self.width, self.height))
                        yield self._Image_Reshape(rot_img, rot_mask)

                        # filp 이미지에 대한 회전
                        rot_flip_img = cv2.warpAffine(flip_img, matrix, (self.width, self.height))
                        rot_flip_mask = cv2.warpAffine(flip_mask, matrix, (self.width, self.height))
                        yield self._Image_Reshape(rot_flip_img, rot_flip_mask)

                
    def test_generator(self):
        images_list = os.listdir(self.base_dir + "Test/images/")
        for i, file_name in enumerate(images_list):
            # 원본 이미지
            img = cv2.imread(f"{self.base_dir}Test/images/{file_name}")
            mask = cv2.imread(f"{self.base_dir}Test/masks/{file_name}", 0)
            yield self._Image_Reshape(img, mask)

 

아래는 테스트를 위한 실행 코드이다. 어디까지나 테스트 코드임을 감안하자.

 

실제로 학습시에 사용되는 코드는 다음 게시글에서 설명하도록 하겠다.

## Training dataset check
dataset = Dataset_Generator()

gen = dataset.train_generator(1)
for i in range(720 * 4):
    result = next(gen)
    img, mask = result[0][0], result[1][0]

    fig = plt.figure(i, figsize = (10,10))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(img)
    ax1.set_title('Image')
    ax1.axis("off")

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(mask, cmap = "bone")
    ax2.set_title('Ground Truth Mask')
    ax2.axis("off")
    if i == 3:
        break
## Validation dataset check
gen = dataset.valid_generator(1)
for i in range(720):
    result = next(gen)
    img, mask = result[0][0], result[1][0]
    fig = plt.figure(i, figsize = (10,10))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(img)
    ax1.set_title('Image')
    ax1.axis("off")

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(mask, cmap = "bone")
    ax2.set_title('Ground Truth Mask')
    ax2.axis("off")
    if i == 3:
        break
## Test dataset check
test_dataset = tf.data.Dataset.from_generator(
                    dataset.test_generator,
                    (tf.float32, tf.int32),
                    (tf.TensorShape([1, HEIGHT, WIDTH, 3]), tf.TensorShape([1,  HEIGHT, WIDTH])),
                    )

del dataset

### Check Test set test
for i, element in enumerate(test_dataset):
    img = element[0][0].numpy()
    mask = element[1][0].numpy()
    fig = plt.figure(i, figsize = (10,10))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(img)
    ax1.set_title('Image')
    ax1.axis("off")

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(mask, cmap = "bone")
    ax2.set_title('Ground Truth Mask')
    ax2.axis("off")
    
    if i == 3:
        break
    plt.show()

del gen, test_dataset

 

 

확인 결과 스크린샷

Figure 6은 전체 Training 데이터셋에서 5-fold 중 4-fold에 대한 데이터 샘플이다. train_generator에서 _preprocessor가 실행되어 출력된 영상 위에 데이터 전처리한 로그가 출력된다.

Figure 7은 5-fold 중에서 validation으로 사용할 나머지 1-fold 데이터 샘플이다.

Figure 8은 테스트 데이터셋을 출력한 것이다.

 

영상이 파란색인 이유는 opencv로 데이터를 read할 때는 RGB가 아닌 BGR로 하기에 파란색으로 나온다.

 

Figure 6. 4-fold of 5-fold Training dataset sample

 

Figure 7. 1-fold of 5-fold validation dataset sample

 

Figure 8. Test dataset sample

 

다음 장에서는 U-Net과 VGG16, ResNet50, EfficientNet B0, 손실 함수에 대해 간략한 설명과 구현한 코드를 설명하도록 하겠다.

 


 

Reference

[1] “DIARETDB1 diabetic retinopathy database and evaluation protocol", “Kauppi, Tomi, Kalesnykiene, Valentin”, 2007, Proc. Medical Image Understanding and Analysis (MIUA)


 

 

 

 

 

 

 


"""

수정 전 코드

'21.05.11 수정 전 코드

"""

# 해당 아래 코드는 데이터셋을 새로 추가하고, K-fold 교차 검증하기 의 데이터셋 생성 코드

# IDRiD dataset만 해당.

 

해당 코드는 class로 구현했으며, 멤버 변수에 대한 설명은 읽어보면 이해되므로 생략하도록 한다.

(사실 이렇게 지저분하게 코딩하면 안 되는데, 테스트하고 그대로 그냥 복붙하다보니 지저분하게 되었다...)

 

코드 길이가 매우 길어 멤버 변수 / 마스크 통합 메소드 / 데이터 증식 메소드로 분할 하였다.

 

1.2.1. Member Variable

멤버 변수는 주로 파일 경로 설정과 파일 이름을 list로 저장한 것이 대부분이다.

class Dataset_Generator():
    def __init__(self,
                 base_dir = BASE_DIR,
                 num_classes = NUM_CLASSES,
                 batch_size = BATCH_SIZE,
                 height = HEIGHT,
                 width = WIDTH
                ):
        
        self.base_dir = BASE_DIR
        self.num_classes = float(num_classes)
        self.batch_size = batch_size
        self.height = HEIGHT
        self.width = WIDTH
        
        self.train_list = os.listdir(self.base_dir + "1. Original Images/a. Training Set/")
        for i, file_name in enumerate(self.train_list):
            self.train_list[i] = file_name.split(".")[0]
        self.train_list.sort()

        self.test_list = os.listdir(self.base_dir + "1. Original Images/b. Testing Set/")
        for i, file_name in enumerate(self.test_list):
            self.test_list[i] = file_name.split(".")[0]
        self.test_list.sort()

        self.Train_GT_BASE_DIR = "2. All Segmentation Groundtruths/a. Training Set/"
        self.Test_GT_BASE_DIR = "2. All Segmentation Groundtruths/b. Testing Set/"
        self.class_1_1 = "1. Microaneurysms/"
        self.class_1_2 = "2. Haemorrhages/"
        self.class_2_1 = "3. Hard Exudates/"
        self.class_2_2 = "4. Soft Exudates/"
        self.class_3 = "5. Optic Disc/"

        # Train
        self.Train_MA_list = os.listdir(self.base_dir + self.Train_GT_BASE_DIR + self.class_1_1)
        self.Train_HE_list = os.listdir(self.base_dir + self.Train_GT_BASE_DIR + self.class_1_2)
        self.Train_EX_list = os.listdir(self.base_dir + self.Train_GT_BASE_DIR + self.class_2_1)
        self.Train_SE_list = os.listdir(self.base_dir + self.Train_GT_BASE_DIR + self.class_2_2)
        self.Train_OD_list = os.listdir(self.base_dir + self.Train_GT_BASE_DIR + self.class_3)

        # Test
        self.Test_MA_list = os.listdir(self.base_dir + self.Test_GT_BASE_DIR + self.class_1_1)
        self.Test_HE_list = os.listdir(self.base_dir + self.Test_GT_BASE_DIR + self.class_1_2)
        self.Test_EX_list = os.listdir(self.base_dir + self.Test_GT_BASE_DIR + self.class_2_1)
        self.Test_SE_list = os.listdir(self.base_dir + self.Test_GT_BASE_DIR + self.class_2_2)
        self.Test_OD_list = os.listdir(self.base_dir + self.Test_GT_BASE_DIR + self.class_3)
        
    def __del__(self):
        print("Dataset Generator is destructed")

 

 

1.2.2. Mask Code

해당 메소드는 1.1.1 ~ 1.1.2에서 설명한 하나의 마스크로 통합하는 과정이다. 읽어보면 이해가 되겠지만,

# 6. Padding 에서는

 

if (x_max-x_min)/2848 >= 1.25:

 

가 이해가 안 갈수 도 있으니 설명하도록 하겠다.

 

위에서 일반적인 데이터를 Crop하면 영상 크기가 (2848, 3390)가 된다고 하였는데, 이 때 높이에 대한 너비 비율이 대략 1.19정도이다. 하지만 (Figure3 -(b))와 같은 일부 데이터를 Crop한 크기는 (2848, 3800)으로 비율이 대략 1.33이다. 이들을 구분하기 위해서 그 사이값인 1.25를 경곗값(threshold)으로 잡아 패딩을 다르게 하였다.

    # 1. Mask 하나로 합치기
    ## # 0. Background, 1, Exudatas(Hard + Soft), 2. Hemorrhages + Microaneurysms, 3. Optic disc
    def _generator_mask(self, Training_or_Test):
        if Training_or_Test == "Training":
            MA_list = self.Train_MA_list
            HE_list = self.Train_HE_list
            EX_list = self.Train_EX_list
            SE_list = self.Train_SE_list
            OD_list = self.Train_OD_list
            GT_BASE_DIR = self.Train_GT_BASE_DIR
            path = "a. Training Set"
            data_list = self.train_list

        elif Training_or_Test == "Test":
            MA_list = self.Test_MA_list
            HE_list = self.Test_HE_list
            EX_list = self.Test_EX_list
            SE_list = self.Test_SE_list
            OD_list = self.Test_OD_list
            GT_BASE_DIR = self.Test_GT_BASE_DIR
            path = "b. Testing Set"
            data_list = self.test_list
        
        for i, file_name in enumerate(data_list):
            # 0. Class 0 배경 공간 생성 -> 0
            Class_0 = np.zeros([2848, 4288], dtype = np.uint8)
            
            
            
            # 1. Class 1 찾기 -> 1
            # 1.1. MA 찾기
            # 동일 이름에 대한 마스크가 있을 때
            if file_name + "_MA.tif" in MA_list:
                MA = cv2.imread(self.base_dir + GT_BASE_DIR + self.class_1_1 + file_name + "_MA.tif", 0)
                _, MA = cv2.threshold(MA, 1, 1, cv2.THRESH_BINARY)
            # 동일 이름에 대한 마스크가 없을 때 -> 0으로 된 빈 공간 생성
            else:
                MA = np.zeros([2848, 4288], dtype = np.uint8)

            # 1.2. HE 찾기
            # 동일 이름에 대한 마스크가 있을 때
            if file_name + "_HE.tif" in HE_list:
                HE = cv2.imread(self.base_dir + GT_BASE_DIR + self.class_1_2 + file_name + "_HE.tif", 0)
                _, HE = cv2.threshold(HE, 1, 1, cv2.THRESH_BINARY)
            # 동일 이름에 대한 마스크가 없을 때 -> 0으로 된 빈 공간 생성
            else:
                HE = np.zeros([2848, 4288], dtype = np.uint8)

            # 1.3. 합치기
            Class_1 = cv2.bitwise_or(MA, HE)


            
            # 2. Class 2 -> 2
            # 2.1. EX 찾기
            if file_name + "_EX.tif" in EX_list:
                EX = cv2.imread(self.base_dir + GT_BASE_DIR + self.class_2_1 + file_name + "_EX.tif", 0)
                _, EX = cv2.threshold(EX, 1, 1, cv2.THRESH_BINARY)
            else:
                EX = np.zeros([2848, 4288], dtype = np.uint8)

            # 2.2. SE 찾기
            if file_name + "_SE.tif" in SE_list:
                SE = cv2.imread(self.base_dir + GT_BASE_DIR + self.class_2_2 + file_name + "_SE.tif", 0)
                _, SE = cv2.threshold(SE, 1, 1, cv2.THRESH_BINARY)
            else:
                SE = np.zeros([2848, 4288], dtype = np.uint8)

            # 2.3. 합치기
            Class_2 = cv2.bitwise_or(EX, SE) * 2


            
            # 3. Class 3 찾기 -> 3
            if file_name + "_OD.tif" in OD_list:
                Class_3 = cv2.imread(self.base_dir + GT_BASE_DIR + self.class_3 + file_name + "_OD.tif", 0)
                _, Class_3 = cv2.threshold(Class_3, 1, 1, cv2.THRESH_BINARY)
                Class_3 = Class_3 * 3
            else:
                Class_3 = np.zeros([2848, 4288], dtype = np.uint8)
            
            
            # 4. Mask 합치기
            mask = Class_0 + Class_1 + Class_2 + Class_3
            
            del Class_0
            del Class_1
            del Class_2
            del Class_3
            
            # 5. Crop
            if Training_or_Test == "Training" and (i==3 or i == 10):
                thres = 10
            else:
                thres = 30
            
            # 5.0. Binaryzation
            img = cv2.imread(self.base_dir + f"1. Original Images/{path}/" + file_name + ".jpg")
            gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            _, binary_img = cv2.threshold(gray_img, thres, 255, cv2.THRESH_BINARY)

            # 5.1. contours
            contours, hierachy = cv2.findContours(binary_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)

            # 5.2. x_min, x_max 찾기
            x_min = np.min(contours[-1], axis = 0)
            x_max = np.max(contours[-1], axis = 0)
            x_min, x_max = x_min[0][0], x_max[0][0]

            # 5.3. Crop
            img = img[:, x_min:x_max+1]
            mask = mask[:, x_min:x_max+1]
            
            # 6. Padding
            if (x_max-x_min)/2848 >= 1.25:
                pad_left, pad_right = 0, 0
            else:
                pad_left, pad_right = 200, 200

            img = cv2.copyMakeBorder(img, 500, 500, pad_left, pad_right, cv2.BORDER_CONSTANT,value=0)
            mask = cv2.copyMakeBorder(mask, 500, 500, pad_left, pad_right, cv2.BORDER_CONSTANT,value=0)

            # 7. Resize
            img = cv2.resize(img, dsize=(self.height, self.width), interpolation=cv2.INTER_AREA)
            mask = cv2.resize(mask, dsize=(self.height, self.width), interpolation=cv2.INTER_AREA)
            yield img, mask
            

 

 

1.2.3. Data Augmentation

해당 메소드는 1.1.3.에서 설명한 데이터 증식하는 코드이며, 추가적으로 학습 데이터셋 생성으로 위해 데이터를 reshape 하는 과정이 있다. def _Image_Reshape이 여기에 해당되는 함수이며, 이는

 

- 영상 데이터 : (1024, 1024, 3) -> (1, 1024, 1024, 3) 차원으로 변환

- 마스크 데이터 : (1024, 1024, 1) -> (1, 1024, 1024, 1) 차원으로 변환

 

def train_preprocessor(self):와 def test_preprocessor(self):는 사실상 같은 기능이며, 알아보기 쉽게 그냥 구분 해놓았다. 이는 데이터를 증식하는 함수이다.

   
    def _Image_Reshape(self, image, mask):
        image = np.reshape(image, ((self.batch_size,) + image.shape))
        #mask = mask[..., np.newaxis]
        mask = np.reshape(mask, ((self.batch_size,) + mask.shape))
        return (image/255, mask)
    
    
    def train_preprocessor(self):
        x_center, y_center = self.width/2, self.height/2
        gen_mask = self._generator_mask("Training")
        while True:
            try:
                # 원본 이미지
                img, mask = next(gen_mask)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

                # 좌우 반전
                flip_img = cv2.flip(img, 1)
                flip_mask = cv2.flip(mask, 1)

                for degree in range(5, 360, 5):
                    matrix = cv2.getRotationMatrix2D((x_center, y_center), degree, 1)
                    
                    # 원본 이미지에 대한 회전
                    rot_img = cv2.warpAffine(img, matrix, (self.width, self.height))
                    rot_mask = cv2.warpAffine(mask, matrix, (self.width, self.height))
                    yield self._Image_Reshape(rot_img, rot_mask)

                    # filp 이미지에 대한 회전
                    rot_flip_img = cv2.warpAffine(flip_img, matrix, (self.width, self.height))
                    rot_flip_mask = cv2.warpAffine(flip_mask, matrix, (self.width, self.height))
                    yield self._Image_Reshape(rot_flip_img, rot_flip_mask)

                # 원본 이미지
                yield self._Image_Reshape(img, mask)

                # filp 이미지
                yield self._Image_Reshape(flip_img, flip_mask)
                
            except StopIteration:
                break

    def test_preprocessor(self):
        x_center, y_center = self.width/2, self.height/2
        gen_mask = self._generator_mask("Test")
        while True:
            try:
                # 원본 이미지
                img, mask = next(gen_mask)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                
                # 좌우 반전
                flip_img = cv2.flip(img, 1)
                flip_mask = cv2.flip(mask, 1)

                for degree in range(5, 360, 5):
                    matrix = cv2.getRotationMatrix2D((x_center, y_center), degree, 1)
                    
                    # 원본 이미지에 대한 회전
                    rot_img = cv2.warpAffine(img, matrix, (self.width, self.height))
                    rot_mask = cv2.warpAffine(mask, matrix, (self.width, self.height))
                    yield self._Image_Reshape(rot_img, rot_mask)

                    # filp 이미지에 대한 회전
                    rot_flip_img = cv2.warpAffine(flip_img, matrix, (self.width, self.height))
                    rot_flip_mask = cv2.warpAffine(flip_mask, matrix, (self.width, self.height))
                    yield self._Image_Reshape(rot_flip_img, rot_flip_mask)

                # 원본 이미지
                yield self._Image_Reshape(img, mask)

                # filp 이미지
                yield self._Image_Reshape(flip_img, flip_mask)
            
            except StopIteration:
                break

 

아래 코드는 데이터셋으로 생성하는 코드이다.

위의 두 코드와 별도임을 기억하자.

dataset = Dataset_Generator()

train_dataset = tf.data.Dataset.from_generator(
                    dataset.train_preprocessor,
                    (tf.float32, tf.int32),
                    (tf.TensorShape([1, HEIGHT, WIDTH, 3]), tf.TensorShape([1,  HEIGHT, WIDTH])),
                    )

test_dataset = tf.data.Dataset.from_generator(
                    dataset.test_preprocessor,
                    (tf.float32, tf.int32),
                    (tf.TensorShape([1, HEIGHT, WIDTH, 3]), tf.TensorShape([1,  HEIGHT, WIDTH])),
                    )

del dataset