From 73c269414f41dac87122963dc1fc456442de8b85 Mon Sep 17 00:00:00 2001 From: Damien George Date: Wed, 26 Jun 2019 14:26:20 +1000 Subject: [PATCH] tests/perf_bench: Add some miscellaneous performance benchmarks. misc_aes.py and misc_mandel.py are adapted from sources in this repository. misc_pystone.py is the standard Python pystone test. misc_raytrace.py is written from scratch. --- tests/perf_bench/misc_aes.py | 225 +++++++++++++++++++++++++++ tests/perf_bench/misc_mandel.py | 31 ++++ tests/perf_bench/misc_pystone.py | 226 ++++++++++++++++++++++++++++ tests/perf_bench/misc_raytrace.py | 242 ++++++++++++++++++++++++++++++ 4 files changed, 724 insertions(+) create mode 100644 tests/perf_bench/misc_aes.py create mode 100644 tests/perf_bench/misc_mandel.py create mode 100644 tests/perf_bench/misc_pystone.py create mode 100644 tests/perf_bench/misc_raytrace.py diff --git a/tests/perf_bench/misc_aes.py b/tests/perf_bench/misc_aes.py new file mode 100644 index 000000000..5413a06b1 --- /dev/null +++ b/tests/perf_bench/misc_aes.py @@ -0,0 +1,225 @@ +# Pure Python AES encryption routines. +# +# AES is integer based and inplace so doesn't use the heap. It is therefore +# a good test of raw performance and correctness of the VM/runtime. +# +# The AES code comes first (code originates from a C version authored by D.P.George) +# and then the test harness at the bottom. +# +# MIT license; Copyright (c) 2016 Damien P. George on behalf of Pycom Ltd + +################################################################## +# discrete arithmetic routines, mostly from a precomputed table + +# non-linear, invertible, substitution box +aes_s_box_table = bytes(( + 0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76, + 0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0, + 0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15, + 0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75, + 0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84, + 0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf, + 0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8, + 0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2, + 0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73, + 0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb, + 0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79, + 0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08, + 0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a, + 0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e, + 0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf, + 0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16, +)) + +# multiplication of polynomials modulo x^8 + x^4 + x^3 + x + 1 = 0x11b +def aes_gf8_mul_2(x): + if x & 0x80: + return (x << 1) ^ 0x11b + else: + return x << 1 + +def aes_gf8_mul_3(x): + return x ^ aes_gf8_mul_2(x) + +# non-linear, invertible, substitution box +def aes_s_box(a): + return aes_s_box_table[a & 0xff] + +# return 0x02^(a-1) in GF(2^8) +def aes_r_con(a): + ans = 1 + while a > 1: + ans <<= 1; + if ans & 0x100: + ans ^= 0x11b + a -= 1 + return ans + +################################################################## +# basic AES algorithm; see FIPS-197 + +# all inputs must be size 16 +def aes_add_round_key(state, w): + for i in range(16): + state[i] ^= w[i] + +# combined sub_bytes, shift_rows, mix_columns, add_round_key +# all inputs must be size 16 +def aes_sb_sr_mc_ark(state, w, w_idx, temp): + temp_idx = 0 + for i in range(4): + x0 = aes_s_box_table[state[i * 4]] + x1 = aes_s_box_table[state[1 + ((i + 1) & 3) * 4]] + x2 = aes_s_box_table[state[2 + ((i + 2) & 3) * 4]] + x3 = aes_s_box_table[state[3 + ((i + 3) & 3) * 4]] + temp[temp_idx] = aes_gf8_mul_2(x0) ^ aes_gf8_mul_3(x1) ^ x2 ^ x3 ^ w[w_idx] + temp[temp_idx + 1] = x0 ^ aes_gf8_mul_2(x1) ^ aes_gf8_mul_3(x2) ^ x3 ^ w[w_idx + 1] + temp[temp_idx + 2] = x0 ^ x1 ^ aes_gf8_mul_2(x2) ^ aes_gf8_mul_3(x3) ^ w[w_idx + 2] + temp[temp_idx + 3] = aes_gf8_mul_3(x0) ^ x1 ^ x2 ^ aes_gf8_mul_2(x3) ^ w[w_idx + 3] + w_idx += 4 + temp_idx += 4 + for i in range(16): + state[i] = temp[i] + +# combined sub_bytes, shift_rows, add_round_key +# all inputs must be size 16 +def aes_sb_sr_ark(state, w, w_idx, temp): + temp_idx = 0 + for i in range(4): + x0 = aes_s_box_table[state[i * 4]] + x1 = aes_s_box_table[state[1 + ((i + 1) & 3) * 4]] + x2 = aes_s_box_table[state[2 + ((i + 2) & 3) * 4]] + x3 = aes_s_box_table[state[3 + ((i + 3) & 3) * 4]] + temp[temp_idx] = x0 ^ w[w_idx] + temp[temp_idx + 1] = x1 ^ w[w_idx + 1] + temp[temp_idx + 2] = x2 ^ w[w_idx + 2] + temp[temp_idx + 3] = x3 ^ w[w_idx + 3] + w_idx += 4 + temp_idx += 4 + for i in range(16): + state[i] = temp[i] + +# take state as input and change it to the next state in the sequence +# state and temp have size 16, w has size 16 * (Nr + 1), Nr >= 1 +def aes_state(state, w, temp, nr): + aes_add_round_key(state, w) + w_idx = 16 + for i in range(nr - 1): + aes_sb_sr_mc_ark(state, w, w_idx, temp) + w_idx += 16 + aes_sb_sr_ark(state, w, w_idx, temp) + +# expand 'key' to 'w' for use with aes_state +# key has size 4 * Nk, w has size 16 * (Nr + 1), temp has size 16 +def aes_key_expansion(key, w, temp, nk, nr): + for i in range(4 * nk): + w[i] = key[i] + w_idx = 4 * nk - 4 + for i in range(nk, 4 * (nr + 1)): + t = temp + t_idx = 0 + if i % nk == 0: + t[0] = aes_s_box(w[w_idx + 1]) ^ aes_r_con(i // nk) + for j in range(1, 4): + t[j] = aes_s_box(w[w_idx + (j + 1) % 4]) + elif nk > 6 and i % nk == 4: + for j in range(0, 4): + t[j] = aes_s_box(w[w_idx + j]) + else: + t = w + t_idx = w_idx + w_idx += 4 + for j in range(4): + w[w_idx + j] = w[w_idx + j - 4 * nk] ^ t[t_idx + j] + +################################################################## +# simple use of AES algorithm, using output feedback (OFB) mode + +class AES: + def __init__(self, keysize): + if keysize == 128: + self.nk = 4 + self.nr = 10 + elif keysize == 192: + self.nk = 6 + self.nr = 12 + else: + assert keysize == 256 + self.nk = 8 + self.nr = 14 + + self.state = bytearray(16) + self.w = bytearray(16 * (self.nr + 1)) + self.temp = bytearray(16) + self.state_pos = 16 + + def set_key(self, key): + aes_key_expansion(key, self.w, self.temp, self.nk, self.nr) + self.state_pos = 16 + + def set_iv(self, iv): + for i in range(16): + self.state[i] = iv[i] + self.state_pos = 16; + + def get_some_state(self, n_needed): + if self.state_pos >= 16: + aes_state(self.state, self.w, self.temp, self.nr) + self.state_pos = 0 + n = 16 - self.state_pos + if n > n_needed: + n = n_needed + return n + + def apply_to(self, data): + idx = 0 + n = len(data) + while n > 0: + ln = self.get_some_state(n) + n -= ln + for i in range(ln): + data[idx + i] ^= self.state[self.state_pos + i] + idx += ln + self.state_pos += n + +########################################################################### +# Benchmark interface + +bm_params = { + (50, 25): (1, 16), + (100, 100): (1, 32), + (1000, 1000): (4, 256), + (5000, 1000): (20, 256), +} + +def bm_setup(params): + nloop, datalen = params + + aes = AES(256) + key = bytearray(256 // 8) + iv = bytearray(16) + data = bytearray(datalen) + # from now on we don't use the heap + + def run(): + for loop in range(nloop): + # encrypt + aes.set_key(key) + aes.set_iv(iv) + for i in range(2): + aes.apply_to(data) + + # decrypt + aes.set_key(key) + aes.set_iv(iv) + for i in range(2): + aes.apply_to(data) + + # verify + for i in range(len(data)): + assert data[i] == 0 + + def result(): + return params[0] * params[1], True + + return run, result diff --git a/tests/perf_bench/misc_mandel.py b/tests/perf_bench/misc_mandel.py new file mode 100644 index 000000000..a4b789136 --- /dev/null +++ b/tests/perf_bench/misc_mandel.py @@ -0,0 +1,31 @@ +# Compute the Mandelbrot set, to test complex numbers + +def mandelbrot(w, h): + def in_set(c): + z = 0 + for i in range(32): + z = z * z + c + if abs(z) > 100: + return i + return 0 + + img = bytearray(w * h) + + xscale = ((w - 1) / 2.4) + yscale = ((h - 1) / 3.2) + for v in range(h): + line = memoryview(img)[v * w:v * w + w] + for u in range(w): + c = in_set(complex(v / yscale - 2.3, u / xscale - 1.2)) + line[u] = c + + return img + +bm_params = { + (100, 100): (20, 20), + (1000, 1000): (80, 80), + (5000, 1000): (150, 150), +} + +def bm_setup(ps): + return lambda: mandelbrot(ps[0], ps[1]), lambda: (ps[0] * ps[1], None) diff --git a/tests/perf_bench/misc_pystone.py b/tests/perf_bench/misc_pystone.py new file mode 100644 index 000000000..88626774b --- /dev/null +++ b/tests/perf_bench/misc_pystone.py @@ -0,0 +1,226 @@ +""" +"PYSTONE" Benchmark Program + +Version: Python/1.2 (corresponds to C/1.1 plus 3 Pystone fixes) + +Author: Reinhold P. Weicker, CACM Vol 27, No 10, 10/84 pg. 1013. + + Translated from ADA to C by Rick Richardson. + Every method to preserve ADA-likeness has been used, + at the expense of C-ness. + + Translated from C to Python by Guido van Rossum. +""" + +__version__ = "1.2" + +[Ident1, Ident2, Ident3, Ident4, Ident5] = range(1, 6) + +class Record: + + def __init__(self, PtrComp = None, Discr = 0, EnumComp = 0, + IntComp = 0, StringComp = 0): + self.PtrComp = PtrComp + self.Discr = Discr + self.EnumComp = EnumComp + self.IntComp = IntComp + self.StringComp = StringComp + + def copy(self): + return Record(self.PtrComp, self.Discr, self.EnumComp, + self.IntComp, self.StringComp) + +TRUE = 1 +FALSE = 0 + +def Setup(): + global IntGlob + global BoolGlob + global Char1Glob + global Char2Glob + global Array1Glob + global Array2Glob + + IntGlob = 0 + BoolGlob = FALSE + Char1Glob = '\0' + Char2Glob = '\0' + Array1Glob = [0]*51 + Array2Glob = [x[:] for x in [Array1Glob]*51] + +def Proc0(loops): + global IntGlob + global BoolGlob + global Char1Glob + global Char2Glob + global Array1Glob + global Array2Glob + global PtrGlb + global PtrGlbNext + + PtrGlbNext = Record() + PtrGlb = Record() + PtrGlb.PtrComp = PtrGlbNext + PtrGlb.Discr = Ident1 + PtrGlb.EnumComp = Ident3 + PtrGlb.IntComp = 40 + PtrGlb.StringComp = "DHRYSTONE PROGRAM, SOME STRING" + String1Loc = "DHRYSTONE PROGRAM, 1'ST STRING" + Array2Glob[8][7] = 10 + + for i in range(loops): + Proc5() + Proc4() + IntLoc1 = 2 + IntLoc2 = 3 + String2Loc = "DHRYSTONE PROGRAM, 2'ND STRING" + EnumLoc = Ident2 + BoolGlob = not Func2(String1Loc, String2Loc) + while IntLoc1 < IntLoc2: + IntLoc3 = 5 * IntLoc1 - IntLoc2 + IntLoc3 = Proc7(IntLoc1, IntLoc2) + IntLoc1 = IntLoc1 + 1 + Proc8(Array1Glob, Array2Glob, IntLoc1, IntLoc3) + PtrGlb = Proc1(PtrGlb) + CharIndex = 'A' + while CharIndex <= Char2Glob: + if EnumLoc == Func1(CharIndex, 'C'): + EnumLoc = Proc6(Ident1) + CharIndex = chr(ord(CharIndex)+1) + IntLoc3 = IntLoc2 * IntLoc1 + IntLoc2 = IntLoc3 // IntLoc1 + IntLoc2 = 7 * (IntLoc3 - IntLoc2) - IntLoc1 + IntLoc1 = Proc2(IntLoc1) + +def Proc1(PtrParIn): + PtrParIn.PtrComp = NextRecord = PtrGlb.copy() + PtrParIn.IntComp = 5 + NextRecord.IntComp = PtrParIn.IntComp + NextRecord.PtrComp = PtrParIn.PtrComp + NextRecord.PtrComp = Proc3(NextRecord.PtrComp) + if NextRecord.Discr == Ident1: + NextRecord.IntComp = 6 + NextRecord.EnumComp = Proc6(PtrParIn.EnumComp) + NextRecord.PtrComp = PtrGlb.PtrComp + NextRecord.IntComp = Proc7(NextRecord.IntComp, 10) + else: + PtrParIn = NextRecord.copy() + NextRecord.PtrComp = None + return PtrParIn + +def Proc2(IntParIO): + IntLoc = IntParIO + 10 + while 1: + if Char1Glob == 'A': + IntLoc = IntLoc - 1 + IntParIO = IntLoc - IntGlob + EnumLoc = Ident1 + if EnumLoc == Ident1: + break + return IntParIO + +def Proc3(PtrParOut): + global IntGlob + + if PtrGlb is not None: + PtrParOut = PtrGlb.PtrComp + else: + IntGlob = 100 + PtrGlb.IntComp = Proc7(10, IntGlob) + return PtrParOut + +def Proc4(): + global Char2Glob + + BoolLoc = Char1Glob == 'A' + BoolLoc = BoolLoc or BoolGlob + Char2Glob = 'B' + +def Proc5(): + global Char1Glob + global BoolGlob + + Char1Glob = 'A' + BoolGlob = FALSE + +def Proc6(EnumParIn): + EnumParOut = EnumParIn + if not Func3(EnumParIn): + EnumParOut = Ident4 + if EnumParIn == Ident1: + EnumParOut = Ident1 + elif EnumParIn == Ident2: + if IntGlob > 100: + EnumParOut = Ident1 + else: + EnumParOut = Ident4 + elif EnumParIn == Ident3: + EnumParOut = Ident2 + elif EnumParIn == Ident4: + pass + elif EnumParIn == Ident5: + EnumParOut = Ident3 + return EnumParOut + +def Proc7(IntParI1, IntParI2): + IntLoc = IntParI1 + 2 + IntParOut = IntParI2 + IntLoc + return IntParOut + +def Proc8(Array1Par, Array2Par, IntParI1, IntParI2): + global IntGlob + + IntLoc = IntParI1 + 5 + Array1Par[IntLoc] = IntParI2 + Array1Par[IntLoc+1] = Array1Par[IntLoc] + Array1Par[IntLoc+30] = IntLoc + for IntIndex in range(IntLoc, IntLoc+2): + Array2Par[IntLoc][IntIndex] = IntLoc + Array2Par[IntLoc][IntLoc-1] = Array2Par[IntLoc][IntLoc-1] + 1 + Array2Par[IntLoc+20][IntLoc] = Array1Par[IntLoc] + IntGlob = 5 + +def Func1(CharPar1, CharPar2): + CharLoc1 = CharPar1 + CharLoc2 = CharLoc1 + if CharLoc2 != CharPar2: + return Ident1 + else: + return Ident2 + +def Func2(StrParI1, StrParI2): + IntLoc = 1 + while IntLoc <= 1: + if Func1(StrParI1[IntLoc], StrParI2[IntLoc+1]) == Ident1: + CharLoc = 'A' + IntLoc = IntLoc + 1 + if CharLoc >= 'W' and CharLoc <= 'Z': + IntLoc = 7 + if CharLoc == 'X': + return TRUE + else: + if StrParI1 > StrParI2: + IntLoc = IntLoc + 7 + return TRUE + else: + return FALSE + +def Func3(EnumParIn): + EnumLoc = EnumParIn + if EnumLoc == Ident3: return TRUE + return FALSE + + +########################################################################### +# Benchmark interface + +bm_params = { + (50, 10): (80,), + (100, 10): (300,), + (1000, 10): (4000,), + (5000, 10): (20000,), +} + +def bm_setup(params): + Setup() + return lambda: Proc0(params[0]), lambda: (params[0], 0) diff --git a/tests/perf_bench/misc_raytrace.py b/tests/perf_bench/misc_raytrace.py new file mode 100644 index 000000000..76d4194bc --- /dev/null +++ b/tests/perf_bench/misc_raytrace.py @@ -0,0 +1,242 @@ +# A simple ray tracer +# MIT license; Copyright (c) 2019 Damien P. George + +INF = 1e30 +EPS = 1e-6 + +class Vec: + def __init__(self, x, y, z): + self.x, self.y, self.z = x, y, z + + def __neg__(self): + return Vec(-self.x, -self.y, -self.z) + + def __add__(self, rhs): + return Vec(self.x + rhs.x, self.y + rhs.y, self.z + rhs.z) + + def __sub__(self, rhs): + return Vec(self.x - rhs.x, self.y - rhs.y, self.z - rhs.z) + + def __mul__(self, rhs): + return Vec(self.x * rhs, self.y * rhs, self.z * rhs) + + def length(self): + return (self.x ** 2 + self.y ** 2 + self.z ** 2) ** 0.5 + + def normalise(self): + l = self.length() + return Vec(self.x / l, self.y / l, self.z / l) + + def dot(self, rhs): + return self.x * rhs.x + self.y * rhs.y + self.z * rhs.z + +RGB = Vec + +class Ray: + def __init__(self, p, d): + self.p, self.d = p, d + +class View: + def __init__(self, width, height, depth, pos, xdir, ydir, zdir): + self.width = width + self.height = height + self.depth = depth + self.pos = pos + self.xdir = xdir + self.ydir = ydir + self.zdir = zdir + + def calc_dir(self, dx, dy): + return (self.xdir * dx + self.ydir * dy + self.zdir * self.depth).normalise() + +class Light: + def __init__(self, pos, colour, casts_shadows): + self.pos = pos + self.colour = colour + self.casts_shadows = casts_shadows + +class Surface: + def __init__(self, diffuse, specular, spec_idx, reflect, transp, colour): + self.diffuse = diffuse + self.specular = specular + self.spec_idx = spec_idx + self.reflect = reflect + self.transp = transp + self.colour = colour + + @staticmethod + def dull(colour): + return Surface(0.7, 0.0, 1, 0.0, 0.0, colour * 0.6) + + @staticmethod + def shiny(colour): + return Surface(0.2, 0.9, 32, 0.8, 0.0, colour * 0.3) + + @staticmethod + def transparent(colour): + return Surface(0.2, 0.9, 32, 0.0, 0.8, colour * 0.3) + +class Sphere: + def __init__(self, surface, centre, radius): + self.surface = surface + self.centre = centre + self.radsq = radius ** 2 + + def intersect(self, ray): + v = self.centre - ray.p + b = v.dot(ray.d) + det = b ** 2 - v.dot(v) + self.radsq + if det > 0: + det **= 0.5 + t1 = b - det + if t1 > EPS: + return t1 + t2 = b + det + if t2 > EPS: + return t2 + return INF + + def surface_at(self, v): + return self.surface, (v - self.centre).normalise() + +class Plane: + def __init__(self, surface, centre, normal): + self.surface = surface + self.normal = normal.normalise() + self.cdotn = centre.dot(normal) + + def intersect(self, ray): + ddotn = ray.d.dot(self.normal) + if abs(ddotn) > EPS: + t = (self.cdotn - ray.p.dot(self.normal)) / ddotn + if t > 0: + return t + return INF + + def surface_at(self, p): + return self.surface, self.normal + +class Scene: + def __init__(self, ambient, light, objs): + self.ambient = ambient + self.light = light + self.objs = objs + +def trace_scene(canvas, view, scene, max_depth): + for v in range(canvas.height): + y = (-v + 0.5 * (canvas.height - 1)) * view.height / canvas.height + for u in range(canvas.width): + x = (u - 0.5 * (canvas.width - 1)) * view.width / canvas.width + ray = Ray(view.pos, view.calc_dir(x, y)) + c = trace_ray(scene, ray, max_depth) + canvas.put_pix(u, v, c) + +def trace_ray(scene, ray, depth): + # Find closest intersecting object + hit_t = INF + hit_obj = None + for obj in scene.objs: + t = obj.intersect(ray) + if t < hit_t: + hit_t = t + hit_obj = obj + + # Check if any objects hit + if hit_obj is None: + return RGB(0, 0, 0) + + # Compute location of ray intersection + point = ray.p + ray.d * hit_t + surf, surf_norm = hit_obj.surface_at(point) + if ray.d.dot(surf_norm) > 0: + surf_norm = -surf_norm + + # Compute reflected ray + reflected = ray.d - surf_norm * (surf_norm.dot(ray.d) * 2) + + # Ambient light + col = surf.colour * scene.ambient + + # Diffuse, specular and shadow from light source + light_vec = scene.light.pos - point + light_dist = light_vec.length() + light_vec = light_vec.normalise() + ndotl = surf_norm.dot(light_vec) + ldotv = light_vec.dot(reflected) + if ndotl > 0 or ldotv > 0: + light_ray = Ray(point + light_vec * EPS, light_vec) + light_col = trace_to_light(scene, light_ray, light_dist) + if ndotl > 0: + col += light_col * surf.diffuse * ndotl + if ldotv > 0: + col += light_col * surf.specular * ldotv ** surf.spec_idx + + # Reflections + if depth > 0 and surf.reflect > 0: + col += trace_ray(scene, Ray(point + reflected * EPS, reflected), depth - 1) * surf.reflect + + # Transparency + if depth > 0 and surf.transp > 0: + col += trace_ray(scene, Ray(point + ray.d * EPS, ray.d), depth - 1) * surf.transp + + return col + +def trace_to_light(scene, ray, light_dist): + col = scene.light.colour + for obj in scene.objs: + t = obj.intersect(ray) + if t < light_dist: + col *= obj.surface.transp + return col + +class Canvas: + def __init__(self, width, height): + self.width = width + self.height = height + self.data = bytearray(3 * width * height) + + def put_pix(self, x, y, c): + off = 3 * (y * self.width + x) + self.data[off] = min(255, max(0, int(255 * c.x))) + self.data[off + 1] = min(255, max(0, int(255 * c.y))) + self.data[off + 2] = min(255, max(0, int(255 * c.z))) + + def write_ppm(self, filename): + with open(filename, 'wb') as f: + f.write(bytes('P6 %d %d 255\n' % (self.width, self.height), 'ascii')) + f.write(self.data) + +def main(w, h, d): + canvas = Canvas(w, h) + view = View(32, 32, 64, Vec(0, 0, 50), Vec(1, 0, 0), Vec(0, 1, 0), Vec(0, 0, -1)) + scene = Scene( + 0.5, + Light(Vec(0, 8, 0), RGB(1, 1, 1), True), + [ + Plane(Surface.dull(RGB(1, 0, 0)), Vec(-10, 0, 0), Vec(1, 0, 0)), + Plane(Surface.dull(RGB(0, 1, 0)), Vec(10, 0, 0), Vec(-1, 0, 0)), + Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, 0, -10), Vec(0, 0, 1)), + Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, -10, 0), Vec(0, 1, 0)), + Plane(Surface.dull(RGB(1, 1, 1)), Vec(0, 10, 0), Vec(0, -1, 0)), + Sphere(Surface.shiny(RGB(1, 1, 1)), Vec(-5, -4, 3), 4), + Sphere(Surface.dull(RGB(0, 0, 1)), Vec(4, -5, 0), 4), + Sphere(Surface.transparent(RGB(0.2, 0.2, 0.2)), Vec(6, -1, 8), 4), + ] + ) + trace_scene(canvas, view, scene, d) + return canvas + +# For testing +#main(256, 256, 4).write_ppm('rt.ppm') + +########################################################################### +# Benchmark interface + +bm_params = { + (100, 100): (5, 5, 2), + (1000, 100): (18, 18, 3), + (5000, 100): (40, 40, 3), +} + +def bm_setup(params): + return lambda: main(*params), lambda: (params[0] * params[1] * params[2], None)