diff --git a/external/elegant-weapons b/external/elegant-weapons index 4c33e67..70f4b74 160000 --- a/external/elegant-weapons +++ b/external/elegant-weapons @@ -1 +1 @@ -Subproject commit 4c33e677ff70917a0ea6ded5703e4fc161920f3f +Subproject commit 70f4b7473f760170acb6cda5864a8eaeb515e7f2 diff --git a/harlan/front/expand-macros.scm b/harlan/front/expand-macros.scm index f8b6b8d..57b182f 100644 --- a/harlan/front/expand-macros.scm +++ b/harlan/front/expand-macros.scm @@ -273,6 +273,22 @@ env))) `(,a ((,(reify x env) ,e) ...) ,(reify b* env) ...))))) + ((kernel-update!) + (match x + ((,_ ((,x ,e) ,<- (,x* ,e*) ...) ,b* ...) + (let ((e (reify e env)) + (e* (map (lambda (e) + (reify e env)) + e*)) + (env (cons (cons x (gensym (ident-symbol x))) + (append (map (lambda (x) + (cons x + (gensym (ident-symbol x)))) + x*) + env)))) + `(,a ((,(reify x env) ,e) ,(ident-symbol <-) + (,(reify x* env) ,e*) ...) + ,(reify b* env) ...))))) ((let-region) (match x ((,_ (,r ...) ,b ...) diff --git a/harlan/front/expand-primitives.scm b/harlan/front/expand-primitives.scm index ec302cb..38a7508 100644 --- a/harlan/front/expand-primitives.scm +++ b/harlan/front/expand-primitives.scm @@ -3,6 +3,7 @@ (export expand-primitives) (import (rnrs) + (only (chezscheme) pretty-print) (elegant-weapons helpers) (elegant-weapons compat)) @@ -155,10 +156,12 @@ `(if ,test ,conseq ,altern)) ((if ,[test] ,[conseq]) `(if ,test ,conseq)) - ((kernel ,ktype (((,x ,t) (,[xs] ,ts)) ...) ,[body]) - `(kernel ,ktype ,(gensym 'region) (((,x ,t) (,xs ,ts)) ...) ,body)) ((kernel-r ,ktype ,r (((,x ,t) (,[xs] ,ts)) ...) ,[body]) `(kernel ,ktype ,r (((,x ,t) (,xs ,ts)) ...) ,body)) + ((kernel-update! ,t (((,x ,tx) (,[xs] ,ts)) ...) ,[body]) + `(kernel-update! ,t (((,x ,tx) (,xs ,ts)) ...) ,body)) + ((kernel-update! . ,_) + (pretty-print `(kernel-update! . ,_))) ((let ((,x* ,t* ,[e*]) ...) ,[e]) `(let ((,x* ,t* ,e*) ...) ,e)) ((begin ,[expand-prim-stmt -> s*] ... ,[e]) diff --git a/harlan/front/parser.scm b/harlan/front/parser.scm index 0e0faa3..c596866 100644 --- a/harlan/front/parser.scm +++ b/harlan/front/parser.scm @@ -165,6 +165,16 @@ ,(make-begin `(,@(map (parse-stmt env) stmt*) ,((parse-expr env) e))))))) + ((kernel-update! ((,x ,[e]) <- (,x* ,[e*]) ...) ,stmt* ... ,b) + (begin + (check-idents (cons x x*)) + (let* ((x^ (gensym x)) + (x*^ (map gensym x*)) + (env (cons (cons x x^) (append (map cons x* x*^) env)))) + `(kernel-update! ((,x^ ,e) <- (,x*^ ,e*) ...) + ,(make-begin + `(,@(map (parse-stmt env) stmt*) + ,((parse-expr env) b))))))) ((match ,[e] ((,tag ,x* ...) ,s* ... ,e*) ...) (guard (and (andmap ident? tag) diff --git a/harlan/front/typecheck.scm b/harlan/front/typecheck.scm index 39331c1..efcf1cc 100644 --- a/harlan/front/typecheck.scm +++ b/harlan/front/typecheck.scm @@ -468,6 +468,33 @@ (,e (vec . ,t*))) ...) ,b) `(vec ,r ,t)))) + ((kernel-update! ((,x ,e) <- (,x* ,e*) ...) ,b) + (let ((tv (make-tvar (gensym 'kt)))) + (do* (((e* t*) (let loop ((e e*)) + (if (null? e) + (return '() '()) + (let ((e* (cdr e)) + (e (car e)) + (t (make-tvar (gensym 'kt))) + (r (make-rvar (gensym 'rkt)))) + (do* (((e* t*) (loop e*)) + ((e _) + (require-type e env `(vec ,r ,t)))) + (return (cons e e*) + (cons (list r t) t*))))))) + ((e t^) (let ((r (make-rvar (gensym 'rkt)))) + (require-type e env `(vec ,r ,tv)))) + ((b t) (infer-expr b (append + (map (lambda (x t) (cons x (cadr t))) + (cons x x*) + (cons `(r ,tv) t*)) + env)))) + (return `(kernel-update! ,t^ + (((,x ,tv) (,e ,t^)) + ((,x* ,(map cadr t*)) + (,e* (vec . ,t*))) ...) + ,b) + t^)))) ((call ,f ,e* ...) (guard (ident? f)) (let ((t (make-tvar (gensym 'rt))) (ft (lookup f env))) @@ -774,6 +801,11 @@ ,[b]) `(kernel-r ,t ,(region-name (walk r s)) (((,x ,ta*) (,e ,ta**)) ...) ,b)) + ((kernel-update! ,[ground-type -> t] + (((,x ,[ground-type -> ta*]) + (,[e] ,[ground-type -> ta**])) ...) + ,[b]) + `(kernel-update! ,t (((,x ,ta*) (,e ,ta**)) ...) ,b)) ((reduce ,[ground-type -> t] + ,[e]) `(reduce ,t + ,e)) ((set! ,[x] ,[e]) `(set! ,x ,e)) ((begin ,[e*] ...) `(begin ,e* ...)) @@ -823,6 +855,11 @@ (((,x ,[free-regions-type -> t*]) (,xs ,[free-regions-type -> ts*])) ...) ,[b]) (set-add (union b t (apply union (append t* ts*))) r)) + ((kernel-update! ,[free-regions-type -> t] + (((,x ,[free-regions-type -> t*]) + (,xs ,[free-regions-type -> ts*])) ...) + ,[b]) + (union b t (apply union (append t* ts*)))) ((reduce ,[free-regions-type -> t] ,op ,[e]) (union t e)) ((set! ,[x] ,[e]) (union x e)) ((begin ,[e*] ...) (apply union e*)) diff --git a/harlan/middle/desugar-match.scm b/harlan/middle/desugar-match.scm index 100c89f..c21f3f2 100644 --- a/harlan/middle/desugar-match.scm +++ b/harlan/middle/desugar-match.scm @@ -151,6 +151,8 @@ ((vector ,t ,r ,[e] ...) `(vector ,t ,r ,e ...)) ((kernel ,t ,r (((,x ,t*) (,[xs] ,ts)) ...) ,[e]) `(kernel ,t ,r (((,x ,t*) (,xs ,ts)) ...) ,e)) + ((kernel-update! ,t (((,x ,t*) (,[xs] ,ts)) ...) ,[e]) + `(kernel-update! ,t (((,x ,t*) (,xs ,ts)) ...) ,e)) ((let ((,x ,t ,[e]) ...) ,[b]) `(let ((,x ,t ,e) ...) ,b)) ((make-vector ,t ,r ,[e]) `(make-vector ,t ,r ,e)) diff --git a/harlan/middle/languages.scm b/harlan/middle/languages.scm index 8100dda..cf8cecb 100644 --- a/harlan/middle/languages.scm +++ b/harlan/middle/languages.scm @@ -107,8 +107,8 @@ (begin e e* ...) (assert e) (let (lbind* ...) e) - (let-region (r ...) e) (kernel t r (((x0 t0) (e1 t1)) ...) e) + (kernel-update! t (((x0 t0) (e1 t1)) ...) e) (iota-r r e) (for (x e0 e1 e2) e) (lambda t0 ((x t) ...) e) @@ -270,6 +270,7 @@ (set! e1 e2) (error x) (begin stmt ...) + (let-region (r ...) stmt) (if e stmt1 stmt2) (if e stmt) (while e stmt) @@ -297,6 +298,7 @@ (Expr (e) (- (kernel t r (e* ...) (((x0 t0) (e1 t1) i*) ...) e) + (kernel-update! t (((x0 t0) (e1 t1)) ...) e) (error x)) (+ (addressof e) (deref e)))) diff --git a/harlan/middle/optimize-fuse-kernels.scm b/harlan/middle/optimize-fuse-kernels.scm index ac63c6d..c5cdcdc 100644 --- a/harlan/middle/optimize-fuse-kernels.scm +++ b/harlan/middle/optimize-fuse-kernels.scm @@ -79,6 +79,9 @@ `(if ,test ,conseq)) ((kernel ,t ,r ,dims ,iters ,[body]) (make-2d-kernel t r dims iters body)) + ((kernel-update! ,t (((,x ,t^) (,[xs] ,ts)) ...) ,[e]) + ;; TODO: we can do some limited fusion here. + `(kernel-update! ,t (((,x ,t^) (,xs ,ts)) ...) ,e)) ((let ((,x* ,t* ,[e*]) ...) ,[e]) `(let ((,x* ,t* ,e*) ...) ,e)) ((begin ,[Stmt -> s*] ... ,[e]) diff --git a/harlan/middle/remove-nested-kernels.scm b/harlan/middle/remove-nested-kernels.scm index 3054625..7661a1e 100644 --- a/harlan/middle/remove-nested-kernels.scm +++ b/harlan/middle/remove-nested-kernels.scm @@ -68,6 +68,12 @@ `(kernel (vec ,r ,t) ,r ,dims (((,x* ,t*) (,xs* ,ts*) ,d*) ...) ,e))) + ((kernel-update! ,t (((,x ,t^) (,[xs] ,ts)) ...) ,[(Expr #t) -> e]) + (if k + ;; TODO: allow nested kernel-update! + (error 'remove-nested-kernels + "nesting kernel-update! is not yet supported.") + `(kernel-update! ,t (((,x ,t^) (,xs ,ts)) ...) ,e))) ((if ,[t] ,[c] ,[a]) `(if ,t ,c ,a)) ((call ,[fn] ,[args] ...) diff --git a/harlan/middle/returnify-kernels.scm b/harlan/middle/returnify-kernels.scm index 86723bb..ec8c0cc 100644 --- a/harlan/middle/returnify-kernels.scm +++ b/harlan/middle/returnify-kernels.scm @@ -80,7 +80,46 @@ ((field ,[e] ,x) `(field ,e ,x)) ((empty-struct) '(empty-struct)) ((kernel . ,body*) - (returnify-kernel `(kernel . ,body*)))) + (returnify-kernel `(kernel . ,body*))) + ((kernel-update! (vec ,r ,t) + (((,x ,t) (,[x^] (vec ,r ,t))) + ((,x* ,t*) (,[xs*] ,ts*)) ...) + ,body) + ;; This is where we get ride of kernel-update!. After this point, + ;; we'll just have regular kernels. + (make-danger-vector + (lambda (danger) + (let ((body ((set-retval t x danger) body))) + `(kernel (vec ,r ,t) + ((length ,x^)) + (((,x ,t) (,x^ (vec ,r ,t)) 0) + ((,x* ,t*) (,xs* ,ts*) 0) ...) + ,body))) + x^))) + +(define (make-danger-vector make-body rval) + (let ((r (gensym 'danger-region))) + (let ((danger-vector (gensym 'danger_vector)) + (danger-vec-t `(vec ,r ,danger-type)) + (i (gensym 'i))) + `(begin + (let-region (,r) + (let ((,danger-vector + ,danger-vec-t + (make-vector ,danger-type ,r ,num-danger))) + (begin + (for (,i (int 0) ,num-danger (int 1)) + (set! (vector-ref + ,danger-type + (var ,danger-vec-t ,danger-vector) + (var int ,i)) + (bool #f))) + ,(make-body (lambda (d) + `(vector-ref ,danger-type + (var ,danger-vec-t ,danger-vector) + ,d))) + ,(check-danger-vector danger-vector r num-danger)))) + ,rval)))) (define-match returnify-kernel ((kernel (vec ,r ,t) @@ -89,48 +128,35 @@ (((,x* ,tx*) (,[returnify-kernel-expr -> xe*] ,xet*) ,dim) ...) ,body) (let ((retvars (map (lambda (_) (gensym 'retval)) dims)) - (danger-vector (gensym 'danger_vector)) - (danger-vec-t `(vec ,r ,danger-type)) - (i (gensym 'i)) (vv (gensym 'vv)) + (i (gensym 'i)) (id (gensym 'kern))) - `(let ((,id (vec ,r ,t) (make-vector ,t ,r ,(car dims))) - (,danger-vector - ;; TODO: the danger vector should probably get its own region. - ,danger-vec-t - (make-vector ,danger-type ,r ,num-danger))) - (begin - (for (,i (int 0) ,num-danger (int 1)) - (set! (vector-ref - ,danger-type - (var ,danger-vec-t ,danger-vector) - (var int ,i)) - (bool #f))) - ,@(if (null? (cdr dims)) - `() - (match t - ((vec ,r^ ,t^) - `((for (,i (int 0) ,(car dims) (int 1)) - (let ((,vv (vec ,r^ ,t^) - (make-vector ,t^ ,r^ - ,(cadr dims)))) - (set! (vector-ref (vec ,r^ ,t^) - (var (vec ,r ,t) ,id) - (var int ,i)) - (var (vec ,r^ ,t^) ,vv)))))))) - (kernel - (vec ,r ,t) - ,dims - ,(insert-retvars r retvars (cons id retvars) 0 t - `(((,x* ,tx*) (,xe* ,xet*) ,dim) ...)) - ,((set-retval (shave-type (length dims) `(vec ,r ,t)) - (car (reverse retvars)) - (lambda (d) `(vector-ref ,danger-type - (var ,danger-vec-t ,danger-vector) - ,d))) - body)) - ,(check-danger-vector danger-vector r num-danger) - (var (vec ,r ,t) ,id)))))) + `(let ((,id (vec ,r ,t) (make-vector ,t ,r ,(car dims)))) + ,(make-danger-vector + (lambda (danger) + `(begin + ,@(if (null? (cdr dims)) + `() + (match t + ((vec ,r^ ,t^) + `((for (,i (int 0) ,(car dims) (int 1)) + (let ((,vv (vec ,r^ ,t^) + (make-vector ,t^ ,r^ + ,(cadr dims)))) + (set! (vector-ref (vec ,r^ ,t^) + (var (vec ,r ,t) ,id) + (var int ,i)) + (var (vec ,r^ ,t^) ,vv)))))))) + (kernel + (vec ,r ,t) + ,dims + ,(insert-retvars r retvars (cons id retvars) 0 t + `(((,x* ,tx*) (,xe* ,xet*) ,dim) ...)) + ,((set-retval (shave-type (length dims) `(vec ,r ,t)) + (car (reverse retvars)) + danger) + body)))) + `(var (vec ,r ,t) ,id)))))) (define (check-danger-vector danger-vector r len) (let ((i (gensym 'danger_i))