본문 바로가기

개발/algorithm

[코드트리] 산타의 선물 공장 - Python

문제 링크

https://www.codetree.ai/training-field/frequent-problems/problems/santa-gift-factory/description?page=1&pageSize=20 

 

코드트리 | 코딩테스트 준비를 위한 알고리즘 정석

국가대표가 만든 코딩 공부의 가이드북 코딩 왕초보부터 꿈의 직장 코테 합격까지, 국가대표가 엄선한 커리큘럼으로 준비해보세요.

www.codetree.ai

 

풀이 

해설 참고해서 풀었다. 

실제 queue로 구현해서 풀면 안되고, 숫자의 범위가 매우 크기 때문에 최적화가 필요한 문제이다. 

prv 딕셔너리 : id에 해당하는 상자의 prv id 관리 

nxt 딕셔너리 : id에 해당하는 상자의 nxt id 관리

weigt 딕셔너리 : id에 해당하는 상자 무게 

# id에 해당하는 상자의 prv, nxt 관리 
prv = defaultdict(int)
nxt = defaultdict(int)

# id에 해당하는 상자 무게 
weight = {}

head : 벨트별 head 상자 id 

tail: 벨트별 tail 상자 id 

broken : 부서진 벨트인지 확인 

belt_num : id에 해당하는 상자 별 벨트 번호 

# 벨트 별 head, tail, id 관리
head = [0] * MAX_M
tail = [0] * MAX_M

broken = [False] * MAX_M

# 물건 별 벨트 번호 
belt_num = defaultdict(lambda: -1)

 

1. 공장 설립 

def build(data):
    global n, m 
    
    n, m = data[1], data[2]
    ids, ws = data[3:3+n], data[3+n:3+n+n]
    
    # id 별 무게 
    for i in range(n):
        weight[ids[i]] = ws[i]
    
    s = n//m 
    for i in range(m):
    	# 벨트 별 head, tail 값 저장
        head[i] = ids[i*s]
        tail[i] = ids[(i+1)*s-1]
        
        for j in range(i*s, (i+1)*s):
            # 상자 별 벨트 번호
            belt_num[ids[j]] = i
            
            # prv, nxt 값 저장
            if j < (i+1)*s - 1 :
                nxt[ids[j]] = ids[j+1]
                prv[ids[j+1]] = ids[j]

2. 물건 하차 

def drop(data):
    w_max = data[1]
    ans = 0 
    
    for i in range(m):
        # 부서진 벨트라면 넘어감 
        if broken[i] :
            continue 
        
        # 원소가 존재한다면 
        if head[i] != 0 :
            id = head[i]
            w = weight[id]
            # w_max 이하 
            if w <= w_max :
                ans += w
                # 해당 박스 벨트에서 제거 
                remove_id(id, True)
            # 그렇지 않다면 
            # 하나만 있는 경우는 그대로 유지 
            # 아니면 맨 뒤로 이동
            elif nxt[id] != 0 :
                # 벨트에서 제거 후 
                remove_id(id, False)
                # 맨 뒤에 삽입 
                push_id(tail[i], id)
    
    print(ans)

위에서 사용된 상자 제거, 상자 삽입 함수는 각각 아래와 같다. 

def remove_id(r_id, remove_belt):
    b_num = belt_num[r_id]
    
    # 벨트 번호 제거 -> 상자가 이동한 것이 아니라, 아에 하차된 경우 
    if remove_belt :
        belt_num[r_id] = -1 
    
    # 하나 남은 원소이면 사라지고 끝 
    if head[b_num] == tail[b_num]:
        head[b_num] = tail[b_num] = 0
        
    # head 삭제 일 경우 
    elif r_id == head[b_num]:
        nid = nxt[r_id]
        head[b_num] = nid
        prv[nid] = 0 
    # tail 삭제일 경우 
    elif r_id == tail[b_num] :
        pid = prv[r_id]
        tail[b_num] = pid 
        nxt[pid] = 0 
    # 중간에 있는 id 삭제 
    else :
        pid, nid = prv[r_id], nxt[r_id]
        nxt[pid] = nid 
        prv[nid] = pid 
    
    # 원소 제거 
    nxt[r_id], prv[r_id] = 0, 0
