Skip to content

Commit

Permalink
Added update-in-place kernels. #130
Browse files Browse the repository at this point in the history
  • Loading branch information
eholk committed Feb 25, 2014
1 parent a945432 commit acc1cd8
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 45 deletions.
2 changes: 1 addition & 1 deletion external/elegant-weapons
16 changes: 16 additions & 0 deletions harlan/front/expand-macros.scm
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,22 @@
env)))
`(,a ((,(reify x env) ,e) ...)
,(reify b* env) ...)))))
((kernel-update!)
(match x
((,_ ((,x ,e) ,<- (,x* ,e*) ...) ,b* ...)
(let ((e (reify e env))
(e* (map (lambda (e)
(reify e env))
e*))
(env (cons (cons x (gensym (ident-symbol x)))
(append (map (lambda (x)
(cons x
(gensym (ident-symbol x))))
x*)
env))))
`(,a ((,(reify x env) ,e) ,(ident-symbol <-)
(,(reify x* env) ,e*) ...)
,(reify b* env) ...)))))
((let-region)
(match x
((,_ (,r ...) ,b ...)
Expand Down
7 changes: 5 additions & 2 deletions harlan/front/expand-primitives.scm
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
(export expand-primitives)
(import
(rnrs)
(only (chezscheme) pretty-print)
(elegant-weapons helpers)
(elegant-weapons compat))

Expand Down Expand Up @@ -155,10 +156,12 @@
`(if ,test ,conseq ,altern))
((if ,[test] ,[conseq])
`(if ,test ,conseq))
((kernel ,ktype (((,x ,t) (,[xs] ,ts)) ...) ,[body])
`(kernel ,ktype ,(gensym 'region) (((,x ,t) (,xs ,ts)) ...) ,body))
((kernel-r ,ktype ,r (((,x ,t) (,[xs] ,ts)) ...) ,[body])
`(kernel ,ktype ,r (((,x ,t) (,xs ,ts)) ...) ,body))
((kernel-update! ,t (((,x ,tx) (,[xs] ,ts)) ...) ,[body])
`(kernel-update! ,t (((,x ,tx) (,xs ,ts)) ...) ,body))
((kernel-update! . ,_)
(pretty-print `(kernel-update! . ,_)))
((let ((,x* ,t* ,[e*]) ...) ,[e])
`(let ((,x* ,t* ,e*) ...) ,e))
((begin ,[expand-prim-stmt -> s*] ... ,[e])
Expand Down
10 changes: 10 additions & 0 deletions harlan/front/parser.scm
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@
,(make-begin
`(,@(map (parse-stmt env) stmt*)
,((parse-expr env) e)))))))
((kernel-update! ((,x ,[e]) <- (,x* ,[e*]) ...) ,stmt* ... ,b)
(begin
(check-idents (cons x x*))
(let* ((x^ (gensym x))
(x*^ (map gensym x*))
(env (cons (cons x x^) (append (map cons x* x*^) env))))
`(kernel-update! ((,x^ ,e) <- (,x*^ ,e*) ...)
,(make-begin
`(,@(map (parse-stmt env) stmt*)
,((parse-expr env) b)))))))
((match ,[e]
((,tag ,x* ...) ,s* ... ,e*) ...)
(guard (and (andmap ident? tag)
Expand Down
37 changes: 37 additions & 0 deletions harlan/front/typecheck.scm
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,33 @@
(,e (vec . ,t*))) ...)
,b)
`(vec ,r ,t))))
((kernel-update! ((,x ,e) <- (,x* ,e*) ...) ,b)
(let ((tv (make-tvar (gensym 'kt))))
(do* (((e* t*) (let loop ((e e*))
(if (null? e)
(return '() '())
(let ((e* (cdr e))
(e (car e))
(t (make-tvar (gensym 'kt)))
(r (make-rvar (gensym 'rkt))))
(do* (((e* t*) (loop e*))
((e _)
(require-type e env `(vec ,r ,t))))
(return (cons e e*)
(cons (list r t) t*)))))))
((e t^) (let ((r (make-rvar (gensym 'rkt))))
(require-type e env `(vec ,r ,tv))))
((b t) (infer-expr b (append
(map (lambda (x t) (cons x (cadr t)))
(cons x x*)
(cons `(r ,tv) t*))
env))))
(return `(kernel-update! ,t^
(((,x ,tv) (,e ,t^))
((,x* ,(map cadr t*))
(,e* (vec . ,t*))) ...)
,b)
t^))))
((call ,f ,e* ...) (guard (ident? f))
(let ((t (make-tvar (gensym 'rt)))
(ft (lookup f env)))
Expand Down Expand Up @@ -774,6 +801,11 @@
,[b])
`(kernel-r ,t ,(region-name (walk r s))
(((,x ,ta*) (,e ,ta**)) ...) ,b))
((kernel-update! ,[ground-type -> t]
(((,x ,[ground-type -> ta*])
(,[e] ,[ground-type -> ta**])) ...)
,[b])
`(kernel-update! ,t (((,x ,ta*) (,e ,ta**)) ...) ,b))
((reduce ,[ground-type -> t] + ,[e]) `(reduce ,t + ,e))
((set! ,[x] ,[e]) `(set! ,x ,e))
((begin ,[e*] ...) `(begin ,e* ...))
Expand Down Expand Up @@ -823,6 +855,11 @@
(((,x ,[free-regions-type -> t*]) (,xs ,[free-regions-type -> ts*])) ...)
,[b])
(set-add (union b t (apply union (append t* ts*))) r))
((kernel-update! ,[free-regions-type -> t]
(((,x ,[free-regions-type -> t*])
(,xs ,[free-regions-type -> ts*])) ...)
,[b])
(union b t (apply union (append t* ts*))))
((reduce ,[free-regions-type -> t] ,op ,[e]) (union t e))
((set! ,[x] ,[e]) (union x e))
((begin ,[e*] ...) (apply union e*))
Expand Down
2 changes: 2 additions & 0 deletions harlan/middle/desugar-match.scm
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@
((vector ,t ,r ,[e] ...) `(vector ,t ,r ,e ...))
((kernel ,t ,r (((,x ,t*) (,[xs] ,ts)) ...) ,[e])
`(kernel ,t ,r (((,x ,t*) (,xs ,ts)) ...) ,e))
((kernel-update! ,t (((,x ,t*) (,[xs] ,ts)) ...) ,[e])
`(kernel-update! ,t (((,x ,t*) (,xs ,ts)) ...) ,e))
((let ((,x ,t ,[e]) ...) ,[b])
`(let ((,x ,t ,e) ...) ,b))
((make-vector ,t ,r ,[e]) `(make-vector ,t ,r ,e))
Expand Down
4 changes: 3 additions & 1 deletion harlan/middle/languages.scm
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@
(begin e e* ...)
(assert e)
(let (lbind* ...) e)
(let-region (r ...) e)
(kernel t r (((x0 t0) (e1 t1)) ...) e)
(kernel-update! t (((x0 t0) (e1 t1)) ...) e)
(iota-r r e)
(for (x e0 e1 e2) e)
(lambda t0 ((x t) ...) e)
Expand Down Expand Up @@ -270,6 +270,7 @@
(set! e1 e2)
(error x)
(begin stmt ...)
(let-region (r ...) stmt)
(if e stmt1 stmt2)
(if e stmt)
(while e stmt)
Expand Down Expand Up @@ -297,6 +298,7 @@
(Expr
(e)
(- (kernel t r (e* ...) (((x0 t0) (e1 t1) i*) ...) e)
(kernel-update! t (((x0 t0) (e1 t1)) ...) e)
(error x))
(+ (addressof e)
(deref e))))
Expand Down
3 changes: 3 additions & 0 deletions harlan/middle/optimize-fuse-kernels.scm
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@
`(if ,test ,conseq))
((kernel ,t ,r ,dims ,iters ,[body])
(make-2d-kernel t r dims iters body))
((kernel-update! ,t (((,x ,t^) (,[xs] ,ts)) ...) ,[e])
;; TODO: we can do some limited fusion here.
`(kernel-update! ,t (((,x ,t^) (,xs ,ts)) ...) ,e))
((let ((,x* ,t* ,[e*]) ...) ,[e])
`(let ((,x* ,t* ,e*) ...) ,e))
((begin ,[Stmt -> s*] ... ,[e])
Expand Down
6 changes: 6 additions & 0 deletions harlan/middle/remove-nested-kernels.scm
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@
`(kernel (vec ,r ,t) ,r ,dims
(((,x* ,t*) (,xs* ,ts*) ,d*) ...)
,e)))
((kernel-update! ,t (((,x ,t^) (,[xs] ,ts)) ...) ,[(Expr #t) -> e])
(if k
;; TODO: allow nested kernel-update!
(error 'remove-nested-kernels
"nesting kernel-update! is not yet supported.")
`(kernel-update! ,t (((,x ,t^) (,xs ,ts)) ...) ,e)))
((if ,[t] ,[c] ,[a])
`(if ,t ,c ,a))
((call ,[fn] ,[args] ...)
Expand Down
108 changes: 67 additions & 41 deletions harlan/middle/returnify-kernels.scm
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,46 @@
((field ,[e] ,x) `(field ,e ,x))
((empty-struct) '(empty-struct))
((kernel . ,body*)
(returnify-kernel `(kernel . ,body*))))
(returnify-kernel `(kernel . ,body*)))
((kernel-update! (vec ,r ,t)
(((,x ,t) (,[x^] (vec ,r ,t)))
((,x* ,t*) (,[xs*] ,ts*)) ...)
,body)
;; This is where we get ride of kernel-update!. After this point,
;; we'll just have regular kernels.
(make-danger-vector
(lambda (danger)
(let ((body ((set-retval t x danger) body)))
`(kernel (vec ,r ,t)
((length ,x^))
(((,x ,t) (,x^ (vec ,r ,t)) 0)
((,x* ,t*) (,xs* ,ts*) 0) ...)
,body)))
x^)))

(define (make-danger-vector make-body rval)
(let ((r (gensym 'danger-region)))
(let ((danger-vector (gensym 'danger_vector))
(danger-vec-t `(vec ,r ,danger-type))
(i (gensym 'i)))
`(begin
(let-region (,r)
(let ((,danger-vector
,danger-vec-t
(make-vector ,danger-type ,r ,num-danger)))
(begin
(for (,i (int 0) ,num-danger (int 1))
(set! (vector-ref
,danger-type
(var ,danger-vec-t ,danger-vector)
(var int ,i))
(bool #f)))
,(make-body (lambda (d)
`(vector-ref ,danger-type
(var ,danger-vec-t ,danger-vector)
,d)))
,(check-danger-vector danger-vector r num-danger))))
,rval))))

