728x90

https://www.acmicpc.net/problem/1517

 

1517번: 버블 소트

첫째 줄에 N(1 ≤ N ≤ 500,000)이 주어진다. 다음 줄에는 N개의 정수로 A[1], A[2], …, A[N]이 주어진다. 각각의 A[i]는 0 ≤ |A[i]| ≤ 1,000,000,000의 범위에 들어있다.

www.acmicpc.net

문제

N개의 수로 이루어진 수열 A[1], A[2], …, A[N]이 있다. 이 수열에 대해서 버블 소트를 수행할 때, Swap이 총 몇 번 발생하는지 알아내는 프로그램을 작성하시오.

(버블 소트는 서로 인접해 있는 두 수를 바꿔가며 정렬하는 방법이다. 예를 들어 수열이 3 2 1 이었다고 하자. 이 경우에는 인접해 있는 3, 2가 바뀌어야 하므로 2 3 1 이 된다. 다음으로는 3, 1이 바뀌어야 하므로 2 1 3 이 된다. 다음에는 2, 1이 바뀌어야 하므로 1 2 3 이 된다. 그러면 더 이상 바꿔야 할 경우가 없으므로 정렬이 완료된다.)

 

Tip & 풀이 핵심

1. '시간제한 1초' & '첫째 줄에 N(1 ≤ N ≤ 500,000)이 주어진다.'

  • (시간제한이 1초이면) input 길이가 10,000이 넘어가면 O(N**2)의 알고리즘으로 절대 풀 수 없음!!!
    => Bubble sort, Insertion sort, Selection sort: O(N**2)
    => Merge sort: O(N* logN) / Quick sort
    => Counting sort

2. Merge sort에서 swap하는 과정이 있더라!

   => "어떻게 하면 Merge sort 로직에 result 구하는 코드 구현?"

 

 

모범 코드

모범 코드 보면서 효율적이고 정확하게 코드 짜는 방법을 많이 배웠다.

import sys
input = sys.stdin.readline
result = 0 #최종 결과값

N = int(input())
A = list(map(int, input().split()))
A.insert(0,0) # 0 index에 0을 추가 # index 처리 쉽게 하기 위함
tmp = [0] *(N+1) # A배열의 크기를 맞춤 # A를 바로 처리하기 부담스러우므로 복제

def merge_sort(s, e):
    global result
    if e-s < 1: return
        
    else:
        m = (e+s) // 2
        merge_sort(s, m)
        merge_sort(m+1, e)
        
        for i in range(s, e+1):
            tmp[i] = A[i]
        
        # two pointers
        k = s #현재 바꾸어야 하는 A의 index
        index1 = s
        index2 = m+1
        while index1 <= m and index2 <=e:
            if tmp[index1] > tmp[index2]:
                A[k] = tmp[index2] #swap
                result += (index2-k)
                k += 1
                index2 += 1
                # A[index2] = tmp[index1]
            
            else:
                A[k] = tmp[index1]
                k += 1
                index1 += 1
         
        while index1 <= m:
             A[k] = tmp[index1]
             k += 1
             index1 += 1
        
        while index2 <= e:
            A[k] = tmp[index2]
            k += 1
            index2 += 1
        

merge_sort(1, N) # 1번째 idx부터 N번째 idx까지 merge_sort 진행
print(result)
  • global 이해 필수: https://lets-hci-la-ai-withme.tistory.com/68
  • A.insert(0,0) : idx 처리 편하도록 0번째에 0value 추가
  • tmp 생성: input 리스트(A) for문으로 복제 (value 값 저장하여 안전하게)
  • 재귀식에서 base condition 처리, return만: if e-s<1: return
  • 함수에서 return이 없을 수 있다. 단 위 경우, 전역변수인 result에 계속 값을 저장.
  • 특히, Merge_sort에서 return 없이, 시작 idx(s)와 끝 idx(e)만을 입력 파라미터로 이용하여 처리 가능.
  • k는 현재 처리해야 하는 A에서의 idx. 따로 지정.

 

내가 작성한 시간 초과 코드:

더보기

- merge_sort

- left_에서 현재 값보다 큰 것 중 가장 작은 것을 search하는 binary search

def merge_sort(inputs):
    
    if len(inputs) < 2:
        return inputs, 0
    
    else:
        mid = len(inputs)//2
        left = inputs[:mid]
        right = inputs[mid:]
        left_ , l_count  = merge_sort(left)
        right_ , r_count = merge_sort(right)
        
        merged, count = merge(left_, right_, l_count, r_count)

        return merged, count

def merge(left_, right_, l_count, r_count):
    merged = []
    count = (l_count + r_count)
    l_ptr = 0
    r_ptr = 0
    
    while l_ptr < len(left_) and r_ptr < len(right_):
        if left_[l_ptr] > right_[r_ptr]:
            # left에서 right_[r_ptr]보다 큰 원소 중 최소인 것의 idx 찾기
            s, mid, e = 0, 0, len(left_)-1

            while s <= e:
                mid = (s + e)//2
                if right_[r_ptr] == left_[mid]:
                    break
                
                elif right_[r_ptr] < left_[mid]:
                    e = mid -1
                else:
                    s = mid+1
                        
            if left_[mid] <= right_[r_ptr] < left_[e]:
                count += len(left_) - (mid+1)

            elif right_[r_ptr] < left_[s]:
                count += len(left_) - s

            elif mid > 0 and mid <= len(left_)-2:
                idx = mid+1
                while right_[r_ptr] >= left_[idx]:
                    idx += 1
                count += len(left_) - idx
            else:
                print(-1)
            
            merged.append(right_[r_ptr])
            r_ptr += 1
        else:
            merged.append(left_[l_ptr])
            l_ptr +=1
    
    while r_ptr < len(right_):
        merged.append(right_[r_ptr])
        r_ptr += 1
        
    while l_ptr < len(left_):
        merged.append(left_[l_ptr])
        l_ptr +=1
    
    return merged, count

import sys

for t in range(2):
    if t == 0:
        n = int(sys.stdin.readline())
    else:
        inputs = list(map(int, sys.stdin.readline().split()))
        print(merge_sort(inputs)[1])
728x90

+ Recent posts