본문 바로가기

개발/algorithm

[코드트리] 포탑 부수기 - Python

문제 링크 

https://www.codetree.ai/training-field/frequent-problems/problems/destroy-the-turret/description?page=1&pageSize=20 

 

코드트리 | 코딩테스트 준비를 위한 알고리즘 정석

국가대표가 만든 코딩 공부의 가이드북 코딩 왕초보부터 꿈의 직장 코테 합격까지, 국가대표가 엄선한 커리큘럼으로 준비해보세요.

www.codetree.ai

 

풀이 

아래 과정을 부서지지않은 포탑이 1개 이상 남았을 때까지 K번 반복한다. 

1. 공격자 선정 

조건에 맞게 tower 배열 대입한 후 정렬해서 구해주었다. 

공격력이 작은 순, 최근 공격한 순, 행과 열의 합이 작은 순, 열이 작은 순 

최근 공격한 순은 몇 번째 turn 인지 graph에 함께 넣어주어 turn이 큰 게 최근에 공격한 포탑이 된다. 

def find_weak():
    s_x, s_y = -1,-1
    tower = []
    for i in range(n):
        for j in range(m):
            if len(graph[i][j]) > 0 :
            	# 공격력, 최근 공격한 turn, 행과 열의 합, 열 
                tower.append((graph[i][j][0], -graph[i][j][1], -(i+j), -j))

    tower.sort()
    s_x = -(tower[0][2] - tower[0][3])
    s_y = - tower[0][3]
    
    return s_x, s_y

2. 공격자의 공격 

2-1) 공격할 포탑 찾기 

1번이랑 비슷하게 조건에 맞게 tower 배열에 대입한 후 정렬해서 구해주었다. 

여기서 공격자가 되면 안되므로, check 배열로 확인해준다. 

def find_strong():
    s_x, s_y = -1,-1
    tower = []
    for i in range(n):
        for j in range(m):
            if len(graph[i][j]) and not check[i][j] :
                tower.append((-graph[i][j][0], graph[i][j][1], (i+j), j))
    tower.sort()
    s_x = tower[0][2] - tower[0][3]
    s_y = tower[0][3]
    
    return s_x, s_y

2-2) 레이저 공격

route 배열에 경로와 걸린 시간 넣어주고, 정렬해서 이동할 수 있다면 레이저 공격 적용 

이동할 수 없다면 2-3) 포탄 공격 함수 호출 

dx = [0,1,0,-1]
dy = [1,0,-1,0]

def laser(w_x,w_y, s_x,s_y):
    q = deque()
    visited = [[False] * m for _ in range(n)]
    
    q.append((w_x,w_y,[]))
    visited[w_x][w_y] = True 
    route = []
    
    while q :
        x, y, r= q.popleft()
        
        for i in range(4):
        	# 반대쪽으로 나옴 
            nx = (x + dx[i])%n
            ny = (y + dy[i])%m
            
            if len(graph[nx][ny]) == 0  :
                continue 
        
            if not visited[nx][ny] :
                tmp = copy.deepcopy(r)
                # 도착했으면 route 배열에 넣어줌 
                if nx == s_x and ny == s_y :
                    route.append((len(tmp), tmp))
                tmp.append(i)
                visited[nx][ny] = True 
                q.append((nx,ny,tmp))
    
    # 이동할 수 없는 경우 
    if len(route) == 0 :
        bomb(w_x,w_y,s_x,s_y)
        return 

    route.sort()
    tmp_x = w_x
    tmp_y = w_y
    # 최단 거리 경로에 있는 포탑에 공격 적용 
    for d in route[0][1]:
    	# 반대쪽으로 나옴 
        tmp_x = (tmp_x + dx[d]) % n 
        tmp_y = (tmp_y + dy[d]) % m 
        # 공격할 포탑이면 
        if tmp_x == s_x and tmp_y == s_y:
            graph[s_x][s_y][0] -= graph[w_x][w_y][0]
            # 부서진 포탑 
            if graph[s_x][s_y][0] <= 0 :
                graph[s_x][s_y] = [] 
        # 경로에 있는 포탑이면 
        else :
        	# 영향 받은 포탑임을 표시 
            check[tmp_x][tmp_y] = True 
            graph[tmp_x][tmp_y][0] -= graph[w_x][w_y][0] // 2 
            # 부서진 포탑 
            if graph[tmp_x][tmp_y][0] <= 0 :
                graph[tmp_x][tmp_y] = []

2-2) 포탄 공격

n_dx = [1,-1,0,0,1,1,-1,-1]
n_dy = [0,0,1,-1,1,-1,1,-1]

def bomb(w_x,w_y,s_x,s_y):
	# 공격할 포탑에 포탄 던지기 
    graph[s_x][s_y][0] -= graph[w_x][w_y][0]
    if graph[s_x][s_y][0] <= 0 :
        graph[s_x][s_y] = [] 
    
    # 8개 방향 공격 
    for i in range(8):
        nx = (s_x + n_dx[i])%n
        ny = (s_y + n_dy[i])%m
        
        # 부서진 포탑이면 넘어감 
        if len(graph[nx][ny]) == 0  :
            continue 
        
        # 공격자거나 공격할 포탑이면 넘어감 
        if check[nx][ny] :
            continue 
        # 영향 받은 포탑임을 표시 
        check[nx][ny] = True
        graph[nx][ny][0] -= graph[w_x][w_y][0] // 2 
        # 부서진 포탑 
        if graph[nx][ny][0] <= 0 :
            graph[nx][ny] = []