(define-match returnify-kernel
((kernel (vec ,r ,t)
Expand All @@ -89,48 +128,35 @@
(((,x* ,tx*) (,[returnify-kernel-expr -> xe*] ,xet*) ,dim) ...)
,body)
(let ((retvars (map (lambda (_) (gensym 'retval)) dims))
(danger-vector (gensym 'danger_vector))
(danger-vec-t `(vec ,r ,danger-type))
(i (gensym 'i))
(vv (gensym 'vv))
(i (gensym 'i))
(id (gensym 'kern)))
`(let ((,id (vec ,r ,t) (make-vector ,t ,r ,(car dims)))
(,danger-vector
;; TODO: the danger vector should probably get its own region.
,danger-vec-t
(make-vector ,danger-type ,r ,num-danger)))
(begin
(for (,i (int 0) ,num-danger (int 1))
(set! (vector-ref
,danger-type
(var ,danger-vec-t ,danger-vector)
(var int ,i))
(bool #f)))
,@(if (null? (cdr dims))
`()
(match t
((vec ,r^ ,t^)
`((for (,i (int 0) ,(car dims) (int 1))
(let ((,vv (vec ,r^ ,t^)
(make-vector ,t^ ,r^
,(cadr dims))))
(set! (vector-ref (vec ,r^ ,t^)
(var (vec ,r ,t) ,id)
(var int ,i))
(var (vec ,r^ ,t^) ,vv))))))))
(kernel
(vec ,r ,t)
,dims
,(insert-retvars r retvars (cons id retvars) 0 t
`(((,x* ,tx*) (,xe* ,xet*) ,dim) ...))
,((set-retval (shave-type (length dims) `(vec ,r ,t))
(car (reverse retvars))
(lambda (d) `(vector-ref ,danger-type
(var ,danger-vec-t ,danger-vector)
,d)))
body))
,(check-danger-vector danger-vector r num-danger)
(var (vec ,r ,t) ,id))))))
`(let ((,id (vec ,r ,t) (make-vector ,t ,r ,(car dims))))
,(make-danger-vector
(lambda (danger)
`(begin
,@(if (null? (cdr dims))
`()
(match t
((vec ,r^ ,t^)
`((for (,i (int 0) ,(car dims) (int 1))
(let ((,vv (vec ,r^ ,t^)
(make-vector ,t^ ,r^
,(cadr dims))))
(set! (vector-ref (vec ,r^ ,t^)
(var (vec ,r ,t) ,id)
(var int ,i))
(var (vec ,r^ ,t^) ,vv))))))))
(kernel
(vec ,r ,t)
,dims
,(insert-retvars r retvars (cons id retvars) 0 t
`(((,x* ,tx*) (,xe* ,xet*) ,dim) ...))
,((set-retval (shave-type (length dims) `(vec ,r ,t))
(car (reverse retvars))
danger)
body))))
`(var (vec ,r ,t) ,id))))))

(define (check-danger-vector danger-vector r len)
(let ((i (gensym 'danger_i))
Expand Down

0 comments on commit acc1cd8

Please sign in to comment.