def push_id(t_id, id):
    nxt[t_id] = id 
    prv[id] = t_id 
    
    b_num = belt_num[t_id]
    # 마지막이었으면 tail 값 갱신
    if t_id == tail[b_num]:
        tail[b_num] = id

3. 물건 제거

# id에 해당하는 상자 삭제 
def remove(data):
    r_id = data[1]
    
    # 이미 삭제된 상자인 경우
    if belt_num[r_id] == -1 :
        print(-1)
        return 

    print(r_id)
    remove_id(r_id, True)

4. 물건 확인 

   def find(data):
    f_id = data[1]
    
    # 없는 상자인 경우 
    if belt_num[f_id] == -1:
        print(-1)
        return 

    b_num = belt_num[f_id]
    
    # 맨 앞 아닌 경우 맨 앞으로 이동
    if head[b_num] != f_id :
        orig_tail = tail[b_num]
        orig_head = head[b_num]
        
        now_tail = prv[f_id]
        tail[b_num] = now_tail 
        nxt[now_tail] = 0
        
        nxt[orig_tail] = orig_head 
        prv[orig_head] = orig_tail 
        
        head[b_num] = f_id 
        
    print(b_num +1)

5. 벨트 고장

def broken_belt(data):
    b_num = data[1] -1 
    
    # 이미 고장난 벨트인 경우 
    if broken[b_num]:
        print(-1)
        return 
    
    broken[b_num] = True 
    
    # 벨트에 아무 상자도 없는 경우 
    if head[b_num] == 0 :
        print(b_num + 1)
        return 

    next_num = b_num
    while True :
    	# 오른쪽으로 탐색 
        next_num = (next_num + 1)%m 
        
        # 부서지지 않으 벨트가 존재 
        if not broken[next_num]:
        	# 벨트에 아무것도 없는 경우 그대로 옮김 
            # head와 tail 만 갱신 
            if tail[next_num] == 0 :
                head[next_num] = head[b_num]
                tail[next_num] = tail[b_num]
            # 뒤에 추가 
            else :
                push_id(tail[next_num] , head[b_num])
                tail[next_num] = tail[b_num]
            
            # 상자 별 벨트 번호 갱신 
            id = head[b_num]
            while id != 0 :
                belt_num[id] = next_num
                id = nxt[id]
            
            # 기존 벨트 head, tail 삭제 
            head[b_num] = tail[b_num] = 0 
            break 
        
    print(b_num+1)

 

전체 코드는 아래와 같다.

from collections import defaultdict

MAX_M = 10
q = int(input())
n, m = -1, -1 
# id에 해당하는 상자의 prv, nxt 관리 
prv = defaultdict(int)
nxt = defaultdict(int)

# id에 해당하는 상자 무게 
weight = {}

# 벨트 별 head, tail, id 관리
head = [0] * MAX_M
tail = [0] * MAX_M

broken = [False] * MAX_M

# 물건 별 벨트 번호 
belt_num = defaultdict(lambda: -1)

def build(data):
    global n, m 
    
    n, m = data[1], data[2]
    ids, ws = data[3:3+n], data[3+n:3+n+n]
    
    # id 별 무게 
    for i in range(n):
        weight[ids[i]] = ws[i]
    
    s = n//m 
    for i in range(m):
        head[i] = ids[i*s]
        tail[i] = ids[(i+1)*s-1]
        
        for j in range(i*s, (i+1)*s):
            # 상자 별 벨트 번호
            belt_num[ids[j]] = i
            
            if j < (i+1)*s - 1 :
                nxt[ids[j]] = ids[j+1]
                prv[ids[j+1]] = ids[j]

