0001 """Unittests for heapq.""" 0002 0003 from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest 0004 import random 0005 import unittest 0006 from test import test_support 0007 import sys 0008 0009 0010 def heapiter(heap): 0011 # An iterator returning a heap's elements, smallest-first. 0012 try: 0013 while 1: 0014 yield heappop(heap) 0015 except IndexError: 0016 pass 0017 0018 class TestHeap(unittest.TestCase): 0019 0020 def test_push_pop(self): 0021 # 1) Push 256 random numbers and pop them off, verifying all's OK. 0022 heap = [] 0023 data = [] 0024 self.check_invariant(heap) 0025 for i in range(256): 0026 item = random.random() 0027 data.append(item) 0028 heappush(heap, item) 0029 self.check_invariant(heap) 0030 results = [] 0031 while heap: 0032 item = heappop(heap) 0033 self.check_invariant(heap) 0034 results.append(item) 0035 data_sorted = data[:] 0036 data_sorted.sort() 0037 self.assertEqual(data_sorted, results) 0038 # 2) Check that the invariant holds for a sorted array 0039 self.check_invariant(results) 0040 0041 self.assertRaises(TypeError, heappush, []) 0042 self.assertRaises(TypeError, heappush, None, None) 0043 self.assertRaises(TypeError, heappop, None) 0044 0045 def check_invariant(self, heap): 0046 # Check the heap invariant. 
0047 for pos, item in enumerate(heap): 0048 if pos: # pos 0 has no parent 0049 parentpos = (pos-1) >> 1 0050 self.assert_(heap[parentpos] <= item) 0051 0052 def test_heapify(self): 0053 for size in range(30): 0054 heap = [random.random() for dummy in range(size)] 0055 heapify(heap) 0056 self.check_invariant(heap) 0057 0058 self.assertRaises(TypeError, heapify, None) 0059 0060 def test_naive_nbest(self): 0061 data = [random.randrange(2000) for i in range(1000)] 0062 heap = [] 0063 for item in data: 0064 heappush(heap, item) 0065 if len(heap) > 10: 0066 heappop(heap) 0067 heap.sort() 0068 self.assertEqual(heap, sorted(data)[-10:]) 0069 0070 def test_nbest(self): 0071 # Less-naive "N-best" algorithm, much faster (if len(data) is big 0072 # enough <wink>) than sorting all of data. However, if we had a max 0073 # heap instead of a min heap, it could go faster still via 0074 # heapify'ing all of data (linear time), then doing 10 heappops 0075 # (10 log-time steps). 0076 data = [random.randrange(2000) for i in range(1000)] 0077 heap = data[:10] 0078 heapify(heap) 0079 for item in data[10:]: 0080 if item > heap[0]: # this gets rarer the longer we run 0081 heapreplace(heap, item) 0082 self.assertEqual(list(heapiter(heap)), sorted(data)[-10:]) 0083 0084 self.assertRaises(TypeError, heapreplace, None) 0085 self.assertRaises(TypeError, heapreplace, None, None) 0086 self.assertRaises(IndexError, heapreplace, [], None) 0087 0088 def test_heapsort(self): 0089 # Exercise everything with repeated heapsort checks 0090 for trial in xrange(100): 0091 size = random.randrange(50) 0092 data = [random.randrange(25) for i in range(size)] 0093 if trial & 1: # Half of the time, use heapify 0094 heap = data[:] 0095 heapify(heap) 0096 else: # The rest of the time, use heappush 0097 heap = [] 0098 for item in data: 0099 heappush(heap, item) 0100 heap_sorted = [heappop(heap) for i in range(size)] 0101 self.assertEqual(heap_sorted, sorted(data)) 0102 0103 def test_nsmallest(self): 0104 data = 
[random.randrange(2000) for i in range(1000)] 0105 for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): 0106 self.assertEqual(nsmallest(n, data), sorted(data)[:n]) 0107 0108 def test_largest(self): 0109 data = [random.randrange(2000) for i in range(1000)] 0110 for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100): 0111 self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n]) 0112 0113 0114 #============================================================================== 0115 0116 class LenOnly: 0117 "Dummy sequence class defining __len__ but not __getitem__." 0118 def __len__(self): 0119 return 10 0120 0121 class GetOnly: 0122 "Dummy sequence class defining __getitem__ but not __len__." 0123 def __getitem__(self, ndx): 0124 return 10 0125 0126 class CmpErr: 0127 "Dummy element that always raises an error during comparison" 0128 def __cmp__(self, other): 0129 raise ZeroDivisionError 0130 0131 def R(seqn): 0132 'Regular generator' 0133 for i in seqn: 0134 yield i 0135 0136 class G: 0137 'Sequence using __getitem__' 0138 def __init__(self, seqn): 0139 self.seqn = seqn 0140 def __getitem__(self, i): 0141 return self.seqn[i] 0142 0143 class I: 0144 'Sequence using iterator protocol' 0145 def __init__(self, seqn): 0146 self.seqn = seqn 0147 self.i = 0 0148 def __iter__(self): 0149 return self 0150 def next(self): 0151 if self.i >= len(self.seqn): raise StopIteration 0152 v = self.seqn[self.i] 0153 self.i += 1 0154 return v 0155 0156 class Ig: 0157 'Sequence using iterator protocol defined with a generator' 0158 def __init__(self, seqn): 0159 self.seqn = seqn 0160 self.i = 0 0161 def __iter__(self): 0162 for val in self.seqn: 0163 yield val 0164 0165 class X: 0166 'Missing __getitem__ and __iter__' 0167 def __init__(self, seqn): 0168 self.seqn = seqn 0169 self.i = 0 0170 def next(self): 0171 if self.i >= len(self.seqn): raise StopIteration 0172 v = self.seqn[self.i] 0173 self.i += 1 0174 return v 0175 0176 class N: 0177 'Iterator missing next()' 0178 def __init__(self, 
seqn): 0179 self.seqn = seqn 0180 self.i = 0 0181 def __iter__(self): 0182 return self 0183 0184 class E: 0185 'Test propagation of exceptions' 0186 def __init__(self, seqn): 0187 self.seqn = seqn 0188 self.i = 0 0189 def __iter__(self): 0190 return self 0191 def next(self): 0192 3 // 0 0193 0194 class S: 0195 'Test immediate stop' 0196 def __init__(self, seqn): 0197 pass 0198 def __iter__(self): 0199 return self 0200 def next(self): 0201 raise StopIteration 0202 0203 from itertools import chain, imap 0204 def L(seqn): 0205 'Test multiple tiers of iterators' 0206 return chain(imap(lambda x:x, R(Ig(G(seqn))))) 0207 0208 class TestErrorHandling(unittest.TestCase): 0209 0210 def test_non_sequence(self): 0211 for f in (heapify, heappop): 0212 self.assertRaises(TypeError, f, 10) 0213 for f in (heappush, heapreplace, nlargest, nsmallest): 0214 self.assertRaises(TypeError, f, 10, 10) 0215 0216 def test_len_only(self): 0217 for f in (heapify, heappop): 0218 self.assertRaises(TypeError, f, LenOnly()) 0219 for f in (heappush, heapreplace): 0220 self.assertRaises(TypeError, f, LenOnly(), 10) 0221 for f in (nlargest, nsmallest): 0222 self.assertRaises(TypeError, f, 2, LenOnly()) 0223 0224 def test_get_only(self): 0225 for f in (heapify, heappop): 0226 self.assertRaises(TypeError, f, GetOnly()) 0227 for f in (heappush, heapreplace): 0228 self.assertRaises(TypeError, f, GetOnly(), 10) 0229 for f in (nlargest, nsmallest): 0230 self.assertRaises(TypeError, f, 2, GetOnly()) 0231 0232 def test_get_only(self): 0233 seq = [CmpErr(), CmpErr(), CmpErr()] 0234 for f in (heapify, heappop): 0235 self.assertRaises(ZeroDivisionError, f, seq) 0236 for f in (heappush, heapreplace): 0237 self.assertRaises(ZeroDivisionError, f, seq, 10) 0238 for f in (nlargest, nsmallest): 0239 self.assertRaises(ZeroDivisionError, f, 2, seq) 0240 0241 def test_arg_parsing(self): 0242 for f in (heapify, heappop, heappush, heapreplace, nlargest, nsmallest): 0243 self.assertRaises(TypeError, f, 10) 0244 0245 def 
test_iterable_args(self): 0246 for f in (nlargest, nsmallest): 0247 for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)): 0248 for g in (G, I, Ig, L, R): 0249 self.assertEqual(f(2, g(s)), f(2,s)) 0250 self.assertEqual(f(2, S(s)), []) 0251 self.assertRaises(TypeError, f, 2, X(s)) 0252 self.assertRaises(TypeError, f, 2, N(s)) 0253 self.assertRaises(ZeroDivisionError, f, 2, E(s)) 0254 0255 #============================================================================== 0256 0257 0258 def test_main(verbose=None): 0259 from types import BuiltinFunctionType 0260 0261 test_classes = [TestHeap] 0262 if isinstance(heapify, BuiltinFunctionType): 0263 test_classes.append(TestErrorHandling) 0264 test_support.run_unittest(*test_classes) 0265 0266 # verify reference counting 0267 if verbose and hasattr(sys, "gettotalrefcount"): 0268 import gc 0269 counts = [None] * 5 0270 for i in xrange(len(counts)): 0271 test_support.run_unittest(*test_classes) 0272 gc.collect() 0273 counts[i] = sys.gettotalrefcount() 0274 print counts 0275 0276 if __name__ == "__main__": 0277 test_main(verbose=True) 0278
# Generated by PyXR 0.9.4