Empirical Risk Minimization (ERM)
머신러닝 모델은 주어진 데이터를 이용해 예측 같은 임무(task)를 수행한다. 그런데 모델이 ‘잘 했다’라는 것을 어떻게 알 수 있을까??
이를 위해 우리는 보통 학습 시에 손실 함수(loss function) 를 정의한다. 이는 모델이 얼마나 못했는지를 수치화한 함수다. 예측이 완벽하면 손실은 보통 0, 예측이 틀릴수록 손실 값이 커진다. 모델 학습의 목표는 이 손실 함수를 최소화하는 방향으로 파라미터를 조정하는 것이다. 이 손실 함수를 목적 함수(objective function)라고도 부른다.
수식과 이해
우리가 사실 진짜로 최소화하고 싶은 건 데이터가 속한 진짜 분포 또는 내가 앞으로 예측할 테스트 데이터에 대한 기대 손실이다. 하지만 우리는 진짜 분포, 즉 데이터가 생성되는 메커니즘 자체를 알 수는 없다. (신이 아닌 이상) 가지고 있는 건 트레이닝 데이터(관측한 데이터) 들이고, 그래서 보통은 훈련 데이터 에 대해 평균 손실을 계산한다.
이를 경험적 위험(Empirical Risk) 이라고 하고 아래와 같은 수식으로 정의한다.
-
: 파라미터 를 가진 모델
-
: 예측값과 실제 값의 차이를 측정하는 손실 함수
가정
즉, ERM은 모든 트레이닝 데이터에 대해 얻은 손실의 평균을 최소화하는 방법이다. 즉 훈련 데이터에서 얻은 로스가 실제 로스의 좋은 추정량이 된다고 믿는단 건데,
따라서 ERM에는 중요한 가정이 깔려 있다. 바로 훈련 데이터가 진짜 분포를 대표한다고 믿는다는 것이다.
즉 훈련 데이터에 대해 손실을 최소화하면 곧 실제 분포에서도 성능이 좋아질 것이라고 기대한다는 것. 따라서 훈련한 데이터와 실제로 적용할 테스트 데이터의 분포가 다른 상황(Covariate Shift)이 생긴다거나, 데이터가 동일한 분포에서 독립적으로 추출되었다는 i.i.d 가정이 깨지면 ERM은 더 이상 실제 성능을 보장하지 못할 수 있다.
회귀분석 예시
단순 선형 회귀를 생각해보자. 우리는 설명변수 로부터 값을 예측하는 모델 를 세운다.
그리고 이 를 추정하기 위해 최소제곱법(Ordinary Least Squares)을 사용했음을 기억한다면, 이것 역시 ERM과 관련해서 이해해볼 수 있을 것이다.
우리는 OLS 추정을 위해 아래와 같은 식을 최소화했다. 무엇에 대해? 훈련 데이터에 대해
그러면 우리가 최소화하는 값은 결국
과 같다.
즉, 회귀분석에서 배우는 최소제곱법은 주어진 훈련 데이터에 대한 평균 제곱 오차(손실)을 최소화하는 것, ERM의 한 구체적인 예시인 셈이다.
여러 머신러닝, 딥러닝의 학습에서도 훈련데이터에 대해 로스를 최소화하고 있다면, ERM을 이용한다는 기본적인 원리는 같다.