def remove_id(r_id, remove_belt):
    b_num = belt_num[r_id]
    
    # 벨트 번호 제거 
    if remove_belt :
        belt_num[r_id] = -1 
    
    # 하나 남은 원소이면 사라지고 끝 
    if head[b_num] == tail[b_num]:
        head[b_num] = tail[b_num] = 0
    # head 삭제 일 경우 
    elif r_id == head[b_num]:
        nid = nxt[r_id]
        head[b_num] = nid
        prv[nid] = 0 
    # tail 삭제일 경우 
    elif r_id == tail[b_num] :
        pid = prv[r_id]
        tail[b_num] = pid 
        nxt[pid] = 0 
    # 중간에 있는 id 삭제 
    else :
        pid, nid = prv[r_id], nxt[r_id]
        nxt[pid] = nid 
        prv[nid] = pid 
    
    # 원소 제거 
    nxt[r_id], prv[r_id] = 0, 0
    
# id에 해당하는 상자 삭제 
def remove(data):
    r_id = data[1]
    
    if belt_num[r_id] == -1 :
        print(-1)
        return 

    print(r_id)
    remove_id(r_id, True)

def push_id(t_id, id):
    nxt[t_id] = id 
    prv[id] = t_id 
    
    b_num = belt_num[t_id]
    if t_id == tail[b_num]:
        tail[b_num] = id 
    
def drop(data):
    w_max = data[1]
    ans = 0 
    
    for i in range(m):
        # 부서진 벨트라면 넘어감 
        if broken[i] :
            continue 
        
        # 원소가 존재한다면 
        if head[i] != 0 :
            id = head[i]
            w = weight[id]
            # w_max 이하 
            if w <= w_max :
                ans += w
                # 해당 박스 벨트에서 제거 
                remove_id(id, True)
            # 그렇지 않다면 
            # 하나만 있는 경우는 그대로 유지 
            # 아니면 맨 뒤로 이동
            elif nxt[id] != 0 :
                # 벨트에서 제거 후 
                remove_id(id, False)
                # 맨 뒤에 삽입 
                push_id(tail[i], id)
    
    print(ans)
                
def find(data):
    f_id = data[1]
    
    if belt_num[f_id] == -1:
        print(-1)
        return 

    b_num = belt_num[f_id]
    
    # 맨 앞 아닌 경우 맨 앞으로 이동
    if head[b_num] != f_id :
        orig_tail = tail[b_num]
        orig_head = head[b_num]
        
        now_tail = prv[f_id]
        tail[b_num] = now_tail 
        nxt[now_tail] = 0
        
        nxt[orig_tail] = orig_head 
        prv[orig_head] = orig_tail 
        
        head[b_num] = f_id 
        # 이거 필요하지 않나 ? 
        #prv[f_id] = 0
        
    print(b_num +1)

def broken_belt(data):
    b_num = data[1] -1 
    
    if broken[b_num]:
        print(-1)
        return 
    
    broken[b_num] = True 
    
    if head[b_num] == 0 :
        print(b_num + 1)
        return 

    next_num = b_num
    while True :
        next_num = (next_num + 1)%m 
        
        if not broken[next_num]:
            if tail[next_num] == 0 :
                head[next_num] = head[b_num]
                tail[next_num] = tail[b_num]
            
            else :
                push_id(tail[next_num] , head[b_num])
                tail[next_num] = tail[b_num]
            
            id = head[b_num]
            while id != 0 :
                belt_num[id] = next_num
                id = nxt[id]
            
            head[b_num] = tail[b_num] = 0 
            break 
        
    print(b_num+1)

for _ in range(q):
    tmp = list(map(int,input().split()))
    q_type = tmp[0]
    
    if q_type == 100 :
        build(tmp)
    elif q_type == 200 :
        drop(tmp)
    elif q_type == 300:
        remove(tmp)
    elif q_type == 400:
        find(tmp)
    else :
        broken_belt(tmp)