본문 바로가기

개발/algorithm

[백준 2096번] 내려가기 - python

문제 링크

https://www.acmicpc.net/problem/2096

 

2096번: 내려가기

첫째 줄에 N(1 ≤ N ≤ 100,000)이 주어진다. 다음 N개의 줄에는 숫자가 세 개씩 주어진다. 숫자는 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 중의 하나가 된다.

www.acmicpc.net

풀이

- dp 문제

- 메모제이션으로 풀었는데 메모리 초과 발생 

- 현재 상태만 배열에 저장해두고 값을 갱신해야 한다고 한다. 

- 사용한 아이디어는

제일 왼쪽 인덱스 일 경우 바로 위 숫자와 위 숫자 바로 오른쪽 숫자 중에 최댓값 또는 최솟값 더하기

가운데 인덱스일 경우, 바로 위 숫자, 위 숫자 바로 왼쪽, 오른쪽 숫자 중에 최댓값 또는 최솟값 더하기

제일 오른쪽 인덱스 일 경우, 바로 위 숫자와 위 숫자 바로 왼쪽 숫자 중에 최댓값 또는 최솟값 더하기 

# 메모리 초과 발생한 코드 
import sys
input = sys.stdin.readline 

n = int(input())

graph = []

for _ in range(n):
  graph.append(list(map(int,input().split())))

# 최댓값, 최솟값 메모제이션 할 배열 
max_dp = [ [0] * n for _ in range(n)]
min_dp =  [ [0] * n for _ in range(n)]

for i in range(n):
  max_dp[0][i] = graph[0][i]
  min_dp[0][i] = graph[0][i]

for i in range(1,n):
  for j in range(n):
    if j == 0 :
      max_dp[i][j] = max(max_dp[i-1][j], max_dp[i-1][j+1]) + graph[i][j]
      min_dp[i][j] = min(min_dp[i-1][j], min_dp[i-1][j+1]) + graph[i][j]
    elif j == n-1:
      max_dp[i][j] = max(max_dp[i-1][j], max_dp[i-1][j-1]) + graph[i][j]
      min_dp[i][j] = min(min_dp[i-1][j], min_dp[i-1][j-1]) + graph[i][j]
    else:
      max_dp[i][j] = max(max_dp[i-1][j], max_dp[i-1][j-1], max_dp[i-1][j+1]) + graph[i][j]
      min_dp[i][j] = min(min_dp[i-1][j], min_dp[i-1][j-1], min_dp[i-1][j+1]) + graph[i][j]

print(max(max_dp[n-1]), min(min_dp[n-1]))

- max_dp와 min_dp에 현재값만 저장 

- dp에서 메모리 초과를 줄일 수 있는 방법을 배웠다. 

# 수정한 코드 
import sys
input = sys.stdin.readline 

n = int(input())

max_dp =  [0] * 3
min_dp =  [0] * 3

# 임시로 저장해둘 배열 
max_tmp =  [0] * 3
min_tmp =  [0] * 3

for i in range(n):
  a, b, c = map(int,input().split())
  for j in range(3):
    if j == 0 :
      max_tmp[j] = a + max(max_dp[j], max_dp[j+1]) 
      min_tmp[j] = a + min(min_dp[j], min_dp[j+1])
    elif j == 1:
      max_tmp[j] = b + max(max_dp[j], max_dp[j-1], max_dp[j+1])
      min_tmp[j] = b + min(min_dp[j], min_dp[j-1], min_dp[j+1]) 
    else:
      max_tmp[j] = c + max(max_dp[j], max_dp[j-1])
      min_tmp[j] = c + min(min_dp[j], min_dp[j-1])

  for j in range(3):
    max_dp[j] = max_tmp[j]
    min_dp[j] = min_tmp[j]

print(max(max_dp), min(min_dp))