3. 정비 

영향을 받은 포탑이 아니고, 포탑이 존재하면 + 1 

def maintain():
    for i in range(n):
        for j in range(m):
            if not check[i][j] and len(graph[i][j] ):
                graph[i][j][0] +=1

 

전체 코드는 다음과 같다. 

graph : 해당 위치에 있는 포탑의 ( 공격력, 최근 공격한 turn) 저장 

check : 해당 Turn에 영향을 받은 포탑인지 체크 

from collections import deque
import copy 

n, m, k = map(int,input().split())

graph = [ [[] for _ in range(m)] for _ in range(n)]

for i in range(n):
    data = list(map(int,input().split()))
    for j in range(m):
        if data[j]:
            graph[i][j] = [data[j],0]
            
check = [ [False] * m for _ in range(n)]  

def find_weak():
    s_x, s_y = -1,-1
    tower = []
    for i in range(n):
        for j in range(m):
            if len(graph[i][j]) > 0 :
                tower.append((graph[i][j][0], -graph[i][j][1], -(i+j), -j))

    tower.sort()
    s_x = -(tower[0][2] - tower[0][3])
    s_y = - tower[0][3]
    
    return s_x, s_y 

def find_strong():
    s_x, s_y = -1,-1
    tower = []
    for i in range(n):
        for j in range(m):
            if len(graph[i][j]) and not check[i][j] :
                tower.append((-graph[i][j][0], graph[i][j][1], (i+j), j))
    tower.sort()
    s_x = tower[0][2] - tower[0][3]
    s_y = tower[0][3]
    
    return s_x, s_y 

dx = [0,1,0,-1]
dy = [1,0,-1,0]

def laser(w_x,w_y, s_x,s_y):
    q = deque()
    visited = [[False] * m for _ in range(n)]
    
    q.append((w_x,w_y,[]))
    visited[w_x][w_y] = True 
    route = []
    
    while q :
        x, y, r= q.popleft()
        
        for i in range(4):
            nx = (x + dx[i])%n
            ny = (y + dy[i])%m
            
            if len(graph[nx][ny]) == 0  :
                continue 
        
            if not visited[nx][ny] :
                tmp = copy.deepcopy(r)
                if nx == s_x and ny == s_y :
                    route.append((len(tmp), tmp))
                tmp.append(i)
                visited[nx][ny] = True 
                q.append((nx,ny,tmp))
    
    if len(route) == 0 :
        bomb(w_x,w_y,s_x,s_y)
        return 

    route.sort()
    tmp_x = w_x
    tmp_y = w_y
    for d in route[0][1]:
        tmp_x = (tmp_x + dx[d]) % n 
        tmp_y = (tmp_y + dy[d]) % m 
        if tmp_x == s_x and tmp_y == s_y:
            graph[s_x][s_y][0] -= graph[w_x][w_y][0]
            if graph[s_x][s_y][0] <= 0 :
                graph[s_x][s_y] = [] 
        else :
            check[tmp_x][tmp_y] = True 
            graph[tmp_x][tmp_y][0] -= graph[w_x][w_y][0] // 2 
            if graph[tmp_x][tmp_y][0] <= 0 :
                graph[tmp_x][tmp_y] = []

n_dx = [1,-1,0,0,1,1,-1,-1]
n_dy = [0,0,1,-1,1,-1,1,-1]

def bomb(w_x,w_y,s_x,s_y):
    graph[s_x][s_y][0] -= graph[w_x][w_y][0]
    if graph[s_x][s_y][0] <= 0 :
        graph[s_x][s_y] = [] 
    
    for i in range(8):
        nx = (s_x + n_dx[i])%n
        ny = (s_y + n_dy[i])%m
        
        if len(graph[nx][ny]) == 0  :
            continue 
        
        if check[nx][ny] :
            continue 
        
        check[nx][ny] = True
        graph[nx][ny][0] -= graph[w_x][w_y][0] // 2 
        if graph[nx][ny][0] <= 0 :
            graph[nx][ny] = []
            
def maintain():
    for i in range(n):
        for j in range(m):
            if not check[i][j] and len(graph[i][j] ):
                graph[i][j][0] +=1 
                
for turn in range(1,k+1):
    cnt = 0 
    for i in range(n):
        for j in range(m):
            if len(graph[i][j]) :
                cnt +=1 
    
    if cnt <= 1 :
        break 
    
    w_x, w_y = find_weak()
    
    check[w_x][w_y] = True 
    graph[w_x][w_y][0] += n + m
    graph[w_x][w_y][1] = turn 
    
    s_x, s_y = find_strong()
    check[s_x][s_y] = True 
    laser(w_x,w_y,s_x,s_y)
    maintain()

    check = [ [False] * m for _ in range(n)]  

s_x, s_y = find_strong()
print(graph[s_x][s_y][0])